diff --git a/models/synapses/stdp_synapse.nestml b/models/synapses/stdp_synapse.nestml index 30ef63968..c002bf916 100644 --- a/models/synapses/stdp_synapse.nestml +++ b/models/synapses/stdp_synapse.nestml @@ -33,22 +33,26 @@ References Stable Hebbian learning from spike timing-dependent plasticity, Journal of Neuroscience, 20:23,8812--8821 """ -synapse stdp: +synapse stdp_synapse: state: - w real = 1. @nest::weight # Synaptic weight + w real = 1. @nest::weight # Synaptic weight (> 0 for excitatory and < 0 for inhibitory synapses) pre_trace real = 0. post_trace real = 0. parameters: d ms = 1 ms @nest::delay # Synaptic transmission delay - lambda real = .01 - tau_tr_pre ms = 20 ms - tau_tr_post ms = 20 ms - alpha real = 1 - mu_plus real = 1 - mu_minus real = 1 - Wmax real = 100. - Wmin real = 0. + lambda real = 0.01 # (dimensionless) learning rate for causal updates + alpha real = 1 # relative learning rate for acausal firing + tau_tr_pre ms = 20 ms # time constant of presynaptic trace + tau_tr_post ms = 20 ms # time constant of postsynaptic trace + mu_plus real = 1 # weight dependence exponent for causal updates + mu_minus real = 1 # weight dependence exponent for acausal updates + + Wmax real = 100. # maximum absolute value of synaptic weight + Wmin real = 0. # minimum absolute value of synaptic weight + + internals: + w_sign real = w / abs(w) # sign of synaptic weight equations: pre_trace' = -pre_trace / tau_tr_pre @@ -64,16 +68,18 @@ synapse stdp: onReceive(post_spikes): post_trace += 1 + println("post spike, w_sign = {w_sign}") # potentiate synapse - w_ real = Wmax * ( w / Wmax + (lambda * ( 1. - ( w / Wmax ) )**mu_plus * pre_trace )) - w = min(Wmax, w_) + w_ real = Wmax * ( abs(w) / Wmax + (lambda * ( 1. - ( abs(w) / Wmax ) )**mu_plus * pre_trace )) + w = w_sign * min(Wmax, w_) onReceive(pre_spikes): pre_trace += 1 + println("pre spike, w_sign = {w_sign}") # depress synapse - w_ real = Wmax * ( w / Wmax - ( alpha * lambda * ( w / Wmax )**mu_minus * post_trace )) - w = max(Wmin, w_) + w_ real = Wmax * ( abs(w) / Wmax - ( alpha * lambda * ( abs(w) / Wmax )**mu_minus * post_trace )) + w = w_sign * max(Wmin, w_) # deliver spike to postsynaptic partner deliver_spike(w, d) diff --git a/tests/nest_tests/test_plastic_synapse_weight_sign.py b/tests/nest_tests/test_plastic_synapse_weight_sign.py new file mode 100644 index 000000000..052ecc2c4 --- /dev/null +++ b/tests/nest_tests/test_plastic_synapse_weight_sign.py @@ -0,0 +1,106 @@ +# -*- coding: utf-8 -*- +# +# test_plastic_synapse_weight_sign.py +# +# This file is part of NEST. +# +# Copyright (C) 2004 The NEST Initiative +# +# NEST is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 2 of the License, or +# (at your option) any later version. +# +# NEST is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with NEST. If not, see . + +from typing import Sequence + +import numpy as np +import os +import pytest + +import nest + +from pynestml.codegeneration.nest_tools import NESTTools +from pynestml.frontend.pynestml_frontend import generate_nest_target + +try: + import matplotlib + matplotlib.use("Agg") + import matplotlib.ticker + import matplotlib.pyplot as plt + TEST_PLOTS = True +except Exception: + TEST_PLOTS = False + + +synapse_model_names = ["stdp_synapse"]#, "triplet_stdp_synapse", "stdp_nn_symm", "stdp_nn_restr_symm", "stdp_nn_pre_centered"] + +class TestPlasticSynapseWeightSign: + r"""Test that the sign of the weight of plastic synapses never changes (negative stays negative, positive stays positive)""" + + neuron_model_name = "iaf_psc_exp" + + @pytest.fixture(autouse=True, + scope="module") + def generate_model_code(self): + """Generate the model code""" + + codegen_opts = {"neuron_synapse_pairs": []} + + files = [os.path.join("models", "neurons", self.neuron_model_name + ".nestml")] + for synapse_model_name in synapse_model_names: + files.append(os.path.join("models", "synapses", synapse_model_name + ".nestml")) + codegen_opts["neuron_synapse_pairs"].append({"neuron": self.neuron_model_name, + "synapse": synapse_model_name, + "post_ports": ["post_spikes"]}) + + input_path = [os.path.realpath(os.path.join(os.path.dirname(__file__), os.path.join(os.pardir, os.pardir, s))) for s in files] + generate_nest_target(input_path=input_path, + logging_level="DEBUG", + suffix="_nestml", + codegen_opts=codegen_opts) + + nest.Install("nestmlmodule") + + @pytest.mark.parametrize("synapse_model_name", synapse_model_names) + @pytest.mark.parametrize("test", ["potentiation", "depression"]) + def test_nest_stdp_synapse(self, synapse_model_name: str, test: str): + pre_spike_times = np.linspace(100., 1000., 10) + + if test == "potentiation": + init_weight = -1. + post_spike_times = pre_spike_times + 10. + else: + init_weight = 1. + post_spike_times = pre_spike_times - 10. + + nest.ResetKernel() + + # create spike_generators with these times + pre_sg = nest.Create("spike_generator", + params={"spike_times": pre_spike_times, + "allow_offgrid_times": True}) + post_sg = nest.Create("spike_generator", + params={"spike_times": post_spike_times, + "allow_offgrid_times": True}) + + pre_neuron = nest.Create("parrot_neuron") + post_neuron = nest.Create(self.neuron_model_name) + + nest.Connect(pre_sg, pre_neuron, syn_spec={"weight": 9999.}) + nest.Connect(post_sg, post_neuron, syn_spec={"weight": 9999.}) + nest.Connect(pre_neuron, post_neuron, syn_spec={"synapse_model": synapse_model_name, + "weight": init_weight}) + + syn = nest.GetConnections(source=pre_neuron, synapse_model=synapse_model_name) + + nest.Simulate(100. + max(np.amax(pre_spike_times), np.amax(post_spike_times))) + + assert np.sign(syn.weight) == 0. # should not pass through zero