Skip to content

Commit

Permalink
add support for forward Euler integrator
Browse files Browse the repository at this point in the history
  • Loading branch information
C.A.P. Linssen committed Aug 7, 2023
1 parent c0b21ea commit bb8190f
Show file tree
Hide file tree
Showing 7 changed files with 130 additions and 6 deletions.
15 changes: 10 additions & 5 deletions pynestml/codegeneration/nest_code_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,7 @@ class NESTCodeGenerator(CodeGenerator):
- **module_templates**: A list of the jinja templates or a relative path to a directory containing the templates related to generating the NEST module.
- **nest_version**: A string identifying the version of NEST Simulator to generate code for. The string corresponds to the NEST Simulator git repository tag or git branch name, for instance, ``"v2.20.2"`` or ``"master"``. The default is the empty string, which causes the NEST version to be automatically identified from the ``nest`` Python module.
- **solver**: A string identifying the preferred ODE solver. ``"analytic"`` for propagator solver preferred; fallback to numeric solver in case ODEs are not analytically solvable. Use ``"numeric"`` to disable analytic solver.
- **numeric_solver**: A string identifying the preferred numeric ODE solver. Supported are ``"rk45"`` and ``"forward-Euler"``.
- **redirect_build_output**: An optional boolean key for redirecting the build output. Setting the key to ``True``, two files will be created for redirecting the ``stdout`` and the ``stderr`. The ``target_path`` will be used as the default location for creating the two files.
- **build_output_dir**: An optional string key representing the new path where the files corresponding to the output of the build phase will be created. This key requires that the ``redirect_build_output`` is set to ``True``.
Expand All @@ -124,7 +125,8 @@ class NESTCodeGenerator(CodeGenerator):
"module_templates": ["setup"]
},
"nest_version": "",
"solver": "analytic"
"solver": "analytic",
"numeric_solver": "rk45"
}

def __init__(self, options: Optional[Mapping[str, Any]] = None):
Expand Down Expand Up @@ -438,6 +440,13 @@ def _get_model_namespace(self, astnode: ASTNeuronOrSynapse) -> Dict:
if kw.isupper():
namespace["PyNestMLLexer"][kw] = eval("PyNestMLLexer." + kw)

# ODE solving
namespace["uses_numeric_solver"] = astnode.get_name() in self.numeric_solver.keys() \
and self.numeric_solver[astnode.get_name()] is not None

if namespace["uses_numeric_solver"]:
namespace["numeric_solver"] = self.get_option("numeric_solver")

return namespace

def _get_synapse_model_namespace(self, synapse: ASTSynapse) -> Dict:
Expand Down Expand Up @@ -522,8 +531,6 @@ def _get_synapse_model_namespace(self, synapse: ASTSynapse) -> Dict:

namespace["propagators"] = self.analytic_solver[synapse.get_name()]["propagators"]

namespace["uses_numeric_solver"] = synapse.get_name() in self.numeric_solver.keys() \
and self.numeric_solver[synapse.get_name()] is not None
if namespace["uses_numeric_solver"]:
namespace["numeric_state_variables"] = self.numeric_solver[synapse.get_name()]["state_variables"]
namespace["variable_symbols"].update({sym: synapse.get_equations_blocks()[0].get_scope().resolve_to_symbol(
Expand Down Expand Up @@ -653,8 +660,6 @@ def _get_neuron_model_namespace(self, neuron: ASTNeuron) -> Dict:
_names = [ASTUtils.to_ode_toolbox_processed_name(var.get_complete_name()) for var in _names]
namespace["non_equations_state_variables"] = _names

namespace["uses_numeric_solver"] = neuron.get_name() in self.numeric_solver.keys() \
and self.numeric_solver[neuron.get_name()] is not None
if namespace["uses_numeric_solver"]:
namespace["numeric_state_variables_moved"] = []
if "paired_synapse" in dir(neuron):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -188,7 +188,9 @@ std::vector< std::tuple< int, int > > {{neuronName}}::rport_to_nestml_buffer_idx
, spike_inputs_grid_sum_( std::vector< double >( NUM_SPIKE_RECEPTORS ) )
{%- endif %}
{%- if uses_numeric_solver %}
{%- if numeric_solver == "rk45" %}
, __s( nullptr ), __c( nullptr ), __e( nullptr )
{%- endif %}
{%- endif %}
{
// Initialization of the remaining members is deferred to init_buffers_().
Expand All @@ -201,7 +203,9 @@ std::vector< std::tuple< int, int > > {{neuronName}}::rport_to_nestml_buffer_idx
, spike_inputs_grid_sum_( std::vector< double >( NUM_SPIKE_RECEPTORS ) )
{%- endif %}
{%- if uses_numeric_solver %}
{%- if numeric_solver == "rk45" %}
, __s( nullptr ), __c( nullptr ), __e( nullptr )
{%- endif %}
{%- endif %}
{
// Initialization of the remaining members is deferred to init_buffers_().
Expand All @@ -221,9 +225,11 @@ std::vector< std::tuple< int, int > > {{neuronName}}::rport_to_nestml_buffer_idx
{%- endif %}

{%- if uses_numeric_solver %}
{%- if numeric_solver == "rk45" %}

// use a default "good enough" value for the absolute error. It can be adjusted via `node.set()`
P_.__gsl_error_tol = 1e-3;
{%- endif %}
{%- endif %}

{%- if parameter_vars_with_iv|length > 0 %}
Expand Down Expand Up @@ -325,6 +331,7 @@ std::vector< std::tuple< int, int > > {{neuronName}}::rport_to_nestml_buffer_idx
{{neuronName}}::~{{neuronName}}()
{
{%- if uses_numeric_solver %}
{%- if numeric_solver == "rk45" %}
// GSL structs may not have been allocated, so we need to protect destruction

if (B_.__s)
Expand All @@ -341,6 +348,7 @@ std::vector< std::tuple< int, int > > {{neuronName}}::rport_to_nestml_buffer_idx
{
gsl_odeiv_evolve_free( B_.__e );
}
{%- endif %}
{%- endif %}
}

Expand Down Expand Up @@ -389,6 +397,7 @@ void {{neuronName}}::init_buffers_()
clear_history();
{%- endif %}
{%- if uses_numeric_solver %}
{%- if numeric_solver == "rk45" %}

if ( not B_.__s )
{
Expand Down Expand Up @@ -423,6 +432,7 @@ void {{neuronName}}::init_buffers_()
B_.__sys.params = reinterpret_cast< void* >( this );
B_.__step = nest::Time::get_resolution().get_ms();
B_.__integration_step = nest::Time::get_resolution().get_ms();
{%- endif %}
{%- endif %}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -69,15 +69,17 @@ along with NEST. If not, see <http://www.gnu.org/licenses/>.
{%- endif %}
{%- endif %}
{%- if uses_numeric_solver %}
{%- if numeric_solver == "rk45" %}

#ifndef HAVE_GSL
#error "The GSL library is required for neurons that require a numerical solver."
#error "The GSL library is required for the Runge-Kutta solver."
#endif

// External includes:
#include <gsl/gsl_errno.h>
#include <gsl/gsl_matrix.h>
#include <gsl/gsl_odeiv.h>
{%- endif %}
{%- endif %}

// Includes from nestkernel:
Expand Down Expand Up @@ -518,8 +520,10 @@ private:
{%- endfor %}
{%- endfilter %}
{%- if uses_numeric_solver %}
{%- if numeric_solver == "rk45" %}

double __gsl_error_tol;
{%- endif %}
{%- endif %}

/**
Expand Down Expand Up @@ -685,6 +689,7 @@ private:
{%- endfor %}

{%- if uses_numeric_solver %}
{%- if numeric_solver == "rk45" %}

// -----------------------------------------------------------------------
// GSL ODE solver data structures
Expand All @@ -701,6 +706,7 @@ private:
// it is safe to place both here.
double __step; //!< step size in ms
double __integration_step; //!< current integration time step, updated by GSL
{%- endif %}
{%- endif %}

};
Expand Down Expand Up @@ -919,10 +925,12 @@ inline void {{neuronName}}::get_status(DictionaryDatum &__d) const

(*__d)[nest::names::recordables] = recordablesMap_.get_list();
{%- if uses_numeric_solver %}
{%- if numeric_solver == "rk45" %}
def< double >(__d, nest::names::gsl_error_tol, P_.__gsl_error_tol);
if ( P_.__gsl_error_tol <= 0. ){
throw nest::BadProperty( "The gsl_error_tol must be strictly positive." );
}
{%- endif %}
{%- endif %}
}

Expand Down Expand Up @@ -977,11 +985,13 @@ inline void {{neuronName}}::set_status(const DictionaryDatum &__d)
{%- endfor %}

{% if uses_numeric_solver %}
{%- if numeric_solver == "rk45" %}
updateValue< double >(__d, nest::names::gsl_error_tol, P_.__gsl_error_tol);
if ( P_.__gsl_error_tol <= 0. )
{
throw nest::BadProperty( "The gsl_error_tol must be strictly positive." );
}
{%- endif %}
{%- endif %}

// recompute internal variables in case they are dependent on parameters or state that might have been updated in this call to set_status()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -340,7 +340,9 @@ private:
{%- endfilter %}

{% if uses_numeric_solver %}
{%- if numeric_solver == "rk45" %}
double __gsl_error_tol;
{% endif %}
{% endif %}

/** Initialize parameters to their default values. */
Expand Down Expand Up @@ -1091,11 +1093,13 @@ updateValue<{{ declarations.print_variable_type(variable_symbol) }}>(__d, "{{ ne
}
{%- endfor %}
{% if uses_numeric_solver %}
{%- if numeric_solver == "rk45" %}

updateValue< double >(__d, nest::names::gsl_error_tol, P_.__gsl_error_tol);
if ( P_.__gsl_error_tol <= 0. ){
throw nest::BadProperty( "The gsl_error_tol must be strictly positive." );
}
{% endif %}
{% endif %}

// special treatment of NEST delay
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,5 +35,9 @@ extern "C" inline int {{neuronName}}_dynamics(double, const double ode_state[],
{%- endfor %}
{%- endif %}

{%- if numeric_solver == "rk45" %}
return GSL_SUCCESS;
{%- else %}
return 0;
{%- endif %}
}
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
all odes defined the neuron.
#}
{%- if tracing %}/* generated by {{self._TemplateReference__context.name}} */ {% endif %}
{%- if numeric_solver == "rk45" %}
double __t = 0;
// numerical integration with adaptive step size control:
// ------------------------------------------------------
Expand Down Expand Up @@ -32,3 +33,13 @@ while ( __t < B_.__step )
throw nest::GSLSolverFailure( get_name(), status );
}
}
{%- elif numeric_solver == "forward-Euler" %}
double f[State_::STATE_VEC_SIZE];
{{neuronName}}_dynamics( get_t(), S_.ode_state, f, reinterpret_cast< void* >( this ) );
for (size_t i = 0; i < State_::STATE_VEC_SIZE; ++i)
{
S_.ode_state[i] += __resolution * f[i];
}
{%- else %}
{{ raise('Unknown numeric ODE solver requested.') }}
{%- endif %}
80 changes: 80 additions & 0 deletions tests/nest_tests/test_forward_euler_integrator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
# -*- coding: utf-8 -*-
#
# test_forward_euler_integrator.py
#
# This file is part of NEST.
#
# Copyright (C) 2004 The NEST Initiative
#
# NEST is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 2 of the License, or
# (at your option) any later version.
#
# NEST is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with NEST. If not, see <http://www.gnu.org/licenses/>.

import numpy as np
import os
import pytest

import nest

from pynestml.codegeneration.nest_tools import NESTTools
from pynestml.frontend.pynestml_frontend import generate_nest_target


class TestForwardEulerIntegrator:
"""
Tests the forward Euler integrator by comparing it to RK45.
"""

def generate_target(self, numeric_solver: str):
r"""Generate the neuron model code"""

# generate the "jit" model (co-generated neuron and synapse), that does not rely on ArchivingNode
files = [os.path.join("models", "neurons", "izhikevich.nestml")]
input_path = [os.path.realpath(os.path.join(os.path.dirname(__file__), os.path.join(
os.pardir, os.pardir, s))) for s in files]
generate_nest_target(input_path=input_path,
logging_level="DEBUG",
suffix="_" + numeric_solver.replace("-", "_") + "_nestml",
module_name="nestml" + numeric_solver.replace("-", "_") + "module",
codegen_opts={"numeric_solver": numeric_solver})

nest.Install("nestml" + numeric_solver.replace("-", "_") + "module")


def test_forward_euler_integrator(self):
self.generate_target("forward-Euler")
self.generate_target("rk45")

nest.ResetKernel()
nest.resolution = .001

nrn1 = nest.Create("izhikevich_rk45_nestml")
nrn2 = nest.Create("izhikevich_forward_Euler_nestml")

nrn1.I_e = 10.
nrn2.I_e = 10.

mm1 = nest.Create("multimeter")
mm1.set({"record_from": ["V_m"]})

mm2 = nest.Create("multimeter")
mm2.set({"record_from": ["V_m"]})

nest.Connect(mm1, nrn1)
nest.Connect(mm2, nrn2)

nest.Simulate(100.)

v_m1 = mm1.get("events")["V_m"]
v_m2 = mm2.get("events")["V_m"]

np.testing.assert_allclose(v_m1, v_m2, atol=2, rtol=0) # allow max 2 mV difference between the solutions

0 comments on commit bb8190f

Please sign in to comment.