Skip to content

Commit

Permalink
Fix and add test for second-order integration in neuromodulated synap…
Browse files Browse the repository at this point in the history
…se (#905)
  • Loading branch information
clinssen authored Aug 9, 2023
1 parent d053d40 commit 924ec9b
Show file tree
Hide file tree
Showing 5 changed files with 227 additions and 21 deletions.
9 changes: 6 additions & 3 deletions pynestml/codegeneration/nest_assignments_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,8 @@ def lhs_variable(cls, assignment: ASTAssignment) -> Optional[VariableSymbol]:
if symbol is not None:
return symbol

Logger.log_message(message='No symbol could be resolved!', log_level=LoggingLevel.ERROR)
Logger.log_message(message="No symbol could be resolved for assignment \"" + str(assignment) + "\"!", log_level=LoggingLevel.ERROR)

return None

@classmethod
Expand All @@ -63,7 +64,8 @@ def lhs_vector_variable(cls, assignment: ASTAssignment) -> VariableSymbol:
if symbol is not None:
return symbol

Logger.log_message(message='No symbol could be resolved!', log_level=LoggingLevel.WARNING)
Logger.log_message(message="No symbol could be resolved for assignment \"" + str(assignment) + "\"!", log_level=LoggingLevel.WARNING)

return None

@classmethod
Expand Down Expand Up @@ -115,7 +117,8 @@ def is_vectorized_assignment(cls, assignment) -> bool:

return False

Logger.log_message(message='No symbol could be resolved!', log_level=LoggingLevel.ERROR)
Logger.log_message(message="No symbol could be resolved for assignment \"" + str(assignment) + "\"!", log_level=LoggingLevel.ERROR)

return False

@classmethod
Expand Down
28 changes: 17 additions & 11 deletions pynestml/transformers/synapse_post_neuron_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -258,18 +258,17 @@ def transform_neuron_synapse_pair_(self, neuron, synapse):
strictly_synaptic_vars = []
for input_block in new_synapse.get_input_blocks():
for port in input_block.get_input_ports():
if not self.is_post_port(port.name, neuron.name, synapse.name):
strictly_synaptic_vars += self.get_all_variables_assigned_to(
synapse.get_on_receive_block(port.name))
if (not self.is_post_port(port.name, neuron.name, synapse.name)) or self.is_vt_port(port.name, neuron.name, synapse.name):
strictly_synaptic_vars += self.get_all_variables_assigned_to(synapse.get_on_receive_block(port.name))

for update_block in synapse.get_update_blocks():
strictly_synaptic_vars += self.get_all_variables_assigned_to(update_block)

convolve_with_not_post_vars = self.get_convolve_with_not_post_vars(
synapse.get_equations_blocks(), neuron.name, synapse.name, synapse)
convolve_with_not_post_vars = self.get_convolve_with_not_post_vars(synapse.get_equations_blocks(), neuron.name, synapse.name, synapse)

syn_to_neuron_state_vars = list(set(all_state_vars) - (set(strictly_synaptic_vars) | set(convolve_with_not_post_vars)))
Logger.log_message(None, -1, "State variables that will be moved from synapse to neuron: " + str(syn_to_neuron_state_vars),
None, LoggingLevel.INFO)
strictly_synaptic_vars_dependent = ASTUtils.recursive_dependent_variables_search(strictly_synaptic_vars, synapse)

syn_to_neuron_state_vars = list(set(all_state_vars) - (set(strictly_synaptic_vars) | set(convolve_with_not_post_vars) | set(strictly_synaptic_vars_dependent)))

#
# collect all the variable/parameter/kernel/function/etc. names used in defining expressions of `syn_to_neuron_state_vars`
Expand All @@ -281,6 +280,15 @@ def transform_neuron_synapse_pair_(self, neuron, synapse):
for neuron_state_var in syn_to_neuron_state_vars
if new_synapse.get_kernel_by_name(neuron_state_var) is None]

# all state variables that will be moved from synapse to neuron
syn_to_neuron_state_vars = []
for var_name in recursive_vars_used:
if ASTUtils.get_state_variable_by_name(synapse, var_name) or ASTUtils.get_inline_expression_by_name(synapse, var_name) or ASTUtils.get_kernel_by_name(synapse, var_name):
syn_to_neuron_state_vars.append(var_name)

Logger.log_message(None, -1, "State variables that will be moved from synapse to neuron: " + str(syn_to_neuron_state_vars),
None, LoggingLevel.INFO)

#
# collect all the parameters
#
Expand Down Expand Up @@ -446,12 +454,10 @@ def mark_post_port(_expr=None):

Logger.log_message(None, -1, "Copying parameters from synapse to neuron...", None, LoggingLevel.INFO)
for param_var in syn_to_neuron_params:
Logger.log_message(None, -1, "\tCopying parameter with name " + str(param_var)
+ " from synapse to neuron", None, LoggingLevel.INFO)
decls = ASTUtils.move_decls(param_var,
new_synapse.get_parameters_blocks()[0],
new_neuron.get_parameters_blocks()[0],
var_name_suffix,
var_name_suffix=var_name_suffix,
block_type=BlockType.PARAMETERS,
mode="copy")

Expand Down
29 changes: 22 additions & 7 deletions pynestml/utils/ast_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -554,8 +554,8 @@ def replace_var(_expr=None):
else:
return

if not suffix in var.get_name() \
and not var.get_name() == "t":
if not var.get_name() == "t" \
and not var.get_name().endswith(suffix):
var.set_name(var.get_name() + suffix)

astnode.accept(ASTHigherOrderVisitor(lambda x: replace_var(x)))
Expand All @@ -569,6 +569,15 @@ def get_inline_expression_by_name(cls, node, name: str) -> Optional[ASTInlineExp

return None

@classmethod
def get_kernel_by_name(cls, node, name: str) -> Optional[ASTKernel]:
for equations_block in node.get_equations_blocks():
for kernel in equations_block.get_kernels():
if name in kernel.get_variable_names():
return kernel

return None

@classmethod
def replace_with_external_variable(cls, var_name, node: ASTNode, suffix, new_scope, alternate_name=None):
"""
Expand Down Expand Up @@ -712,7 +721,8 @@ def visit_function_call(self, node):
return variables

@classmethod
def move_decls(cls, var_name, from_block, to_block, var_name_suffix, block_type: BlockType, mode="move", scope=None) -> List[ASTDeclaration]:
def move_decls(cls, var_name, from_block, to_block, var_name_suffix: str, block_type: BlockType, mode="move") -> List[ASTDeclaration]:
"""Move or copy declarations from ``from_block`` to ``to_block``."""
from pynestml.visitors.ast_symbol_table_visitor import ASTSymbolTableVisitor
assert mode in ["move", "copy"]

Expand All @@ -721,19 +731,19 @@ def move_decls(cls, var_name, from_block, to_block, var_name_suffix, block_type:
return []

decls = ASTUtils.get_declarations_from_block(var_name, from_block)
if var_name.endswith(var_name_suffix):
if var_name_suffix and var_name.endswith(var_name_suffix):
decls.extend(ASTUtils.get_declarations_from_block(removesuffix(var_name, var_name_suffix), from_block))

if decls:
Logger.log_message(None, -1, "Moving definition of " + var_name + " from synapse to neuron",
Logger.log_message(None, -1, ("Moving" if mode == "move" else "Copying") + " definition of " + var_name + " from synapse to neuron",
None, LoggingLevel.INFO)
for decl in decls:
if mode == "move":
from_block.declarations.remove(decl)
if mode == "copy":
decl = decl.clone()
assert len(decl.get_variables()) <= 1
if not decl.get_variables()[0].name.endswith(var_name_suffix):
if not decl.get_variables()[0].name.endswith(var_name_suffix) and var_name_suffix:
ASTUtils.add_suffix_to_decl_lhs(decl, suffix=var_name_suffix)
to_block.get_declarations().append(decl)
decl.update_scope(to_block.get_scope())
Expand Down Expand Up @@ -1456,7 +1466,11 @@ def collect_vars(_expr=None):
elif isinstance(_expr, ASTVariable):
var = _expr

if var:
symbol = None
if var and var.get_scope():
symbol = var.get_scope().resolve_to_symbol(var.get_complete_name(), SymbolKind.VARIABLE)

if var and symbol:
vars_used_.append(var)

expr.accept(ASTHigherOrderVisitor(lambda x: collect_vars(x)))
Expand Down Expand Up @@ -1513,6 +1527,7 @@ def recursive_dependent_variables_search(cls, vars: List[str], node: ASTNode) ->
if not _var in vars_checked:
var = _var
break

if not var:
# all variables checked
break
Expand Down
125 changes: 125 additions & 0 deletions tests/nest_tests/dopa_synapse_second_order_tests.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
# -*- coding: utf-8 -*-
#
# dopa_synapse_second_order_tests.py
#
# This file is part of NEST.
#
# Copyright (C) 2004 The NEST Initiative
#
# NEST is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 2 of the License, or
# (at your option) any later version.
#
# NEST is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with NEST. If not, see <http://www.gnu.org/licenses/>.

import numpy as np
import os
import pytest

import nest

from pynestml.codegeneration.nest_tools import NESTTools
from pynestml.frontend.pynestml_frontend import generate_nest_target

try:
import matplotlib
matplotlib.use("Agg")
import matplotlib.ticker
import matplotlib.pyplot as plt
TEST_PLOTS = True
except Exception:
TEST_PLOTS = False


class TestDopaSecondOrder:
r"""
Test second-order integration in a neuromodulated synapse.
"""

neuron_model_name = "iaf_psc_exp_nestml__with_dopa_synapse_second_order_nestml"
synapse_model_name = "dopa_synapse_second_order_nestml__with_iaf_psc_exp_nestml"

@pytest.fixture(scope="module", autouse=True)
def setUp(self):
r"""generate code for neuron and synapse and build NEST user module"""
files = [os.path.join("models", "neurons", "iaf_psc_exp.nestml"),
os.path.join("tests", "nest_tests", "resources", "dopa_synapse_second_order.nestml")]
input_path = [os.path.realpath(os.path.join(os.path.dirname(__file__), os.path.join(
os.pardir, os.pardir, s))) for s in files]
generate_nest_target(input_path=input_path,
logging_level="DEBUG",
module_name="nestmlmodule",
suffix="_nestml",
codegen_opts={"neuron_parent_class": "StructuralPlasticityNode",
"neuron_parent_class_include": "structural_plasticity_node.h",
"neuron_synapse_pairs": [{"neuron": "iaf_psc_exp",
"synapse": "dopa_synapse_second_order",
"vt_ports": ["dopa_spikes"]}]})

@pytest.mark.skipif(NESTTools.detect_nest_version().startswith("v2"),
reason="This test does not support NEST 2")
def test_nest_stdp_synapse(self):

resolution = .25 # [ms]
delay = 1. # [ms]
t_stop = 250. # [ms]

nest.ResetKernel()
nest.SetKernelStatus({"resolution": resolution})
nest.Install("nestmlmodule")

# create spike_generator
vt_sg = nest.Create("poisson_generator",
params={"rate": 20.})

# create volume transmitter
vt = nest.Create("volume_transmitter")
vt_parrot = nest.Create("parrot_neuron")
nest.Connect(vt_sg, vt_parrot)
nest.Connect(vt_parrot, vt, syn_spec={"synapse_model": "static_synapse",
"weight": 1.,
"delay": 1.}) # delay is ignored!
vt_gid = vt.get("global_id")

# set up custom synapse model
wr = nest.Create("weight_recorder")
nest.CopyModel(self.synapse_model_name, "stdp_nestml_rec",
{"weight_recorder": wr[0], "d": delay, "receptor_type": 0,
"vt": vt_gid})

# create parrot neurons and connect spike_generators
pre_neuron = nest.Create("parrot_neuron")
post_neuron = nest.Create(self.neuron_model_name)
nest.Connect(pre_neuron, post_neuron, syn_spec={"synapse_model": "stdp_nestml_rec"})

syn = nest.GetConnections(pre_neuron, post_neuron)
syn.tau_dopa = 25. # [ms]

log = {"t": [0.],
"dopa_rate": [syn.dopa_rate],
"dopa_rate_d": [syn.dopa_rate_d]}

n_timesteps = int(np.ceil(t_stop / resolution))
for timestep in range(n_timesteps):
nest.Simulate(resolution)
log["t"].append(nest.biological_time)
log["dopa_rate"].append(syn.dopa_rate)
log["dopa_rate_d"].append(syn.dopa_rate_d)

if TEST_PLOTS:
fig, ax = plt.subplots(nrows=2, dpi=300)
ax[0].plot(log["t"], log["dopa_rate"], label="dopa_rate")
ax[1].plot(log["t"], log["dopa_rate_d"], label="dopa_rate_d")
for _ax in ax:
_ax.legend()
fig.savefig("/tmp/dopa_synapse_second_order_tests.png")
plt.close(fig)

np.testing.assert_allclose(log["dopa_rate"][-1], 0.6834882070000989)
57 changes: 57 additions & 0 deletions tests/nest_tests/resources/dopa_synapse_second_order.nestml
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
"""
dopa_synapse_second_order
#########################


Description
+++++++++++

This model is used to test second-order integration of dopamine spikes.



Copyright statement
+++++++++++++++++++

This file is part of NEST.

Copyright (C) 2004 The NEST Initiative

NEST is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation, either version 2 of the License, or
(at your option) any later version.

NEST is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU General Public License for more details.

You should have received a copy of the GNU General Public License
along with NEST. If not, see <http://www.gnu.org/licenses/>.
"""
synapse dopa_synapse_second_order:
state:
dopa_rate real = 0.
dopa_rate_d real = 0.

parameters:
tau_dopa ms = 100 ms
d ms = 1 ms @nest::delay

equations:
dopa_rate' = dopa_rate_d / ms
dopa_rate_d' = -dopa_rate / tau_dopa**2 * ms - 2 * dopa_rate_d / tau_dopa

input:
pre_spikes real <- spike
dopa_spikes real <- spike

output:
spike

onReceive(dopa_spikes):
dopa_rate_d += 1. / tau_dopa

onReceive(pre_spikes):
deliver_spike(1., 1 ms)

0 comments on commit 924ec9b

Please sign in to comment.