diff --git a/doc/htmldoc/examples/index.rst b/doc/htmldoc/examples/index.rst index 4a48fabadc..c9696f1fae 100644 --- a/doc/htmldoc/examples/index.rst +++ b/doc/htmldoc/examples/index.rst @@ -212,10 +212,13 @@ PyNEST examples .. grid-item-card:: :doc:`../auto_examples/eprop_plasticity/index` :img-top: ../static/img/pynest/eprop_supervised_classification_infrastructure.png + * :doc:`/auto_examples/eprop_plasticity/eprop_supervised_classification_evidence-accumulation_bsshslm_2020` + * :doc:`/auto_examples/eprop_plasticity/eprop_supervised_regression_sine-waves_bsshslm_2020` + * :doc:`/auto_examples/eprop_plasticity/eprop_supervised_regression_handwriting_bsshslm_2020` + * :doc:`/auto_examples/eprop_plasticity/eprop_supervised_regression_lemniscate_bsshslm_2020` * :doc:`/auto_examples/eprop_plasticity/eprop_supervised_classification_evidence-accumulation` * :doc:`/auto_examples/eprop_plasticity/eprop_supervised_regression_sine-waves` - * :doc:`/auto_examples/eprop_plasticity/eprop_supervised_regression_handwriting` - * :doc:`/auto_examples/eprop_plasticity/eprop_supervised_regression_infinite-loop` + * :doc:`/auto_examples/eprop_plasticity/eprop_supervised_classification_neuromorphic_mnist` .. grid:: 1 1 2 3 diff --git a/doc/htmldoc/static/img/eprop_model_diagram.svg b/doc/htmldoc/static/img/eprop_model_diagram.svg new file mode 100644 index 0000000000..1b9a6d0fdb --- /dev/null +++ b/doc/htmldoc/static/img/eprop_model_diagram.svg @@ -0,0 +1,1494 @@ + + + + + + image/svg+xml + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + Full e-prop model + + Models based on Bellec et al. (2020) + + + Models based on Korcsak-Gorzo et al. (2024) + + + Models existingin NEST + + eprop_iaf + eprop_iaf_adapt_bsshslm_2020 + eprop_iaf_adapt + eprop_readout_bsshslm_2020 + eprop_readout + eprop_synapse_bsshlsm_2020 + eprop_synapse + eprop_learning_signal_connection_bsshslm_2020 + eprop_learning_signal_connection + eprop_iaf_psc_delta + iaf_psc_delta + + + + + adaptive threshold + + + e-prop plasticity + + + eprop_iaf_bsshslm_2020 + + + + + + + + + + + adaptive threshold + + + different reset mechanisms & refractory dynamics + + + + + + + + + + + + + + + + + biological features + diff --git a/doc/htmldoc/whats_new/v3.7/index.rst b/doc/htmldoc/whats_new/v3.7/index.rst index a168d2145b..44732f7e29 100644 --- a/doc/htmldoc/whats_new/v3.7/index.rst +++ b/doc/htmldoc/whats_new/v3.7/index.rst @@ -17,12 +17,13 @@ E-prop plasticity in NEST ------------------------- Another new NEST feature is eligibility propagation (e-prop) [1]_, a local and -online learning algorithm for recurrent spiking neural networks (RSNNs) that -serves as a biologically plausible approximation to backpropagation through time -(BPTT). It relies on eligibility traces and neuron-specific learning signals to -compute gradients without the need for error propagation backward in time. This -approach aligns with the brain's learning mechanisms and offers a strong -candidate for efficient training of RSNNs in low-power neuromorphic hardware. +online learning algorithm for recurrent spiking neural networks (RSNNs) that is +biologically plausible and approaches the performance of backpropagation through +time (BPTT). 
It relies on eligibility traces and neuron-specific learning +signals to compute gradients without the need for error propagation backward in +time. This approach aligns with the brain's learning mechanisms and offers a +strong candidate for efficient training of RSNNs in low-power neuromorphic +hardware. For further information, see: diff --git a/models/eprop_iaf.cpp b/models/eprop_iaf.cpp new file mode 100644 index 0000000000..f95ae67f0c --- /dev/null +++ b/models/eprop_iaf.cpp @@ -0,0 +1,449 @@ +/* + * eprop_iaf.cpp + * + * This file is part of NEST. + * + * Copyright (C) 2004 The NEST Initiative + * + * NEST is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 2 of the License, or + * (at your option) any later version. + * + * NEST is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with NEST. If not, see . + * + */ + +// nest models +#include "eprop_iaf.h" + +// C++ +#include + +// libnestutil +#include "dict_util.h" +#include "numerics.h" + +// nestkernel +#include "exceptions.h" +#include "kernel_manager.h" +#include "nest_impl.h" +#include "universal_data_logger_impl.h" + +// sli +#include "dictutils.h" + +namespace nest +{ + +void +register_eprop_iaf( const std::string& name ) +{ + register_node_model< eprop_iaf >( name ); +} + +/* ---------------------------------------------------------------- + * Recordables map + * ---------------------------------------------------------------- */ + +RecordablesMap< eprop_iaf > eprop_iaf::recordablesMap_; + +template <> +void +RecordablesMap< eprop_iaf >::create() +{ + insert_( names::learning_signal, &eprop_iaf::get_learning_signal_ ); + insert_( names::surrogate_gradient, &eprop_iaf::get_surrogate_gradient_ ); + insert_( names::V_m, &eprop_iaf::get_v_m_ ); +} + +/* ---------------------------------------------------------------- + * Default constructors for parameters, state, and buffers + * ---------------------------------------------------------------- */ + +eprop_iaf::Parameters_::Parameters_() + : C_m_( 250.0 ) + , c_reg_( 0.0 ) + , E_L_( -70.0 ) + , f_target_( 0.01 ) + , beta_( 1.0 ) + , gamma_( 0.3 ) + , I_e_( 0.0 ) + , regular_spike_arrival_( true ) + , surrogate_gradient_function_( "piecewise_linear" ) + , t_ref_( 2.0 ) + , tau_m_( 10.0 ) + , V_min_( -std::numeric_limits< double >::max() ) + , V_th_( -55.0 - E_L_ ) + , kappa_( 0.97 ) + , kappa_reg_( 0.97 ) + , eprop_isi_trace_cutoff_( 1000.0 ) +{ +} + +eprop_iaf::State_::State_() + : learning_signal_( 0.0 ) + , r_( 0 ) + , surrogate_gradient_( 0.0 ) + , i_in_( 0.0 ) + , v_m_( 0.0 ) + , z_( 0.0 ) + , z_in_( 0.0 ) +{ +} + +eprop_iaf::Buffers_::Buffers_( eprop_iaf& n ) + : logger_( n ) +{ +} + +eprop_iaf::Buffers_::Buffers_( const Buffers_&, eprop_iaf& n ) + : logger_( n ) +{ +} + +/* ---------------------------------------------------------------- + * Getter and setter functions for parameters and state + * ---------------------------------------------------------------- */ + +void +eprop_iaf::Parameters_::get( DictionaryDatum& d ) const +{ + def< double >( d, names::C_m, C_m_ ); + def< double >( d, names::c_reg, c_reg_ ); + def< double >( d, names::E_L, E_L_ ); + def< double >( d, names::f_target, 
f_target_ ); + def< double >( d, names::beta, beta_ ); + def< double >( d, names::gamma, gamma_ ); + def< double >( d, names::I_e, I_e_ ); + def< bool >( d, names::regular_spike_arrival, regular_spike_arrival_ ); + def< std::string >( d, names::surrogate_gradient_function, surrogate_gradient_function_ ); + def< double >( d, names::t_ref, t_ref_ ); + def< double >( d, names::tau_m, tau_m_ ); + def< double >( d, names::V_min, V_min_ + E_L_ ); + def< double >( d, names::V_th, V_th_ + E_L_ ); + def< double >( d, names::kappa, kappa_ ); + def< double >( d, names::kappa_reg, kappa_reg_ ); + def< double >( d, names::eprop_isi_trace_cutoff, eprop_isi_trace_cutoff_ ); +} + +double +eprop_iaf::Parameters_::set( const DictionaryDatum& d, Node* node ) +{ + // if leak potential is changed, adjust all variables defined relative to it + const double ELold = E_L_; + updateValueParam< double >( d, names::E_L, E_L_, node ); + const double delta_EL = E_L_ - ELold; + + V_th_ -= updateValueParam< double >( d, names::V_th, V_th_, node ) ? E_L_ : delta_EL; + V_min_ -= updateValueParam< double >( d, names::V_min, V_min_, node ) ? E_L_ : delta_EL; + + updateValueParam< double >( d, names::C_m, C_m_, node ); + updateValueParam< double >( d, names::c_reg, c_reg_, node ); + + if ( updateValueParam< double >( d, names::f_target, f_target_, node ) ) + { + f_target_ /= 1000.0; // convert from spikes/s to spikes/ms + } + + updateValueParam< double >( d, names::beta, beta_, node ); + updateValueParam< double >( d, names::gamma, gamma_, node ); + updateValueParam< double >( d, names::I_e, I_e_, node ); + updateValueParam< bool >( d, names::regular_spike_arrival, regular_spike_arrival_, node ); + updateValueParam< std::string >( d, names::surrogate_gradient_function, surrogate_gradient_function_, node ); + updateValueParam< double >( d, names::t_ref, t_ref_, node ); + updateValueParam< double >( d, names::tau_m, tau_m_, node ); + updateValueParam< double >( d, names::kappa, kappa_, node ); + updateValueParam< double >( d, names::kappa_reg, kappa_reg_, node ); + updateValueParam< double >( d, names::eprop_isi_trace_cutoff, eprop_isi_trace_cutoff_, node ); + + if ( C_m_ <= 0 ) + { + throw BadProperty( "Membrane capacitance C_m > 0 required." ); + } + + if ( c_reg_ < 0 ) + { + throw BadProperty( "Firing rate regularization coefficient c_reg ≥ 0 required." ); + } + + if ( f_target_ < 0 ) + { + throw BadProperty( "Firing rate regularization target rate f_target ≥ 0 required." ); + } + + if ( tau_m_ <= 0 ) + { + throw BadProperty( "Membrane time constant tau_m > 0 required." ); + } + + if ( t_ref_ < 0 ) + { + throw BadProperty( "Refractory time t_ref ≥ 0 required." ); + } + + if ( V_th_ < V_min_ ) + { + throw BadProperty( "Spike threshold voltage V_th ≥ minimal voltage V_min required." ); + } + + if ( kappa_ < 0.0 or kappa_ > 1.0 ) + { + throw BadProperty( "Eligibility trace low-pass filter kappa from range [0, 1] required." ); + } + + if ( kappa_reg_ < 0.0 or kappa_reg_ > 1.0 ) + { + throw BadProperty( "Firing rate low-pass filter for regularization kappa_reg from range [0, 1] required." ); + } + + if ( eprop_isi_trace_cutoff_ < 0.0 ) + { + throw BadProperty( "Cutoff of integration of eprop trace between spikes eprop_isi_trace_cutoff ≥ 0 required." 
); + } + + return delta_EL; +} + +void +eprop_iaf::State_::get( DictionaryDatum& d, const Parameters_& p ) const +{ + def< double >( d, names::V_m, v_m_ + p.E_L_ ); + def< double >( d, names::surrogate_gradient, surrogate_gradient_ ); + def< double >( d, names::learning_signal, learning_signal_ ); +} + +void +eprop_iaf::State_::set( const DictionaryDatum& d, const Parameters_& p, double delta_EL, Node* node ) +{ + v_m_ -= updateValueParam< double >( d, names::V_m, v_m_, node ) ? p.E_L_ : delta_EL; +} + +/* ---------------------------------------------------------------- + * Default and copy constructor for node + * ---------------------------------------------------------------- */ + +eprop_iaf::eprop_iaf() + : EpropArchivingNodeRecurrent() + , P_() + , S_() + , B_( *this ) +{ + recordablesMap_.create(); +} + +eprop_iaf::eprop_iaf( const eprop_iaf& n ) + : EpropArchivingNodeRecurrent( n ) + , P_( n.P_ ) + , S_( n.S_ ) + , B_( n.B_, *this ) +{ +} + +/* ---------------------------------------------------------------- + * Node initialization functions + * ---------------------------------------------------------------- */ + +void +eprop_iaf::init_buffers_() +{ + B_.spikes_.clear(); // includes resize + B_.currents_.clear(); // includes resize + B_.logger_.reset(); // includes resize +} + +void +eprop_iaf::pre_run_hook() +{ + B_.logger_.init(); // ensures initialization in case multimeter connected after Simulate + + V_.RefractoryCounts_ = Time( Time::ms( P_.t_ref_ ) ).get_steps(); + V_.eprop_isi_trace_cutoff_steps_ = Time( Time::ms( P_.eprop_isi_trace_cutoff_ ) ).get_steps(); + + compute_surrogate_gradient_ = select_surrogate_gradient( P_.surrogate_gradient_function_ ); + + // calculate the entries of the propagator matrix for the evolution of the state vector + + const double dt = Time::get_resolution().get_ms(); + + V_.P_v_m_ = std::exp( -dt / P_.tau_m_ ); + V_.P_i_in_ = P_.tau_m_ / P_.C_m_ * ( 1.0 - V_.P_v_m_ ); + V_.P_z_in_ = P_.regular_spike_arrival_ ? 
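+  // scale factor zeta for incoming spikes (cf. model documentation): 1 if regular_spike_arrival, else 1 - alpha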
1.0 : 1.0 - V_.P_v_m_; +} + +long +eprop_iaf::get_shift() const +{ + return offset_gen_ + delay_in_rec_; +} + +bool +eprop_iaf::is_eprop_recurrent_node() const +{ + return true; +} + +/* ---------------------------------------------------------------- + * Update function + * ---------------------------------------------------------------- */ + +void +eprop_iaf::update( Time const& origin, const long from, const long to ) +{ + for ( long lag = from; lag < to; ++lag ) + { + const long t = origin.get_steps() + lag; + + if ( S_.r_ > 0 ) + { + --S_.r_; + } + + S_.z_in_ = B_.spikes_.get_value( lag ); + + S_.v_m_ = V_.P_i_in_ * S_.i_in_ + V_.P_z_in_ * S_.z_in_ + V_.P_v_m_ * S_.v_m_; + S_.v_m_ -= P_.V_th_ * S_.z_; + S_.v_m_ = std::max( S_.v_m_, P_.V_min_ ); + + S_.z_ = 0.0; + + S_.surrogate_gradient_ = ( this->*compute_surrogate_gradient_ )( S_.r_, S_.v_m_, P_.V_th_, P_.beta_, P_.gamma_ ); + + if ( S_.v_m_ >= P_.V_th_ and S_.r_ == 0 ) + { + SpikeEvent se; + kernel().event_delivery_manager.send( *this, se, lag ); + + S_.z_ = 1.0; + S_.r_ = V_.RefractoryCounts_; + } + + append_new_eprop_history_entry( t ); + write_surrogate_gradient_to_history( t, S_.surrogate_gradient_ ); + write_firing_rate_reg_to_history( t, S_.z_, P_.f_target_, P_.kappa_reg_, P_.c_reg_ ); + + S_.learning_signal_ = get_learning_signal_from_history( t, false ); + + S_.i_in_ = B_.currents_.get_value( lag ) + P_.I_e_; + + B_.logger_.record_data( t ); + } +} + +/* ---------------------------------------------------------------- + * Event handling functions + * ---------------------------------------------------------------- */ + +void +eprop_iaf::handle( SpikeEvent& e ) +{ + assert( e.get_delay_steps() > 0 ); + + B_.spikes_.add_value( + e.get_rel_delivery_steps( kernel().simulation_manager.get_slice_origin() ), e.get_weight() * e.get_multiplicity() ); +} + +void +eprop_iaf::handle( CurrentEvent& e ) +{ + assert( e.get_delay_steps() > 0 ); + + B_.currents_.add_value( + e.get_rel_delivery_steps( kernel().simulation_manager.get_slice_origin() ), e.get_weight() * e.get_current() ); +} + +void +eprop_iaf::handle( LearningSignalConnectionEvent& e ) +{ + for ( auto it_event = e.begin(); it_event != e.end(); ) + { + const long time_step = e.get_stamp().get_steps(); + const double weight = e.get_weight(); + const double error_signal = e.get_coeffvalue( it_event ); // get_coeffvalue advances iterator + const double learning_signal = weight * error_signal; + + write_learning_signal_to_history( time_step, learning_signal, false ); + } +} + +void +eprop_iaf::handle( DataLoggingRequest& e ) +{ + B_.logger_.handle( e ); +} + +void +eprop_iaf::compute_gradient( const long t_spike, + const long t_spike_previous, + double& z_previous_buffer, + double& z_bar, + double& e_bar, + double& e_bar_reg, + double& epsilon, + double& weight, + const CommonSynapseProperties& cp, + WeightOptimizer* optimizer ) +{ + double e = 0.0; // eligibility trace + double z = 0.0; // spiking variable + double z_current_buffer = 1.0; // buffer containing the spike that triggered the current integration + double psi = 0.0; // surrogate gradient + double L = 0.0; // learning signal + double firing_rate_reg = 0.0; // firing rate regularization + double grad = 0.0; // gradient + + const EpropSynapseCommonProperties& ecp = static_cast< const EpropSynapseCommonProperties& >( cp ); + const auto optimize_each_step = ( *ecp.optimizer_cp_ ).optimize_each_step_; + + auto eprop_hist_it = get_eprop_history( t_spike_previous - 1 ); + + const long t_compute_until = std::min( 
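+  // compute the gradient for at most eprop_isi_trace_cutoff_steps_ steps after the previous spike, but never beyond the current spike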
t_spike_previous + V_.eprop_isi_trace_cutoff_steps_, t_spike ); + + for ( long t = t_spike_previous; t < t_compute_until; ++t, ++eprop_hist_it ) + { + z = z_previous_buffer; + z_previous_buffer = z_current_buffer; + z_current_buffer = 0.0; + + psi = eprop_hist_it->surrogate_gradient_; + L = eprop_hist_it->learning_signal_; + firing_rate_reg = eprop_hist_it->firing_rate_reg_; + + z_bar = V_.P_v_m_ * z_bar + V_.P_z_in_ * z; + e = psi * z_bar; + e_bar = P_.kappa_ * e_bar + ( 1.0 - P_.kappa_ ) * e; + e_bar_reg = P_.kappa_reg_ * e_bar_reg + ( 1.0 - P_.kappa_reg_ ) * e; + + if ( optimize_each_step ) + { + grad = L * e_bar + firing_rate_reg * e_bar_reg; + weight = optimizer->optimized_weight( *ecp.optimizer_cp_, t, grad, weight ); + } + else + { + grad += L * e_bar + firing_rate_reg * e_bar_reg; + } + } + + if ( not optimize_each_step ) + { + weight = optimizer->optimized_weight( *ecp.optimizer_cp_, t_compute_until, grad, weight ); + } + + const long cutoff_to_spike_interval = t_spike - t_compute_until; + + if ( cutoff_to_spike_interval > 0 ) + { + z_bar *= std::pow( V_.P_v_m_, cutoff_to_spike_interval ); + e_bar *= std::pow( P_.kappa_, cutoff_to_spike_interval ); + e_bar_reg *= std::pow( P_.kappa_reg_, cutoff_to_spike_interval ); + } +} + +} // namespace nest diff --git a/models/eprop_iaf.h b/models/eprop_iaf.h new file mode 100644 index 0000000000..a531def0fb --- /dev/null +++ b/models/eprop_iaf.h @@ -0,0 +1,676 @@ +/* + * eprop_iaf.h + * + * This file is part of NEST. + * + * Copyright (C) 2004 The NEST Initiative + * + * NEST is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 2 of the License, or + * (at your option) any later version. + * + * NEST is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with NEST. If not, see . + * + */ + +#ifndef EPROP_IAF_H +#define EPROP_IAF_H + +// nestkernel +#include "connection.h" +#include "eprop_archiving_node.h" +#include "eprop_archiving_node_impl.h" +#include "eprop_synapse.h" +#include "event.h" +#include "nest_types.h" +#include "ring_buffer.h" +#include "universal_data_logger.h" + +namespace nest +{ + +/* BeginUserDocs: neuron, e-prop plasticity, current-based, integrate-and-fire + +Short description ++++++++++++++++++ + +Current-based leaky integrate-and-fire neuron model with delta-shaped +postsynaptic currents for e-prop plasticity + +Description ++++++++++++ + +``eprop_iaf`` is an implementation of a leaky integrate-and-fire +neuron model with delta-shaped postsynaptic currents used for eligibility +propagation (e-prop) plasticity. + +E-prop plasticity was originally introduced and implemented in TensorFlow in [1]_. + +.. note:: + The neuron dynamics of the ``eprop_iaf`` model (excluding e-prop + plasticity) are similar to the neuron dynamics of the ``iaf_psc_delta`` model, + with minor differences, such as the propagator of the post-synaptic current + and the voltage reset upon a spike. + +The membrane voltage time course :math:`v_j^t` of the neuron :math:`j` is given by: + +.. 
math:: + v_j^t &= \alpha v_j^{t-1} + \zeta \sum_{i \neq j} W_{ji}^\text{rec} z_i^{t-1} + + \zeta \sum_i W_{ji}^\text{in} x_i^t - z_j^{t-1} v_\text{th} \,, \\ + \alpha &= e^{ -\frac{ \Delta t }{ \tau_\text{m} } } \,, \\ + \zeta &= + \begin{cases} + 1 \\ + 1 - \alpha + \end{cases} \,, \\ + +where :math:`W_{ji}^\text{rec}` and :math:`W_{ji}^\text{in}` are the recurrent and +input synaptic weight matrices, and :math:`z_i^{t-1}` is the recurrent presynaptic +state variable, while :math:`x_i^t` represents the input at time :math:`t`. + +Descriptions of further parameters and variables can be found in the table below. + +The spike state variable is expressed by a Heaviside function: + +.. math:: + z_j^t = H \left( v_j^t - v_\text{th} \right) \,. \\ + +If the membrane voltage crosses the threshold voltage :math:`v_\text{th}`, a spike is +emitted and the membrane voltage is reduced by :math:`v_\text{th}` in the next +time step. After the time step of the spike emission, the neuron is not +able to spike for an absolute refractory period :math:`t_\text{ref}`. + +An additional state variable and the corresponding differential equation +represents a piecewise constant external current. + +See the documentation on the :doc:`iaf_psc_delta<../models/iaf_psc_delta/>` neuron model +for more information on the integration of the subthreshold dynamics. + +The change of the synaptic weight is calculated from the gradient :math:`g^t` of +the loss :math:`E^t` with respect to the synaptic weight :math:`W_{ji}`: +:math:`\frac{ \text{d} E^t }{ \text{d} W_{ij} }` +which depends on the presynaptic +spikes :math:`z_i^{t-2}`, the surrogate gradient or pseudo-derivative +of the spike state variable with respect to the postsynaptic membrane +voltage :math:`\psi_j^{t-1}` (the product of which forms the eligibility +trace :math:`e_{ji}^{t-1}`), and the learning signal :math:`L_j^t` emitted +by the readout neurons. + +.. start_surrogate-gradient-functions + +Surrogate gradients help overcome the challenge of the spiking function's +non-differentiability, facilitating the use of gradient-based learning +techniques such as e-prop. The non-existent derivative of the spiking +variable with respect to the membrane voltage, +:math:`\frac{\partial z^t_j}{ \partial v^t_j}`, can be effectively +replaced with a variety of surrogate gradient functions, as detailed in +various studies (see, e.g., [3]_). NEST currently provides four +different surrogate gradient functions: + +1. A piecewise linear function used among others in [1]_: + +.. math:: + \psi_j^t = \frac{ \gamma }{ v_\text{th} } \text{max} + \left( 0, 1-\beta \left| \frac{ v_j^t - v_\text{th} }{ v_\text{th} }\right| \right) \,. \\ + +2. An exponential function used in [4]_: + +.. math:: + \psi_j^t = \gamma \exp \left( -\beta \left| v_j^t - v_\text{th} \right| \right) \,. \\ + +3. The derivative of a fast sigmoid function used in [5]_: + +.. math:: + \psi_j^t = \gamma \left( 1 + \beta \left| v_j^t - v_\text{th} \right| \right)^2 \,. \\ + +4. An arctan function used in [6]_: + +.. math:: + \psi_j^t = \frac{\gamma}{\pi} \frac{1}{ 1 + \left( \beta \pi \left( v_j^t - v_\text{th} \right) \right)^2 } \,. \\ + +.. end_surrogate-gradient-functions + +In the interval between two presynaptic spikes, the gradient is calculated +at each time step until the cutoff time point. This computation occurs over +the time range: + +:math:`t \in \left[ t_\text{spk,prev}, \min \left( t_\text{spk,prev} + \Delta t_\text{c}, t_\text{spk,curr} \right) +\right]`. 
+ +Here, :math:`t_\text{spk,prev}` represents the time of the previous spike that +passed the synapse, while :math:`t_\text{spk,curr}` is the time of the +current spike, which triggers the application of the learning rule and the +subsequent synaptic weight update. The cutoff :math:`\Delta t_\text{c}` +defines the maximum allowable interval for integration between spikes. +The expression for the gradient is given by: + +.. math:: + \frac{ \text{d} E^t }{ \text{d} W_{ji} } &= L_j^t \bar{e}_{ji}^{t-1} \,, \\ + e_{ji}^{t-1} &= \psi_j^{t-1} \bar{z}_i^{t-2} \,, \\ + +The eligibility trace and the presynaptic spike trains are low-pass filtered +with the following exponential kernels: + +.. math:: + \bar{e}_{ji}^t &= \mathcal{F}_\kappa \left( e_{ji}^t \right) + = \kappa \bar{e}_{ji}^{t-1} + \left( 1 - \kappa \right) e_{ji}^t \,, \\ + \bar{z}_i^t &= \mathcal{F}_\alpha \left( z_{i}^t \right)= \alpha \bar{z}_i^{t-1} + \zeta z_i^t \,. \\ + +Furthermore, a firing rate regularization mechanism keeps the exponential moving average of the postsynaptic +neuron's firing rate :math:`f_j^{\text{ema},t}` close to a target firing rate +:math:`f^\text{target}`. The gradient :math:`g_\text{reg}^t` of the regularization loss :math:`E_\text{reg}^t` +with respect to the synaptic weight :math:`W_{ji}` is given by: + +.. math:: + \frac{ \text{d} E_\text{reg}^t }{ \text{d} W_{ji}} + &\approx c_\text{reg} \left( f^{\text{ema},t}_j - f^\text{target} \right) \bar{e}_{ji}^t \,, \\ + f^{\text{ema},t}_j &= \mathcal{F}_{\kappa_\text{reg}} \left( \frac{z_j^t}{\Delta t} \right) + = \kappa_\text{reg} f^{\text{ema},t-1}_j + \left( 1 - \kappa_\text{reg} \right) \frac{z_j^t}{\Delta t} \,, \\ + +where :math:`c_\text{reg}` is a constant scaling factor. + +The overall gradient is given by the addition of the two gradients. + +As a last step for every round in the loop over the time steps :math:`t`, the new weight is retrieved by feeding the +current gradient :math:`g^t` to the optimizer (see :doc:`weight_optimizer<../models/weight_optimizer/>` +for more information on the available optimizers): + +.. math:: + w^t = \text{optimizer} \left( t, g^t, w^{t-1} \right) \,. \\ + +After the loop has terminated, the filtered dynamic variables of e-prop are propagated from the end of the cutoff until +the next spike: + +.. math:: + p &= \text{max} \left( 0, t_\text{s}^{t} - \left( t_\text{s}^{t-1} + {\Delta t}_\text{c} \right) \right) \,, \\ + \bar{e}_{ji}^{t+p} &= \bar{e}_{ji}^t \kappa^p \,, \\ + \bar{z}_i^{t+p} &= \bar{z}_i^t \alpha^p \,. \\ + +For more information on e-prop plasticity, see the documentation on the other e-prop models: + + * :doc:`eprop_iaf_adapt<../models/eprop_iaf_adapt/>` + * :doc:`eprop_readout<../models/eprop_readout/>` + * :doc:`eprop_synapse<../models/eprop_synapse/>` + * :doc:`eprop_learning_signal_connection<../models/eprop_learning_signal_connection/>` + +Details on the event-based NEST implementation of e-prop can be found in [2]_. + +Parameters +++++++++++ + +The following parameters can be set in the status dictionary. 
+ +=========================== ======= ======================= ================ =================================== +**Neuron parameters** +---------------------------------------------------------------------------------------------------------------- +Parameter Unit Math equivalent Default Description +=========================== ======= ======================= ================ =================================== +``C_m`` pF :math:`C_\text{m}` 250.0 Capacitance of the membrane +``E_L`` mV :math:`E_\text{L}` -70.0 Leak / resting membrane potential +``I_e`` pA :math:`I_\text{e}` 0.0 Constant external input current +``regular_spike_arrival`` Boolean ``True`` If ``True``, the input spikes + arrive at the end of the time step, + if ``False`` at the beginning + (determines PSC scale) +``t_ref`` ms :math:`t_\text{ref}` 2.0 Duration of the refractory period +``tau_m`` ms :math:`\tau_\text{m}` 10.0 Time constant of the membrane +``V_min`` mV :math:`v_\text{min}` negative maximum Absolute lower bound of the + value membrane voltage + representable by + a ``double`` + type in C++ +``V_th`` mV :math:`v_\text{th}` -55.0 Spike threshold voltage +=========================== ======= ======================= ================ =================================== + +=============================== ======= =========================== ================== ========================= +**E-prop parameters** +---------------------------------------------------------------------------------------------------------------- +Parameter Unit Math equivalent Default Description +=============================== ======= =========================== ================== ========================= +``c_reg`` :math:`c_\text{reg}` 0.0 Coefficient of firing + rate regularization +``eprop_isi_trace_cutoff`` ms :math:`{\Delta t}_\text{c}` maximum value Cutoff for integration of + representable e-prop update between two + by a ``long`` spikes + type in C++ +``f_target`` Hz :math:`f^\text{target}` 10.0 Target firing rate of + rate regularization +``kappa`` :math:`\kappa` 0.97 Low-pass filter of the + eligibility trace +``kappa_reg`` :math:`\kappa_\text{reg}` 0.97 Low-pass filter of the + firing rate for + regularization +``beta`` :math:`\beta` 1.0 Width scaling of + surrogate gradient / + pseudo-derivative of + membrane voltage +``gamma`` :math:`\gamma` 0.3 Height scaling of + surrogate gradient / + pseudo-derivative of + membrane voltage +``surrogate_gradient_function`` :math:`\psi` "piecewise_linear" Surrogate gradient / + pseudo-derivative + function + ["piecewise_linear", + "exponential", + "fast_sigmoid_derivative" + , "arctan"] +=============================== ======= =========================== ================== ========================= + +Recordables ++++++++++++ + +The following state variables evolve during simulation and can be recorded. 
+ +================== ==== =============== ============= ======================== +**Neuron state variables and recordables** +------------------------------------------------------------------------------ +State variable Unit Math equivalent Initial value Description +================== ==== =============== ============= ======================== +``V_m`` mV :math:`v_j` -70.0 Membrane voltage +================== ==== =============== ============= ======================== + +====================== ==== =============== ============= ========================================= +**E-prop state variables and recordables** +--------------------------------------------------------------------------------------------------- +State variable Unit Math equivalent Initial value Description +====================== ==== =============== ============= ========================================= +``learning_signal`` pA :math:`L_j` 0.0 Learning signal +``surrogate_gradient`` :math:`\psi_j` 0.0 Surrogate gradient / pseudo-derivative of + membrane voltage +====================== ==== =============== ============= ========================================= + +Usage ++++++ + +This model can only be used in combination with the other e-prop models +and the network architecture requires specific wiring, input, and output. +The usage is demonstrated in several +:doc:`supervised regression and classification tasks <../auto_examples/eprop_plasticity/index>` +reproducing among others the original proof-of-concept tasks in [1]_. + +References +++++++++++ + +.. [1] Bellec G, Scherr F, Subramoney F, Hajek E, Salaj D, Legenstein R, + Maass W (2020). A solution to the learning dilemma for recurrent + networks of spiking neurons. Nature Communications, 11:3625. + https://doi.org/10.1038/s41467-020-17236-y + +.. [2] Korcsak-Gorzo A, Stapmanns J, Espinoza Valverde JA, Plesser HE, + Dahmen D, Bolten M, Van Albada SJ*, Diesmann M*. Event-based + implementation of eligibility propagation (in preparation) + +.. start_surrogate-gradient-references + +.. [3] Neftci EO, Mostafa H, Zenke F (2019). Surrogate Gradient Learning in + Spiking Neural Networks. IEEE Signal Processing Magazine, 36(6), 51-63. + https://doi.org/10.1109/MSP.2019.2931595 + +.. [4] Shrestha SB, Orchard G (2018). SLAYER: Spike Layer Error Reassignment in + Time. Advances in Neural Information Processing Systems, 31:1412-1421. + https://proceedings.neurips.cc/paper_files/paper/2018/hash/82.. rubric:: References + +.. [5] Zenke F, Ganguli S (2018). SuperSpike: Supervised Learning in Multilayer + Spiking Neural Networks. Neural Computation, 30:1514–1541. + https://doi.org/10.1162/neco_a_01086 + +.. [6] Fang W, Yu Z, Chen Y, Huang T, Masquelier T, Tian Y (2021). Deep residual + learning in spiking neural networks. Advances in Neural Information + Processing Systems, 34:21056–21069. + https://proceedings.neurips.cc/paper/2021/hash/afe434653a898da20044041262b3ac74-Abstract.html + +.. end_surrogate-gradient-references + +Sends ++++++ + +SpikeEvent + +Receives +++++++++ + +SpikeEvent, CurrentEvent, LearningSignalConnectionEvent, DataLoggingRequest + +See also +++++++++ + +Examples using this model ++++++++++++++++++++++++++ + +.. listexamples:: eprop_iaf + +EndUserDocs */ + +void register_eprop_iaf( const std::string& name ); + +/** + * @brief Class implementing a LIF neuron model for e-prop plasticity with additional biological features. 
+ * + * Class implementing a current-based leaky integrate-and-fire neuron model with delta-shaped postsynaptic currents for + * e-prop plasticity according to Bellec et al. (2020) with additional biological features described in + * Korcsak-Gorzo, Stapmanns, and Espinoza Valverde et al. (in preparation). + */ +class eprop_iaf : public EpropArchivingNodeRecurrent +{ + +public: + //! Default constructor. + eprop_iaf(); + + //! Copy constructor. + eprop_iaf( const eprop_iaf& ); + + using Node::handle; + using Node::handles_test_event; + + size_t send_test_event( Node&, size_t, synindex, bool ) override; + + void handle( SpikeEvent& ) override; + void handle( CurrentEvent& ) override; + void handle( LearningSignalConnectionEvent& ) override; + void handle( DataLoggingRequest& ) override; + + size_t handles_test_event( SpikeEvent&, size_t ) override; + size_t handles_test_event( CurrentEvent&, size_t ) override; + size_t handles_test_event( LearningSignalConnectionEvent&, size_t ) override; + size_t handles_test_event( DataLoggingRequest&, size_t ) override; + + void get_status( DictionaryDatum& ) const override; + void set_status( const DictionaryDatum& ) override; + +private: + void init_buffers_() override; + void pre_run_hook() override; + + void update( Time const&, const long, const long ) override; + + void compute_gradient( const long, + const long, + double&, + double&, + double&, + double&, + double&, + double&, + const CommonSynapseProperties&, + WeightOptimizer* ) override; + + long get_shift() const override; + bool is_eprop_recurrent_node() const override; + long get_eprop_isi_trace_cutoff() const override; + + //! Pointer to member function selected for computing the surrogate gradient. + surrogate_gradient_function compute_surrogate_gradient_; + + //! Map for storing a static set of recordables. + friend class RecordablesMap< eprop_iaf >; + + //! Logger for universal data supporting the data logging request / reply mechanism. Populated with a recordables map. + friend class UniversalDataLogger< eprop_iaf >; + + //! Structure of parameters. + struct Parameters_ + { + //! Capacitance of the membrane (pF). + double C_m_; + + //! Coefficient of firing rate regularization. + double c_reg_; + + //! Leak / resting membrane potential (mV). + double E_L_; + + //! Target firing rate of rate regularization (spikes/s). + double f_target_; + + //! Width scaling of surrogate gradient / pseudo-derivative of membrane voltage. + double beta_; + + //! Height scaling of surrogate gradient / pseudo-derivative of membrane voltage. + double gamma_; + + //! Constant external input current (pA). + double I_e_; + + //! If True, the input spikes arrive at the beginning of the time step, if False at the end (determines PSC scale). + bool regular_spike_arrival_; + + //! Surrogate gradient / pseudo-derivative function of the membrane voltage ["piecewise_linear", "exponential", + //! "fast_sigmoid_derivative", "arctan"] + std::string surrogate_gradient_function_; + + //! Duration of the refractory period (ms). + double t_ref_; + + //! Time constant of the membrane (ms). + double tau_m_; + + //! Absolute lower bound of the membrane voltage relative to the leak membrane potential (mV). + double V_min_; + + //! Spike threshold voltage relative to the leak membrane potential (mV). + double V_th_; + + //! Low-pass filter of the eligibility trace. + double kappa_; + + //! Low-pass filter of the firing rate for regularization. + double kappa_reg_; + + //! 
Time interval from the previous spike until the cutoff of e-prop update integration between two spikes (ms). + double eprop_isi_trace_cutoff_; + + //! Default constructor. + Parameters_(); + + //! Get the parameters and their values. + void get( DictionaryDatum& ) const; + + //! Set the parameters and throw errors in case of invalid values. + double set( const DictionaryDatum&, Node* ); + }; + + //! Structure of state variables. + struct State_ + { + //! Learning signal. Sum of weighted error signals coming from the readout neurons. + double learning_signal_; + + //! Number of remaining refractory steps. + int r_; + + //! Surrogate gradient / pseudo-derivative of the membrane voltage. + double surrogate_gradient_; + + //! Input current (pA). + double i_in_; + + //! Membrane voltage relative to the leak membrane potential (mV). + double v_m_; + + //! Binary spike state variable - 1.0 if the neuron has spiked in the previous time step and 0.0 otherwise. + double z_; + + //! Binary input spike state variable - 1.0 if the neuron has spiked in the previous time step and 0.0 otherwise. + double z_in_; + + //! Default constructor. + State_(); + + //! Get the state variables and their values. + void get( DictionaryDatum&, const Parameters_& ) const; + + //! Set the state variables. + void set( const DictionaryDatum&, const Parameters_&, double, Node* ); + }; + + //! Structure of buffers. + struct Buffers_ + { + //! Default constructor. + Buffers_( eprop_iaf& ); + + //! Copy constructor. + Buffers_( const Buffers_&, eprop_iaf& ); + + //! Buffer for incoming spikes. + RingBuffer spikes_; + + //! Buffer for incoming currents. + RingBuffer currents_; + + //! Logger for universal data. + UniversalDataLogger< eprop_iaf > logger_; + }; + + //! Structure of internal variables. + struct Variables_ + { + //! Propagator matrix entry for evolving the membrane voltage (mathematical symbol "alpha" in user documentation). + double P_v_m_; + + //! Propagator matrix entry for evolving the incoming spike state variables (mathematical symbol "zeta" in user + //! documentation). + double P_z_in_; + + //! Propagator matrix entry for evolving the incoming currents. + double P_i_in_; + + //! Total refractory steps. + int RefractoryCounts_; + + //! Time steps from the previous spike until the cutoff of e-prop update integration between two spikes. + long eprop_isi_trace_cutoff_steps_; + }; + + //! Get the current value of the membrane voltage. + double + get_v_m_() const + { + return S_.v_m_ + P_.E_L_; + } + + //! Get the current value of the surrogate gradient. + double + get_surrogate_gradient_() const + { + return S_.surrogate_gradient_; + } + + //! Get the current value of the learning signal. + double + get_learning_signal_() const + { + return S_.learning_signal_; + } + + // the order in which the structure instances are defined is important for speed + + //! Structure of parameters. + Parameters_ P_; + + //! Structure of state variables. + State_ S_; + + //! Structure of internal variables. + Variables_ V_; + + //! Structure of buffers. + Buffers_ B_; + + //! Map storing a static set of recordables. 
+ static RecordablesMap< eprop_iaf > recordablesMap_; +}; + +inline long +eprop_iaf::get_eprop_isi_trace_cutoff() const +{ + return V_.eprop_isi_trace_cutoff_steps_; +} + +inline size_t +eprop_iaf::send_test_event( Node& target, size_t receptor_type, synindex, bool ) +{ + SpikeEvent e; + e.set_sender( *this ); + return target.handles_test_event( e, receptor_type ); +} + +inline size_t +eprop_iaf::handles_test_event( SpikeEvent&, size_t receptor_type ) +{ + if ( receptor_type != 0 ) + { + throw UnknownReceptorType( receptor_type, get_name() ); + } + + return 0; +} + +inline size_t +eprop_iaf::handles_test_event( CurrentEvent&, size_t receptor_type ) +{ + if ( receptor_type != 0 ) + { + throw UnknownReceptorType( receptor_type, get_name() ); + } + + return 0; +} + +inline size_t +eprop_iaf::handles_test_event( LearningSignalConnectionEvent&, size_t receptor_type ) +{ + if ( receptor_type != 0 ) + { + throw UnknownReceptorType( receptor_type, get_name() ); + } + + return 0; +} + +inline size_t +eprop_iaf::handles_test_event( DataLoggingRequest& dlr, size_t receptor_type ) +{ + if ( receptor_type != 0 ) + { + throw UnknownReceptorType( receptor_type, get_name() ); + } + + return B_.logger_.connect_logging_device( dlr, recordablesMap_ ); +} + +inline void +eprop_iaf::get_status( DictionaryDatum& d ) const +{ + P_.get( d ); + S_.get( d, P_ ); + ( *d )[ names::recordables ] = recordablesMap_.get_list(); +} + +inline void +eprop_iaf::set_status( const DictionaryDatum& d ) +{ + // temporary copies in case of errors + Parameters_ ptmp = P_; + State_ stmp = S_; + + // make sure that ptmp and stmp consistent - throw BadProperty if not + const double delta_EL = ptmp.set( d, this ); + stmp.set( d, ptmp, delta_EL, this ); + + P_ = ptmp; + S_ = stmp; +} + +} // namespace nest + +#endif // EPROP_IAF_H diff --git a/models/eprop_iaf_adapt.cpp b/models/eprop_iaf_adapt.cpp new file mode 100644 index 0000000000..1534b4d772 --- /dev/null +++ b/models/eprop_iaf_adapt.cpp @@ -0,0 +1,490 @@ +/* + * eprop_iaf_adapt.cpp + * + * This file is part of NEST. + * + * Copyright (C) 2004 The NEST Initiative + * + * NEST is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 2 of the License, or + * (at your option) any later version. + * + * NEST is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with NEST. If not, see . 
+ * + */ + +// nest models +#include "eprop_iaf_adapt.h" + +// C++ +#include + +// libnestutil +#include "dict_util.h" +#include "numerics.h" + +// nestkernel +#include "exceptions.h" +#include "kernel_manager.h" +#include "nest_impl.h" +#include "universal_data_logger_impl.h" + +// sli +#include "dictutils.h" + +namespace nest +{ + +void +register_eprop_iaf_adapt( const std::string& name ) +{ + register_node_model< eprop_iaf_adapt >( name ); +} + +/* ---------------------------------------------------------------- + * Recordables map + * ---------------------------------------------------------------- */ + +RecordablesMap< eprop_iaf_adapt > eprop_iaf_adapt::recordablesMap_; + +template <> +void +RecordablesMap< eprop_iaf_adapt >::create() +{ + insert_( names::adaptation, &eprop_iaf_adapt::get_adaptation_ ); + insert_( names::V_th_adapt, &eprop_iaf_adapt::get_v_th_adapt_ ); + insert_( names::learning_signal, &eprop_iaf_adapt::get_learning_signal_ ); + insert_( names::surrogate_gradient, &eprop_iaf_adapt::get_surrogate_gradient_ ); + insert_( names::V_m, &eprop_iaf_adapt::get_v_m_ ); +} + +/* ---------------------------------------------------------------- + * Default constructors for parameters, state, and buffers + * ---------------------------------------------------------------- */ + +eprop_iaf_adapt::Parameters_::Parameters_() + : adapt_beta_( 1.0 ) + , adapt_tau_( 10.0 ) + , C_m_( 250.0 ) + , c_reg_( 0.0 ) + , E_L_( -70.0 ) + , f_target_( 0.01 ) + , beta_( 1.0 ) + , gamma_( 0.3 ) + , I_e_( 0.0 ) + , regular_spike_arrival_( true ) + , surrogate_gradient_function_( "piecewise_linear" ) + , t_ref_( 2.0 ) + , tau_m_( 10.0 ) + , V_min_( -std::numeric_limits< double >::max() ) + , V_th_( -55.0 - E_L_ ) + , kappa_( 0.97 ) + , kappa_reg_( 0.97 ) + , eprop_isi_trace_cutoff_( 1000.0 ) +{ +} + +eprop_iaf_adapt::State_::State_() + : adapt_( 0.0 ) + , v_th_adapt_( 15.0 ) + , learning_signal_( 0.0 ) + , r_( 0 ) + , surrogate_gradient_( 0.0 ) + , i_in_( 0.0 ) + , v_m_( 0.0 ) + , z_( 0.0 ) + , z_in_( 0.0 ) +{ +} + +eprop_iaf_adapt::Buffers_::Buffers_( eprop_iaf_adapt& n ) + : logger_( n ) +{ +} + +eprop_iaf_adapt::Buffers_::Buffers_( const Buffers_&, eprop_iaf_adapt& n ) + : logger_( n ) +{ +} + +/* ---------------------------------------------------------------- + * Getter and setter functions for parameters and state + * ---------------------------------------------------------------- */ + +void +eprop_iaf_adapt::Parameters_::get( DictionaryDatum& d ) const +{ + def< double >( d, names::adapt_beta, adapt_beta_ ); + def< double >( d, names::adapt_tau, adapt_tau_ ); + def< double >( d, names::C_m, C_m_ ); + def< double >( d, names::c_reg, c_reg_ ); + def< double >( d, names::E_L, E_L_ ); + def< double >( d, names::f_target, f_target_ ); + def< double >( d, names::beta, beta_ ); + def< double >( d, names::gamma, gamma_ ); + def< double >( d, names::I_e, I_e_ ); + def< bool >( d, names::regular_spike_arrival, regular_spike_arrival_ ); + def< std::string >( d, names::surrogate_gradient_function, surrogate_gradient_function_ ); + def< double >( d, names::t_ref, t_ref_ ); + def< double >( d, names::tau_m, tau_m_ ); + def< double >( d, names::V_min, V_min_ + E_L_ ); + def< double >( d, names::V_th, V_th_ + E_L_ ); + def< double >( d, names::kappa, kappa_ ); + def< double >( d, names::kappa_reg, kappa_reg_ ); + def< double >( d, names::eprop_isi_trace_cutoff, eprop_isi_trace_cutoff_ ); +} + +double +eprop_iaf_adapt::Parameters_::set( const DictionaryDatum& d, Node* node ) +{ + // if leak potential is 
changed, adjust all variables defined relative to it + const double ELold = E_L_; + updateValueParam< double >( d, names::E_L, E_L_, node ); + const double delta_EL = E_L_ - ELold; + + V_th_ -= updateValueParam< double >( d, names::V_th, V_th_, node ) ? E_L_ : delta_EL; + V_min_ -= updateValueParam< double >( d, names::V_min, V_min_, node ) ? E_L_ : delta_EL; + + updateValueParam< double >( d, names::adapt_beta, adapt_beta_, node ); + updateValueParam< double >( d, names::adapt_tau, adapt_tau_, node ); + updateValueParam< double >( d, names::C_m, C_m_, node ); + updateValueParam< double >( d, names::c_reg, c_reg_, node ); + + if ( updateValueParam< double >( d, names::f_target, f_target_, node ) ) + { + f_target_ /= 1000.0; // convert from spikes/s to spikes/ms + } + + updateValueParam< double >( d, names::beta, beta_, node ); + updateValueParam< double >( d, names::gamma, gamma_, node ); + updateValueParam< double >( d, names::I_e, I_e_, node ); + updateValueParam< bool >( d, names::regular_spike_arrival, regular_spike_arrival_, node ); + updateValueParam< std::string >( d, names::surrogate_gradient_function, surrogate_gradient_function_, node ); + updateValueParam< double >( d, names::t_ref, t_ref_, node ); + updateValueParam< double >( d, names::tau_m, tau_m_, node ); + updateValueParam< double >( d, names::kappa, kappa_, node ); + updateValueParam< double >( d, names::kappa_reg, kappa_reg_, node ); + updateValueParam< double >( d, names::eprop_isi_trace_cutoff, eprop_isi_trace_cutoff_, node ); + + if ( adapt_beta_ < 0 ) + { + throw BadProperty( "Threshold adaptation prefactor adapt_beta ≥ 0 required." ); + } + + if ( adapt_tau_ <= 0 ) + { + throw BadProperty( "Threshold adaptation time constant adapt_tau > 0 required." ); + } + + if ( C_m_ <= 0 ) + { + throw BadProperty( "Membrane capacitance C_m > 0 required." ); + } + + if ( c_reg_ < 0 ) + { + throw BadProperty( "Firing rate regularization coefficient c_reg ≥ 0 required." ); + } + + if ( f_target_ < 0 ) + { + throw BadProperty( "Firing rate regularization target rate f_target ≥ 0 required." ); + } + + if ( tau_m_ <= 0 ) + { + throw BadProperty( "Membrane time constant tau_m > 0 required." ); + } + + if ( t_ref_ < 0 ) + { + throw BadProperty( "Refractory time t_ref ≥ 0 required." ); + } + + if ( V_th_ < V_min_ ) + { + throw BadProperty( "Spike threshold voltage V_th ≥ minimal voltage V_min required." ); + } + + if ( kappa_ < 0.0 or kappa_ > 1.0 ) + { + throw BadProperty( "Eligibility trace low-pass filter kappa from range [0, 1] required." ); + } + + if ( kappa_reg_ < 0.0 or kappa_reg_ > 1.0 ) + { + throw BadProperty( "Firing rate low-pass filter for regularization kappa_reg from range [0, 1] required." ); + } + + if ( eprop_isi_trace_cutoff_ < 0.0 ) + { + throw BadProperty( "Cutoff of integration of eprop trace between spikes eprop_isi_trace_cutoff ≥ 0 required." ); + } + + return delta_EL; +} + +void +eprop_iaf_adapt::State_::get( DictionaryDatum& d, const Parameters_& p ) const +{ + def< double >( d, names::adaptation, adapt_ ); + def< double >( d, names::V_m, v_m_ + p.E_L_ ); + def< double >( d, names::V_th_adapt, v_th_adapt_ + p.E_L_ ); + def< double >( d, names::surrogate_gradient, surrogate_gradient_ ); + def< double >( d, names::learning_signal, learning_signal_ ); +} + +void +eprop_iaf_adapt::State_::set( const DictionaryDatum& d, const Parameters_& p, double delta_EL, Node* node ) +{ + v_m_ -= updateValueParam< double >( d, names::V_m, v_m_, node ) ? 
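+  // V_m is stored relative to E_L: subtract E_L if V_m was set explicitly, otherwise subtract only the change in E_L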
p.E_L_ : delta_EL; + + // adaptive threshold can only be set indirectly via the adaptation variable + if ( updateValueParam< double >( d, names::adaptation, adapt_, node ) ) + { + // if E_L changed in this SetStatus call, p.V_th_ has been adjusted and no further action is needed + v_th_adapt_ = p.V_th_ + p.adapt_beta_ * adapt_; + } + else + { + // adjust voltage to change in E_L + v_th_adapt_ -= delta_EL; + } +} + +/* ---------------------------------------------------------------- + * Default and copy constructor for node + * ---------------------------------------------------------------- */ + +eprop_iaf_adapt::eprop_iaf_adapt() + : EpropArchivingNodeRecurrent() + , P_() + , S_() + , B_( *this ) +{ + recordablesMap_.create(); +} + +eprop_iaf_adapt::eprop_iaf_adapt( const eprop_iaf_adapt& n ) + : EpropArchivingNodeRecurrent( n ) + , P_( n.P_ ) + , S_( n.S_ ) + , B_( n.B_, *this ) +{ +} + +/* ---------------------------------------------------------------- + * Node initialization functions + * ---------------------------------------------------------------- */ + +void +eprop_iaf_adapt::init_buffers_() +{ + B_.spikes_.clear(); // includes resize + B_.currents_.clear(); // includes resize + B_.logger_.reset(); // includes resize +} + +void +eprop_iaf_adapt::pre_run_hook() +{ + B_.logger_.init(); // ensures initialization in case multimeter connected after Simulate + + V_.RefractoryCounts_ = Time( Time::ms( P_.t_ref_ ) ).get_steps(); + V_.eprop_isi_trace_cutoff_steps_ = Time( Time::ms( P_.eprop_isi_trace_cutoff_ ) ).get_steps(); + + compute_surrogate_gradient_ = select_surrogate_gradient( P_.surrogate_gradient_function_ ); + + // calculate the entries of the propagator matrix for the evolution of the state vector + + const double dt = Time::get_resolution().get_ms(); + + V_.P_v_m_ = std::exp( -dt / P_.tau_m_ ); + V_.P_i_in_ = P_.tau_m_ / P_.C_m_ * ( 1.0 - V_.P_v_m_ ); + V_.P_z_in_ = P_.regular_spike_arrival_ ? 
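+  // scale factor zeta for incoming spikes (cf. model documentation): 1 if regular_spike_arrival, else 1 - alpha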
1.0 : 1.0 - V_.P_v_m_; + V_.P_adapt_ = std::exp( -dt / P_.adapt_tau_ ); +} + +long +eprop_iaf_adapt::get_shift() const +{ + return offset_gen_ + delay_in_rec_; +} + +bool +eprop_iaf_adapt::is_eprop_recurrent_node() const +{ + return true; +} + +/* ---------------------------------------------------------------- + * Update function + * ---------------------------------------------------------------- */ + +void +eprop_iaf_adapt::update( Time const& origin, const long from, const long to ) +{ + for ( long lag = from; lag < to; ++lag ) + { + const long t = origin.get_steps() + lag; + + if ( S_.r_ > 0 ) + { + --S_.r_; + } + + S_.z_in_ = B_.spikes_.get_value( lag ); + + S_.v_m_ = V_.P_i_in_ * S_.i_in_ + V_.P_z_in_ * S_.z_in_ + V_.P_v_m_ * S_.v_m_; + S_.v_m_ -= P_.V_th_ * S_.z_; + S_.v_m_ = std::max( S_.v_m_, P_.V_min_ ); + + S_.adapt_ = V_.P_adapt_ * S_.adapt_ + S_.z_; + S_.v_th_adapt_ = P_.V_th_ + P_.adapt_beta_ * S_.adapt_; + + S_.z_ = 0.0; + + S_.surrogate_gradient_ = + ( this->*compute_surrogate_gradient_ )( S_.r_, S_.v_m_, S_.v_th_adapt_, P_.beta_, P_.gamma_ ); + + if ( S_.v_m_ >= S_.v_th_adapt_ and S_.r_ == 0 ) + { + SpikeEvent se; + kernel().event_delivery_manager.send( *this, se, lag ); + + S_.z_ = 1.0; + S_.r_ = V_.RefractoryCounts_; + } + + append_new_eprop_history_entry( t ); + write_surrogate_gradient_to_history( t, S_.surrogate_gradient_ ); + write_firing_rate_reg_to_history( t, S_.z_, P_.f_target_, P_.kappa_reg_, P_.c_reg_ ); + + S_.learning_signal_ = get_learning_signal_from_history( t, false ); + + S_.i_in_ = B_.currents_.get_value( lag ) + P_.I_e_; + + B_.logger_.record_data( t ); + } +} + +/* ---------------------------------------------------------------- + * Event handling functions + * ---------------------------------------------------------------- */ + +void +eprop_iaf_adapt::handle( SpikeEvent& e ) +{ + assert( e.get_delay_steps() > 0 ); + + B_.spikes_.add_value( + e.get_rel_delivery_steps( kernel().simulation_manager.get_slice_origin() ), e.get_weight() * e.get_multiplicity() ); +} + +void +eprop_iaf_adapt::handle( CurrentEvent& e ) +{ + assert( e.get_delay_steps() > 0 ); + + B_.currents_.add_value( + e.get_rel_delivery_steps( kernel().simulation_manager.get_slice_origin() ), e.get_weight() * e.get_current() ); +} + +void +eprop_iaf_adapt::handle( LearningSignalConnectionEvent& e ) +{ + for ( auto it_event = e.begin(); it_event != e.end(); ) + { + const long time_step = e.get_stamp().get_steps(); + const double weight = e.get_weight(); + const double error_signal = e.get_coeffvalue( it_event ); // get_coeffvalue advances iterator + const double learning_signal = weight * error_signal; + + write_learning_signal_to_history( time_step, learning_signal, false ); + } +} + +void +eprop_iaf_adapt::handle( DataLoggingRequest& e ) +{ + B_.logger_.handle( e ); +} + +void +eprop_iaf_adapt::compute_gradient( const long t_spike, + const long t_spike_previous, + double& z_previous_buffer, + double& z_bar, + double& e_bar, + double& e_bar_reg, + double& epsilon, + double& weight, + const CommonSynapseProperties& cp, + WeightOptimizer* optimizer ) +{ + double e = 0.0; // eligibility trace + double z = 0.0; // spiking variable + double z_current_buffer = 1.0; // buffer containing the spike that triggered the current integration + double psi = 0.0; // surrogate gradient + double L = 0.0; // learning signal + double firing_rate_reg = 0.0; // firing rate regularization + double grad = 0.0; // gradient + + const EpropSynapseCommonProperties& ecp = static_cast< const 
EpropSynapseCommonProperties& >( cp ); + const auto optimize_each_step = ( *ecp.optimizer_cp_ ).optimize_each_step_; + + auto eprop_hist_it = get_eprop_history( t_spike_previous - 1 ); + + const long t_compute_until = std::min( t_spike_previous + V_.eprop_isi_trace_cutoff_steps_, t_spike ); + + for ( long t = t_spike_previous; t < t_compute_until; ++t, ++eprop_hist_it ) + { + z = z_previous_buffer; + z_previous_buffer = z_current_buffer; + z_current_buffer = 0.0; + + psi = eprop_hist_it->surrogate_gradient_; + L = eprop_hist_it->learning_signal_; + firing_rate_reg = eprop_hist_it->firing_rate_reg_; + + z_bar = V_.P_v_m_ * z_bar + V_.P_z_in_ * z; + e = psi * ( z_bar - P_.adapt_beta_ * epsilon ); + epsilon = V_.P_adapt_ * epsilon + e; + e_bar = P_.kappa_ * e_bar + ( 1.0 - P_.kappa_ ) * e; + e_bar_reg = P_.kappa_reg_ * e_bar_reg + ( 1.0 - P_.kappa_reg_ ) * e; + + if ( optimize_each_step ) + { + grad = L * e_bar + firing_rate_reg * e_bar_reg; + weight = optimizer->optimized_weight( *ecp.optimizer_cp_, t, grad, weight ); + } + else + { + grad += L * e_bar + firing_rate_reg * e_bar_reg; + } + } + + if ( not optimize_each_step ) + { + weight = optimizer->optimized_weight( *ecp.optimizer_cp_, t_compute_until, grad, weight ); + } + + const long cutoff_to_spike_interval = t_spike - t_compute_until; + + if ( cutoff_to_spike_interval > 0 ) + { + z_bar *= std::pow( V_.P_v_m_, cutoff_to_spike_interval ); + e_bar *= std::pow( P_.kappa_, cutoff_to_spike_interval ); + e_bar_reg *= std::pow( P_.kappa_reg_, cutoff_to_spike_interval ); + epsilon *= std::pow( V_.P_adapt_, cutoff_to_spike_interval ); + } +} + +} // namespace nest diff --git a/models/eprop_iaf_adapt.h b/models/eprop_iaf_adapt.h new file mode 100644 index 0000000000..07972840a5 --- /dev/null +++ b/models/eprop_iaf_adapt.h @@ -0,0 +1,673 @@ +/* + * eprop_iaf_adapt.h + * + * This file is part of NEST. + * + * Copyright (C) 2004 The NEST Initiative + * + * NEST is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 2 of the License, or + * (at your option) any later version. + * + * NEST is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with NEST. If not, see . + * + */ + +#ifndef EPROP_IAF_ADAPT_H +#define EPROP_IAF_ADAPT_H + +// nestkernel +#include "connection.h" +#include "eprop_archiving_node.h" +#include "eprop_archiving_node_impl.h" +#include "eprop_synapse.h" +#include "event.h" +#include "nest_types.h" +#include "ring_buffer.h" +#include "universal_data_logger.h" + +namespace nest +{ + +/* BeginUserDocs: neuron, e-prop plasticity, current-based, integrate-and-fire, adaptive threshold + +Short description ++++++++++++++++++ + +Current-based leaky integrate-and-fire neuron model with delta-shaped +postsynaptic currents and threshold adaptation for e-prop plasticity + +Description ++++++++++++ + +``eprop_iaf_adapt`` is an implementation of a leaky integrate-and-fire +neuron model with delta-shaped postsynaptic currents and threshold adaptation +used for eligibility propagation (e-prop) plasticity. + +E-prop plasticity was originally introduced and implemented in TensorFlow in [1]_. + + .. 
note:: + The neuron dynamics of the ``eprop_iaf_adapt`` model (excluding + e-prop plasticity and the threshold adaptation) are similar to the neuron + dynamics of the ``iaf_psc_delta`` model, with minor differences, such as the + propagator of the post-synaptic current and the voltage reset upon a spike. + +The membrane voltage time course :math:`v_j^t` of the neuron :math:`j` is given by: + +.. math:: + v_j^t &= \alpha v_j^{t-1} + \zeta \sum_{i \neq j} W_{ji}^\text{rec} z_i^{t-1} + + \zeta \sum_i W_{ji}^\text{in} x_i^t - z_j^{t-1} v_\text{th} \,, \\ + \alpha &= e^{ -\frac{ \Delta t }{ \tau_\text{m} } } \,, \\ + \zeta &= + \begin{cases} + 1 \\ + 1 - \alpha + \end{cases} \,, \\ + +where :math:`W_{ji}^\text{rec}` and :math:`W_{ji}^\text{in}` are the recurrent and +input synaptic weight matrices, and :math:`z_i^{t-1}` is the recurrent presynaptic +state variable, while :math:`x_i^t` represents the input at time :math:`t`. + +Descriptions of further parameters and variables can be found in the table below. + +The threshold adaptation is given by: + +.. math:: + A_j^t &= v_\text{th} + \beta a_j^t \,, \\ + a_j^t &= \rho a_j^{t-1} + z_j^{t-1} \,, \\ + \rho &= e^{-\frac{ \Delta t }{ \tau_\text{a} }} \,. \\ + +The spike state variable is expressed by a Heaviside function: + +.. math:: + z_j^t = H \left( v_j^t - A_j^t \right) \,. \\ + +If the membrane voltage crosses the adaptive threshold voltage :math:`A_j^t`, a spike is +emitted and the membrane voltage is reduced by :math:`v_\text{th}` in the next +time step. After the time step of the spike emission, the neuron is not +able to spike for an absolute refractory period :math:`t_\text{ref}`. + +An additional state variable and the corresponding differential equation +represents a piecewise constant external current. + +See the documentation on the :doc:`iaf_psc_delta<../models/iaf_psc_delta/>` neuron model +for more information on the integration of the subthreshold dynamics. + +The change of the synaptic weight is calculated from the gradient :math:`g^t` of +the loss :math:`E^t` with respect to the synaptic weight :math:`W_{ji}`: +:math:`\frac{ \text{d} E^t }{ \text{d} W_{ij} }` +which depends on the presynaptic +spikes :math:`z_i^{t-2}`, the surrogate gradient or pseudo-derivative +of the spike state variable with respect to the postsynaptic membrane +voltage :math:`\psi_j^{t-1}` (the product of which forms the eligibility +trace :math:`e_{ji}^{t-1}`), and the learning signal :math:`L_j^t` emitted +by the readout neurons. + +.. include:: ../models/eprop_iaf.rst + :start-after: .. start_surrogate-gradient-functions + :end-before: .. end_surrogate-gradient-functions + +In the interval between two presynaptic spikes, the gradient is calculated +at each time step until the cutoff time point. This computation occurs over +the time range: + +:math:`t \in \left[ t_\text{spk,prev}, \min \left( t_\text{spk,prev} + \Delta t_\text{c}, t_\text{spk,curr} \right) +\right]`. + +Here, :math:`t_\text{spk,prev}` represents the time of the previous spike that +passed the synapse, while :math:`t_\text{spk,curr}` is the time of the +current spike, which triggers the application of the learning rule and the +subsequent synaptic weight update. The cutoff :math:`\Delta t_\text{c}` +defines the maximum allowable interval for integration between spikes. +The expression for the gradient is given by: + +.. 
math:: + \frac{ \text{d} E^t }{ \text{d} W_{ji} } &= L_j^t \bar{e}_{ji}^{t-1} \,, \\ + e_{ji}^{t-1} &= \psi_j^{t-1} \left( \bar{z}_i^{t-2} - \beta \epsilon_{ji,a}^{t-2} \right) \,, \\ + \epsilon^{t-2}_{ji,\text{a}} &= e_{ji}^{t-1} + \rho \epsilon_{ji,a}^{t-3} \,. \\ + +The eligibility trace and the presynaptic spike trains are low-pass filtered +with the following exponential kernels: + +.. math:: + \bar{e}_{ji}^t &= \mathcal{F}_\kappa \left( e_{ji}^t \right) + = \kappa \bar{e}_{ji}^{t-1} + \left( 1 - \kappa \right) e_{ji}^t \,, \\ + \bar{z}_i^t &= \mathcal{F}_\alpha \left( z_{i}^t \right)= \alpha \bar{z}_i^{t-1} + \zeta z_i^t \,. \\ + +Furthermore, a firing rate regularization mechanism keeps the exponential moving average of the postsynaptic +neuron's firing rate :math:`f_j^{\text{ema},t}` close to a target firing rate +:math:`f^\text{target}`. The gradient :math:`g_\text{reg}^t` of the regularization loss :math:`E_\text{reg}^t` +with respect to the synaptic weight :math:`W_{ji}` is given by: + +.. math:: + \frac{ \text{d} E_\text{reg}^t }{ \text{d} W_{ji}} + &\approx c_\text{reg} \left( f^{\text{ema},t}_j - f^\text{target} \right) \bar{e}_{ji}^t \,, \\ + f^{\text{ema},t}_j &= \mathcal{F}_{\kappa_\text{reg}} \left( \frac{z_j^t}{\Delta t} \right) + = \kappa_\text{reg} f^{\text{ema},t-1}_j + \left( 1 - \kappa_\text{reg} \right) \frac{z_j^t}{\Delta t} \,, \\ + +where :math:`c_\text{reg}` is a constant scaling factor. + +The overall gradient is given by the addition of the two gradients. + +As a last step for every round in the loop over the time steps :math:`t`, the new weight is retrieved by feeding the +current gradient :math:`g^t` to the optimizer (see :doc:`weight_optimizer<../models/weight_optimizer/>` +for more information on the available optimizers): + +.. math:: + w^t = \text{optimizer} \left( t, g^t, w^{t-1} \right) \,. \\ + +After the loop has terminated, the filtered dynamic variables of e-prop are propagated from the end of the cutoff until +the next spike: + +.. math:: + p &= \text{max} \left( 0, t_\text{s}^{t} - \left( t_\text{s}^{t-1} + {\Delta t}_\text{c} \right) \right) \,, \\ + \bar{e}_{ji}^{t+p} &= \bar{e}_{ji}^t \kappa^p \,, \\ + \bar{z}_i^{t+p} &= \bar{z}_i^t \alpha^p \,, \\ + \epsilon^{t+p} &= \epsilon^t \rho^p \,. \\ + +For more information on e-prop plasticity, see the documentation on the other e-prop models: + + * :doc:`eprop_iaf<../models/eprop_iaf/>` + * :doc:`eprop_readout<../models/eprop_readout/>` + * :doc:`eprop_synapse<../models/eprop_synapse/>` + * :doc:`eprop_learning_signal_connection<../models/eprop_learning_signal_connection/>` + +Details on the event-based NEST implementation of e-prop can be found in [2]_. + +Parameters +++++++++++ + +The following parameters can be set in the status dictionary. 
+ +=========================== ======= ======================= ================ =================================== +**Neuron parameters** +---------------------------------------------------------------------------------------------------------------- +Parameter Unit Math equivalent Default Description +=========================== ======= ======================= ================ =================================== +``adapt_beta`` :math:`\beta` 1.0 Prefactor of the threshold + adaptation +``adapt_tau`` ms :math:`\tau_\text{a}` 10.0 Time constant of the threshold + adaptation +``C_m`` pF :math:`C_\text{m}` 250.0 Capacitance of the membrane +``E_L`` mV :math:`E_\text{L}` -70.0 Leak / resting membrane potential +``I_e`` pA :math:`I_\text{e}` 0.0 Constant external input current +``regular_spike_arrival`` Boolean ``True`` If ``True``, the input spikes + arrive at the end of the time step, + if ``False`` at the beginning + (determines PSC scale) +``t_ref`` ms :math:`t_\text{ref}` 2.0 Duration of the refractory period +``tau_m`` ms :math:`\tau_\text{m}` 10.0 Time constant of the membrane +``V_min`` mV :math:`v_\text{min}` negative maximum Absolute lower bound of the + value membrane voltage + representable by + a ``double`` + type in C++ +``V_th`` mV :math:`v_\text{th}` -55.0 Spike threshold voltage +=========================== ======= ======================= ================ =================================== + +=============================== ======= =========================== ================== ========================= +**E-prop parameters** +---------------------------------------------------------------------------------------------------------------- +Parameter Unit Math equivalent Default Description +=============================== ======= =========================== ================== ========================= +``c_reg`` :math:`c_\text{reg}` 0.0 Coefficient of firing + rate regularization +``eprop_isi_trace_cutoff`` ms :math:`{\Delta t}_\text{c}` maximum value Cutoff for integration of + representable e-prop update between two + by a ``long`` spikes + type in C++ +``f_target`` Hz :math:`f^\text{target}` 10.0 Target firing rate of + rate regularization +``kappa`` :math:`\kappa` 0.97 Low-pass filter of the + eligibility trace +``kappa_reg`` :math:`\kappa_\text{reg}` 0.97 Low-pass filter of the + firing rate for + regularization +``beta`` :math:`\beta` 1.0 Width scaling of + surrogate gradient / + pseudo-derivative of + membrane voltage +``gamma`` :math:`\gamma` 0.3 Height scaling of + surrogate gradient / + pseudo-derivative of + membrane voltage +``surrogate_gradient_function`` :math:`\psi` "piecewise_linear" Surrogate gradient / + pseudo-derivative + function + ["piecewise_linear", + "exponential", + "fast_sigmoid_derivative" + , "arctan"] +=============================== ======= =========================== ================== ========================= + +Recordables ++++++++++++ + +The following state variables evolve during simulation and can be recorded. 
+ +================== ==== =============== ============= ======================== +**Neuron state variables and recordables** +------------------------------------------------------------------------------ +State variable Unit Math equivalent Initial value Description +================== ==== =============== ============= ======================== +``adaptation`` :math:`a_j` 0.0 Adaptation variable +``V_m`` mV :math:`v_j` -70.0 Membrane voltage +``V_th_adapt`` mV :math:`A_j` -55.0 Adapting spike threshold +================== ==== =============== ============= ======================== + +====================== ==== =============== ============= ========================================= +**E-prop state variables and recordables** +--------------------------------------------------------------------------------------------------- +State variable Unit Math equivalent Initial value Description +====================== ==== =============== ============= ========================================= +``learning_signal`` pA :math:`L_j` 0.0 Learning signal +``surrogate_gradient`` :math:`\psi_j` 0.0 Surrogate gradient / pseudo-derivative of + membrane voltage +====================== ==== =============== ============= ========================================= + +Usage ++++++ + +This model can only be used in combination with the other e-prop models +and the network architecture requires specific wiring, input, and output. +The usage is demonstrated in several +:doc:`supervised regression and classification tasks <../auto_examples/eprop_plasticity/index>` +reproducing among others the original proof-of-concept tasks in [1]_. + +References +++++++++++ + +.. [1] Bellec G, Scherr F, Subramoney F, Hajek E, Salaj D, Legenstein R, + Maass W (2020). A solution to the learning dilemma for recurrent + networks of spiking neurons. Nature Communications, 11:3625. + https://doi.org/10.1038/s41467-020-17236-y + +.. [2] Korcsak-Gorzo A, Stapmanns J, Espinoza Valverde JA, Plesser HE, + Dahmen D, Bolten M, Van Albada SJ*, Diesmann M*. Event-based + implementation of eligibility propagation (in preparation) + +.. include:: ../models/eprop_iaf.rst + :start-after: .. start_surrogate-gradient-references + :end-before: .. end_surrogate-gradient-references + +Sends ++++++ + +SpikeEvent + +Receives +++++++++ + +SpikeEvent, CurrentEvent, LearningSignalConnectionEvent, DataLoggingRequest + +See also +++++++++ + +Examples using this model ++++++++++++++++++++++++++ + +.. listexamples:: eprop_iaf_adapt + +EndUserDocs */ + +void register_eprop_iaf_adapt( const std::string& name ); + +/** + * @brief Class implementing an adaptive LIF neuron model for e-prop plasticity with additional biological features. + * + * Class implementing a current-based leaky integrate-and-fire neuron model with delta-shaped postsynaptic currents and + * threshold adaptation for e-prop plasticity according to Bellec et al. (2020) with additional biological features + * described in Korcsak-Gorzo, Stapmanns, and Espinoza Valverde et al. (in preparation). + */ +class eprop_iaf_adapt : public EpropArchivingNodeRecurrent +{ + +public: + //! Default constructor. + eprop_iaf_adapt(); + + //! Copy constructor. 
+ eprop_iaf_adapt( const eprop_iaf_adapt& ); + + using Node::handle; + using Node::handles_test_event; + + size_t send_test_event( Node&, size_t, synindex, bool ) override; + + void handle( SpikeEvent& ) override; + void handle( CurrentEvent& ) override; + void handle( LearningSignalConnectionEvent& ) override; + void handle( DataLoggingRequest& ) override; + + size_t handles_test_event( SpikeEvent&, size_t ) override; + size_t handles_test_event( CurrentEvent&, size_t ) override; + size_t handles_test_event( LearningSignalConnectionEvent&, size_t ) override; + size_t handles_test_event( DataLoggingRequest&, size_t ) override; + + void get_status( DictionaryDatum& ) const override; + void set_status( const DictionaryDatum& ) override; + +private: + void init_buffers_() override; + void pre_run_hook() override; + + void update( Time const&, const long, const long ) override; + + void compute_gradient( const long, + const long, + double&, + double&, + double&, + double&, + double&, + double&, + const CommonSynapseProperties&, + WeightOptimizer* ) override; + + long get_shift() const override; + bool is_eprop_recurrent_node() const override; + long get_eprop_isi_trace_cutoff() const override; + + //! Pointer to member function selected for computing the surrogate gradient. + surrogate_gradient_function compute_surrogate_gradient_; + + //! Map for storing a static set of recordables. + friend class RecordablesMap< eprop_iaf_adapt >; + + //! Logger for universal data supporting the data logging request / reply mechanism. Populated with a recordables map. + friend class UniversalDataLogger< eprop_iaf_adapt >; + + //! Structure of parameters. + struct Parameters_ + { + //! Prefactor of the threshold adaptation. + double adapt_beta_; + + //! Time constant of the threshold adaptation (ms). + double adapt_tau_; + + //! Capacitance of the membrane (pF). + double C_m_; + + //! Coefficient of firing rate regularization. + double c_reg_; + + //! Leak / resting membrane potential (mV). + double E_L_; + + //! Target firing rate of rate regularization (spikes/s). + double f_target_; + + //! Width scaling of surrogate gradient / pseudo-derivative of membrane voltage. + double beta_; + + //! Height scaling of surrogate gradient / pseudo-derivative of membrane voltage. + double gamma_; + + //! Constant external input current (pA). + double I_e_; + + //! If True, the input spikes arrive at the beginning of the time step, if False at the end (determines PSC scale). + bool regular_spike_arrival_; + + //! Surrogate gradient / pseudo-derivative function of the membrane voltage ["piecewise_linear", "exponential", + //! "fast_sigmoid_derivative", "arctan"] + std::string surrogate_gradient_function_; + + //! Duration of the refractory period (ms). + double t_ref_; + + //! Time constant of the membrane (ms). + double tau_m_; + + //! Absolute lower bound of the membrane voltage relative to the leak membrane potential (mV). + double V_min_; + + //! Spike threshold voltage relative to the leak membrane potential (mV). + double V_th_; + + //! Low-pass filter of the eligibility trace. + double kappa_; + + //! Low-pass filter of the firing rate for regularization. + double kappa_reg_; + + //! Time interval from the previous spike until the cutoff of e-prop update integration between two spikes (ms). + double eprop_isi_trace_cutoff_; + + //! Default constructor. + Parameters_(); + + //! Get the parameters and their values. + void get( DictionaryDatum& ) const; + + //! 
Set the parameters and throw errors in case of invalid values. + double set( const DictionaryDatum&, Node* ); + }; + + //! Structure of state variables. + struct State_ + { + //! Adaptation variable. + double adapt_; + + //! Adapting spike threshold voltage. + double v_th_adapt_; + + //! Learning signal. Sum of weighted error signals coming from the readout neurons. + double learning_signal_; + + //! Number of remaining refractory steps. + int r_; + + //! Surrogate gradient / pseudo-derivative of the membrane voltage. + double surrogate_gradient_; + + //! Input current (pA). + double i_in_; + + //! Membrane voltage relative to the leak membrane potential (mV). + double v_m_; + + //! Binary spike state variable - 1.0 if the neuron has spiked in the previous time step and 0.0 otherwise. + double z_; + + //! Binary input spike state variable - 1.0 if the neuron has spiked in the previous time step and 0.0 otherwise. + double z_in_; + + //! Default constructor. + State_(); + + //! Get the state variables and their values. + void get( DictionaryDatum&, const Parameters_& ) const; + + //! Set the state variables. + void set( const DictionaryDatum&, const Parameters_&, double, Node* ); + }; + + //! Structure of buffers. + struct Buffers_ + { + //! Default constructor. + Buffers_( eprop_iaf_adapt& ); + + //! Copy constructor. + Buffers_( const Buffers_&, eprop_iaf_adapt& ); + + //! Buffer for incoming spikes. + RingBuffer spikes_; + + //! Buffer for incoming currents. + RingBuffer currents_; + + //! Logger for universal data. + UniversalDataLogger< eprop_iaf_adapt > logger_; + }; + + //! Structure of internal variables. + struct Variables_ + { + //! Propagator matrix entry for evolving the membrane voltage (mathematical symbol "alpha" in user documentation). + double P_v_m_; + + //! Propagator matrix entry for evolving the incoming spike state variables (mathematical symbol "zeta" in user + //! documentation). + double P_z_in_; + + //! Propagator matrix entry for evolving the incoming currents. + double P_i_in_; + + //! Propagator matrix entry for evolving the adaptation (mathematical symbol "rho" in user documentation). + double P_adapt_; + + //! Total refractory steps. + int RefractoryCounts_; + + //! Time steps from the previous spike until the cutoff of e-prop update integration between two spikes. + long eprop_isi_trace_cutoff_steps_; + }; + + //! Get the current value of the membrane voltage. + double + get_v_m_() const + { + return S_.v_m_ + P_.E_L_; + } + + //! Get the current value of the surrogate gradient. + double + get_surrogate_gradient_() const + { + return S_.surrogate_gradient_; + } + + //! Get the current value of the learning signal. + double + get_learning_signal_() const + { + return S_.learning_signal_; + } + + //! Get the current value of the adapting threshold. + double + get_v_th_adapt_() const + { + return S_.v_th_adapt_ + P_.E_L_; + } + + //! Get the current value of the adaptation. + double + get_adaptation_() const + { + return S_.adapt_; + } + + // the order in which the structure instances are defined is important for speed + + //! Structure of parameters. + Parameters_ P_; + + //! Structure of state variables. + State_ S_; + + //! Structure of internal variables. + Variables_ V_; + + //! Structure of buffers. + Buffers_ B_; + + //! Map storing a static set of recordables. 
+ static RecordablesMap< eprop_iaf_adapt > recordablesMap_; +}; + +inline long +eprop_iaf_adapt::get_eprop_isi_trace_cutoff() const +{ + return V_.eprop_isi_trace_cutoff_steps_; +} + +inline size_t +eprop_iaf_adapt::send_test_event( Node& target, size_t receptor_type, synindex, bool ) +{ + SpikeEvent e; + e.set_sender( *this ); + return target.handles_test_event( e, receptor_type ); +} + +inline size_t +eprop_iaf_adapt::handles_test_event( SpikeEvent&, size_t receptor_type ) +{ + if ( receptor_type != 0 ) + { + throw UnknownReceptorType( receptor_type, get_name() ); + } + + return 0; +} + +inline size_t +eprop_iaf_adapt::handles_test_event( CurrentEvent&, size_t receptor_type ) +{ + if ( receptor_type != 0 ) + { + throw UnknownReceptorType( receptor_type, get_name() ); + } + + return 0; +} + +inline size_t +eprop_iaf_adapt::handles_test_event( LearningSignalConnectionEvent&, size_t receptor_type ) +{ + if ( receptor_type != 0 ) + { + throw UnknownReceptorType( receptor_type, get_name() ); + } + + return 0; +} + +inline size_t +eprop_iaf_adapt::handles_test_event( DataLoggingRequest& dlr, size_t receptor_type ) +{ + if ( receptor_type != 0 ) + { + throw UnknownReceptorType( receptor_type, get_name() ); + } + + return B_.logger_.connect_logging_device( dlr, recordablesMap_ ); +} + +inline void +eprop_iaf_adapt::get_status( DictionaryDatum& d ) const +{ + P_.get( d ); + S_.get( d, P_ ); + ( *d )[ names::recordables ] = recordablesMap_.get_list(); +} + +inline void +eprop_iaf_adapt::set_status( const DictionaryDatum& d ) +{ + // temporary copies in case of errors + Parameters_ ptmp = P_; + State_ stmp = S_; + + // make sure that ptmp and stmp consistent - throw BadProperty if not + const double delta_EL = ptmp.set( d, this ); + stmp.set( d, ptmp, delta_EL, this ); + + P_ = ptmp; + S_ = stmp; +} + +} // namespace nest + +#endif // EPROP_IAF_ADAPT_H diff --git a/models/eprop_iaf_adapt_bsshslm_2020.cpp b/models/eprop_iaf_adapt_bsshslm_2020.cpp index 58cb06b9e0..65e6ca927c 100644 --- a/models/eprop_iaf_adapt_bsshslm_2020.cpp +++ b/models/eprop_iaf_adapt_bsshslm_2020.cpp @@ -76,6 +76,7 @@ eprop_iaf_adapt_bsshslm_2020::Parameters_::Parameters_() , c_reg_( 0.0 ) , E_L_( -70.0 ) , f_target_( 0.01 ) + , beta_( 1.0 ) , gamma_( 0.3 ) , I_e_( 0.0 ) , regular_spike_arrival_( true ) @@ -123,6 +124,7 @@ eprop_iaf_adapt_bsshslm_2020::Parameters_::get( DictionaryDatum& d ) const def< double >( d, names::c_reg, c_reg_ ); def< double >( d, names::E_L, E_L_ ); def< double >( d, names::f_target, f_target_ ); + def< double >( d, names::beta, beta_ ); def< double >( d, names::gamma, gamma_ ); def< double >( d, names::I_e, I_e_ ); def< bool >( d, names::regular_spike_arrival, regular_spike_arrival_ ); @@ -154,6 +156,7 @@ eprop_iaf_adapt_bsshslm_2020::Parameters_::set( const DictionaryDatum& d, Node* f_target_ /= 1000.0; // convert from spikes/s to spikes/ms } + updateValueParam< double >( d, names::beta, beta_, node ); updateValueParam< double >( d, names::gamma, gamma_, node ); updateValueParam< double >( d, names::I_e, I_e_, node ); updateValueParam< bool >( d, names::regular_spike_arrival, regular_spike_arrival_, node ); @@ -178,7 +181,7 @@ eprop_iaf_adapt_bsshslm_2020::Parameters_::set( const DictionaryDatum& d, Node* if ( c_reg_ < 0 ) { - throw BadProperty( "Firing rate regularization prefactor c_reg ≥ 0 required." ); + throw BadProperty( "Firing rate regularization coefficient c_reg ≥ 0 required." 
); } if ( f_target_ < 0 ) @@ -186,18 +189,6 @@ eprop_iaf_adapt_bsshslm_2020::Parameters_::set( const DictionaryDatum& d, Node* throw BadProperty( "Firing rate regularization target rate f_target ≥ 0 required." ); } - if ( gamma_ < 0.0 or 1.0 <= gamma_ ) - { - throw BadProperty( "Surrogate gradient / pseudo-derivative scaling gamma from interval [0,1) required." ); - } - - if ( surrogate_gradient_function_ != "piecewise_linear" ) - { - throw BadProperty( - "Surrogate gradient / pseudo derivate function surrogate_gradient_function from [\"piecewise_linear\"] " - "required." ); - } - if ( tau_m_ <= 0 ) { throw BadProperty( "Membrane time constant tau_m > 0 required." ); @@ -208,12 +199,6 @@ eprop_iaf_adapt_bsshslm_2020::Parameters_::set( const DictionaryDatum& d, Node* throw BadProperty( "Refractory time t_ref ≥ 0 required." ); } - if ( surrogate_gradient_function_ == "piecewise_linear" and fabs( V_th_ ) < 1e-6 ) - { - throw BadProperty( - "Relative threshold voltage V_th-E_L ≠ 0 required if surrogate_gradient_function is \"piecewise_linear\"." ); - } - if ( V_th_ < V_min_ ) { throw BadProperty( "Spike threshold voltage V_th ≥ minimal voltage V_min required." ); @@ -290,16 +275,13 @@ eprop_iaf_adapt_bsshslm_2020::pre_run_hook() V_.RefractoryCounts_ = Time( Time::ms( P_.t_ref_ ) ).get_steps(); - if ( P_.surrogate_gradient_function_ == "piecewise_linear" ) - { - compute_surrogate_gradient = &eprop_iaf_adapt_bsshslm_2020::compute_piecewise_linear_derivative; - } + compute_surrogate_gradient_ = select_surrogate_gradient( P_.surrogate_gradient_function_ ); // calculate the entries of the propagator matrix for the evolution of the state vector const double dt = Time::get_resolution().get_ms(); - V_.P_v_m_ = std::exp( -dt / P_.tau_m_ ); // called alpha in reference [1]_ + V_.P_v_m_ = std::exp( -dt / P_.tau_m_ ); V_.P_i_in_ = P_.tau_m_ / P_.C_m_ * ( 1.0 - V_.P_v_m_ ); V_.P_z_in_ = P_.regular_spike_arrival_ ? 
1.0 : 1.0 - V_.P_v_m_; V_.P_adapt_ = std::exp( -dt / P_.adapt_tau_ ); @@ -336,7 +318,6 @@ eprop_iaf_adapt_bsshslm_2020::update( Time const& origin, const long from, const if ( interval_step == 0 ) { erase_used_firing_rate_reg_history(); - erase_used_update_history(); erase_used_eprop_history(); if ( with_reset ) @@ -348,6 +329,11 @@ eprop_iaf_adapt_bsshslm_2020::update( Time const& origin, const long from, const } } + if ( S_.r_ > 0 ) + { + --S_.r_; + } + S_.z_in_ = B_.spikes_.get_value( lag ); S_.v_m_ = V_.P_i_in_ * S_.i_in_ + V_.P_z_in_ * S_.z_in_ + V_.P_v_m_ * S_.v_m_; @@ -359,9 +345,8 @@ eprop_iaf_adapt_bsshslm_2020::update( Time const& origin, const long from, const S_.z_ = 0.0; - S_.surrogate_gradient_ = ( this->*compute_surrogate_gradient )(); - - write_surrogate_gradient_to_history( t, S_.surrogate_gradient_ ); + S_.surrogate_gradient_ = + ( this->*compute_surrogate_gradient_ )( S_.r_, S_.v_m_, S_.v_th_adapt_, P_.beta_, P_.gamma_ ); if ( S_.v_m_ >= S_.v_th_adapt_ and S_.r_ == 0 ) { @@ -371,13 +356,12 @@ eprop_iaf_adapt_bsshslm_2020::update( Time const& origin, const long from, const kernel().event_delivery_manager.send( *this, se, lag ); S_.z_ = 1.0; - - if ( V_.RefractoryCounts_ > 0 ) - { - S_.r_ = V_.RefractoryCounts_; - } + S_.r_ = V_.RefractoryCounts_; } + append_new_eprop_history_entry( t ); + write_surrogate_gradient_to_history( t, S_.surrogate_gradient_ ); + if ( interval_step == update_interval - 1 ) { write_firing_rate_reg_to_history( t, P_.f_target_, P_.c_reg_ ); @@ -386,32 +370,12 @@ eprop_iaf_adapt_bsshslm_2020::update( Time const& origin, const long from, const S_.learning_signal_ = get_learning_signal_from_history( t ); - if ( S_.r_ > 0 ) - { - --S_.r_; - } - S_.i_in_ = B_.currents_.get_value( lag ) + P_.I_e_; B_.logger_.record_data( t ); } } -/* ---------------------------------------------------------------- - * Surrogate gradient functions - * ---------------------------------------------------------------- */ - -double -eprop_iaf_adapt_bsshslm_2020::compute_piecewise_linear_derivative() -{ - if ( S_.r_ > 0 ) - { - return 0.0; - } - - return P_.gamma_ * std::max( 0.0, 1.0 - std::fabs( ( S_.v_m_ - S_.v_th_adapt_ ) / P_.V_th_ ) ) / P_.V_th_; -} - /* ---------------------------------------------------------------- * Event handling functions * ---------------------------------------------------------------- */ @@ -486,7 +450,7 @@ eprop_iaf_adapt_bsshslm_2020::compute_gradient( std::vector< long >& presyn_isis z_bar = V_.P_v_m_ * z_bar + V_.P_z_in_ * z; e = psi * ( z_bar - P_.adapt_beta_ * epsilon ); - epsilon = psi * z_bar + ( V_.P_adapt_ - psi * P_.adapt_beta_ ) * epsilon; + epsilon = V_.P_adapt_ * epsilon + e; e_bar = kappa * e_bar + ( 1.0 - kappa ) * e; grad += L * e_bar; sum_e += e; @@ -497,16 +461,17 @@ eprop_iaf_adapt_bsshslm_2020::compute_gradient( std::vector< long >& presyn_isis } presyn_isis.clear(); + const long update_interval = kernel().simulation_manager.get_eprop_update_interval().get_steps(); const long learning_window = kernel().simulation_manager.get_eprop_learning_window().get_steps(); + const auto firing_rate_reg = get_firing_rate_reg_history( t_previous_update + get_shift() + update_interval ); + + grad += firing_rate_reg * sum_e; + if ( average_gradient ) { grad /= learning_window; } - const long update_interval = kernel().simulation_manager.get_eprop_update_interval().get_steps(); - const auto it_reg_hist = get_firing_rate_reg_history( t_previous_update + get_shift() + update_interval ); - grad += it_reg_hist->firing_rate_reg_ * sum_e; - return 
grad; } diff --git a/models/eprop_iaf_adapt_bsshslm_2020.h b/models/eprop_iaf_adapt_bsshslm_2020.h index b27bc400e8..209858ac26 100644 --- a/models/eprop_iaf_adapt_bsshslm_2020.h +++ b/models/eprop_iaf_adapt_bsshslm_2020.h @@ -65,27 +65,32 @@ names and the publication year. The membrane voltage time course :math:`v_j^t` of the neuron :math:`j` is given by: .. math:: - v_j^t &= \alpha v_j^{t-1}+\sum_{i \neq j}W_{ji}^\mathrm{rec}z_i^{t-1} - + \sum_i W_{ji}^\mathrm{in}x_i^t-z_j^{t-1}v_\mathrm{th} \,, \\ - \alpha &= e^{-\frac{\Delta t}{\tau_\mathrm{m}}} \,, - -whereby :math:`W_{ji}^\mathrm{rec}` and :math:`W_{ji}^\mathrm{in}` are the recurrent and -input synaptic weights, and :math:`z_i^{t-1}` and :math:`x_i^t` are the -recurrent and input presynaptic spike state variables, respectively. + v_j^t &= \alpha v_j^{t-1} + \zeta \sum_{i \neq j} W_{ji}^\text{rec} z_i^{t-1} + + \zeta \sum_i W_{ji}^\text{in} x_i^t - z_j^{t-1} v_\text{th} \,, \\ + \alpha &= e^{ -\frac{ \Delta t }{ \tau_\text{m} } } \,, \\ + \zeta &= + \begin{cases} + 1 \\ + 1 - \alpha + \end{cases} \,, \\ + +where :math:`W_{ji}^\text{rec}` and :math:`W_{ji}^\text{in}` are the recurrent and +input synaptic weight matrices, and :math:`z_i^{t-1}` is the recurrent presynaptic +state variable, while :math:`x_i^t` represents the input at time :math:`t`. Descriptions of further parameters and variables can be found in the table below. The threshold adaptation is given by: .. math:: - A_j^t &= v_\mathrm{th} + \beta a_j^t \,, \\ - a_j^t &= \rho a_j^{t-1} + z_j^{t-1} \,, \\ - \rho &= e^{-\frac{\Delta t}{\tau_\mathrm{a}}} \,. + A_j^t &= v_\text{th} + \beta a_j^t \,, \\ + a_j^t &= \rho a_j^{t-1} + z_j^{t-1} \,, \\ + \rho &= e^{-\frac{ \Delta t }{ \tau_\text{a} }} \,. \\ The spike state variable is expressed by a Heaviside function: .. math:: - z_j^t = H\left(v_j^t-A_j^t\right) \,. + z_j^t = H \left( v_j^t - A_j^t \right) \,. \\ If the membrane voltage crosses the adaptive threshold voltage :math:`A_j^t`, a spike is emitted and the membrane voltage is reduced by :math:`v_\text{th}` in the next @@ -95,53 +100,52 @@ able to spike for an absolute refractory period :math:`t_\text{ref}`. An additional state variable and the corresponding differential equation represents a piecewise constant external current. -Furthermore, the pseudo derivative of the membrane voltage needed for e-prop -plasticity is calculated: - -.. math:: - \psi_j^t = \frac{\gamma}{v_\text{th}} \text{max} - \left(0, 1-\left| \frac{v_j^t-A_j^t}{v_\text{th}}\right| \right) \,. - -See the documentation on the ``iaf_psc_delta`` neuron model for more information -on the integration of the subthreshold dynamics. +See the documentation on the :doc:`iaf_psc_delta<../models/iaf_psc_delta/>` neuron model +for more information on the integration of the subthreshold dynamics. The change of the synaptic weight is calculated from the gradient :math:`g` of the loss :math:`E` with respect to the synaptic weight :math:`W_{ji}`: -:math:`\frac{\mathrm{d}{E}}{\mathrm{d}{W_{ij}}}=g` +:math:`\frac{ \text{d}E }{ \text{d} W_{ij} }` which depends on the presynaptic -spikes :math:`z_i^{t-1}`, the surrogate-gradient / pseudo-derivative of the postsynaptic membrane -voltage :math:`\psi_j^t` (which together form the eligibility trace -:math:`e_{ji}^t`), and the learning signal :math:`L_j^t` emitted by the readout -neurons. 
+spikes :math:`z_i^{t-1}`, the surrogate gradient or pseudo-derivative +of the spike state variable with respect to the postsynaptic membrane +voltage :math:`\psi_j^t` (the product of which forms the eligibility +trace :math:`e_{ji}^t`), and the learning signal :math:`L_j^t` emitted +by the readout neurons. .. math:: - \frac{\mathrm{d}E}{\mathrm{d}W_{ji}} = g &= \sum_t L_j^t \bar{e}_{ji}^t, \\ - e_{ji}^t &= \psi_j^t \left(\bar{z}_i^{t-1} - \beta \epsilon_{ji,a}^{t-1}\right)\,, \\ - \epsilon^{t-1}_{ji,\text{a}} &= \psi_j^{t-1}\bar{z}_i^{t-2} + \left( \rho - \psi_j^{t-1} \beta \right) - \epsilon^{t-2}_{ji,a}\,. \\ + \frac{ \text{d} E }{ \text{d} W_{ji} } &= \sum_t L_j^t \bar{e}_{ji}^t \,, \\ + e_{ji}^t &= \psi_j^t \left( \bar{z}_i^{t-1} - \beta \epsilon_{ji,a}^{t-1} \right) \,, \\ + \epsilon^{t-1}_{ji,\text{a}} &= \psi_j^{t-1} \bar{z}_i^{t-2} + \left( \rho - \psi_j^{t-1} \beta \right) + \epsilon^{t-2}_{ji,a} \,. \\ + +.. include:: ../models/eprop_iaf.rst + :start-after: .. start_surrogate-gradient-functions + :end-before: .. end_surrogate-gradient-functions The eligibility trace and the presynaptic spike trains are low-pass filtered -with some exponential kernels: +with the following exponential kernels: .. math:: - \bar{e}_{ji}^t&=\mathcal{F}_\kappa(e_{ji}^t) \;\text{with}\, \kappa=e^{-\frac{\Delta t}{ - \tau_\text{m,out}}}\,,\\ - \bar{z}_i^t&=\mathcal{F}_\alpha(z_i^t)\,,\\ - \mathcal{F}_\alpha(z_i^t) &= \alpha\, \mathcal{F}_\alpha(z_i^{t-1}) + z_i^t - \;\text{with}\, \mathcal{F}_\alpha(z_i^0)=z_i^0\,\,, + \bar{e}_{ji}^t &= \mathcal{F}_\kappa \left( e_{ji}^t \right) \,, \\ + \kappa &= e^{ -\frac{\Delta t }{ \tau_\text{m,out} }} \,, \\ + \bar{z}_i^t &= \mathcal{F}_\alpha(z_i^t) \,, \\ + \mathcal{F}_\alpha \left( z_i^t \right) &= \alpha \mathcal{F}_\alpha \left( z_i^{t-1} \right) + z_i^t \,, \\ + \mathcal{F}_\alpha \left( z_i^0 \right) &= z_i^0 \,, \\ -whereby :math:`\tau_\text{m,out}` is the membrane time constant of the readout neuron. +where :math:`\tau_\text{m,out}` is the membrane time constant of the readout neuron. Furthermore, a firing rate regularization mechanism keeps the average firing rate :math:`f^\text{av}_j` of the postsynaptic neuron close to a target firing rate -:math:`f^\text{target}`. The gradient :math:`g^\text{reg}` of the regularization loss :math:`E^\text{reg}` +:math:`f^\text{target}`. The gradient :math:`g_\text{reg}` of the regularization loss :math:`E_\text{reg}` with respect to the synaptic weight :math:`W_{ji}` is given by: .. math:: - \frac{\mathrm{d}E^\text{reg}}{\mathrm{d}W_{ji}} = g^\text{reg} = c_\text{reg} - \sum_t \frac{1}{Tn_\text{trial}} \left( f^\text{target}-f^\text{av}_j\right)e_{ji}^t\,, + \frac{ \text{d} E_\text{reg} }{ \text{d} W_{ji} } + = c_\text{reg} \sum_t \frac{ 1 }{ T n_\text{trial} } + \left( f^\text{target} - f^\text{av}_j \right) e_{ji}^t \,, \\ -whereby :math:`c_\text{reg}` scales the overall regularization and the average +where :math:`c_\text{reg}` is a constant scaling factor and the average is taken over the time that passed since the previous update, that is, the number of trials :math:`n_\text{trial}` times the duration of an update interval :math:`T`. @@ -166,64 +170,80 @@ The following parameters can be set in the status dictionary. 
---------------------------------------------------------------------------------------------------------------- Parameter Unit Math equivalent Default Description =========================== ======= ======================= ================ =================================== -adapt_beta :math:`\beta` 1.0 Prefactor of the threshold +``adapt_beta`` :math:`\beta` 1.0 Prefactor of the threshold adaptation -adapt_tau ms :math:`\tau_\text{a}` 10.0 Time constant of the threshold +``adapt_tau`` ms :math:`\tau_\text{a}` 10.0 Time constant of the threshold adaptation -C_m pF :math:`C_\text{m}` 250.0 Capacitance of the membrane -c_reg :math:`c_\text{reg}` 0.0 Prefactor of firing rate - regularization -E_L mV :math:`E_\text{L}` -70.0 Leak / resting membrane potential -f_target Hz :math:`f^\text{target}` 10.0 Target firing rate of rate - regularization -gamma :math:`\gamma` 0.3 Scaling of surrogate gradient / - pseudo-derivative of - membrane voltage -I_e pA :math:`I_\text{e}` 0.0 Constant external input current -regular_spike_arrival Boolean True If True, the input spikes arrive at - the end of the time step, if False - at the beginning (determines PSC - scale) -surrogate_gradient_function :math:`\psi` piecewise_linear Surrogate gradient / - pseudo-derivative function - ["piecewise_linear"] -t_ref ms :math:`t_\text{ref}` 2.0 Duration of the refractory period -tau_m ms :math:`\tau_\text{m}` 10.0 Time constant of the membrane -V_min mV :math:`v_\text{min}` -1.79e+308 Absolute lower bound of the - membrane voltage -V_th mV :math:`v_\text{th}` -55.0 Spike threshold voltage +``C_m`` pF :math:`C_\text{m}` 250.0 Capacitance of the membrane +``E_L`` mV :math:`E_\text{L}` -70.0 Leak / resting membrane potential +``I_e`` pA :math:`I_\text{e}` 0.0 Constant external input current +``regular_spike_arrival`` Boolean ``True`` If ``True``, the input spikes + arrive at the end of the time step, + if ``False`` at the beginning + (determines PSC scale) +``t_ref`` ms :math:`t_\text{ref}` 2.0 Duration of the refractory period +``tau_m`` ms :math:`\tau_\text{m}` 10.0 Time constant of the membrane +``V_min`` mV :math:`v_\text{min}` negative maximum Absolute lower bound of the + value membrane voltage + representable by + a ``double`` + type in C++ +``V_th`` mV :math:`v_\text{th}` -55.0 Spike threshold voltage =========================== ======= ======================= ================ =================================== -The following state variables evolve during simulation. 
+=============================== ======= ======================= ================== ============================= +**E-prop parameters** +---------------------------------------------------------------------------------------------------------------- +Parameter Unit Math equivalent Default Description +=============================== ======= ======================= ================== ============================= +``c_reg`` :math:`c_\text{reg}` 0.0 Coefficient of firing rate + regularization +``f_target`` Hz :math:`f^\text{target}` 10.0 Target firing rate of rate + regularization +``beta`` :math:`\beta` 1.0 Width scaling of surrogate + gradient / pseudo-derivative + of membrane voltage +``gamma`` :math:`\gamma` 0.3 Height scaling of surrogate + gradient / pseudo-derivative + of membrane voltage +``surrogate_gradient_function`` :math:`\psi` "piecewise_linear" Surrogate gradient / + pseudo-derivative function + ["piecewise_linear", + "exponential", + "fast_sigmoid_derivative", + "arctan"] +=============================== ======= ======================= ================== ============================= + +Recordables ++++++++++++ + +The following state variables evolve during simulation and can be recorded. ================== ==== =============== ============= ======================== **Neuron state variables and recordables** ------------------------------------------------------------------------------ State variable Unit Math equivalent Initial value Description ================== ==== =============== ============= ======================== -adaptation :math:`a_j` 0.0 Adaptation variable -learning_signal :math:`L_j` 0.0 Learning signal -surrogate_gradient :math:`\psi_j` 0.0 Surrogate gradient -V_m mV :math:`v_j` -70.0 Membrane voltage -V_th_adapt mV :math:`A_j` -55.0 Adapting spike threshold +``adaptation`` :math:`a_j` 0.0 Adaptation variable +``V_m`` mV :math:`v_j` -70.0 Membrane voltage +``V_th_adapt`` mV :math:`A_j` -55.0 Adapting spike threshold ================== ==== =============== ============= ======================== -Recordables -+++++++++++ - -The following variables can be recorded: - - - adaptation variable ``adaptation`` - - adapting spike threshold ``V_th_adapt`` - - learning signal ``learning_signal`` - - membrane potential ``V_m`` - - surrogate gradient ``surrogate_gradient`` +====================== ==== =============== ============= ========================================= +**E-prop state variables and recordables** +--------------------------------------------------------------------------------------------------- +State variable Unit Math equivalent Initial value Description +====================== ==== =============== ============= ========================================= +``learning_signal`` pA :math:`L_j` 0.0 Learning signal +``surrogate_gradient`` :math:`\psi_j` 0.0 Surrogate gradient / pseudo-derivative of + membrane voltage +====================== ==== =============== ============= ========================================= Usage +++++ -This model can only be used in combination with the other e-prop models, -whereby the network architecture requires specific wiring, input, and output. +This model can only be used in combination with the other e-prop models +and the network architecture requires specific wiring, input, and output. The usage is demonstrated in several :doc:`supervised regression and classification tasks <../auto_examples/eprop_plasticity/index>` reproducing among others the original proof-of-concept tasks in [1]_. 
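To give a concrete impression of how the documented parameters and recordables are accessed, the following minimal PyNEST sketch creates a single ``eprop_iaf_adapt_bsshslm_2020`` neuron and attaches a multimeter. The parameter values are illustrative only, and a lone neuron does not learn anything: actual training requires the full e-prop architecture (recurrent ``eprop_synapse_bsshslm_2020`` connections, readout neurons, and learning-signal connections) wired as in the linked examples::

    import nest

    nest.ResetKernel()

    # Illustrative parameter values; the defaults are listed in the tables above.
    nrn = nest.Create(
        "eprop_iaf_adapt_bsshslm_2020",
        params={
            "adapt_beta": 1.0,                                  # prefactor of the threshold adaptation
            "adapt_tau": 10.0,                                  # adaptation time constant (ms)
            "c_reg": 2.0,                                       # firing rate regularization coefficient
            "f_target": 10.0,                                   # target rate for regularization (Hz)
            "beta": 1.0,                                        # width scaling of the surrogate gradient
            "gamma": 0.3,                                       # height scaling of the surrogate gradient
            "surrogate_gradient_function": "piecewise_linear",
        },
    )

    # Record the documented state variables.
    mm = nest.Create(
        "multimeter",
        params={"record_from": ["V_m", "V_th_adapt", "adaptation", "surrogate_gradient", "learning_signal"]},
    )
    nest.Connect(mm, nrn)

    nest.Simulate(100.0)
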
@@ -235,12 +255,17 @@ References Maass W (2020). A solution to the learning dilemma for recurrent networks of spiking neurons. Nature Communications, 11:3625. https://doi.org/10.1038/s41467-020-17236-y -.. [2] Korcsak-Gorzo A, Stapmanns J, Espinoza Valverde JA, Dahmen D, - van Albada SJ, Bolten M, Diesmann M. Event-based implementation of - eligibility propagation (in preparation) + +.. [2] Korcsak-Gorzo A, Stapmanns J, Espinoza Valverde JA, Plesser HE, + Dahmen D, Bolten M, Van Albada SJ*, Diesmann M*. Event-based + implementation of eligibility propagation (in preparation) + +.. include:: ../models/eprop_iaf.rst + :start-after: .. start_surrogate-gradient-references + :end-before: .. end_surrogate-gradient-references Sends -++++++++ ++++++ SpikeEvent @@ -253,7 +278,7 @@ See also ++++++++ Examples using this model -++++++++++++++++++++++++++ ++++++++++++++++++++++++++ .. listexamples:: eprop_iaf_adapt_bsshslm_2020 @@ -262,8 +287,10 @@ EndUserDocs */ void register_eprop_iaf_adapt_bsshslm_2020( const std::string& name ); /** + * @brief Class implementing an adaptive LIF neuron model for e-prop plasticity. + * * Class implementing a current-based leaky integrate-and-fire neuron model with delta-shaped postsynaptic currents and - * threshold adaptation for e-prop plasticity according to Bellec et al (2020). + * threshold adaptation for e-prop plasticity according to Bellec et al. (2020). */ class eprop_iaf_adapt_bsshslm_2020 : public EpropArchivingNodeRecurrent { @@ -293,26 +320,19 @@ class eprop_iaf_adapt_bsshslm_2020 : public EpropArchivingNodeRecurrent void get_status( DictionaryDatum& ) const override; void set_status( const DictionaryDatum& ) override; - double compute_gradient( std::vector< long >& presyn_isis, - const long t_previous_update, - const long t_previous_trigger_spike, - const double kappa, - const bool average_gradient ) override; - +private: + void init_buffers_() override; void pre_run_hook() override; - long get_shift() const override; - bool is_eprop_recurrent_node() const override; + void update( Time const&, const long, const long ) override; -protected: - void init_buffers_() override; + double compute_gradient( std::vector< long >&, const long, const long, const double, const bool ) override; -private: - //! Compute the piecewise linear surrogate gradient. - double compute_piecewise_linear_derivative(); + long get_shift() const override; + bool is_eprop_recurrent_node() const override; - //! Compute the surrogate gradient. - double ( eprop_iaf_adapt_bsshslm_2020::*compute_surrogate_gradient )(); + //! Pointer to member function selected for computing the surrogate gradient. + surrogate_gradient_function compute_surrogate_gradient_; //! Map for storing a static set of recordables. friend class RecordablesMap< eprop_iaf_adapt_bsshslm_2020 >; @@ -332,7 +352,7 @@ class eprop_iaf_adapt_bsshslm_2020 : public EpropArchivingNodeRecurrent //! Capacitance of the membrane (pF). double C_m_; - //! Prefactor of firing rate regularization. + //! Coefficient of firing rate regularization. double c_reg_; //! Leak / resting membrane potential (mV). @@ -341,7 +361,10 @@ class eprop_iaf_adapt_bsshslm_2020 : public EpropArchivingNodeRecurrent //! Target firing rate of rate regularization (spikes/s). double f_target_; - //! Scaling of surrogate-gradient / pseudo-derivative of membrane voltage. + //! Width scaling of surrogate gradient / pseudo-derivative of membrane voltage. + double beta_; + + //! Height scaling of surrogate gradient / pseudo-derivative of membrane voltage. 
double gamma_; //! Constant external input current (pA). @@ -350,7 +373,8 @@ class eprop_iaf_adapt_bsshslm_2020 : public EpropArchivingNodeRecurrent //! If True, the input spikes arrive at the beginning of the time step, if False at the end (determines PSC scale). bool regular_spike_arrival_; - //! Surrogate gradient / pseudo-derivative function ["piecewise_linear"]. + //! Surrogate gradient / pseudo-derivative function of the membrane voltage ["piecewise_linear", "exponential", + //! "fast_sigmoid_derivative", "arctan"] std::string surrogate_gradient_function_; //! Duration of the refractory period (ms). @@ -399,10 +423,10 @@ class eprop_iaf_adapt_bsshslm_2020 : public EpropArchivingNodeRecurrent //! Membrane voltage relative to the leak membrane potential (mV). double v_m_; - //! Binary spike variable - 1.0 if the neuron has spiked in the previous time step and 0.0 otherwise. + //! Binary spike state variable - 1.0 if the neuron has spiked in the previous time step and 0.0 otherwise. double z_; - //! Binary input spike variables - 1.0 if the neuron has spiked in the previous time step and 0.0 otherwise. + //! Binary input spike state variable - 1.0 if the neuron has spiked in the previous time step and 0.0 otherwise. double z_in_; //! Default constructor. @@ -434,19 +458,20 @@ class eprop_iaf_adapt_bsshslm_2020 : public EpropArchivingNodeRecurrent UniversalDataLogger< eprop_iaf_adapt_bsshslm_2020 > logger_; }; - //! Structure of general variables. + //! Structure of internal variables. struct Variables_ { - //! Propagator matrix entry for evolving the membrane voltage. + //! Propagator matrix entry for evolving the membrane voltage (mathematical symbol "alpha" in user documentation). double P_v_m_; - //! Propagator matrix entry for evolving the incoming spike variables. + //! Propagator matrix entry for evolving the incoming spike state variables (mathematical symbol "zeta" in user + //! documentation). double P_z_in_; //! Propagator matrix entry for evolving the incoming currents. double P_i_in_; - //! Propagator matrix entry for evolving the adaptation. + //! Propagator matrix entry for evolving the adaptation (mathematical symbol "rho" in user documentation). double P_adapt_; //! Total refractory steps. @@ -490,16 +515,16 @@ class eprop_iaf_adapt_bsshslm_2020 : public EpropArchivingNodeRecurrent // the order in which the structure instances are defined is important for speed - //!< Structure of parameters. + //! Structure of parameters. Parameters_ P_; - //!< Structure of state variables. + //! Structure of state variables. State_ S_; - //!< Structure of general variables. + //! Structure of internal variables. Variables_ V_; - //!< Structure of buffers. + //! Structure of buffers. Buffers_ B_; //! Map storing a static set of recordables. 
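For readers who want to see the weight-gradient recursion of the adaptive models above in isolation, here is a standalone NumPy sketch of the per-time-step loop that ``compute_gradient()`` implements: low-pass filtering of the presynaptic spike train, the eligibility trace with its adaptation term, and gradient accumulation. All inputs are made-up placeholders; in NEST, the surrogate gradients :math:`\psi_j^t` and learning signals :math:`L_j^t` come from the recorded e-prop history, and the firing rate regularization and optimizer step are omitted here for brevity::

    import numpy as np

    # Propagators and filters mirroring the documented symbols (illustrative values).
    dt = 1.0
    alpha = np.exp(-dt / 10.0)   # membrane propagator exp(-dt / tau_m)
    rho = np.exp(-dt / 10.0)     # adaptation propagator exp(-dt / adapt_tau)
    kappa = 0.97                 # low-pass filter of the eligibility trace
    zeta = 1.0                   # spike-arrival scaling (1.0 for regular_spike_arrival = True)
    adapt_beta = 1.0             # prefactor of the threshold adaptation

    rng = np.random.default_rng(1)
    T = 200
    z = rng.integers(0, 2, T).astype(float)   # presynaptic spikes z_i^t (placeholder)
    psi = rng.uniform(0.0, 0.3, T)            # surrogate gradients psi_j^t (placeholder)
    L = rng.normal(0.0, 0.1, T)               # learning signals L_j^t (placeholder)

    z_bar = e_bar = epsilon = grad = 0.0
    for t in range(T):
        z_bar = alpha * z_bar + zeta * z[t]           # filtered presynaptic spike train
        e = psi[t] * (z_bar - adapt_beta * epsilon)   # eligibility trace with adaptation term
        epsilon = rho * epsilon + e                   # adaptation eligibility variable
        e_bar = kappa * e_bar + (1.0 - kappa) * e     # filtered eligibility trace
        grad += L[t] * e_bar                          # accumulate dE/dW_ji

    print(f"accumulated gradient: {grad:.6f}")
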
diff --git a/models/eprop_iaf_bsshslm_2020.cpp b/models/eprop_iaf_bsshslm_2020.cpp index 108ea1e71a..3b8cde4cbc 100644 --- a/models/eprop_iaf_bsshslm_2020.cpp +++ b/models/eprop_iaf_bsshslm_2020.cpp @@ -72,6 +72,7 @@ eprop_iaf_bsshslm_2020::Parameters_::Parameters_() , c_reg_( 0.0 ) , E_L_( -70.0 ) , f_target_( 0.01 ) + , beta_( 1.0 ) , gamma_( 0.3 ) , I_e_( 0.0 ) , regular_spike_arrival_( true ) @@ -115,6 +116,7 @@ eprop_iaf_bsshslm_2020::Parameters_::get( DictionaryDatum& d ) const def< double >( d, names::c_reg, c_reg_ ); def< double >( d, names::E_L, E_L_ ); def< double >( d, names::f_target, f_target_ ); + def< double >( d, names::beta, beta_ ); def< double >( d, names::gamma, gamma_ ); def< double >( d, names::I_e, I_e_ ); def< bool >( d, names::regular_spike_arrival, regular_spike_arrival_ ); @@ -144,6 +146,7 @@ eprop_iaf_bsshslm_2020::Parameters_::set( const DictionaryDatum& d, Node* node ) f_target_ /= 1000.0; // convert from spikes/s to spikes/ms } + updateValueParam< double >( d, names::beta, beta_, node ); updateValueParam< double >( d, names::gamma, gamma_, node ); updateValueParam< double >( d, names::I_e, I_e_, node ); updateValueParam< bool >( d, names::regular_spike_arrival, regular_spike_arrival_, node ); @@ -158,7 +161,7 @@ eprop_iaf_bsshslm_2020::Parameters_::set( const DictionaryDatum& d, Node* node ) if ( c_reg_ < 0 ) { - throw BadProperty( "Firing rate regularization prefactor c_reg ≥ 0 required." ); + throw BadProperty( "Firing rate regularization coefficient c_reg ≥ 0 required." ); } if ( f_target_ < 0 ) @@ -166,18 +169,6 @@ eprop_iaf_bsshslm_2020::Parameters_::set( const DictionaryDatum& d, Node* node ) throw BadProperty( "Firing rate regularization target rate f_target ≥ 0 required." ); } - if ( gamma_ < 0.0 or 1.0 <= gamma_ ) - { - throw BadProperty( "Surrogate gradient / pseudo-derivative scaling gamma from interval [0,1) required." ); - } - - if ( surrogate_gradient_function_ != "piecewise_linear" ) - { - throw BadProperty( - "Surrogate gradient / pseudo derivate function surrogate_gradient_function from [\"piecewise_linear\"] " - "required." ); - } - if ( tau_m_ <= 0 ) { throw BadProperty( "Membrane time constant tau_m > 0 required." ); @@ -188,12 +179,6 @@ eprop_iaf_bsshslm_2020::Parameters_::set( const DictionaryDatum& d, Node* node ) throw BadProperty( "Refractory time t_ref ≥ 0 required." ); } - if ( surrogate_gradient_function_ == "piecewise_linear" and fabs( V_th_ ) < 1e-6 ) - { - throw BadProperty( - "Relative threshold voltage V_th-E_L ≠ 0 required if surrogate_gradient_function is \"piecewise_linear\"." ); - } - if ( V_th_ < V_min_ ) { throw BadProperty( "Spike threshold voltage V_th ≥ minimal voltage V_min required." ); @@ -256,16 +241,13 @@ eprop_iaf_bsshslm_2020::pre_run_hook() V_.RefractoryCounts_ = Time( Time::ms( P_.t_ref_ ) ).get_steps(); - if ( P_.surrogate_gradient_function_ == "piecewise_linear" ) - { - compute_surrogate_gradient = &eprop_iaf_bsshslm_2020::compute_piecewise_linear_derivative; - } + compute_surrogate_gradient_ = select_surrogate_gradient( P_.surrogate_gradient_function_ ); // calculate the entries of the propagator matrix for the evolution of the state vector const double dt = Time::get_resolution().get_ms(); - V_.P_v_m_ = std::exp( -dt / P_.tau_m_ ); // called alpha in reference [1] + V_.P_v_m_ = std::exp( -dt / P_.tau_m_ ); V_.P_i_in_ = P_.tau_m_ / P_.C_m_ * ( 1.0 - V_.P_v_m_ ); V_.P_z_in_ = P_.regular_spike_arrival_ ? 
1.0 : 1.0 - V_.P_v_m_; } @@ -301,7 +283,6 @@ eprop_iaf_bsshslm_2020::update( Time const& origin, const long from, const long if ( interval_step == 0 ) { erase_used_firing_rate_reg_history(); - erase_used_update_history(); erase_used_eprop_history(); if ( with_reset ) @@ -312,6 +293,11 @@ eprop_iaf_bsshslm_2020::update( Time const& origin, const long from, const long } } + if ( S_.r_ > 0 ) + { + --S_.r_; + } + S_.z_in_ = B_.spikes_.get_value( lag ); S_.v_m_ = V_.P_i_in_ * S_.i_in_ + V_.P_z_in_ * S_.z_in_ + V_.P_v_m_ * S_.v_m_; @@ -320,9 +306,7 @@ eprop_iaf_bsshslm_2020::update( Time const& origin, const long from, const long S_.z_ = 0.0; - S_.surrogate_gradient_ = ( this->*compute_surrogate_gradient )(); - - write_surrogate_gradient_to_history( t, S_.surrogate_gradient_ ); + S_.surrogate_gradient_ = ( this->*compute_surrogate_gradient_ )( S_.r_, S_.v_m_, P_.V_th_, P_.beta_, P_.gamma_ ); if ( S_.v_m_ >= P_.V_th_ and S_.r_ == 0 ) { @@ -332,13 +316,12 @@ eprop_iaf_bsshslm_2020::update( Time const& origin, const long from, const long kernel().event_delivery_manager.send( *this, se, lag ); S_.z_ = 1.0; - - if ( V_.RefractoryCounts_ > 0 ) - { - S_.r_ = V_.RefractoryCounts_; - } + S_.r_ = V_.RefractoryCounts_; } + append_new_eprop_history_entry( t ); + write_surrogate_gradient_to_history( t, S_.surrogate_gradient_ ); + if ( interval_step == update_interval - 1 ) { write_firing_rate_reg_to_history( t, P_.f_target_, P_.c_reg_ ); @@ -347,32 +330,12 @@ eprop_iaf_bsshslm_2020::update( Time const& origin, const long from, const long S_.learning_signal_ = get_learning_signal_from_history( t ); - if ( S_.r_ > 0 ) - { - --S_.r_; - } - S_.i_in_ = B_.currents_.get_value( lag ) + P_.I_e_; B_.logger_.record_data( t ); } } -/* ---------------------------------------------------------------- - * Surrogate gradient functions - * ---------------------------------------------------------------- */ - -double -eprop_iaf_bsshslm_2020::compute_piecewise_linear_derivative() -{ - if ( S_.r_ > 0 ) - { - return 0.0; - } - - return P_.gamma_ * std::max( 0.0, 1.0 - std::fabs( ( S_.v_m_ - P_.V_th_ ) / P_.V_th_ ) ) / P_.V_th_; -} - /* ---------------------------------------------------------------- * Event handling functions * ---------------------------------------------------------------- */ @@ -456,16 +419,17 @@ eprop_iaf_bsshslm_2020::compute_gradient( std::vector< long >& presyn_isis, } presyn_isis.clear(); + const long update_interval = kernel().simulation_manager.get_eprop_update_interval().get_steps(); const long learning_window = kernel().simulation_manager.get_eprop_learning_window().get_steps(); + const auto firing_rate_reg = get_firing_rate_reg_history( t_previous_update + get_shift() + update_interval ); + + grad += firing_rate_reg * sum_e; + if ( average_gradient ) { grad /= learning_window; } - const long update_interval = kernel().simulation_manager.get_eprop_update_interval().get_steps(); - const auto it_reg_hist = get_firing_rate_reg_history( t_previous_update + get_shift() + update_interval ); - grad += it_reg_hist->firing_rate_reg_ * sum_e; - return grad; } diff --git a/models/eprop_iaf_bsshslm_2020.h b/models/eprop_iaf_bsshslm_2020.h index 2a7f2d96b1..d4214d37e0 100644 --- a/models/eprop_iaf_bsshslm_2020.h +++ b/models/eprop_iaf_bsshslm_2020.h @@ -65,20 +65,25 @@ names and the publication year. The membrane voltage time course :math:`v_j^t` of the neuron :math:`j` is given by: .. 
math:: - v_j^t &= \alpha v_j^{t-1}+\sum_{i \neq j}W_{ji}^\mathrm{rec}z_i^{t-1} - + \sum_i W_{ji}^\mathrm{in}x_i^t-z_j^{t-1}v_\mathrm{th} \,, \\ - \alpha &= e^{-\frac{\Delta t}{\tau_\mathrm{m}}} \,, - -whereby :math:`W_{ji}^\mathrm{rec}` and :math:`W_{ji}^\mathrm{in}` are the recurrent and -input synaptic weights, and :math:`z_i^{t-1}` and :math:`x_i^t` are the -recurrent and input presynaptic spike state variables, respectively. + v_j^t &= \alpha v_j^{t-1} + \zeta \sum_{i \neq j} W_{ji}^\text{rec} z_i^{t-1} + + \zeta \sum_i W_{ji}^\text{in} x_i^t - z_j^{t-1} v_\text{th} \,, \\ + \alpha &= e^{ -\frac{ \Delta t }{ \tau_\text{m} } } \,, \\ + \zeta &= + \begin{cases} + 1 \\ + 1 - \alpha + \end{cases} \,, \\ + +where :math:`W_{ji}^\text{rec}` and :math:`W_{ji}^\text{in}` are the recurrent and +input synaptic weight matrices, and :math:`z_i^{t-1}` is the recurrent presynaptic +state variable, while :math:`x_i^t` represents the input at time :math:`t`. Descriptions of further parameters and variables can be found in the table below. The spike state variable is expressed by a Heaviside function: .. math:: - z_j^t = H\left(v_j^t-v_\mathrm{th}\right) \,. + z_j^t = H \left( v_j^t - v_\text{th} \right) \,. \\ If the membrane voltage crosses the threshold voltage :math:`v_\text{th}`, a spike is emitted and the membrane voltage is reduced by :math:`v_\text{th}` in the next @@ -88,51 +93,50 @@ able to spike for an absolute refractory period :math:`t_\text{ref}`. An additional state variable and the corresponding differential equation represents a piecewise constant external current. -Furthermore, the pseudo derivative of the membrane voltage needed for e-prop -plasticity is calculated: - -.. math:: - \psi_j^t = \frac{\gamma}{v_\text{th}} \text{max} - \left(0, 1-\left| \frac{v_j^t-v_\mathrm{th}}{v_\text{th}}\right| \right) \,. - -See the documentation on the ``iaf_psc_delta`` neuron model for more information -on the integration of the subthreshold dynamics. +See the documentation on the :doc:`iaf_psc_delta<../models/iaf_psc_delta/>` neuron model +for more information on the integration of the subthreshold dynamics. The change of the synaptic weight is calculated from the gradient :math:`g` of the loss :math:`E` with respect to the synaptic weight :math:`W_{ji}`: -:math:`\frac{\mathrm{d}{E}}{\mathrm{d}{W_{ij}}}=g` +:math:`\frac{ \text{d}E }{ \text{d} W_{ij} }` which depends on the presynaptic -spikes :math:`z_i^{t-1}`, the surrogate-gradient / pseudo-derivative of the postsynaptic membrane -voltage :math:`\psi_j^t` (which together form the eligibility trace -:math:`e_{ji}^t`), and the learning signal :math:`L_j^t` emitted by the readout -neurons. +spikes :math:`z_i^{t-1}`, the surrogate gradient or pseudo-derivative +of the spike state variable with respect to the postsynaptic membrane +voltage :math:`\psi_j^t` (the product of which forms the eligibility +trace :math:`e_{ji}^t`), and the learning signal :math:`L_j^t` emitted +by the readout neurons. .. math:: - \frac{\mathrm{d}E}{\mathrm{d}W_{ji}} = g &= \sum_t L_j^t \bar{e}_{ji}^t, \\ - e_{ji}^t &= \psi^t_j \bar{z}_i^{t-1}\,, \\ + \frac{ \text{d} E }{ \text{d} W_{ji} } &= \sum_t L_j^t \bar{e}_{ji}^t \,, \\ + e_{ji}^t &= \psi^t_j \bar{z}_i^{t-1} \,, \\ + +.. include:: ../models/eprop_iaf.rst + :start-after: .. start_surrogate-gradient-functions + :end-before: .. end_surrogate-gradient-functions The eligibility trace and the presynaptic spike trains are low-pass filtered -with some exponential kernels: +with the following exponential kernels: .. 
math:: - \bar{e}_{ji}^t &= \mathcal{F}_\kappa(e_{ji}^t) \;\text{with}\, \kappa=e^{-\frac{\Delta t}{ - \tau_\text{m,out}}}\,,\\ - \bar{z}_i^t&=\mathcal{F}_\alpha(z_i^t)\,,\\ - \mathcal{F}_\alpha(z_i^t) &= \alpha\, \mathcal{F}_\alpha(z_i^{t-1}) + z_i^t - \;\text{with}\, \mathcal{F}_\alpha(z_i^0)=z_i^0\,, + \bar{e}_{ji}^t &= \mathcal{F}_\kappa \left( e_{ji}^t \right) \,, \\ + \kappa &= e^{ -\frac{\Delta t }{ \tau_\text{m,out} }} \,, \\ + \bar{z}_i^t &= \mathcal{F}_\alpha(z_i^t) \,, \\ + \mathcal{F}_\alpha \left( z_i^t \right) &= \alpha \mathcal{F}_\alpha \left( z_i^{t-1} \right) + z_i^t \,, \\ + \mathcal{F}_\alpha \left( z_i^0 \right) &= z_i^0 \,, \\ -whereby :math:`\tau_\text{m,out}` is the membrane time constant of the readout neuron. +where :math:`\tau_\text{m,out}` is the membrane time constant of the readout neuron. Furthermore, a firing rate regularization mechanism keeps the average firing rate :math:`f^\text{av}_j` of the postsynaptic neuron close to a target firing rate -:math:`f^\text{target}`. The gradient :math:`g^\text{reg}` of the regularization loss :math:`E^\text{reg}` +:math:`f^\text{target}`. The gradient :math:`g_\text{reg}` of the regularization loss :math:`E_\text{reg}` with respect to the synaptic weight :math:`W_{ji}` is given by: .. math:: - \frac{\mathrm{d}E^\text{reg}}{\mathrm{d}W_{ji}} = g^\text{reg} = c_\text{reg} - \sum_t \frac{1}{Tn_\text{trial}} \left( f^\text{target}-f^\text{av}_j\right)e_{ji}^t\,, + \frac{ \text{d} E_\text{reg} }{ \text{d} W_{ji} } + = c_\text{reg} \sum_t \frac{ 1 }{ T n_\text{trial} } + \left( f^\text{target} - f^\text{av}_j \right) e_{ji}^t \,, \\ -whereby :math:`c_\text{reg}` scales the overall regularization and the average +where :math:`c_\text{reg}` is a constant scaling factor and the average is taken over the time that passed since the previous update, that is, the number of trials :math:`n_\text{trial}` times the duration of an update interval :math:`T`. @@ -157,56 +161,74 @@ The following parameters can be set in the status dictionary. 
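For illustration, these parameters can be passed at creation time or adjusted later through the status dictionary. The following minimal PyNEST sketch assumes a NEST build (3.7 or later) that provides the e-prop models; the chosen values are arbitrary examples, not recommendations:

.. code-block:: python

    import nest

    # Create one eprop_iaf_bsshslm_2020 neuron and set a few of the parameters
    # documented in the tables below via the status dictionary.
    nrn = nest.Create(
        "eprop_iaf_bsshslm_2020",
        params={
            "C_m": 250.0,          # membrane capacitance (pF)
            "tau_m": 10.0,         # membrane time constant (ms)
            "V_th": -55.0,         # spike threshold voltage (mV)
            "c_reg": 2.0,          # firing rate regularization coefficient (example value)
            "f_target": 10.0,      # target firing rate for regularization (Hz)
            "surrogate_gradient_function": "piecewise_linear",
        },
    )

    # Parameters can be read back or changed after creation.
    print(nrn.get(["V_th", "f_target", "surrogate_gradient_function"]))
    nrn.set(gamma=0.3)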
---------------------------------------------------------------------------------------------------------------- Parameter Unit Math equivalent Default Description =========================== ======= ======================= ================ =================================== -C_m pF :math:`C_\text{m}` 250.0 Capacitance of the membrane -c_reg :math:`c_\text{reg}` 0.0 Prefactor of firing rate - regularization -E_L mV :math:`E_\text{L}` -70.0 Leak / resting membrane potential -f_target Hz :math:`f^\text{target}` 10.0 Target firing rate of rate - regularization -gamma :math:`\gamma` 0.3 Scaling of surrogate gradient / - pseudo-derivative of membrane - voltage -I_e pA :math:`I_\text{e}` 0.0 Constant external input current -regular_spike_arrival Boolean True If True, the input spikes arrive at - the end of the time step, if - False at the beginning (determines - PSC scale) -surrogate_gradient_function :math:`\psi` piecewise_linear Surrogate gradient / - pseudo-derivative function - ["piecewise_linear"] -t_ref ms :math:`t_\text{ref}` 2.0 Duration of the refractory period -tau_m ms :math:`\tau_\text{m}` 10.0 Time constant of the membrane -V_min mV :math:`v_\text{min}` -1.79e+308 Absolute lower bound of the - membrane voltage -V_th mV :math:`v_\text{th}` -55.0 Spike threshold voltage +``C_m`` pF :math:`C_\text{m}` 250.0 Capacitance of the membrane +``E_L`` mV :math:`E_\text{L}` -70.0 Leak / resting membrane potential +``I_e`` pA :math:`I_\text{e}` 0.0 Constant external input current +``regular_spike_arrival`` Boolean ``True`` If ``True``, the input spikes + arrive at the end of the time step, + if ``False`` at the beginning + (determines PSC scale) +``t_ref`` ms :math:`t_\text{ref}` 2.0 Duration of the refractory period +``tau_m`` ms :math:`\tau_\text{m}` 10.0 Time constant of the membrane +``V_min`` mV :math:`v_\text{min}` negative maximum Absolute lower bound of the + value membrane voltage + representable by + a ``double`` + type in C++ +``V_th`` mV :math:`v_\text{th}` -55.0 Spike threshold voltage =========================== ======= ======================= ================ =================================== -The following state variables evolve during simulation. 
- -================== ==== =============== ============= ========================================================== -**Neuron state variables and recordables** +=============================== ==== ======================= ================== ================================ +**E-prop parameters** ---------------------------------------------------------------------------------------------------------------- -State variable Unit Math equivalent Initial value Description -================== ==== =============== ============= ========================================================== -learning_signal pA :math:`L_j` 0.0 Learning signal -surrogate_gradient :math:`\psi_j` 0.0 Surrogate gradient / pseudo-derivative of membrane voltage -V_m mV :math:`v_j` -70.0 Membrane voltage -================== ==== =============== ============= ========================================================== +Parameter Unit Math equivalent Default Description +=============================== ==== ======================= ================== ================================ +``c_reg`` :math:`c_\text{reg}` 0.0 Coefficient of firing rate + regularization +``f_target`` Hz :math:`f^\text{target}` 10.0 Target firing rate of rate + regularization +``beta`` :math:`\beta` 1.0 Width scaling of surrogate + gradient / pseudo-derivative of + membrane voltage +``gamma`` :math:`\gamma` 0.3 Height scaling of surrogate + gradient / pseudo-derivative of + membrane voltage +``surrogate_gradient_function`` :math:`\psi` "piecewise_linear" Surrogate gradient / + pseudo-derivative function + ["piecewise_linear", + "exponential", + "fast_sigmoid_derivative", + "arctan"] +=============================== ==== ======================= ================== ================================ Recordables +++++++++++ -The following variables can be recorded: +The following state variables evolve during simulation and can be recorded. - - learning signal ``learning_signal`` - - membrane potential ``V_m`` - - surrogate gradient ``surrogate_gradient`` +================== ==== =============== ============= ================ +**Neuron state variables and recordables** +---------------------------------------------------------------------- +State variable Unit Math equivalent Initial value Description +================== ==== =============== ============= ================ +``V_m`` mV :math:`v_j` -70.0 Membrane voltage +================== ==== =============== ============= ================ + +====================== ==== =============== ============= ========================================= +**E-prop state variables and recordables** +--------------------------------------------------------------------------------------------------- +State variable Unit Math equivalent Initial value Description +====================== ==== =============== ============= ========================================= +``learning_signal`` pA :math:`L_j` 0.0 Learning signal +``surrogate_gradient`` :math:`\psi_j` 0.0 Surrogate gradient / pseudo-derivative of + membrane voltage +====================== ==== =============== ============= ========================================= Usage +++++ -This model can only be used in combination with the other e-prop models, -whereby the network architecture requires specific wiring, input, and output. +This model can only be used in combination with the other e-prop models +and the network architecture requires specific wiring, input, and output. 
The usage is demonstrated in several :doc:`supervised regression and classification tasks <../auto_examples/eprop_plasticity/index>` reproducing among others the original proof-of-concept tasks in [1]_. @@ -218,12 +240,17 @@ References Maass W (2020). A solution to the learning dilemma for recurrent networks of spiking neurons. Nature Communications, 11:3625. https://doi.org/10.1038/s41467-020-17236-y -.. [2] Korcsak-Gorzo A, Stapmanns J, Espinoza Valverde JA, Dahmen D, - van Albada SJ, Bolten M, Diesmann M. Event-based implementation of - eligibility propagation (in preparation) + +.. [2] Korcsak-Gorzo A, Stapmanns J, Espinoza Valverde JA, Plesser HE, + Dahmen D, Bolten M, Van Albada SJ*, Diesmann M*. Event-based + implementation of eligibility propagation (in preparation) + +.. include:: ../models/eprop_iaf.rst + :start-after: .. start_surrogate-gradient-references + :end-before: .. end_surrogate-gradient-references Sends -++++++++ ++++++ SpikeEvent @@ -236,7 +263,7 @@ See also ++++++++ Examples using this model -++++++++++++++++++++++++++ ++++++++++++++++++++++++++ .. listexamples:: eprop_iaf_bsshslm_2020 @@ -245,8 +272,10 @@ EndUserDocs */ void register_eprop_iaf_bsshslm_2020( const std::string& name ); /** + * @brief Class implementing a LIF neuron model for e-prop plasticity. + * * Class implementing a current-based leaky integrate-and-fire neuron model with delta-shaped postsynaptic currents for - * e-prop plasticity according to Bellec et al (2020). + * e-prop plasticity according to Bellec et al. (2020). */ class eprop_iaf_bsshslm_2020 : public EpropArchivingNodeRecurrent { @@ -276,26 +305,19 @@ class eprop_iaf_bsshslm_2020 : public EpropArchivingNodeRecurrent void get_status( DictionaryDatum& ) const override; void set_status( const DictionaryDatum& ) override; - double compute_gradient( std::vector< long >& presyn_isis, - const long t_previous_update, - const long t_previous_trigger_spike, - const double kappa, - const bool average_gradient ) override; - +private: + void init_buffers_() override; void pre_run_hook() override; - long get_shift() const override; - bool is_eprop_recurrent_node() const override; + void update( Time const&, const long, const long ) override; -protected: - void init_buffers_() override; + double compute_gradient( std::vector< long >&, const long, const long, const double, const bool ) override; -private: - //! Compute the piecewise linear surrogate gradient. - double compute_piecewise_linear_derivative(); + long get_shift() const override; + bool is_eprop_recurrent_node() const override; - //! Compute the surrogate gradient. - double ( eprop_iaf_bsshslm_2020::*compute_surrogate_gradient )(); + //! Pointer to member function selected for computing the surrogate gradient. + surrogate_gradient_function compute_surrogate_gradient_; //! Map for storing a static set of recordables. friend class RecordablesMap< eprop_iaf_bsshslm_2020 >; @@ -309,7 +331,7 @@ class eprop_iaf_bsshslm_2020 : public EpropArchivingNodeRecurrent //! Capacitance of the membrane (pF). double C_m_; - //! Prefactor of firing rate regularization. + //! Coefficient of firing rate regularization. double c_reg_; //! Leak / resting membrane potential (mV). @@ -318,7 +340,10 @@ class eprop_iaf_bsshslm_2020 : public EpropArchivingNodeRecurrent //! Target firing rate of rate regularization (spikes/s). double f_target_; - //! Scaling of surrogate-gradient / pseudo-derivative of membrane voltage. + //! Width scaling of surrogate gradient / pseudo-derivative of membrane voltage. 
+ double beta_; + + //! Height scaling of surrogate gradient / pseudo-derivative of membrane voltage. double gamma_; //! Constant external input current (pA). @@ -327,7 +352,8 @@ class eprop_iaf_bsshslm_2020 : public EpropArchivingNodeRecurrent //! If True, the input spikes arrive at the beginning of the time step, if False at the end (determines PSC scale). bool regular_spike_arrival_; - //! Surrogate gradient / pseudo-derivative function ["piecewise_linear"]. + //! Surrogate gradient / pseudo-derivative function of the membrane voltage ["piecewise_linear", "exponential", + //! "fast_sigmoid_derivative", "arctan"] std::string surrogate_gradient_function_; //! Duration of the refractory period (ms). @@ -370,10 +396,10 @@ class eprop_iaf_bsshslm_2020 : public EpropArchivingNodeRecurrent //! Membrane voltage relative to the leak membrane potential (mV). double v_m_; - //! Binary spike variable - 1.0 if the neuron has spiked in the previous time step and 0.0 otherwise. + //! Binary spike state variable - 1.0 if the neuron has spiked in the previous time step and 0.0 otherwise. double z_; - //! Binary input spike variables - 1.0 if the neuron has spiked in the previous time step and 0.0 otherwise. + //! Binary input spike state variable - 1.0 if the neuron has spiked in the previous time step and 0.0 otherwise. double z_in_; //! Default constructor. @@ -405,13 +431,14 @@ class eprop_iaf_bsshslm_2020 : public EpropArchivingNodeRecurrent UniversalDataLogger< eprop_iaf_bsshslm_2020 > logger_; }; - //! Structure of general variables. + //! Structure of internal variables. struct Variables_ { - //! Propagator matrix entry for evolving the membrane voltage. + //! Propagator matrix entry for evolving the membrane voltage (mathematical symbol "alpha" in user documentation). double P_v_m_; - //! Propagator matrix entry for evolving the incoming spike variables. + //! Propagator matrix entry for evolving the incoming spike state variables (mathematical symbol "zeta" in user + //! documentation). double P_z_in_; //! Propagator matrix entry for evolving the incoming currents. @@ -444,16 +471,16 @@ class eprop_iaf_bsshslm_2020 : public EpropArchivingNodeRecurrent // the order in which the structure instances are defined is important for speed - //!< Structure of parameters. + //! Structure of parameters. Parameters_ P_; - //!< Structure of state variables. + //! Structure of state variables. State_ S_; - //!< Structure of general variables. + //! Structure of internal variables. Variables_ V_; - //!< Structure of buffers. + //! Structure of buffers. Buffers_ B_; //! Map storing a static set of recordables. diff --git a/models/eprop_iaf_psc_delta.cpp b/models/eprop_iaf_psc_delta.cpp new file mode 100644 index 0000000000..a6d5035e3e --- /dev/null +++ b/models/eprop_iaf_psc_delta.cpp @@ -0,0 +1,477 @@ +/* + * eprop_iaf_psc_delta.cpp + * + * This file is part of NEST. + * + * Copyright (C) 2004 The NEST Initiative + * + * NEST is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 2 of the License, or + * (at your option) any later version. + * + * NEST is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with NEST. If not, see . 
+ * + */ + +// nest models +#include "eprop_iaf_psc_delta.h" + +// C++ +#include + +// libnestutil +#include "dict_util.h" +#include "numerics.h" + +// nestkernel +#include "exceptions.h" +#include "kernel_manager.h" +#include "nest_impl.h" +#include "universal_data_logger_impl.h" + +// sli +#include "dictutils.h" + +namespace nest +{ + +void +register_eprop_iaf_psc_delta( const std::string& name ) +{ + register_node_model< eprop_iaf_psc_delta >( name ); +} + +/* ---------------------------------------------------------------- + * Recordables map + * ---------------------------------------------------------------- */ + +RecordablesMap< eprop_iaf_psc_delta > eprop_iaf_psc_delta::recordablesMap_; + +template <> +void +RecordablesMap< eprop_iaf_psc_delta >::create() +{ + insert_( names::V_m, &eprop_iaf_psc_delta::get_v_m_ ); + insert_( names::learning_signal, &eprop_iaf_psc_delta::get_learning_signal_ ); + insert_( names::surrogate_gradient, &eprop_iaf_psc_delta::get_surrogate_gradient_ ); +} + +/* ---------------------------------------------------------------- + * Default constructors for parameters, state, and buffers + * ---------------------------------------------------------------- */ + +eprop_iaf_psc_delta::Parameters_::Parameters_() + : tau_m_( 10.0 ) + , C_m_( 250.0 ) + , t_ref_( 2.0 ) + , E_L_( -70.0 ) + , I_e_( 0.0 ) + , V_th_( -55.0 - E_L_ ) + , V_min_( -std::numeric_limits< double >::max() ) + , V_reset_( -70.0 - E_L_ ) + , with_refr_input_( false ) + , c_reg_( 0.0 ) + , f_target_( 0.01 ) + , beta_( 1.0 ) + , gamma_( 0.3 ) + , surrogate_gradient_function_( "piecewise_linear" ) + , kappa_( 0.97 ) + , kappa_reg_( 0.97 ) + , eprop_isi_trace_cutoff_( 1000.0 ) +{ +} + +eprop_iaf_psc_delta::State_::State_() + : i_in_( 0.0 ) + , v_m_( 0.0 ) + , r_( 0 ) + , refr_spikes_buffer_( 0.0 ) + , learning_signal_( 0.0 ) + , surrogate_gradient_( 0.0 ) +{ +} + +eprop_iaf_psc_delta::Buffers_::Buffers_( eprop_iaf_psc_delta& n ) + : logger_( n ) +{ +} + +eprop_iaf_psc_delta::Buffers_::Buffers_( const Buffers_&, eprop_iaf_psc_delta& n ) + : logger_( n ) +{ +} + +/* ---------------------------------------------------------------- + * Getter and setter functions for parameters and state + * ---------------------------------------------------------------- */ + +void +eprop_iaf_psc_delta::Parameters_::get( DictionaryDatum& d ) const +{ + def< double >( d, names::E_L, E_L_ ); + def< double >( d, names::I_e, I_e_ ); + def< double >( d, names::V_th, V_th_ + E_L_ ); + def< double >( d, names::V_reset, V_reset_ + E_L_ ); + def< double >( d, names::V_min, V_min_ + E_L_ ); + def< double >( d, names::C_m, C_m_ ); + def< double >( d, names::tau_m, tau_m_ ); + def< double >( d, names::t_ref, t_ref_ ); + def< bool >( d, names::refractory_input, with_refr_input_ ); + def< double >( d, names::c_reg, c_reg_ ); + def< double >( d, names::f_target, f_target_ ); + def< double >( d, names::beta, beta_ ); + def< double >( d, names::gamma, gamma_ ); + def< std::string >( d, names::surrogate_gradient_function, surrogate_gradient_function_ ); + def< double >( d, names::kappa, kappa_ ); + def< double >( d, names::kappa_reg, kappa_reg_ ); + def< double >( d, names::eprop_isi_trace_cutoff, eprop_isi_trace_cutoff_ ); +} + +double +eprop_iaf_psc_delta::Parameters_::set( const DictionaryDatum& d, Node* node ) +{ + // if leak potential is changed, adjust all variables defined relative to it + const double ELold = E_L_; + updateValueParam< double >( d, names::E_L, E_L_, node ); + const double delta_EL = E_L_ - ELold; + + V_reset_ -= 
updateValueParam< double >( d, names::V_reset, V_reset_, node ) ? E_L_ : delta_EL; + V_th_ -= updateValueParam< double >( d, names::V_th, V_th_, node ) ? E_L_ : delta_EL; + V_min_ -= updateValueParam< double >( d, names::V_min, V_min_, node ) ? E_L_ : delta_EL; + + updateValueParam< double >( d, names::I_e, I_e_, node ); + updateValueParam< double >( d, names::C_m, C_m_, node ); + updateValueParam< double >( d, names::tau_m, tau_m_, node ); + updateValueParam< double >( d, names::t_ref, t_ref_, node ); + updateValueParam< bool >( d, names::refractory_input, with_refr_input_, node ); + updateValueParam< double >( d, names::c_reg, c_reg_, node ); + + if ( updateValueParam< double >( d, names::f_target, f_target_, node ) ) + { + f_target_ /= 1000.0; // convert from spikes/s to spikes/ms + } + + updateValueParam< double >( d, names::beta, beta_, node ); + updateValueParam< double >( d, names::gamma, gamma_, node ); + updateValueParam< std::string >( d, names::surrogate_gradient_function, surrogate_gradient_function_, node ); + updateValueParam< double >( d, names::kappa, kappa_, node ); + updateValueParam< double >( d, names::kappa_reg, kappa_reg_, node ); + updateValueParam< double >( d, names::eprop_isi_trace_cutoff, eprop_isi_trace_cutoff_, node ); + + if ( V_th_ < V_min_ ) + { + throw BadProperty( "Spike threshold voltage V_th ≥ minimal voltage V_min required." ); + } + + if ( V_reset_ >= V_th_ ) + { + throw BadProperty( "Reset potential must be smaller than threshold." ); + } + + if ( V_reset_ < V_min_ ) + { + throw BadProperty( "Reset voltage V_reset ≥ minimal voltage V_min required." ); + } + + if ( C_m_ <= 0 ) + { + throw BadProperty( "Membrane capacitance C_m > 0 required." ); + } + + if ( t_ref_ < 0 ) + { + throw BadProperty( "Refractory time t_ref ≥ 0 required." ); + } + + if ( tau_m_ <= 0 ) + { + throw BadProperty( "Membrane time constant tau_m > 0 required." ); + } + + if ( c_reg_ < 0 ) + { + throw BadProperty( "Firing rate regularization coefficient c_reg ≥ 0 required." ); + } + + if ( f_target_ < 0 ) + { + throw BadProperty( "Firing rate regularization target rate f_target ≥ 0 required." ); + } + + if ( kappa_ < 0.0 or kappa_ > 1.0 ) + { + throw BadProperty( "Eligibility trace low-pass filter kappa from range [0, 1] required." ); + } + + if ( kappa_reg_ < 0.0 or kappa_reg_ > 1.0 ) + { + throw BadProperty( "Firing rate low-pass filter for regularization kappa_reg from range [0, 1] required." ); + } + + if ( eprop_isi_trace_cutoff_ < 0.0 ) + { + throw BadProperty( "Cutoff of integration of eprop trace between spikes eprop_isi_trace_cutoff ≥ 0 required." ); + } + + return delta_EL; +} + +void +eprop_iaf_psc_delta::State_::get( DictionaryDatum& d, const Parameters_& p ) const +{ + def< double >( d, names::V_m, v_m_ + p.E_L_ ); + def< double >( d, names::surrogate_gradient, surrogate_gradient_ ); + def< double >( d, names::learning_signal, learning_signal_ ); +} + +void +eprop_iaf_psc_delta::State_::set( const DictionaryDatum& d, const Parameters_& p, double delta_EL, Node* node ) +{ + v_m_ -= updateValueParam< double >( d, names::V_m, v_m_, node ) ? 
p.E_L_ : delta_EL; +} + +/* ---------------------------------------------------------------- + * Default and copy constructor for node + * ---------------------------------------------------------------- */ + +eprop_iaf_psc_delta::eprop_iaf_psc_delta() + : EpropArchivingNodeRecurrent() + , P_() + , S_() + , B_( *this ) +{ + recordablesMap_.create(); +} + +eprop_iaf_psc_delta::eprop_iaf_psc_delta( const eprop_iaf_psc_delta& n ) + : EpropArchivingNodeRecurrent( n ) + , P_( n.P_ ) + , S_( n.S_ ) + , B_( n.B_, *this ) +{ +} + +/* ---------------------------------------------------------------- + * Node initialization functions + * ---------------------------------------------------------------- */ + +void +eprop_iaf_psc_delta::init_buffers_() +{ + B_.spikes_.clear(); // includes resize + B_.currents_.clear(); // includes resize + B_.logger_.reset(); // includes resize +} + +void +eprop_iaf_psc_delta::pre_run_hook() +{ + B_.logger_.init(); // ensures initialization in case multimeter connected after Simulate + + V_.RefractoryCounts_ = Time( Time::ms( P_.t_ref_ ) ).get_steps(); + V_.eprop_isi_trace_cutoff_steps_ = Time( Time::ms( P_.eprop_isi_trace_cutoff_ ) ).get_steps(); + + compute_surrogate_gradient_ = select_surrogate_gradient( P_.surrogate_gradient_function_ ); + + // calculate the entries of the propagator matrix for the evolution of the state vector + + const double dt = Time::get_resolution().get_ms(); + + V_.P_v_m_ = std::exp( -dt / P_.tau_m_ ); + V_.P_i_in_ = P_.tau_m_ / P_.C_m_ * ( 1.0 - V_.P_v_m_ ); +} + +long +eprop_iaf_psc_delta::get_shift() const +{ + return offset_gen_ + delay_in_rec_; +} + +bool +eprop_iaf_psc_delta::is_eprop_recurrent_node() const +{ + return true; +} + +/* ---------------------------------------------------------------- + * Update function + * ---------------------------------------------------------------- */ + +void +eprop_iaf_psc_delta::update( Time const& origin, const long from, const long to ) +{ + const double dt = Time::get_resolution().get_ms(); + + for ( long lag = from; lag < to; ++lag ) + { + const long t = origin.get_steps() + lag; + + const auto z_in = B_.spikes_.get_value( lag ); + + if ( S_.r_ == 0 ) // not refractory, can spike + { + S_.v_m_ = V_.P_i_in_ * ( S_.i_in_ + P_.I_e_ ) + V_.P_v_m_ * S_.v_m_ + z_in; + + if ( P_.with_refr_input_ and S_.refr_spikes_buffer_ != 0.0 ) + { + S_.v_m_ += S_.refr_spikes_buffer_; + S_.refr_spikes_buffer_ = 0.0; + } + + S_.v_m_ = std::max( S_.v_m_, P_.V_min_ ); + } + else + { + if ( P_.with_refr_input_ ) + { + S_.refr_spikes_buffer_ += z_in * std::exp( -S_.r_ * dt / P_.tau_m_ ); + } + + --S_.r_; + } + + double z = 0.0; // spike state variable + + S_.surrogate_gradient_ = ( this->*compute_surrogate_gradient_ )( S_.r_, S_.v_m_, P_.V_th_, P_.beta_, P_.gamma_ ); + + if ( S_.v_m_ >= P_.V_th_ ) + { + S_.r_ = V_.RefractoryCounts_; + S_.v_m_ = P_.V_reset_; + + SpikeEvent se; + kernel().event_delivery_manager.send( *this, se, lag ); + + z = 1.0; + } + + append_new_eprop_history_entry( t ); + write_surrogate_gradient_to_history( t, S_.surrogate_gradient_ ); + write_firing_rate_reg_to_history( t, z, P_.f_target_, P_.kappa_reg_, P_.c_reg_ ); + + S_.learning_signal_ = get_learning_signal_from_history( t, false ); + + S_.i_in_ = B_.currents_.get_value( lag ); + + B_.logger_.record_data( t ); + } +} + +/* ---------------------------------------------------------------- + * Event handling functions + * ---------------------------------------------------------------- */ + +void +eprop_iaf_psc_delta::handle( SpikeEvent& e ) +{ + 
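  // Each incoming spike is accumulated in the spike ring buffer at its relative
  // delivery step; weight and multiplicity are folded into a single contribution
  // that update() later reads back as z_in.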
assert( e.get_delay_steps() > 0 ); + + B_.spikes_.add_value( + e.get_rel_delivery_steps( kernel().simulation_manager.get_slice_origin() ), e.get_weight() * e.get_multiplicity() ); +} + +void +eprop_iaf_psc_delta::handle( CurrentEvent& e ) +{ + assert( e.get_delay_steps() > 0 ); + + B_.currents_.add_value( + e.get_rel_delivery_steps( kernel().simulation_manager.get_slice_origin() ), e.get_weight() * e.get_current() ); +} + +void +eprop_iaf_psc_delta::handle( LearningSignalConnectionEvent& e ) +{ + for ( auto it_event = e.begin(); it_event != e.end(); ) + { + const long time_step = e.get_stamp().get_steps(); + const double weight = e.get_weight(); + const double error_signal = e.get_coeffvalue( it_event ); // get_coeffvalue advances iterator + const double learning_signal = weight * error_signal; + + write_learning_signal_to_history( time_step, learning_signal, false ); + } +} + +void +eprop_iaf_psc_delta::handle( DataLoggingRequest& e ) +{ + B_.logger_.handle( e ); +} + +void +eprop_iaf_psc_delta::compute_gradient( const long t_spike, + const long t_spike_previous, + double& z_previous_buffer, + double& z_bar, + double& e_bar, + double& e_bar_reg, + double& epsilon, + double& weight, + const CommonSynapseProperties& cp, + WeightOptimizer* optimizer ) +{ + double e = 0.0; // eligibility trace + double z = 0.0; // spiking variable + double z_current_buffer = 1.0; // buffer containing the spike that triggered the current integration + double psi = 0.0; // surrogate gradient + double L = 0.0; // learning signal + double firing_rate_reg = 0.0; // firing rate regularization + double grad = 0.0; // gradient + + const EpropSynapseCommonProperties& ecp = static_cast< const EpropSynapseCommonProperties& >( cp ); + const auto optimize_each_step = ( *ecp.optimizer_cp_ ).optimize_each_step_; + + auto eprop_hist_it = get_eprop_history( t_spike_previous - 1 ); + + const long t_compute_until = std::min( t_spike_previous + V_.eprop_isi_trace_cutoff_steps_, t_spike ); + + for ( long t = t_spike_previous; t < t_compute_until; ++t, ++eprop_hist_it ) + { + z = z_previous_buffer; + z_previous_buffer = z_current_buffer; + z_current_buffer = 0.0; + + psi = eprop_hist_it->surrogate_gradient_; + L = eprop_hist_it->learning_signal_; + firing_rate_reg = eprop_hist_it->firing_rate_reg_; + + z_bar = V_.P_v_m_ * z_bar + z; + e = psi * z_bar; + e_bar = P_.kappa_ * e_bar + ( 1.0 - P_.kappa_ ) * e; + e_bar_reg = P_.kappa_reg_ * e_bar_reg + ( 1.0 - P_.kappa_reg_ ) * e; + + if ( optimize_each_step ) + { + grad = L * e_bar + firing_rate_reg * e_bar_reg; + weight = optimizer->optimized_weight( *ecp.optimizer_cp_, t, grad, weight ); + } + else + { + grad += L * e_bar + firing_rate_reg * e_bar_reg; + } + } + + if ( not optimize_each_step ) + { + weight = optimizer->optimized_weight( *ecp.optimizer_cp_, t_compute_until, grad, weight ); + } + + const long cutoff_to_spike_interval = t_spike - t_compute_until; + + if ( cutoff_to_spike_interval > 0 ) + { + z_bar *= std::pow( V_.P_v_m_, cutoff_to_spike_interval ); + e_bar *= std::pow( P_.kappa_, cutoff_to_spike_interval ); + e_bar_reg *= std::pow( P_.kappa_reg_, cutoff_to_spike_interval ); + } +} + +} // namespace nest diff --git a/models/eprop_iaf_psc_delta.h b/models/eprop_iaf_psc_delta.h new file mode 100644 index 0000000000..eb87903bc8 --- /dev/null +++ b/models/eprop_iaf_psc_delta.h @@ -0,0 +1,680 @@ +/* + * eprop_iaf_psc_delta.h + * + * This file is part of NEST. 
+ * + * Copyright (C) 2004 The NEST Initiative + * + * NEST is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 2 of the License, or + * (at your option) any later version. + * + * NEST is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with NEST. If not, see . + * + */ + +#ifndef EPROP_IAF_PSC_DELTA_H +#define EPROP_IAF_PSC_DELTA_H + +// nestkernel +#include "connection.h" +#include "eprop_archiving_node.h" +#include "eprop_archiving_node_impl.h" +#include "eprop_synapse.h" +#include "event.h" +#include "nest_types.h" +#include "ring_buffer.h" +#include "universal_data_logger.h" + +namespace nest +{ + +/* BeginUserDocs: neuron, e-prop plasticity, current-based, integrate-and-fire + +Short description ++++++++++++++++++ + +Current-based leaky integrate-and-fire neuron model with delta-shaped +postsynaptic currents for e-prop plasticity + +Description ++++++++++++ + +``eprop_iaf_psc_delta`` is an implementation of a leaky integrate-and-fire +neuron model with delta-shaped postsynaptic currents used for eligibility +propagation (e-prop) plasticity. + +E-prop plasticity was originally introduced and implemented in TensorFlow in [1]_. + +.. note:: + The neuron dynamics of the ``eprop_iaf_psc_delta`` model (excluding e-prop + plasticity) are similar to the neuron dynamics of the ``iaf_psc_delta`` model, + with minor differences, such as the propagator of the post-synaptic current + and the voltage reset upon a spike. + +The membrane voltage time course :math:`v_j^t` of the neuron :math:`j` is given by: + +.. math:: + v_j^t &= \alpha v_j^{t-1} + \sum_{i \neq j} W_{ji}^\text{rec} z_i^{t-1} + + \sum_i W_{ji}^\text{in} x_i^t \,, \\ + \alpha &= e^{ -\frac{ \Delta t }{ \tau_\text{m} } } \,, \\ + +where :math:`W_{ji}^\text{rec}` and :math:`W_{ji}^\text{in}` are the recurrent and +input synaptic weight matrices, and :math:`z_i^{t-1}` is the recurrent presynaptic +state variable, while :math:`x_i^t` represents the input at time :math:`t`. + +Descriptions of further parameters and variables can be found in the table below. + +The spike state variable is expressed by a Heaviside function: + +.. math:: + z_j^t = H \left( v_j^t - v_\text{th} \right) \,. \\ + +If the membrane voltage crosses the threshold voltage :math:`v_\text{th}`, a spike is +emitted and the membrane voltage is reset to :math:`v_\text{reset}`. After the time step +of the spike emission, the neuron is not able to spike for an absolute refractory period +:math:`t_\text{ref}` during which the membrane potential stays clamped to the reset voltage +:math:`v_\text{reset}`, thus + +.. math:: + v_m = v_\text{reset} \quad \text{for} \quad t_\text{spk} \leq t \leq t_\text{spk} + t_\text{ref} \,. + +Spikes arriving while the neuron is refractory are discarded by default. However, +if ``refractory_input`` is set to ``True`` they are damped for each time step +until the end of the refractory period and then added to the membrane voltage. + +An additional state variable and the corresponding differential equation +represents a piecewise constant external current. 
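As a plain-Python illustration of the discrete-time dynamics above (a simplified sketch, not the NEST implementation: a single neuron driven only by input spikes, with the recurrent term, the external current, and the optional ``refractory_input`` accumulation omitted; all values and names are example choices):

.. code-block:: python

    import numpy as np

    dt, tau_m = 1.0, 10.0          # resolution and membrane time constant (ms)
    v_th, v_reset = 15.0, 0.0      # threshold and reset voltage, relative to E_L (mV)
    ref_steps = int(2.0 / dt)      # refractory period t_ref in steps
    alpha = np.exp(-dt / tau_m)    # leak propagator alpha = exp(-dt / tau_m)

    rng = np.random.default_rng(1)
    w_in = rng.uniform(0.5, 1.5, size=20)              # example input weights W^in
    x = (rng.random((300, 20)) < 0.1).astype(float)    # example input spike trains x_i^t

    v, r, spikes = 0.0, 0, []
    for t in range(x.shape[0]):
        if r == 0:
            v = alpha * v + w_in @ x[t]   # v_j^t = alpha v_j^{t-1} + sum_i W_ji^in x_i^t
        else:
            r -= 1                        # membrane stays clamped to v_reset while refractory
        z = float(v >= v_th)              # z_j^t = H(v_j^t - v_th)
        if z == 1.0:
            v = v_reset                   # reset to v_reset after the spike
            r = ref_steps
            spikes.append(t)

    print(f"{len(spikes)} spikes in {x.shape[0]} steps")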
+ +See the documentation on the :doc:`iaf_psc_delta<../models/iaf_psc_delta/>` neuron model +for more information on the integration of the subthreshold dynamics. + +The change of the synaptic weight is calculated from the gradient :math:`g^t` of +the loss :math:`E^t` with respect to the synaptic weight :math:`W_{ji}`: +:math:`\frac{ \text{d} E^t }{ \text{d} W_{ij} }` +which depends on the presynaptic +spikes :math:`z_i^{t-2}`, the surrogate gradient or pseudo-derivative +of the spike state variable with respect to the postsynaptic membrane +voltage :math:`\psi_j^{t-1}` (the product of which forms the eligibility +trace :math:`e_{ji}^{t-1}`), and the learning signal :math:`L_j^t` emitted +by the readout neurons. + +Surrogate gradients help overcome the challenge of the spiking function's +non-differentiability, facilitating the use of gradient-based learning +techniques such as e-prop. The non-existent derivative of the spiking +variable with respect to the membrane voltage, +:math:`\frac{\partial z^t_j}{ \partial v^t_j}`, can be effectively +replaced with a variety of surrogate gradient functions, as detailed in +various studies (see, e.g., [3]_). NEST currently provides four +different surrogate gradient functions: + +1. A piecewise linear function used among others in [1]_: + +.. math:: + \psi_j^t = \frac{ \gamma }{ v_\text{th} } \text{max} + \left( 0, 1-\beta \left| \frac{ v_j^t - v_\text{th} }{ v_\text{th} }\right| \right) \,. \\ + +2. An exponential function used in [4]_: + +.. math:: + \psi_j^t = \gamma \exp \left( -\beta \left| v_j^t - v_\text{th} \right| \right) \,. \\ + +3. The derivative of a fast sigmoid function used in [5]_: + +.. math:: + \psi_j^t = \gamma \left( 1 + \beta \left| v_j^t - v_\text{th} \right| \right)^2 \,. \\ + +4. An arctan function used in [6]_: + +.. math:: + \psi_j^t = \frac{\gamma}{\pi} \frac{1}{ 1 + \left( \beta \pi \left( v_j^t - v_\text{th} \right) \right)^2 } \,. \\ + +In the interval between two presynaptic spikes, the gradient is calculated +at each time step until the cutoff time point. This computation occurs over +the time range: + +:math:`t \in \left[ t_\text{spk,prev}, \min \left( t_\text{spk,prev} + \Delta t_\text{c}, t_\text{spk,curr} \right) +\right]`. + +Here, :math:`t_\text{spk,prev}` represents the time of the previous spike that +passed the synapse, while :math:`t_\text{spk,curr}` is the time of the +current spike, which triggers the application of the learning rule and the +subsequent synaptic weight update. The cutoff :math:`\Delta t_\text{c}` +defines the maximum allowable interval for integration between spikes. +The expression for the gradient is given by: + +.. math:: + \frac{ \text{d} E^t }{ \text{d} W_{ji} } &= L_j^t \bar{e}_{ji}^{t-1} \,, \\ + e_{ji}^{t-1} &= \psi_j^{t-1} \bar{z}_i^{t-2} \,, \\ + +The eligibility trace and the presynaptic spike trains are low-pass filtered +with the following exponential kernels: + +.. math:: + \bar{e}_{ji}^t &= \mathcal{F}_\kappa \left( e_{ji}^t \right) + = \kappa \bar{e}_{ji}^{t-1} + \left( 1 - \kappa \right) e_{ji}^t \,, \\ + \bar{z}_i^t &= \mathcal{F}_\alpha \left( z_{i}^t \right)= \alpha \bar{z}_i^{t-1} + z_i^t \,. \\ + +Furthermore, a firing rate regularization mechanism keeps the exponential moving average of the postsynaptic +neuron's firing rate :math:`f_j^{\text{ema},t}` close to a target firing rate +:math:`f^\text{target}`. The gradient :math:`g_\text{reg}^t` of the regularization loss :math:`E_\text{reg}^t` +with respect to the synaptic weight :math:`W_{ji}` is given by: + +.. 
math:: + \frac{ \text{d} E_\text{reg}^t }{ \text{d} W_{ji}} + &\approx c_\text{reg} \left( f^{\text{ema},t}_j - f^\text{target} \right) \bar{e}_{ji}^t \,, \\ + f^{\text{ema},t}_j &= \mathcal{F}_{\kappa_\text{reg}} \left( \frac{z_j^t}{\Delta t} \right) + = \kappa_\text{reg} f^{\text{ema},t-1}_j + \left( 1 - \kappa_\text{reg} \right) \frac{z_j^t}{\Delta t} \,, \\ + +where :math:`c_\text{reg}` is a constant scaling factor. + +The overall gradient is given by the addition of the two gradients. + +As a last step for every round in the loop over the time steps :math:`t`, the new weight is retrieved by feeding the +current gradient :math:`g^t` to the optimizer (see :doc:`weight_optimizer<../models/weight_optimizer/>` +for more information on the available optimizers): + +.. math:: + w^t = \text{optimizer} \left( t, g^t, w^{t-1} \right) \,. \\ + +After the loop has terminated, the filtered dynamic variables of e-prop are propagated from the end of the cutoff until +the next spike: + +.. math:: + p &= \text{max} \left( 0, t_\text{s}^{t} - \left( t_\text{s}^{t-1} + {\Delta t}_\text{c} \right) \right) \,, \\ + \bar{e}_{ji}^{t+p} &= \bar{e}_{ji}^t \kappa^p \,, \\ + \bar{z}_i^{t+p} &= \bar{z}_i^t \alpha^p \,. \\ + +For more information on the implementation details of the neuron model, see [7]_ and [8]_. + +For more information on e-prop plasticity, see the documentation on the other e-prop models: + + * :doc:`eprop_iaf_psc_delta_adapt<../models/eprop_iaf_psc_delta_adapt/>` + * :doc:`eprop_readout<../models/eprop_readout/>` + * :doc:`eprop_synapse<../models/eprop_synapse/>` + * :doc:`eprop_learning_signal_connection<../models/eprop_learning_signal_connection/>` + +Details on the event-based NEST implementation of e-prop can be found in [2]_. + +Parameters +++++++++++ + +The following parameters can be set in the status dictionary. 
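Analogously to the other e-prop neurons, these parameters can be supplied with ``nest.Create`` or set as model defaults. A minimal sketch, assuming a NEST build that includes ``eprop_iaf_psc_delta`` (the values are arbitrary examples), of selecting a different surrogate gradient and trace cutoff before the tables below:

.. code-block:: python

    import nest

    # Set e-prop-specific defaults for all subsequently created eprop_iaf_psc_delta neurons.
    nest.SetDefaults(
        "eprop_iaf_psc_delta",
        {
            "surrogate_gradient_function": "exponential",  # one of the four options listed below
            "beta": 1.0,                       # width scaling of the surrogate gradient
            "gamma": 0.3,                      # height scaling of the surrogate gradient
            "kappa": 0.97,                     # low-pass filter of the eligibility trace
            "kappa_reg": 0.97,                 # low-pass filter of the rate used for regularization
            "eprop_isi_trace_cutoff": 100.0,   # cutoff of the between-spike integration (ms)
            "refractory_input": True,          # damp and add spikes arriving while refractory
        },
    )

    nrns = nest.Create("eprop_iaf_psc_delta", 5)
    print(nest.GetDefaults("eprop_iaf_psc_delta", "surrogate_gradient_function"))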
+ +=========================== ======= ======================= ================ =================================== +**Neuron parameters** +---------------------------------------------------------------------------------------------------------------- +Parameter Unit Math equivalent Default Description +=========================== ======= ======================= ================ =================================== +``C_m`` pF :math:`C_\text{m}` 250.0 Capacitance of the membrane +``E_L`` mV :math:`E_\text{L}` -70.0 Leak / resting membrane potential +``I_e`` pA :math:`I_\text{e}` 0.0 Constant external input current +``t_ref`` ms :math:`t_\text{ref}` 2.0 Duration of the refractory period +``tau_m`` ms :math:`\tau_\text{m}` 10.0 Time constant of the membrane +``V_min`` mV :math:`v_\text{min}` negative maximum Absolute lower bound of the + value membrane voltage + representable + by a ``double`` + type in C++ +``V_th`` mV :math:`v_\text{th}` -55.0 Spike threshold voltage +``V_reset`` mV :math:`v_\text{reset}` -70.0 Reset voltage +``refractory_input`` Boolean ``False`` If ``True``, spikes arriving during + the refractory period are damped + until it ends and then added to the + membrane voltage +=========================== ======= ======================= ================ =================================== + +=============================== ======= =========================== ================== ========================= +**E-prop parameters** +---------------------------------------------------------------------------------------------------------------- +Parameter Unit Math equivalent Default Description +=============================== ======= =========================== ================== ========================= +``c_reg`` :math:`c_\text{reg}` 0.0 Coefficient of firing + rate regularization +``eprop_isi_trace_cutoff`` ms :math:`{\Delta t}_\text{c}` maximum value Cutoff for integration of + representable e-prop update between two + by a ``long`` spikes + type in C++ +``f_target`` Hz :math:`f^\text{target}` 10.0 Target firing rate of + rate regularization +``kappa`` :math:`\kappa` 0.97 Low-pass filter of the + eligibility trace +``kappa_reg`` :math:`\kappa_\text{reg}` 0.97 Low-pass filter of the + firing rate for + regularization +``beta`` :math:`\beta` 1.0 Width scaling of + surrogate gradient / + pseudo-derivative of + membrane voltage +``gamma`` :math:`\gamma` 0.3 Height scaling of + surrogate gradient / + pseudo-derivative of + membrane voltage +``surrogate_gradient_function`` :math:`\psi` "piecewise_linear" Surrogate gradient / + pseudo-derivative + function + ["piecewise_linear", + "exponential", + "fast_sigmoid_derivative" + , "arctan"] +=============================== ======= =========================== ================== ========================= + +Recordables ++++++++++++ + +The following state variables evolve during simulation and can be recorded. 
+ +================== ==== =============== ============= ======================== +**Neuron state variables and recordables** +------------------------------------------------------------------------------ +State variable Unit Math equivalent Initial value Description +================== ==== =============== ============= ======================== +``V_m`` mV :math:`v_j` -70.0 Membrane voltage +================== ==== =============== ============= ======================== + +====================== ==== =============== ============= ========================================= +**E-prop state variables and recordables** +--------------------------------------------------------------------------------------------------- +State variable Unit Math equivalent Initial value Description +====================== ==== =============== ============= ========================================= +``learning_signal`` pA :math:`L_j` 0.0 Learning signal +``surrogate_gradient`` :math:`\psi_j` 0.0 Surrogate gradient / pseudo-derivative of + membrane voltage +====================== ==== =============== ============= ========================================= + +Usage ++++++ + +This model can only be used in combination with the other e-prop models +and the network architecture requires specific wiring, input, and output. +The usage is demonstrated in several +:doc:`supervised regression and classification tasks <../auto_examples/eprop_plasticity/index>` +reproducing among others the original proof-of-concept tasks in [1]_. + +References +++++++++++ + +.. [1] Bellec G, Scherr F, Subramoney F, Hajek E, Salaj D, Legenstein R, + Maass W (2020). A solution to the learning dilemma for recurrent + networks of spiking neurons. Nature Communications, 11:3625. + https://doi.org/10.1038/s41467-020-17236-y + +.. [2] Korcsak-Gorzo A, Stapmanns J, Espinoza Valverde JA, Plesser HE, + Dahmen D, Bolten M, Van Albada SJ*, Diesmann M*. Event-based + implementation of eligibility propagation (in preparation) + +.. [3] Neftci EO, Mostafa H, Zenke F (2019). Surrogate Gradient Learning in + Spiking Neural Networks. IEEE Signal Processing Magazine, 36(6), 51-63. + https://doi.org/10.1109/MSP.2019.2931595 + +.. [4] Shrestha SB, Orchard G (2018). SLAYER: Spike Layer Error Reassignment in + Time. Advances in Neural Information Processing Systems, 31:1412-1421. + https://proceedings.neurips.cc/paper_files/paper/2018/hash/82.. rubric:: References + +.. [5] Zenke F, Ganguli S (2018). SuperSpike: Supervised Learning in Multilayer + Spiking Neural Networks. Neural Computation, 30:1514–1541. + https://doi.org/10.1162/neco_a_01086 + +.. [6] Fang W, Yu Z, Chen Y, Huang T, Masquelier T, Tian Y (2021). Deep residual + learning in spiking neural networks. Advances in Neural Information + Processing Systems, 34:21056–21069. + https://proceedings.neurips.cc/paper/2021/hash/afe434653a898da20044041262b3ac74-Abstract.html + +.. [7] Rotter S, Diesmann M (1999). Exact simulation of time-invariant linear + systems with applications to neuronal modeling. Biological Cybernetics + 81:381-402. + https://doi.org/10.1007/s004220050570 + +.. [8] Diesmann M, Gewaltig MO, Rotter S, Aertsen A (2001). State space analysis + of synchronous spiking in cortical neural networks. Neurocomputing + 38-40:565-571. + https://doi.org/10.1016/S0925-2312(01)00409-X + +Sends ++++++ + +SpikeEvent + +Receives +++++++++ + +SpikeEvent, CurrentEvent, LearningSignalConnectionEvent, DataLoggingRequest + +See also +++++++++ + +Examples using this model ++++++++++++++++++++++++++ + +.. 
listexamples:: eprop_iaf_psc_delta + +EndUserDocs */ + +void register_eprop_iaf_psc_delta( const std::string& name ); + +/** + * @brief Class implementing an adaptive LIF neuron model for e-prop plasticity with additional biological features. + * + * Class implementing a current-based leaky integrate-and-fire neuron model with delta-shaped postsynaptic currents + * and spike threshold adaptation for e-prop plasticity according to Bellec et al. (2020) with additional biological + * features described in Korcsak-Gorzo, Stapmanns, and Espinoza Valverde et al. (in preparation). + */ +class eprop_iaf_psc_delta : public EpropArchivingNodeRecurrent +{ + +public: + //! Default constructor. + eprop_iaf_psc_delta(); + + //! Copy constructor. + eprop_iaf_psc_delta( const eprop_iaf_psc_delta& ); + + using Node::handle; + using Node::handles_test_event; + + size_t send_test_event( Node&, size_t, synindex, bool ) override; + + void handle( SpikeEvent& ) override; + void handle( CurrentEvent& ) override; + void handle( LearningSignalConnectionEvent& ) override; + void handle( DataLoggingRequest& ) override; + + size_t handles_test_event( SpikeEvent&, size_t ) override; + size_t handles_test_event( CurrentEvent&, size_t ) override; + size_t handles_test_event( LearningSignalConnectionEvent&, size_t ) override; + size_t handles_test_event( DataLoggingRequest&, size_t ) override; + + void get_status( DictionaryDatum& ) const override; + void set_status( const DictionaryDatum& ) override; + +private: + void init_buffers_() override; + void pre_run_hook() override; + + void update( Time const&, const long, const long ) override; + + void compute_gradient( const long, + const long, + double&, + double&, + double&, + double&, + double&, + double&, + const CommonSynapseProperties&, + WeightOptimizer* ) override; + + long get_shift() const override; + bool is_eprop_recurrent_node() const override; + long get_eprop_isi_trace_cutoff() const override; + + //! Pointer to member function selected for computing the surrogate gradient. + surrogate_gradient_function compute_surrogate_gradient_; + + //! Map for storing a static set of recordables. + friend class RecordablesMap< eprop_iaf_psc_delta >; + + //! Logger for universal data supporting the data logging request / reply mechanism. Populated with a recordables map. + friend class UniversalDataLogger< eprop_iaf_psc_delta >; + + //! Structure of parameters. + struct Parameters_ + { + //! Time constant of the membrane (ms). + double tau_m_; + + //! Capacitance of the membrane (pF). + double C_m_; + + //! Duration of the refractory period (ms). + double t_ref_; + + //! Leak / resting membrane potential (mV). + double E_L_; + + //! Constant external input current (pA). + double I_e_; + + //! Spike threshold voltage relative to the leak membrane potential (mV). + double V_th_; + + //! Absolute lower bound of the membrane voltage relative to the leak membrane potential (mV). + double V_min_; + + //! Reset voltage relative to the leak membrane potential (mV). + double V_reset_; + + //! If True, count spikes arriving during the refractory period. + bool with_refr_input_; + + //! Coefficient of firing rate regularization. + double c_reg_; + + //! Target firing rate of rate regularization (spikes/s). + double f_target_; + + //! Width scaling of surrogate gradient / pseudo-derivative of membrane voltage. + double beta_; + + //! Height scaling of surrogate gradient / pseudo-derivative of membrane voltage. + double gamma_; + + //! 
Surrogate gradient / pseudo-derivative function of the membrane voltage ["piecewise_linear", "exponential", + //! "fast_sigmoid_derivative", "arctan"] + std::string surrogate_gradient_function_; + + //! Low-pass filter of the eligibility trace. + double kappa_; + + //! Low-pass filter of the firing rate for regularization. + double kappa_reg_; + + //! Time interval from the previous spike until the cutoff of e-prop update integration between two spikes (ms). + double eprop_isi_trace_cutoff_; + + //! Default constructor. + Parameters_(); + + //! Get the parameters and their values. + void get( DictionaryDatum& ) const; + + //! Set the parameters and throw errors in case of invalid values. + double set( const DictionaryDatum&, Node* ); + }; + + //! Structure of state variables. + struct State_ + { + //! Input current (pA). + double i_in_; + + //! Membrane voltage relative to the leak membrane potential (mV). + double v_m_; + + //! Number of remaining refractory steps. + int r_; + + //! Count of spikes arriving during refractory period discounted for decay until end of refractory period. + double refr_spikes_buffer_; + + //! Learning signal. Sum of weighted error signals coming from the readout neurons. + double learning_signal_; + + //! Surrogate gradient / pseudo-derivative of the membrane voltage. + double surrogate_gradient_; + + //! Default constructor. + State_(); + + //! Get the state variables and their values. + void get( DictionaryDatum&, const Parameters_& ) const; + + //! Set the state variables. + void set( const DictionaryDatum&, const Parameters_&, double, Node* ); + }; + + //! Structure of buffers. + struct Buffers_ + { + //! Default constructor. + Buffers_( eprop_iaf_psc_delta& ); + + //! Copy constructor. + Buffers_( const Buffers_&, eprop_iaf_psc_delta& ); + + //! Buffer for incoming spikes. + RingBuffer spikes_; + + //! Buffer for incoming currents. + RingBuffer currents_; + + //! Logger for universal data. + UniversalDataLogger< eprop_iaf_psc_delta > logger_; + }; + + //! Structure of internal variables. + struct Variables_ + { + //! Propagator matrix entry for evolving the membrane voltage (mathematical symbol "alpha" in user documentation). + double P_v_m_; + + //! Propagator matrix entry for evolving the incoming currents. + double P_i_in_; + + //! Total refractory steps. + int RefractoryCounts_; + + //! Time steps from the previous spike until the cutoff of e-prop update integration between two spikes. + long eprop_isi_trace_cutoff_steps_; + }; + + //! Get the current value of the membrane voltage. + double + get_v_m_() const + { + return S_.v_m_ + P_.E_L_; + } + + //! Get the current value of the surrogate gradient. + double + get_surrogate_gradient_() const + { + return S_.surrogate_gradient_; + } + + //! Get the current value of the learning signal. + double + get_learning_signal_() const + { + return S_.learning_signal_; + } + + // the order in which the structure instances are defined is important for speed + + //! Structure of parameters. + Parameters_ P_; + + //! Structure of state variables. + State_ S_; + + //! Structure of internal variables. + Variables_ V_; + + //! Structure of buffers. + Buffers_ B_; + + //! Map storing a static set of recordables. 
+ static RecordablesMap< eprop_iaf_psc_delta > recordablesMap_; +}; + +inline long +eprop_iaf_psc_delta::get_eprop_isi_trace_cutoff() const +{ + return V_.eprop_isi_trace_cutoff_steps_; +} + +inline size_t +eprop_iaf_psc_delta::send_test_event( Node& target, size_t receptor_type, synindex, bool ) +{ + SpikeEvent e; + e.set_sender( *this ); + return target.handles_test_event( e, receptor_type ); +} + +inline size_t +eprop_iaf_psc_delta::handles_test_event( SpikeEvent&, size_t receptor_type ) +{ + if ( receptor_type != 0 ) + { + throw UnknownReceptorType( receptor_type, get_name() ); + } + + return 0; +} + +inline size_t +eprop_iaf_psc_delta::handles_test_event( CurrentEvent&, size_t receptor_type ) +{ + if ( receptor_type != 0 ) + { + throw UnknownReceptorType( receptor_type, get_name() ); + } + + return 0; +} + +inline size_t +eprop_iaf_psc_delta::handles_test_event( LearningSignalConnectionEvent&, size_t receptor_type ) +{ + if ( receptor_type != 0 ) + { + throw UnknownReceptorType( receptor_type, get_name() ); + } + + return 0; +} + +inline size_t +eprop_iaf_psc_delta::handles_test_event( DataLoggingRequest& dlr, size_t receptor_type ) +{ + if ( receptor_type != 0 ) + { + throw UnknownReceptorType( receptor_type, get_name() ); + } + + return B_.logger_.connect_logging_device( dlr, recordablesMap_ ); +} + +inline void +eprop_iaf_psc_delta::get_status( DictionaryDatum& d ) const +{ + P_.get( d ); + S_.get( d, P_ ); + ( *d )[ names::recordables ] = recordablesMap_.get_list(); +} + +inline void +eprop_iaf_psc_delta::set_status( const DictionaryDatum& d ) +{ + // temporary copies in case of errors + Parameters_ ptmp = P_; + State_ stmp = S_; + + // make sure that ptmp and stmp consistent - throw BadProperty if not + const double delta_EL = ptmp.set( d, this ); + stmp.set( d, ptmp, delta_EL, this ); + + P_ = ptmp; + S_ = stmp; +} + +} // namespace nest + +#endif // EPROP_IAF_PSC_DELTA_H diff --git a/models/eprop_iaf_psc_delta_adapt.cpp b/models/eprop_iaf_psc_delta_adapt.cpp new file mode 100644 index 0000000000..7c8cf8adf8 --- /dev/null +++ b/models/eprop_iaf_psc_delta_adapt.cpp @@ -0,0 +1,519 @@ +/* + * eprop_iaf_psc_delta_adapt.cpp + * + * This file is part of NEST. + * + * Copyright (C) 2004 The NEST Initiative + * + * NEST is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 2 of the License, or + * (at your option) any later version. + * + * NEST is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with NEST. If not, see . 
+ * + */ + +// nest models +#include "eprop_iaf_psc_delta_adapt.h" + +// C++ +#include + +// libnestutil +#include "dict_util.h" +#include "numerics.h" + +// nestkernel +#include "exceptions.h" +#include "kernel_manager.h" +#include "nest_impl.h" +#include "universal_data_logger_impl.h" + +// sli +#include "dictutils.h" + +namespace nest +{ + +void +register_eprop_iaf_psc_delta_adapt( const std::string& name ) +{ + register_node_model< eprop_iaf_psc_delta_adapt >( name ); +} + +/* ---------------------------------------------------------------- + * Recordables map + * ---------------------------------------------------------------- */ + +RecordablesMap< eprop_iaf_psc_delta_adapt > eprop_iaf_psc_delta_adapt::recordablesMap_; + +template <> +void +RecordablesMap< eprop_iaf_psc_delta_adapt >::create() +{ + insert_( names::V_m, &eprop_iaf_psc_delta_adapt::get_v_m_ ); + insert_( names::adaptation, &eprop_iaf_psc_delta_adapt::get_adaptation_ ); + insert_( names::V_th_adapt, &eprop_iaf_psc_delta_adapt::get_v_th_adapt_ ); + insert_( names::learning_signal, &eprop_iaf_psc_delta_adapt::get_learning_signal_ ); + insert_( names::surrogate_gradient, &eprop_iaf_psc_delta_adapt::get_surrogate_gradient_ ); +} + +/* ---------------------------------------------------------------- + * Default constructors for parameters, state, and buffers + * ---------------------------------------------------------------- */ + +eprop_iaf_psc_delta_adapt::Parameters_::Parameters_() + : tau_m_( 10.0 ) + , C_m_( 250.0 ) + , t_ref_( 2.0 ) + , E_L_( -70.0 ) + , I_e_( 0.0 ) + , V_th_( -55.0 - E_L_ ) + , V_min_( -std::numeric_limits< double >::max() ) + , V_reset_( -70.0 - E_L_ ) + , with_refr_input_( false ) + , adapt_beta_( 1.0 ) + , adapt_tau_( 10.0 ) + , c_reg_( 0.0 ) + , f_target_( 0.01 ) + , beta_( 1.0 ) + , gamma_( 0.3 ) + , surrogate_gradient_function_( "piecewise_linear" ) + , kappa_( 0.97 ) + , kappa_reg_( 0.97 ) + , eprop_isi_trace_cutoff_( 1000.0 ) +{ +} + +eprop_iaf_psc_delta_adapt::State_::State_() + : i_in_( 0.0 ) + , v_m_( 0.0 ) + , r_( 0 ) + , refr_spikes_buffer_( 0.0 ) + , z_( 0.0 ) + , adapt_( 0.0 ) + , v_th_adapt_( 15.0 ) + , learning_signal_( 0.0 ) + , surrogate_gradient_( 0.0 ) +{ +} + +eprop_iaf_psc_delta_adapt::Buffers_::Buffers_( eprop_iaf_psc_delta_adapt& n ) + : logger_( n ) +{ +} + +eprop_iaf_psc_delta_adapt::Buffers_::Buffers_( const Buffers_&, eprop_iaf_psc_delta_adapt& n ) + : logger_( n ) +{ +} + +/* ---------------------------------------------------------------- + * Getter and setter functions for parameters and state + * ---------------------------------------------------------------- */ + +void +eprop_iaf_psc_delta_adapt::Parameters_::get( DictionaryDatum& d ) const +{ + def< double >( d, names::E_L, E_L_ ); + def< double >( d, names::I_e, I_e_ ); + def< double >( d, names::V_th, V_th_ + E_L_ ); + def< double >( d, names::V_reset, V_reset_ + E_L_ ); + def< double >( d, names::V_min, V_min_ + E_L_ ); + def< double >( d, names::C_m, C_m_ ); + def< double >( d, names::tau_m, tau_m_ ); + def< double >( d, names::t_ref, t_ref_ ); + def< bool >( d, names::refractory_input, with_refr_input_ ); + def< double >( d, names::adapt_beta, adapt_beta_ ); + def< double >( d, names::adapt_tau, adapt_tau_ ); + def< double >( d, names::c_reg, c_reg_ ); + def< double >( d, names::f_target, f_target_ ); + def< double >( d, names::beta, beta_ ); + def< double >( d, names::gamma, gamma_ ); + def< std::string >( d, names::surrogate_gradient_function, surrogate_gradient_function_ ); + def< double >( d, names::kappa, 
kappa_ ); + def< double >( d, names::kappa_reg, kappa_reg_ ); + def< double >( d, names::eprop_isi_trace_cutoff, eprop_isi_trace_cutoff_ ); +} + +double +eprop_iaf_psc_delta_adapt::Parameters_::set( const DictionaryDatum& d, Node* node ) +{ + // if leak potential is changed, adjust all variables defined relative to it + const double ELold = E_L_; + updateValueParam< double >( d, names::E_L, E_L_, node ); + const double delta_EL = E_L_ - ELold; + + V_reset_ -= updateValueParam< double >( d, names::V_reset, V_reset_, node ) ? E_L_ : delta_EL; + V_th_ -= updateValueParam< double >( d, names::V_th, V_th_, node ) ? E_L_ : delta_EL; + V_min_ -= updateValueParam< double >( d, names::V_min, V_min_, node ) ? E_L_ : delta_EL; + + updateValueParam< double >( d, names::I_e, I_e_, node ); + updateValueParam< double >( d, names::C_m, C_m_, node ); + updateValueParam< double >( d, names::tau_m, tau_m_, node ); + updateValueParam< double >( d, names::t_ref, t_ref_, node ); + updateValueParam< bool >( d, names::refractory_input, with_refr_input_, node ); + updateValueParam< double >( d, names::adapt_beta, adapt_beta_, node ); + updateValueParam< double >( d, names::adapt_tau, adapt_tau_, node ); + updateValueParam< double >( d, names::c_reg, c_reg_, node ); + + if ( updateValueParam< double >( d, names::f_target, f_target_, node ) ) + { + f_target_ /= 1000.0; // convert from spikes/s to spikes/ms + } + + updateValueParam< double >( d, names::beta, beta_, node ); + updateValueParam< double >( d, names::gamma, gamma_, node ); + updateValueParam< std::string >( d, names::surrogate_gradient_function, surrogate_gradient_function_, node ); + updateValueParam< double >( d, names::kappa, kappa_, node ); + updateValueParam< double >( d, names::kappa_reg, kappa_reg_, node ); + updateValueParam< double >( d, names::eprop_isi_trace_cutoff, eprop_isi_trace_cutoff_, node ); + + if ( V_th_ < V_min_ ) + { + throw BadProperty( "Spike threshold voltage V_th ≥ minimal voltage V_min required." ); + } + + if ( V_reset_ >= V_th_ ) + { + throw BadProperty( "Reset potential must be smaller than threshold." ); + } + + if ( V_reset_ < V_min_ ) + { + throw BadProperty( "Reset voltage V_reset ≥ minimal voltage V_min required." ); + } + + if ( C_m_ <= 0 ) + { + throw BadProperty( "Membrane capacitance C_m > 0 required." ); + } + + if ( t_ref_ < 0 ) + { + throw BadProperty( "Refractory time t_ref ≥ 0 required." ); + } + + if ( tau_m_ <= 0 ) + { + throw BadProperty( "Membrane time constant tau_m > 0 required." ); + } + + if ( adapt_beta_ < 0 ) + { + throw BadProperty( "Threshold adaptation prefactor adapt_beta ≥ 0 required." ); + } + + if ( adapt_tau_ <= 0 ) + { + throw BadProperty( "Threshold adaptation time constant adapt_tau > 0 required." ); + } + + if ( c_reg_ < 0 ) + { + throw BadProperty( "Firing rate regularization coefficient c_reg ≥ 0 required." ); + } + + if ( f_target_ < 0 ) + { + throw BadProperty( "Firing rate regularization target rate f_target ≥ 0 required." ); + } + + if ( kappa_ < 0.0 or kappa_ > 1.0 ) + { + throw BadProperty( "Eligibility trace low-pass filter kappa from range [0, 1] required." ); + } + + if ( kappa_reg_ < 0.0 or kappa_reg_ > 1.0 ) + { + throw BadProperty( "Firing rate low-pass filter for regularization kappa_reg from range [0, 1] required." ); + } + + if ( eprop_isi_trace_cutoff_ < 0.0 ) + { + throw BadProperty( "Cutoff of integration of eprop trace between spikes eprop_isi_trace_cutoff ≥ 0 required." 
); + } + + return delta_EL; +} + +void +eprop_iaf_psc_delta_adapt::State_::get( DictionaryDatum& d, const Parameters_& p ) const +{ + def< double >( d, names::V_m, v_m_ + p.E_L_ ); + def< double >( d, names::adaptation, adapt_ ); + def< double >( d, names::V_th_adapt, v_th_adapt_ + p.E_L_ ); + def< double >( d, names::surrogate_gradient, surrogate_gradient_ ); + def< double >( d, names::learning_signal, learning_signal_ ); +} + +void +eprop_iaf_psc_delta_adapt::State_::set( const DictionaryDatum& d, const Parameters_& p, double delta_EL, Node* node ) +{ + v_m_ -= updateValueParam< double >( d, names::V_m, v_m_, node ) ? p.E_L_ : delta_EL; + + // adaptive threshold can only be set indirectly via the adaptation variable + if ( updateValueParam< double >( d, names::adaptation, adapt_, node ) ) + { + // if E_L changed in this SetStatus call, p.V_th_ has been adjusted and no further action is needed + v_th_adapt_ = p.V_th_ + p.adapt_beta_ * adapt_; + } + else + { + // adjust voltage to change in E_L + v_th_adapt_ -= delta_EL; + } +} + +/* ---------------------------------------------------------------- + * Default and copy constructor for node + * ---------------------------------------------------------------- */ + +eprop_iaf_psc_delta_adapt::eprop_iaf_psc_delta_adapt() + : EpropArchivingNodeRecurrent() + , P_() + , S_() + , B_( *this ) +{ + recordablesMap_.create(); +} + +eprop_iaf_psc_delta_adapt::eprop_iaf_psc_delta_adapt( const eprop_iaf_psc_delta_adapt& n ) + : EpropArchivingNodeRecurrent( n ) + , P_( n.P_ ) + , S_( n.S_ ) + , B_( n.B_, *this ) +{ +} + +/* ---------------------------------------------------------------- + * Node initialization functions + * ---------------------------------------------------------------- */ + +void +eprop_iaf_psc_delta_adapt::init_buffers_() +{ + B_.spikes_.clear(); // includes resize + B_.currents_.clear(); // includes resize + B_.logger_.reset(); // includes resize +} + +void +eprop_iaf_psc_delta_adapt::pre_run_hook() +{ + B_.logger_.init(); // ensures initialization in case multimeter connected after Simulate + + V_.RefractoryCounts_ = Time( Time::ms( P_.t_ref_ ) ).get_steps(); + V_.eprop_isi_trace_cutoff_steps_ = Time( Time::ms( P_.eprop_isi_trace_cutoff_ ) ).get_steps(); + + compute_surrogate_gradient_ = select_surrogate_gradient( P_.surrogate_gradient_function_ ); + + // calculate the entries of the propagator matrix for the evolution of the state vector + + const double dt = Time::get_resolution().get_ms(); + + V_.P_v_m_ = std::exp( -dt / P_.tau_m_ ); + V_.P_i_in_ = P_.tau_m_ / P_.C_m_ * ( 1.0 - V_.P_v_m_ ); + V_.P_adapt_ = std::exp( -dt / P_.adapt_tau_ ); +} + +long +eprop_iaf_psc_delta_adapt::get_shift() const +{ + return offset_gen_ + delay_in_rec_; +} + +bool +eprop_iaf_psc_delta_adapt::is_eprop_recurrent_node() const +{ + return true; +} + +/* ---------------------------------------------------------------- + * Update function + * ---------------------------------------------------------------- */ + +void +eprop_iaf_psc_delta_adapt::update( Time const& origin, const long from, const long to ) +{ + const double dt = Time::get_resolution().get_ms(); + + for ( long lag = from; lag < to; ++lag ) + { + const long t = origin.get_steps() + lag; + + const auto z_in = B_.spikes_.get_value( lag ); + + if ( S_.r_ == 0 ) // not refractory, can spike + { + S_.v_m_ = V_.P_i_in_ * ( S_.i_in_ + P_.I_e_ ) + V_.P_v_m_ * S_.v_m_ + z_in; + + if ( P_.with_refr_input_ and S_.refr_spikes_buffer_ != 0.0 ) + { + S_.v_m_ += S_.refr_spikes_buffer_; + 
S_.refr_spikes_buffer_ = 0.0; + } + + S_.v_m_ = std::max( S_.v_m_, P_.V_min_ ); + + S_.adapt_ = V_.P_adapt_ * S_.adapt_ + S_.z_; + S_.v_th_adapt_ = P_.V_th_ + P_.adapt_beta_ * S_.adapt_; + } + else + { + if ( P_.with_refr_input_ ) + { + S_.refr_spikes_buffer_ += z_in * std::exp( -S_.r_ * dt / P_.tau_m_ ); + } + + --S_.r_; + } + + S_.z_ = 0.0; + + S_.surrogate_gradient_ = + ( this->*compute_surrogate_gradient_ )( S_.r_, S_.v_m_, S_.v_th_adapt_, P_.beta_, P_.gamma_ ); + + if ( S_.v_m_ >= S_.v_th_adapt_ ) + { + S_.r_ = V_.RefractoryCounts_; + S_.v_m_ = P_.V_reset_; + + SpikeEvent se; + kernel().event_delivery_manager.send( *this, se, lag ); + + S_.z_ = 1.0; + } + + append_new_eprop_history_entry( t ); + write_surrogate_gradient_to_history( t, S_.surrogate_gradient_ ); + write_firing_rate_reg_to_history( t, S_.z_, P_.f_target_, P_.kappa_reg_, P_.c_reg_ ); + + S_.learning_signal_ = get_learning_signal_from_history( t, false ); + + S_.i_in_ = B_.currents_.get_value( lag ); + + B_.logger_.record_data( t ); + } +} + +/* ---------------------------------------------------------------- + * Event handling functions + * ---------------------------------------------------------------- */ + +void +eprop_iaf_psc_delta_adapt::handle( SpikeEvent& e ) +{ + assert( e.get_delay_steps() > 0 ); + + B_.spikes_.add_value( + e.get_rel_delivery_steps( kernel().simulation_manager.get_slice_origin() ), e.get_weight() * e.get_multiplicity() ); +} + +void +eprop_iaf_psc_delta_adapt::handle( CurrentEvent& e ) +{ + assert( e.get_delay_steps() > 0 ); + + B_.currents_.add_value( + e.get_rel_delivery_steps( kernel().simulation_manager.get_slice_origin() ), e.get_weight() * e.get_current() ); +} + +void +eprop_iaf_psc_delta_adapt::handle( LearningSignalConnectionEvent& e ) +{ + for ( auto it_event = e.begin(); it_event != e.end(); ) + { + const long time_step = e.get_stamp().get_steps(); + const double weight = e.get_weight(); + const double error_signal = e.get_coeffvalue( it_event ); // get_coeffvalue advances iterator + const double learning_signal = weight * error_signal; + + write_learning_signal_to_history( time_step, learning_signal, false ); + } +} + +void +eprop_iaf_psc_delta_adapt::handle( DataLoggingRequest& e ) +{ + B_.logger_.handle( e ); +} + +void +eprop_iaf_psc_delta_adapt::compute_gradient( const long t_spike, + const long t_spike_previous, + double& z_previous_buffer, + double& z_bar, + double& e_bar, + double& e_bar_reg, + double& epsilon, + double& weight, + const CommonSynapseProperties& cp, + WeightOptimizer* optimizer ) +{ + double e = 0.0; // eligibility trace + double z = 0.0; // spiking variable + double z_current_buffer = 1.0; // buffer containing the spike that triggered the current integration + double psi = 0.0; // surrogate gradient + double L = 0.0; // learning signal + double firing_rate_reg = 0.0; // firing rate regularization + double grad = 0.0; // gradient + + const EpropSynapseCommonProperties& ecp = static_cast< const EpropSynapseCommonProperties& >( cp ); + const auto optimize_each_step = ( *ecp.optimizer_cp_ ).optimize_each_step_; + + auto eprop_hist_it = get_eprop_history( t_spike_previous - 1 ); + + const long t_compute_until = std::min( t_spike_previous + V_.eprop_isi_trace_cutoff_steps_, t_spike ); + + for ( long t = t_spike_previous; t < t_compute_until; ++t, ++eprop_hist_it ) + { + z = z_previous_buffer; + z_previous_buffer = z_current_buffer; + z_current_buffer = 0.0; + + psi = eprop_hist_it->surrogate_gradient_; + L = eprop_hist_it->learning_signal_; + firing_rate_reg = 
eprop_hist_it->firing_rate_reg_; + + z_bar = V_.P_v_m_ * z_bar + z; + e = psi * ( z_bar - P_.adapt_beta_ * epsilon ); + epsilon = V_.P_adapt_ * epsilon + e; + e_bar = P_.kappa_ * e_bar + ( 1.0 - P_.kappa_ ) * e; + e_bar_reg = P_.kappa_reg_ * e_bar_reg + ( 1.0 - P_.kappa_reg_ ) * e; + + if ( optimize_each_step ) + { + grad = L * e_bar + firing_rate_reg * e_bar_reg; + weight = optimizer->optimized_weight( *ecp.optimizer_cp_, t, grad, weight ); + } + else + { + grad += L * e_bar + firing_rate_reg * e_bar_reg; + } + } + + if ( not optimize_each_step ) + { + weight = optimizer->optimized_weight( *ecp.optimizer_cp_, t_compute_until, grad, weight ); + } + + const long cutoff_to_spike_interval = t_spike - t_compute_until; + + if ( cutoff_to_spike_interval > 0 ) + { + z_bar *= std::pow( V_.P_v_m_, cutoff_to_spike_interval ); + e_bar *= std::pow( P_.kappa_, cutoff_to_spike_interval ); + e_bar_reg *= std::pow( P_.kappa_reg_, cutoff_to_spike_interval ); + epsilon *= std::pow( V_.P_adapt_, cutoff_to_spike_interval ); + } +} + +} // namespace nest diff --git a/models/eprop_iaf_psc_delta_adapt.h b/models/eprop_iaf_psc_delta_adapt.h new file mode 100644 index 0000000000..d025e8404a --- /dev/null +++ b/models/eprop_iaf_psc_delta_adapt.h @@ -0,0 +1,728 @@ +/* + * eprop_iaf_psc_delta_adapt.h + * + * This file is part of NEST. + * + * Copyright (C) 2004 The NEST Initiative + * + * NEST is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 2 of the License, or + * (at your option) any later version. + * + * NEST is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with NEST. If not, see . + * + */ + +#ifndef EPROP_IAF_PSC_DELTA_ADAPT_H +#define EPROP_IAF_PSC_DELTA_ADAPT_H + +// nestkernel +#include "connection.h" +#include "eprop_archiving_node.h" +#include "eprop_archiving_node_impl.h" +#include "eprop_synapse.h" +#include "event.h" +#include "nest_types.h" +#include "ring_buffer.h" +#include "universal_data_logger.h" + +namespace nest +{ + +/* BeginUserDocs: neuron, e-prop plasticity, current-based, integrate-and-fire, adaptive threshold + +Short description ++++++++++++++++++ + +Current-based leaky integrate-and-fire neuron model with delta-shaped +postsynaptic currents and threshold adaptation for e-prop plasticity + +Description ++++++++++++ + +``eprop_iaf_psc_delta_adapt`` is an implementation of a leaky integrate-and-fire +neuron model with delta-shaped postsynaptic currents and threshold adaptation +used for eligibility propagation (e-prop) plasticity. + +E-prop plasticity was originally introduced and implemented in TensorFlow in [1]_. + + .. note:: + The neuron dynamics of the ``eprop_iaf_psc_delta_adapt`` model (excluding + e-prop plasticity and the threshold adaptation) are similar to the neuron + dynamics of the ``iaf_psc_delta`` model, with minor differences, such as the + propagator of the post-synaptic current and the voltage reset upon a spike. + +The membrane voltage time course :math:`v_j^t` of the neuron :math:`j` is given by: + +.. 
math::
+    v_j^t &= \alpha v_j^{t-1} + \sum_{i \neq j} W_{ji}^\text{rec} z_i^{t-1}
+            + \sum_i W_{ji}^\text{in} x_i^t \,, \\
+    \alpha &= e^{ -\frac{ \Delta t }{ \tau_\text{m} } } \,, \\
+
+where :math:`W_{ji}^\text{rec}` and :math:`W_{ji}^\text{in}` are the recurrent and
+input synaptic weight matrices, and :math:`z_i^{t-1}` is the recurrent presynaptic
+state variable, while :math:`x_i^t` represents the input at time :math:`t`.
+
+Descriptions of further parameters and variables can be found in the table below.
+
+The threshold adaptation is given by:
+
+.. math::
+    A_j^t &= v_\text{th} + \beta a_j^t \,, \\
+    a_j^t &= \rho a_j^{t-1} + z_j^{t-1} \,, \\
+    \rho &= e^{-\frac{ \Delta t }{ \tau_\text{a} }} \,. \\
+
+The spike state variable is expressed by a Heaviside function:
+
+.. math::
+    z_j^t = H \left( v_j^t - A_j^t \right) \,. \\
+
+If the membrane voltage crosses the adaptive threshold voltage :math:`A_j^t`, a spike is
+emitted and the membrane voltage is reset to :math:`v_\text{reset}`. After the time step
+of the spike emission, the neuron is not able to spike for an absolute refractory period
+:math:`t_\text{ref}` during which the membrane potential stays clamped to the reset voltage
+:math:`v_\text{reset}`, thus
+
+.. math::
+    v_m = v_\text{reset} \quad \text{for} \quad t_\text{spk} \leq t \leq t_\text{spk} + t_\text{ref} \,.
+
+Spikes arriving while the neuron is refractory are discarded by default. However,
+if ``refractory_input`` is set to ``True`` they are damped for each time step
+until the end of the refractory period and then added to the membrane voltage.
+
+An additional state variable and the corresponding differential equation
+represents a piecewise constant external current.
+
+See the documentation on the :doc:`iaf_psc_delta<../models/iaf_psc_delta/>` neuron model
+for more information on the integration of the subthreshold dynamics.
+
+The change of the synaptic weight is calculated from the gradient :math:`g^t` of
+the loss :math:`E^t` with respect to the synaptic weight :math:`W_{ji}`:
+:math:`\frac{ \text{d} E^t }{ \text{d} W_{ji} }`,
+which depends on the presynaptic
+spikes :math:`z_i^{t-2}`, the surrogate gradient or pseudo-derivative
+of the spike state variable with respect to the postsynaptic membrane
+voltage :math:`\psi_j^{t-1}` (the product of which forms the eligibility
+trace :math:`e_{ji}^{t-1}`), and the learning signal :math:`L_j^t` emitted
+by the readout neurons.
+
+Surrogate gradients help overcome the challenge of the spiking function's
+non-differentiability, facilitating the use of gradient-based learning
+techniques such as e-prop. The non-existent derivative of the spiking
+variable with respect to the membrane voltage,
+:math:`\frac{\partial z^t_j}{ \partial v^t_j}`, can be effectively
+replaced with a variety of surrogate gradient functions, as detailed in
+various studies (see, e.g., [3]_). NEST currently provides four
+different surrogate gradient functions:
+
+1. A piecewise linear function used among others in [1]_:
+
+.. math::
+    \psi_j^t = \frac{ \gamma }{ v_\text{th} } \text{max}
+    \left( 0, 1-\beta \left| \frac{ v_j^t - v_\text{th} }{ v_\text{th} }\right| \right) \,. \\
+
+2. An exponential function used in [4]_:
+
+.. math::
+    \psi_j^t = \gamma \exp \left( -\beta \left| v_j^t - v_\text{th} \right| \right) \,. \\
+
+3. The derivative of a fast sigmoid function used in [5]_:
+
+.. math::
+    \psi_j^t = \gamma \left( 1 + \beta \left| v_j^t - v_\text{th} \right| \right)^{-2} \,. \\
+
+4. An arctan function used in [6]_:
+
+.. 
math:: + \psi_j^t = \frac{\gamma}{\pi} \frac{1}{ 1 + \left( \beta \pi \left( v_j^t - v_\text{th} \right) \right)^2 } \,. \\ + +In the interval between two presynaptic spikes, the gradient is calculated +at each time step until the cutoff time point. This computation occurs over +the time range: + +:math:`t \in \left[ t_\text{spk,prev}, \min \left( t_\text{spk,prev} + \Delta t_\text{c}, t_\text{spk,curr} \right) +\right]`. + +Here, :math:`t_\text{spk,prev}` represents the time of the previous spike that +passed the synapse, while :math:`t_\text{spk,curr}` is the time of the +current spike, which triggers the application of the learning rule and the +subsequent synaptic weight update. The cutoff :math:`\Delta t_\text{c}` +defines the maximum allowable interval for integration between spikes. +The expression for the gradient is given by: + +.. math:: + \frac{ \text{d} E^t }{ \text{d} W_{ji} } &= L_j^t \bar{e}_{ji}^{t-1} \,, \\ + e_{ji}^{t-1} &= \psi_j^{t-1} \left( \bar{z}_i^{t-2} - \beta \epsilon_{ji,a}^{t-2} \right) \,, \\ + \epsilon^{t-2}_{ji,\text{a}} &= e_{ji}^{t-1} + \rho \epsilon_{ji,a}^{t-3} \,. \\ + +The eligibility trace and the presynaptic spike trains are low-pass filtered +with the following exponential kernels: + +.. math:: + \bar{e}_{ji}^t &= \mathcal{F}_\kappa \left( e_{ji}^t \right) + = \kappa \bar{e}_{ji}^{t-1} + \left( 1 - \kappa \right) e_{ji}^t \,, \\ + \bar{z}_i^t &= \mathcal{F}_\alpha \left( z_{i}^t \right)= \alpha \bar{z}_i^{t-1} + z_i^t \,. \\ + +Furthermore, a firing rate regularization mechanism keeps the exponential moving average of the postsynaptic +neuron's firing rate :math:`f_j^{\text{ema},t}` close to a target firing rate +:math:`f^\text{target}`. The gradient :math:`g_\text{reg}^t` of the regularization loss :math:`E_\text{reg}^t` +with respect to the synaptic weight :math:`W_{ji}` is given by: + +.. math:: + \frac{ \text{d} E_\text{reg}^t }{ \text{d} W_{ji}} + &\approx c_\text{reg} \left( f^{\text{ema},t}_j - f^\text{target} \right) \bar{e}_{ji}^t \,, \\ + f^{\text{ema},t}_j &= \mathcal{F}_{\kappa_\text{reg}} \left( \frac{z_j^t}{\Delta t} \right) + = \kappa_\text{reg} f^{\text{ema},t-1}_j + \left( 1 - \kappa_\text{reg} \right) \frac{z_j^t}{\Delta t} \,, \\ + +where :math:`c_\text{reg}` is a constant scaling factor. + +The overall gradient is given by the addition of the two gradients. + +As a last step for every round in the loop over the time steps :math:`t`, the new weight is retrieved by feeding the +current gradient :math:`g^t` to the optimizer (see :doc:`weight_optimizer<../models/weight_optimizer/>` +for more information on the available optimizers): + +.. math:: + w^t = \text{optimizer} \left( t, g^t, w^{t-1} \right) \,. \\ + +After the loop has terminated, the filtered dynamic variables of e-prop are propagated from the end of the cutoff until +the next spike: + +.. math:: + p &= \text{max} \left( 0, t_\text{s}^{t} - \left( t_\text{s}^{t-1} + {\Delta t}_\text{c} \right) \right) \,, \\ + \bar{e}_{ji}^{t+p} &= \bar{e}_{ji}^t \kappa^p \,, \\ + \bar{z}_i^{t+p} &= \bar{z}_i^t \alpha^p \,, \\ + \epsilon^{t+p} &= \epsilon^t \rho^p \,. \\ + +For more information on the implementation details of the neuron model, see [7]_ and [8]_. 
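+
+The mapping of these update equations onto code is compact. The following minimal
+C++ sketch (illustrative only; the function name ``accumulate_gradient`` and its
+arguments are chosen for this example and are not part of the NEST API) shows how
+the filtered traces interact while the gradient is accumulated between two
+presynaptic spikes, omitting the firing rate regularization term and the optimizer
+call for brevity:
+
+.. code-block:: cpp
+
+    #include <vector>
+
+    // Accumulate the e-prop gradient for one synapse over the interval between
+    // two presynaptic spikes. psi[t] is the surrogate gradient, L[t] the
+    // learning signal, and z[t] the presynaptic spike state at time step t.
+    double
+    accumulate_gradient( const std::vector< double >& psi,
+      const std::vector< double >& L,
+      const std::vector< double >& z,
+      const double alpha,      // exp( -dt / tau_m )
+      const double rho,        // exp( -dt / adapt_tau )
+      const double adapt_beta, // threshold adaptation prefactor
+      const double kappa )     // low-pass filter of the eligibility trace
+    {
+      double z_bar = 0.0;   // filtered presynaptic spike train
+      double epsilon = 0.0; // adaptation-related eligibility component
+      double e_bar = 0.0;   // filtered eligibility trace
+      double grad = 0.0;    // accumulated gradient
+
+      for ( size_t t = 0; t < psi.size(); ++t )
+      {
+        z_bar = alpha * z_bar + z[ t ];
+        const double e = psi[ t ] * ( z_bar - adapt_beta * epsilon );
+        epsilon = rho * epsilon + e;
+        e_bar = kappa * e_bar + ( 1.0 - kappa ) * e;
+        grad += L[ t ] * e_bar;
+      }
+
+      return grad;
+    }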
+ +For more information on e-prop plasticity, see the documentation on the other e-prop models: + + * :doc:`eprop_iaf_psc_delta<../models/eprop_iaf_psc_delta/>` + * :doc:`eprop_readout<../models/eprop_readout/>` + * :doc:`eprop_synapse<../models/eprop_synapse/>` + * :doc:`eprop_learning_signal_connection<../models/eprop_learning_signal_connection/>` + +Details on the event-based NEST implementation of e-prop can be found in [2]_. + +Parameters +++++++++++ + +The following parameters can be set in the status dictionary. + +=========================== ======= ======================= ================ =================================== +**Neuron parameters** +---------------------------------------------------------------------------------------------------------------- +Parameter Unit Math equivalent Default Description +=========================== ======= ======================= ================ =================================== +``C_m`` pF :math:`C_\text{m}` 250.0 Capacitance of the membrane +``E_L`` mV :math:`E_\text{L}` -70.0 Leak / resting membrane potential +``I_e`` pA :math:`I_\text{e}` 0.0 Constant external input current +``t_ref`` ms :math:`t_\text{ref}` 2.0 Duration of the refractory period +``tau_m`` ms :math:`\tau_\text{m}` 10.0 Time constant of the membrane +``V_min`` mV :math:`v_\text{min}` negative maximum Absolute lower bound of the + value membrane voltage + representable + by a ``double`` + type in C++ +``V_th`` mV :math:`v_\text{th}` -55.0 Spike threshold voltage +``V_reset`` mV :math:`v_\text{reset}` -70.0 Reset voltage +``refractory_input`` Boolean ``False`` If ``True``, spikes arriving during + the refractory period are damped + until it ends and then added to the + membrane voltage +``adapt_beta`` :math:`\beta` 1.0 Prefactor of the threshold + adaptation +``adapt_tau`` ms :math:`\tau_\text{a}` 10.0 Time constant of the threshold + adaptation +=========================== ======= ======================= ================ =================================== + +=============================== ======= =========================== ================== ========================= +**E-prop parameters** +---------------------------------------------------------------------------------------------------------------- +Parameter Unit Math equivalent Default Description +=============================== ======= =========================== ================== ========================= +``c_reg`` :math:`c_\text{reg}` 0.0 Coefficient of firing + rate regularization +``eprop_isi_trace_cutoff`` ms :math:`{\Delta t}_\text{c}` maximum value Cutoff for integration of + representable e-prop update between two + by a ``long`` spikes + type in C++ +``f_target`` Hz :math:`f^\text{target}` 10.0 Target firing rate of + rate regularization +``kappa`` :math:`\kappa` 0.97 Low-pass filter of the + eligibility trace +``kappa_reg`` :math:`\kappa_\text{reg}` 0.97 Low-pass filter of the + firing rate for + regularization +``beta`` :math:`\beta` 1.0 Width scaling of + surrogate gradient / + pseudo-derivative of + membrane voltage +``gamma`` :math:`\gamma` 0.3 Height scaling of + surrogate gradient / + pseudo-derivative of + membrane voltage +``surrogate_gradient_function`` :math:`\psi` "piecewise_linear" Surrogate gradient / + pseudo-derivative + function + ["piecewise_linear", + "exponential", + "fast_sigmoid_derivative" + , "arctan"] +=============================== ======= =========================== ================== ========================= + +Recordables ++++++++++++ + +The following state variables 
evolve during simulation and can be recorded.
+
+================== ==== =============== ============= ========================
+**Neuron state variables and recordables**
+------------------------------------------------------------------------------
+State variable     Unit Math equivalent Initial value Description
+================== ==== =============== ============= ========================
+``adaptation``          :math:`a_j`     0.0           Adaptation variable
+``V_m``            mV   :math:`v_j`     -70.0         Membrane voltage
+``V_th_adapt``     mV   :math:`A_j`     -55.0         Adapting spike threshold
+================== ==== =============== ============= ========================
+
+====================== ==== =============== ============= =========================================
+**E-prop state variables and recordables**
+---------------------------------------------------------------------------------------------------
+State variable         Unit Math equivalent Initial value Description
+====================== ==== =============== ============= =========================================
+``learning_signal``    pA   :math:`L_j`     0.0           Learning signal
+``surrogate_gradient``      :math:`\psi_j`  0.0           Surrogate gradient / pseudo-derivative of
+                                                          membrane voltage
+====================== ==== =============== ============= =========================================
+
+Usage
++++++
+
+This model can only be used in combination with the other e-prop models
+and the network architecture requires specific wiring, input, and output.
+The usage is demonstrated in several
+:doc:`supervised regression and classification tasks <../auto_examples/eprop_plasticity/index>`
+reproducing among others the original proof-of-concept tasks in [1]_.
+
+References
+++++++++++
+
+.. [1] Bellec G, Scherr F, Subramoney F, Hajek E, Salaj D, Legenstein R,
+       Maass W (2020). A solution to the learning dilemma for recurrent
+       networks of spiking neurons. Nature Communications, 11:3625.
+       https://doi.org/10.1038/s41467-020-17236-y
+
+.. [2] Korcsak-Gorzo A, Stapmanns J, Espinoza Valverde JA, Plesser HE,
+       Dahmen D, Bolten M, Van Albada SJ*, Diesmann M*. Event-based
+       implementation of eligibility propagation (in preparation)
+
+.. [3] Neftci EO, Mostafa H, Zenke F (2019). Surrogate Gradient Learning in
+       Spiking Neural Networks. IEEE Signal Processing Magazine, 36(6), 51-63.
+       https://doi.org/10.1109/MSP.2019.2931595
+
+.. [4] Shrestha SB, Orchard G (2018). SLAYER: Spike Layer Error Reassignment in
+       Time. Advances in Neural Information Processing Systems, 31:1412-1421.
+       https://proceedings.neurips.cc/paper_files/paper/2018/hash/82
+
+.. [5] Zenke F, Ganguli S (2018). SuperSpike: Supervised Learning in Multilayer
+       Spiking Neural Networks. Neural Computation, 30:1514–1541.
+       https://doi.org/10.1162/neco_a_01086
+
+.. [6] Fang W, Yu Z, Chen Y, Huang T, Masquelier T, Tian Y (2021). Deep residual
+       learning in spiking neural networks. Advances in Neural Information
+       Processing Systems, 34:21056–21069.
+       https://proceedings.neurips.cc/paper/2021/hash/afe434653a898da20044041262b3ac74-Abstract.html
+
+.. [7] Rotter S, Diesmann M (1999). Exact simulation of time-invariant linear
+       systems with applications to neuronal modeling. Biological Cybernetics
+       81:381-402.
+       https://doi.org/10.1007/s004220050570
+
+.. [8] Diesmann M, Gewaltig MO, Rotter S, Aertsen A (2001). State space analysis
+       of synchronous spiking in cortical neural networks. Neurocomputing
+       38-40:565-571.
+ https://doi.org/10.1016/S0925-2312(01)00409-X + +Sends ++++++ + +SpikeEvent + +Receives +++++++++ + +SpikeEvent, CurrentEvent, LearningSignalConnectionEvent, DataLoggingRequest + +See also +++++++++ + +Examples using this model ++++++++++++++++++++++++++ + +.. listexamples:: eprop_iaf_psc_delta_adapt + +EndUserDocs */ + +void register_eprop_iaf_psc_delta_adapt( const std::string& name ); + +/** + * @brief Class implementing a LIF neuron model for e-prop plasticity with additional biological features. + * + * Class implementing a current-based leaky integrate-and-fire neuron model with delta-shaped postsynaptic currents for + * e-prop plasticity according to Bellec et al. (2020) with additional biological features described in + * Korcsak-Gorzo, Stapmanns, and Espinoza Valverde et al. (in preparation). + */ +class eprop_iaf_psc_delta_adapt : public EpropArchivingNodeRecurrent +{ + +public: + //! Default constructor. + eprop_iaf_psc_delta_adapt(); + + //! Copy constructor. + eprop_iaf_psc_delta_adapt( const eprop_iaf_psc_delta_adapt& ); + + using Node::handle; + using Node::handles_test_event; + + size_t send_test_event( Node&, size_t, synindex, bool ) override; + + void handle( SpikeEvent& ) override; + void handle( CurrentEvent& ) override; + void handle( LearningSignalConnectionEvent& ) override; + void handle( DataLoggingRequest& ) override; + + size_t handles_test_event( SpikeEvent&, size_t ) override; + size_t handles_test_event( CurrentEvent&, size_t ) override; + size_t handles_test_event( LearningSignalConnectionEvent&, size_t ) override; + size_t handles_test_event( DataLoggingRequest&, size_t ) override; + + void get_status( DictionaryDatum& ) const override; + void set_status( const DictionaryDatum& ) override; + +private: + void init_buffers_() override; + void pre_run_hook() override; + + void update( Time const&, const long, const long ) override; + + void compute_gradient( const long, + const long, + double&, + double&, + double&, + double&, + double&, + double&, + const CommonSynapseProperties&, + WeightOptimizer* ) override; + + long get_shift() const override; + bool is_eprop_recurrent_node() const override; + long get_eprop_isi_trace_cutoff() const override; + + //! Pointer to member function selected for computing the surrogate gradient. + surrogate_gradient_function compute_surrogate_gradient_; + + //! Map for storing a static set of recordables. + friend class RecordablesMap< eprop_iaf_psc_delta_adapt >; + + //! Logger for universal data supporting the data logging request / reply mechanism. Populated with a recordables map. + friend class UniversalDataLogger< eprop_iaf_psc_delta_adapt >; + + //! Structure of parameters. + struct Parameters_ + { + //! Time constant of the membrane (ms). + double tau_m_; + + //! Capacitance of the membrane (pF). + double C_m_; + + //! Duration of the refractory period (ms). + double t_ref_; + + //! Leak / resting membrane potential (mV). + double E_L_; + + //! Constant external input current (pA). + double I_e_; + + //! Spike threshold voltage relative to the leak membrane potential (mV). + double V_th_; + + //! Absolute lower bound of the membrane voltage relative to the leak membrane potential (mV). + double V_min_; + + //! Reset voltage relative to the leak membrane potential (mV). + double V_reset_; + + //! If True, count spikes arriving during the refractory period. + bool with_refr_input_; + + //! Prefactor of the threshold adaptation. + double adapt_beta_; + + //! Time constant of the threshold adaptation (ms). 
+ double adapt_tau_; + + + //! Coefficient of firing rate regularization. + double c_reg_; + + //! Target firing rate of rate regularization (spikes/s). + double f_target_; + + //! Width scaling of surrogate gradient / pseudo-derivative of membrane voltage. + double beta_; + + //! Height scaling of surrogate gradient / pseudo-derivative of membrane voltage. + double gamma_; + + //! Surrogate gradient / pseudo-derivative function of the membrane voltage ["piecewise_linear", "exponential", + //! "fast_sigmoid_derivative", "arctan"] + std::string surrogate_gradient_function_; + + //! Low-pass filter of the eligibility trace. + double kappa_; + + //! Low-pass filter of the firing rate for regularization. + double kappa_reg_; + + //! Time interval from the previous spike until the cutoff of e-prop update integration between two spikes (ms). + double eprop_isi_trace_cutoff_; + + //! Default constructor. + Parameters_(); + + //! Get the parameters and their values. + void get( DictionaryDatum& ) const; + + //! Set the parameters and throw errors in case of invalid values. + double set( const DictionaryDatum&, Node* ); + }; + + //! Structure of state variables. + struct State_ + { + //! Input current (pA). + double i_in_; + + //! Membrane voltage relative to the leak membrane potential (mV). + double v_m_; + + //! Number of remaining refractory steps. + int r_; + + //! Count of spikes arriving during refractory period discounted for decay until end of refractory period. + double refr_spikes_buffer_; + + //! Binary spike state variable - 1.0 if the neuron has spiked in the previous time step and 0.0 otherwise. + double z_; + + //! Adaptation variable. + double adapt_; + + //! Adapting spike threshold voltage. + double v_th_adapt_; + + //! Learning signal. Sum of weighted error signals coming from the readout neurons. + double learning_signal_; + + //! Surrogate gradient / pseudo-derivative of the membrane voltage. + double surrogate_gradient_; + + //! Default constructor. + State_(); + + //! Get the state variables and their values. + void get( DictionaryDatum&, const Parameters_& ) const; + + //! Set the state variables. + void set( const DictionaryDatum&, const Parameters_&, double, Node* ); + }; + + //! Structure of buffers. + struct Buffers_ + { + //! Default constructor. + Buffers_( eprop_iaf_psc_delta_adapt& ); + + //! Copy constructor. + Buffers_( const Buffers_&, eprop_iaf_psc_delta_adapt& ); + + //! Buffer for incoming spikes. + RingBuffer spikes_; + + //! Buffer for incoming currents. + RingBuffer currents_; + + //! Logger for universal data. + UniversalDataLogger< eprop_iaf_psc_delta_adapt > logger_; + }; + + //! Structure of internal variables. + struct Variables_ + { + //! Propagator matrix entry for evolving the membrane voltage (mathematical symbol "alpha" in user documentation). + double P_v_m_; + + //! Propagator matrix entry for evolving the incoming currents. + double P_i_in_; + + //! Propagator matrix entry for evolving the adaptation (mathematical symbol "rho" in user documentation). + double P_adapt_; + + //! Total refractory steps. + int RefractoryCounts_; + + //! Time steps from the previous spike until the cutoff of e-prop update integration between two spikes. + long eprop_isi_trace_cutoff_steps_; + }; + + //! Get the current value of the membrane voltage. + double + get_v_m_() const + { + return S_.v_m_ + P_.E_L_; + } + + //! Get the current value of the surrogate gradient. + double + get_surrogate_gradient_() const + { + return S_.surrogate_gradient_; + } + + //! 
Get the current value of the learning signal. + double + get_learning_signal_() const + { + return S_.learning_signal_; + } + + //! Get the current value of the adapting threshold. + double + get_v_th_adapt_() const + { + return S_.v_th_adapt_ + P_.E_L_; + } + + //! Get the current value of the adaptation. + double + get_adaptation_() const + { + return S_.adapt_; + } + + // the order in which the structure instances are defined is important for speed + + //! Structure of parameters. + Parameters_ P_; + + //! Structure of state variables. + State_ S_; + + //! Structure of internal variables. + Variables_ V_; + + //! Structure of buffers. + Buffers_ B_; + + //! Map storing a static set of recordables. + static RecordablesMap< eprop_iaf_psc_delta_adapt > recordablesMap_; +}; + +inline long +eprop_iaf_psc_delta_adapt::get_eprop_isi_trace_cutoff() const +{ + return V_.eprop_isi_trace_cutoff_steps_; +} + +inline size_t +eprop_iaf_psc_delta_adapt::send_test_event( Node& target, size_t receptor_type, synindex, bool ) +{ + SpikeEvent e; + e.set_sender( *this ); + return target.handles_test_event( e, receptor_type ); +} + +inline size_t +eprop_iaf_psc_delta_adapt::handles_test_event( SpikeEvent&, size_t receptor_type ) +{ + if ( receptor_type != 0 ) + { + throw UnknownReceptorType( receptor_type, get_name() ); + } + + return 0; +} + +inline size_t +eprop_iaf_psc_delta_adapt::handles_test_event( CurrentEvent&, size_t receptor_type ) +{ + if ( receptor_type != 0 ) + { + throw UnknownReceptorType( receptor_type, get_name() ); + } + + return 0; +} + +inline size_t +eprop_iaf_psc_delta_adapt::handles_test_event( LearningSignalConnectionEvent&, size_t receptor_type ) +{ + if ( receptor_type != 0 ) + { + throw UnknownReceptorType( receptor_type, get_name() ); + } + + return 0; +} + +inline size_t +eprop_iaf_psc_delta_adapt::handles_test_event( DataLoggingRequest& dlr, size_t receptor_type ) +{ + if ( receptor_type != 0 ) + { + throw UnknownReceptorType( receptor_type, get_name() ); + } + + return B_.logger_.connect_logging_device( dlr, recordablesMap_ ); +} + +inline void +eprop_iaf_psc_delta_adapt::get_status( DictionaryDatum& d ) const +{ + P_.get( d ); + S_.get( d, P_ ); + ( *d )[ names::recordables ] = recordablesMap_.get_list(); +} + +inline void +eprop_iaf_psc_delta_adapt::set_status( const DictionaryDatum& d ) +{ + // temporary copies in case of errors + Parameters_ ptmp = P_; + State_ stmp = S_; + + // make sure that ptmp and stmp consistent - throw BadProperty if not + const double delta_EL = ptmp.set( d, this ); + stmp.set( d, ptmp, delta_EL, this ); + + P_ = ptmp; + S_ = stmp; +} + +} // namespace nest + +#endif // EPROP_IAF_PSC_DELTA_ADAPT_H diff --git a/models/eprop_learning_signal_connection.cpp b/models/eprop_learning_signal_connection.cpp new file mode 100644 index 0000000000..7e537ed5c0 --- /dev/null +++ b/models/eprop_learning_signal_connection.cpp @@ -0,0 +1,32 @@ +/* + * eprop_learning_signal_connection.cpp + * + * This file is part of NEST. + * + * Copyright (C) 2004 The NEST Initiative + * + * NEST is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 2 of the License, or + * (at your option) any later version. + * + * NEST is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. 
+ * + * You should have received a copy of the GNU General Public License + * along with NEST. If not, see . + * + */ + +#include "eprop_learning_signal_connection.h" + +// nestkernel +#include "nest_impl.h" + +void +nest::register_eprop_learning_signal_connection( const std::string& name ) +{ + register_connection_model< eprop_learning_signal_connection >( name ); +} diff --git a/models/eprop_learning_signal_connection.h b/models/eprop_learning_signal_connection.h new file mode 100644 index 0000000000..8929993b23 --- /dev/null +++ b/models/eprop_learning_signal_connection.h @@ -0,0 +1,231 @@ +/* + * eprop_learning_signal_connection.h + * + * This file is part of NEST. + * + * Copyright (C) 2004 The NEST Initiative + * + * NEST is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 2 of the License, or + * (at your option) any later version. + * + * NEST is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with NEST. If not, see . + * + */ + + +#ifndef EPROP_LEARNING_SIGNAL_CONNECTION_H +#define EPROP_LEARNING_SIGNAL_CONNECTION_H + +// nestkernel +#include "connection.h" + +namespace nest +{ + +/* BeginUserDocs: synapse, e-prop plasticity + +Short description ++++++++++++++++++ + +Synapse model transmitting feedback learning signals for e-prop plasticity + +Description ++++++++++++ + +``eprop_learning_signal_connection`` is an implementation of a feedback connector from +``eprop_readout`` readout neurons to ``eprop_iaf`` or ``eprop_iaf_adapt`` +recurrent neurons that transmits the learning signals :math:`L_j^t` for eligibility propagation (e-prop) plasticity and +has a static weight :math:`B_{jk}`. + +E-prop plasticity was originally introduced and implemented in TensorFlow in [1]_. + +For more information on e-prop plasticity, see the documentation on the other e-prop models: + + * :doc:`eprop_iaf<../models/eprop_iaf/>` + * :doc:`eprop_iaf_adapt<../models/eprop_iaf_adapt/>` + * :doc:`eprop_readout<../models/eprop_readout/>` + * :doc:`eprop_synapse<../models/eprop_synapse/>` + +Details on the event-based NEST implementation of e-prop can be found in [2]_. + +Parameters +++++++++++ + +The following parameters can be set in the status dictionary. + +========== ===== ================ ======= =============== +**Individual synapse parameters** +--------------------------------------------------------- +Parameter Unit Math equivalent Default Description +========== ===== ================ ======= =============== +``delay`` ms :math:`d_{jk}` 1.0 Dendritic delay +``weight`` pA :math:`B_{jk}` 1.0 Synaptic weight +========== ===== ================ ======= =============== + +Recordables ++++++++++++ + +The following variables can be recorded. Note that since this connection lacks +a plasticity mechanism the weight does not evolve over time. 
+ +============== ==== =============== ============= =============== +**Synapse recordables** +----------------------------------------------------------------- +State variable Unit Math equivalent Initial value Description +============== ==== =============== ============= =============== +``weight`` pA :math:`B_{jk}` 1.0 Synaptic weight +============== ==== =============== ============= =============== + +Usage ++++++ + +This model can only be used in combination with the other e-prop models +and the network architecture requires specific wiring, input, and output. +The usage is demonstrated in several +:doc:`supervised regression and classification tasks <../auto_examples/eprop_plasticity/index>` +reproducing among others the original proof-of-concept tasks in [1]_. + +Transmits ++++++++++ + +LearningSignalConnectionEvent + +References +++++++++++ + +.. [1] Bellec G, Scherr F, Subramoney F, Hajek E, Salaj D, Legenstein R, + Maass W (2020). A solution to the learning dilemma for recurrent + networks of spiking neurons. Nature Communications, 11:3625. + https://doi.org/10.1038/s41467-020-17236-y + +.. [2] Korcsak-Gorzo A, Stapmanns J, Espinoza Valverde JA, Plesser HE, + Dahmen D, Bolten M, Van Albada SJ*, Diesmann M*. Event-based + implementation of eligibility propagation (in preparation) + +See also +++++++++ + +Examples using this model ++++++++++++++++++++++++++ + +.. listexamples:: eprop_learning_signal_connection + +EndUserDocs */ + +void register_eprop_learning_signal_connection( const std::string& name ); + +/** + * @brief Class implementing a feedback connection model for e-prop plasticity with additional biological features. + * + * Class implementing a synapse model transmitting secondary feedback learning signals for e-prop plasticity + * according to Bellec et al. (2020) with additional biological features described in + * Korcsak-Gorzo, Stapmanns, and Espinoza Valverde et al. (in preparation). + */ +template < typename targetidentifierT > +class eprop_learning_signal_connection : public Connection< targetidentifierT > +{ + +public: + //! Type of the common synapse properties. + typedef CommonSynapseProperties CommonPropertiesType; + + //! Type of the connection base. + typedef Connection< targetidentifierT > ConnectionBase; + + //! Properties of the connection model. + static constexpr ConnectionModelProperties properties = ConnectionModelProperties::HAS_DELAY; + + //! Default constructor. + eprop_learning_signal_connection() + : ConnectionBase() + , weight_( 1.0 ) + { + } + + //! Get the secondary learning signal event. + SecondaryEvent* get_secondary_event(); + + using ConnectionBase::get_delay_steps; + using ConnectionBase::get_rport; + using ConnectionBase::get_target; + + //! Check if the target accepts the event and receptor type requested by the sender. + void + check_connection( Node& s, Node& t, size_t receptor_type, const CommonPropertiesType& ) + { + LearningSignalConnectionEvent ge; + + s.sends_secondary_event( ge ); + ge.set_sender( s ); + Connection< targetidentifierT >::target_.set_rport( t.handles_test_event( ge, receptor_type ) ); + Connection< targetidentifierT >::target_.set_target( &t ); + } + + //! Send the learning signal event. + bool + send( Event& e, size_t t, const CommonSynapseProperties& ) + { + e.set_weight( weight_ ); + e.set_delay_steps( get_delay_steps() ); + e.set_receiver( *get_target( t ) ); + e.set_rport( get_rport() ); + e(); + return true; + } + + //! Get the model attributes and their values. 
+ void get_status( DictionaryDatum& d ) const; + + //! Set the values of the model attributes. + void set_status( const DictionaryDatum& d, ConnectorModel& cm ); + + //! Set the synaptic weight to the provided value. + void + set_weight( const double w ) + { + weight_ = w; + } + +private: + //! Synaptic weight. + double weight_; +}; + +template < typename targetidentifierT > +constexpr ConnectionModelProperties eprop_learning_signal_connection< targetidentifierT >::properties; + +template < typename targetidentifierT > +void +eprop_learning_signal_connection< targetidentifierT >::get_status( DictionaryDatum& d ) const +{ + ConnectionBase::get_status( d ); + def< double >( d, names::weight, weight_ ); + def< long >( d, names::size_of, sizeof( *this ) ); +} + +template < typename targetidentifierT > +void +eprop_learning_signal_connection< targetidentifierT >::set_status( const DictionaryDatum& d, ConnectorModel& cm ) +{ + ConnectionBase::set_status( d, cm ); + updateValue< double >( d, names::weight, weight_ ); +} + +template < typename targetidentifierT > +SecondaryEvent* +eprop_learning_signal_connection< targetidentifierT >::get_secondary_event() +{ + return new LearningSignalConnectionEvent(); +} + +} // namespace nest + +#endif // EPROP_LEARNING_SIGNAL_CONNECTION_H diff --git a/models/eprop_learning_signal_connection_bsshslm_2020.h b/models/eprop_learning_signal_connection_bsshslm_2020.h index 98ad6687cf..c48c4ee1b6 100644 --- a/models/eprop_learning_signal_connection_bsshslm_2020.h +++ b/models/eprop_learning_signal_connection_bsshslm_2020.h @@ -65,27 +65,34 @@ Parameters The following parameters can be set in the status dictionary. -========= ===== ================ ======= =============== +========== ===== ================ ======= =============== **Individual synapse parameters** --------------------------------------------------------- -Parameter Unit Math equivalent Default Description -========= ===== ================ ======= =============== -delay ms :math:`d_{jk}` 1.0 Dendritic delay -weight pA :math:`B_{jk}` 1.0 Synaptic weight -========= ===== ================ ======= =============== +--------------------------------------------------------- +Parameter Unit Math equivalent Default Description +========== ===== ================ ======= =============== +``delay`` ms :math:`d_{jk}` 1.0 Dendritic delay +``weight`` pA :math:`B_{jk}` 1.0 Synaptic weight +========== ===== ================ ======= =============== Recordables +++++++++++ -The following variables can be recorded: +The following variables can be recorded. Note that since this connection lacks +a plasticity mechanism the weight does not evolve over time. - - synaptic weight ``weight`` +============== ==== =============== ============= =============== +**Synapse recordables** +----------------------------------------------------------------- +State variable Unit Math equivalent Initial value Description +============== ==== =============== ============= =============== +``weight`` pA :math:`B_{jk}` 1.0 Synaptic weight +============== ==== =============== ============= =============== Usage +++++ -This model can only be used in combination with the other e-prop models, -whereby the network architecture requires specific wiring, input, and output. +This model can only be used in combination with the other e-prop models +and the network architecture requires specific wiring, input, and output. 
The usage is demonstrated in several :doc:`supervised regression and classification tasks <../auto_examples/eprop_plasticity/index>` reproducing among others the original proof-of-concept tasks in [1]_. @@ -102,15 +109,16 @@ References Maass W (2020). A solution to the learning dilemma for recurrent networks of spiking neurons. Nature Communications, 11:3625. https://doi.org/10.1038/s41467-020-17236-y -.. [2] Korcsak-Gorzo A, Stapmanns J, Espinoza Valverde JA, Dahmen D, - van Albada SJ, Bolten M, Diesmann M. Event-based implementation of - eligibility propagation (in preparation) + +.. [2] Korcsak-Gorzo A, Stapmanns J, Espinoza Valverde JA, Plesser HE, + Dahmen D, Bolten M, Van Albada SJ*, Diesmann M*. Event-based + implementation of eligibility propagation (in preparation) See also ++++++++ Examples using this model -++++++++++++++++++++++++++ ++++++++++++++++++++++++++ .. listexamples:: eprop_learning_signal_connection_bsshslm_2020 @@ -119,6 +127,8 @@ EndUserDocs */ void register_eprop_learning_signal_connection_bsshslm_2020( const std::string& name ); /** + * @brief Class implementing a feedback connection model for e-prop plasticity. + * * Class implementing a synapse model transmitting secondary feedback learning signals for e-prop plasticity * according to Bellec et al. (2020). */ diff --git a/models/eprop_readout.cpp b/models/eprop_readout.cpp new file mode 100644 index 0000000000..b52d9722d7 --- /dev/null +++ b/models/eprop_readout.cpp @@ -0,0 +1,390 @@ +/* + * eprop_readout.cpp + * + * This file is part of NEST. + * + * Copyright (C) 2004 The NEST Initiative + * + * NEST is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 2 of the License, or + * (at your option) any later version. + * + * NEST is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with NEST. If not, see . 
+ * + */ + +// nest models +#include "eprop_readout.h" + +// C++ +#include + +// libnestutil +#include "dict_util.h" +#include "numerics.h" + +// nestkernel +#include "exceptions.h" +#include "kernel_manager.h" +#include "nest_impl.h" +#include "universal_data_logger_impl.h" + +// sli +#include "dictutils.h" + +namespace nest +{ + +void +register_eprop_readout( const std::string& name ) +{ + register_node_model< eprop_readout >( name ); +} + +/* ---------------------------------------------------------------- + * Recordables map + * ---------------------------------------------------------------- */ + +RecordablesMap< eprop_readout > eprop_readout::recordablesMap_; + +template <> +void +RecordablesMap< eprop_readout >::create() +{ + insert_( names::error_signal, &eprop_readout::get_error_signal_ ); + insert_( names::readout_signal, &eprop_readout::get_readout_signal_ ); + insert_( names::target_signal, &eprop_readout::get_target_signal_ ); + insert_( names::V_m, &eprop_readout::get_v_m_ ); +} + +/* ---------------------------------------------------------------- + * Default constructors for parameters, state, and buffers + * ---------------------------------------------------------------- */ + +eprop_readout::Parameters_::Parameters_() + : C_m_( 250.0 ) + , E_L_( 0.0 ) + , I_e_( 0.0 ) + , regular_spike_arrival_( true ) + , tau_m_( 10.0 ) + , V_min_( -std::numeric_limits< double >::max() ) + , eprop_isi_trace_cutoff_( 1000.0 ) +{ +} + +eprop_readout::State_::State_() + : error_signal_( 0.0 ) + , readout_signal_( 0.0 ) + , target_signal_( 0.0 ) + , i_in_( 0.0 ) + , v_m_( 0.0 ) + , z_in_( 0.0 ) +{ +} + +eprop_readout::Buffers_::Buffers_( eprop_readout& n ) + : logger_( n ) +{ +} + +eprop_readout::Buffers_::Buffers_( const Buffers_&, eprop_readout& n ) + : logger_( n ) +{ +} + +/* ---------------------------------------------------------------- + * Getter and setter functions for parameters and state + * ---------------------------------------------------------------- */ + +void +eprop_readout::Parameters_::get( DictionaryDatum& d ) const +{ + def< double >( d, names::C_m, C_m_ ); + def< double >( d, names::E_L, E_L_ ); + def< double >( d, names::I_e, I_e_ ); + def< bool >( d, names::regular_spike_arrival, regular_spike_arrival_ ); + def< double >( d, names::tau_m, tau_m_ ); + def< double >( d, names::V_min, V_min_ + E_L_ ); + def< double >( d, names::eprop_isi_trace_cutoff, eprop_isi_trace_cutoff_ ); +} + +double +eprop_readout::Parameters_::set( const DictionaryDatum& d, Node* node ) +{ + // if leak potential is changed, adjust all variables defined relative to it + const double ELold = E_L_; + updateValueParam< double >( d, names::E_L, E_L_, node ); + const double delta_EL = E_L_ - ELold; + + V_min_ -= updateValueParam< double >( d, names::V_min, V_min_, node ) ? E_L_ : delta_EL; + + updateValueParam< double >( d, names::C_m, C_m_, node ); + updateValueParam< double >( d, names::I_e, I_e_, node ); + updateValueParam< bool >( d, names::regular_spike_arrival, regular_spike_arrival_, node ); + updateValueParam< double >( d, names::tau_m, tau_m_, node ); + updateValueParam< double >( d, names::eprop_isi_trace_cutoff, eprop_isi_trace_cutoff_, node ); + + if ( C_m_ <= 0 ) + { + throw BadProperty( "Membrane capacitance C_m > 0 required." ); + } + + if ( tau_m_ <= 0 ) + { + throw BadProperty( "Membrane time constant tau_m > 0 required." ); + } + + if ( eprop_isi_trace_cutoff_ < 0.0 ) + { + throw BadProperty( "Cutoff of integration of eprop trace between spikes eprop_isi_trace_cutoff ≥ 0 required." 
); + } + + return delta_EL; +} + +void +eprop_readout::State_::get( DictionaryDatum& d, const Parameters_& p ) const +{ + def< double >( d, names::V_m, v_m_ + p.E_L_ ); + def< double >( d, names::error_signal, error_signal_ ); + def< double >( d, names::readout_signal, readout_signal_ ); + def< double >( d, names::target_signal, target_signal_ ); +} + +void +eprop_readout::State_::set( const DictionaryDatum& d, const Parameters_& p, double delta_EL, Node* node ) +{ + v_m_ -= updateValueParam< double >( d, names::V_m, v_m_, node ) ? p.E_L_ : delta_EL; +} + +/* ---------------------------------------------------------------- + * Default and copy constructor for node + * ---------------------------------------------------------------- */ + +eprop_readout::eprop_readout() + : EpropArchivingNodeReadout() + , P_() + , S_() + , B_( *this ) +{ + recordablesMap_.create(); +} + +eprop_readout::eprop_readout( const eprop_readout& n ) + : EpropArchivingNodeReadout( n ) + , P_( n.P_ ) + , S_( n.S_ ) + , B_( n.B_, *this ) +{ +} + +/* ---------------------------------------------------------------- + * Node initialization functions + * ---------------------------------------------------------------- */ + +void +eprop_readout::init_buffers_() +{ + B_.spikes_.clear(); // includes resize + B_.currents_.clear(); // includes resize + B_.logger_.reset(); // includes resize +} + +void +eprop_readout::pre_run_hook() +{ + B_.logger_.init(); // ensures initialization in case multimeter connected after Simulate + + V_.eprop_isi_trace_cutoff_steps_ = Time( Time::ms( P_.eprop_isi_trace_cutoff_ ) ).get_steps(); + + compute_error_signal = &eprop_readout::compute_error_signal_mean_squared_error; + + const double dt = Time::get_resolution().get_ms(); + + V_.P_v_m_ = std::exp( -dt / P_.tau_m_ ); + V_.P_i_in_ = P_.tau_m_ / P_.C_m_ * ( 1.0 - V_.P_v_m_ ); + V_.P_z_in_ = P_.regular_spike_arrival_ ? 
1.0 : 1.0 - V_.P_v_m_; +} + +long +eprop_readout::get_shift() const +{ + return offset_gen_ + delay_in_rec_; +} + +bool +eprop_readout::is_eprop_recurrent_node() const +{ + return false; +} + +/* ---------------------------------------------------------------- + * Update function + * ---------------------------------------------------------------- */ + +void +eprop_readout::update( Time const& origin, const long from, const long to ) +{ + const size_t buffer_size = kernel().connection_manager.get_min_delay(); + + std::vector< double > error_signal_buffer( buffer_size, 0.0 ); + + for ( long lag = from; lag < to; ++lag ) + { + const long t = origin.get_steps() + lag; + + S_.z_in_ = B_.spikes_.get_value( lag ); + + S_.v_m_ = V_.P_i_in_ * S_.i_in_ + V_.P_z_in_ * S_.z_in_ + V_.P_v_m_ * S_.v_m_; + S_.v_m_ = std::max( S_.v_m_, P_.V_min_ ); + + ( this->*compute_error_signal )( lag ); + + S_.target_signal_ *= S_.learning_window_signal_; + S_.readout_signal_ *= S_.learning_window_signal_; + S_.error_signal_ *= S_.learning_window_signal_; + + error_signal_buffer[ lag ] = S_.error_signal_; + + append_new_eprop_history_entry( t, false ); + write_error_signal_to_history( t, S_.error_signal_, false ); + + S_.i_in_ = B_.currents_.get_value( lag ) + P_.I_e_; + + B_.logger_.record_data( t ); + } + + LearningSignalConnectionEvent error_signal_event; + error_signal_event.set_coeffarray( error_signal_buffer ); + kernel().event_delivery_manager.send_secondary( *this, error_signal_event ); + + return; +} + +/* ---------------------------------------------------------------- + * Error signal functions + * ---------------------------------------------------------------- */ + +void +eprop_readout::compute_error_signal_mean_squared_error( const long lag ) +{ + S_.readout_signal_ = S_.v_m_ + P_.E_L_; + S_.error_signal_ = S_.readout_signal_ - S_.target_signal_; +} + +/* ---------------------------------------------------------------- + * Event handling functions + * ---------------------------------------------------------------- */ + +void +eprop_readout::handle( DelayedRateConnectionEvent& e ) +{ + const size_t rport = e.get_rport(); + assert( rport < SUP_RATE_RECEPTOR ); + + auto it = e.begin(); + assert( it != e.end() ); + + const double signal = e.get_weight() * e.get_coeffvalue( it ); + if ( rport == LEARNING_WINDOW_SIG ) + { + S_.learning_window_signal_ = signal; + } + else if ( rport == TARGET_SIG ) + { + S_.target_signal_ = signal; + } + + assert( it == e.end() ); +} + +void +eprop_readout::handle( SpikeEvent& e ) +{ + assert( e.get_delay_steps() > 0 ); + + B_.spikes_.add_value( + e.get_rel_delivery_steps( kernel().simulation_manager.get_slice_origin() ), e.get_weight() * e.get_multiplicity() ); +} + +void +eprop_readout::handle( CurrentEvent& e ) +{ + assert( e.get_delay_steps() > 0 ); + + B_.currents_.add_value( + e.get_rel_delivery_steps( kernel().simulation_manager.get_slice_origin() ), e.get_weight() * e.get_current() ); +} + +void +eprop_readout::handle( DataLoggingRequest& e ) +{ + B_.logger_.handle( e ); +} + +void +eprop_readout::compute_gradient( const long t_spike, + const long t_spike_previous, + double& z_previous_buffer, + double& z_bar, + double& e_bar, + double& e_bar_reg, + double& epsilon, + double& weight, + const CommonSynapseProperties& cp, + WeightOptimizer* optimizer ) +{ + double z = 0.0; // spiking variable + double z_current_buffer = 1.0; // buffer containing the spike that triggered the current integration + double L = 0.0; // error signal + double grad = 0.0; // gradient + + const 
EpropSynapseCommonProperties& ecp = static_cast< const EpropSynapseCommonProperties& >( cp ); + const auto optimize_each_step = ( *ecp.optimizer_cp_ ).optimize_each_step_; + + auto eprop_hist_it = get_eprop_history( t_spike_previous - 1 ); + + const long t_compute_until = std::min( t_spike_previous + V_.eprop_isi_trace_cutoff_steps_, t_spike ); + + for ( long t = t_spike_previous; t < t_compute_until; ++t, ++eprop_hist_it ) + { + z = z_previous_buffer; + z_previous_buffer = z_current_buffer; + z_current_buffer = 0.0; + + L = eprop_hist_it->error_signal_; + + z_bar = V_.P_v_m_ * z_bar + V_.P_z_in_ * z; + + if ( optimize_each_step ) + { + grad = L * z_bar; + weight = optimizer->optimized_weight( *ecp.optimizer_cp_, t, grad, weight ); + } + else + { + grad += L * z_bar; + } + } + + if ( not optimize_each_step ) + { + weight = optimizer->optimized_weight( *ecp.optimizer_cp_, t_compute_until, grad, weight ); + } + + const long cutoff_to_spike_interval = t_spike - t_compute_until; + + if ( cutoff_to_spike_interval > 0 ) + { + z_bar *= std::pow( V_.P_v_m_, cutoff_to_spike_interval ); + } +} + +} // namespace nest diff --git a/models/eprop_readout.h b/models/eprop_readout.h new file mode 100644 index 0000000000..fd7e4ac7af --- /dev/null +++ b/models/eprop_readout.h @@ -0,0 +1,586 @@ +/* + * eprop_readout.h + * + * This file is part of NEST. + * + * Copyright (C) 2004 The NEST Initiative + * + * NEST is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 2 of the License, or + * (at your option) any later version. + * + * NEST is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with NEST. If not, see . + * + */ + +#ifndef EPROP_READOUT_H +#define EPROP_READOUT_H + +// nestkernel +#include "connection.h" +#include "eprop_archiving_node.h" +#include "eprop_archiving_node_impl.h" +#include "eprop_synapse.h" +#include "event.h" +#include "nest_types.h" +#include "ring_buffer.h" +#include "universal_data_logger.h" + +namespace nest +{ + +/* BeginUserDocs: neuron, e-prop plasticity, current-based + +Short description ++++++++++++++++++ + +Current-based leaky integrate readout neuron model with delta-shaped +postsynaptic currents for e-prop plasticity + +Description ++++++++++++ + +``eprop_readout`` is an implementation of an integrate-and-fire neuron model +with delta-shaped postsynaptic currents used as readout neuron for eligibility propagation (e-prop) plasticity. + +E-prop plasticity was originally introduced and implemented in TensorFlow in [1]_. + +The membrane voltage time course :math:`v_j^t` of the neuron :math:`j` is given by: + +.. math:: + v_j^t &= \kappa v_j^{t-1}+ \zeta \sum_{i \neq j} W_{ji}^\text{out} z_i^{t-1} \,, \\ + \kappa &= e^{ -\frac{ \Delta t }{ \tau_\text{m} } } \,, \\ + \zeta &= + \begin{cases} + 1 \\ + 1 - \kappa + \end{cases} \,, \\ + +where :math:`W_{ji}^\text{out}` is the output synaptic weight matrix and +:math:`z_i^{t-1}` is the recurrent presynaptic spike state variable. + +Descriptions of further parameters and variables can be found in the table below. + +The spike state variable of a presynaptic neuron is expressed by a Heaviside function: + +.. 
math:: + z_i^t = H \left( v_i^t - v_\text{th} \right) \,. \\ + +An additional state variable and the corresponding differential equation +represents a piecewise constant external current. + +See the documentation on the :doc:`iaf_psc_delta<../models/iaf_psc_delta/>` neuron model +for more information on the integration of the subthreshold dynamics. + +The change of the synaptic weight is calculated from the gradient :math:`g^t` of +the loss :math:`E^t` with respect to the synaptic weight :math:`W_{ji}`: +:math:`\frac{ \text{d} E^t }{ \text{d} W_{ij} }` +which depends on the presynaptic +spikes :math:`z_i^{t-1}` and the learning signal :math:`L_j^t` emitted by the readout +neurons. + +In the interval between two presynaptic spikes, the gradient is calculated +at each time step until the cutoff time point. This computation occurs over +the time range: + +:math:`t \in \left[ t_\text{spk,prev}, \min \left( t_\text{spk,prev} + \Delta t_\text{c}, t_\text{spk,curr} \right) +\right]`. + +Here, :math:`t_\text{spk,prev}` represents the time of the previous spike that +passed the synapse, while :math:`t_\text{spk,curr}` is the time of the +current spike, which triggers the application of the learning rule and the +subsequent synaptic weight update. The cutoff :math:`\Delta t_\text{c}` +defines the maximum allowable interval for integration between spikes. +The expression for the gradient is given by: + +.. math:: + \frac{ \text{d} E^t }{ \text{d} W_{ji} } = L_j^t \bar{z}_i^{t-1} \,. \\ + +The presynaptic spike trains are low-pass filtered with the following exponential kernel: + +.. math:: + \bar{z}_i^t = \mathcal{F}_\kappa \left( z_{i}^t \right) + = \kappa \bar{z}_i^{t-1} + \zeta z_i^t \,. \\ + +Since readout neurons are leaky integrators without a spiking mechanism, the +formula for computing the gradient lacks the surrogate gradient / +pseudo-derivative and a firing regularization term. + +As a last step for every round in the loop over the time steps :math:`t`, the new weight is retrieved by feeding the +current gradient :math:`g^t` to the optimizer (see :doc:`weight_optimizer<../models/weight_optimizer/>` +for more information on the available optimizers): + +.. math:: + w^t = \text{optimizer} \left( t, g^t, w^{t-1} \right) \,. \\ + +After the loop has terminated, the filtered dynamic variables of e-prop are propagated from the end of the cutoff until +the next spike: + +.. math:: + p &= \text{max} \left( 0, t_\text{s}^{t} - \left( t_\text{s}^{t-1} + {\Delta t}_\text{c} \right) \right) \,, \\ + \bar{z}_i^{t+p} &= \bar{z}_i^t \alpha^p \,. \\ + +The learning signal :math:`L_j^t` is given by the non-plastic feedback weight +matrix :math:`B_{jk}` and the continuous error signal :math:`e_k^t` emitted by +readout neuron :math:`k` and :math:`e_k^t` defined via a mean-squared error +loss: + +.. math:: + L_j^t = B_{jk} e_k^t = B_{jk} \left( y_k^t - y_k^{*,t} \right) \,. \\ + +where the readout signal :math:`y_k^t` corresponds to the membrane voltage of +readout neuron :math:`k` and :math:`y_k^{*,t}` is the real-valued target signal. + +Furthermore, the readout and target signal are multiplied by a learning window +signal, which has a value of 1.0 within the learning window and 0.0 outside. 
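
The following NumPy snippet is only a numerical illustration of this per-step recursion and is not NEST code; the spike train, learning signal, filter constants, and learning rate are made-up placeholders::

    import numpy as np

    dt, tau_m = 1.0, 10.0                 # illustrative resolution and membrane time constant (ms)
    kappa = np.exp(-dt / tau_m)           # low-pass filter factor
    zeta = 1.0                            # value used when regular_spike_arrival is True

    rng = np.random.default_rng(0)
    z = (rng.random(200) < 0.05).astype(float)   # placeholder presynaptic spike train z_i^t
    L = 0.1 * rng.standard_normal(200)           # placeholder learning signal L_j^t

    eta = 1e-3                            # learning rate of a plain gradient-descent optimizer
    w, z_bar = 0.5, 0.0
    for step in range(len(z)):
        z_bar = kappa * z_bar + zeta * z[step]   # filtered presynaptic trace
        grad = L[step] * z_bar                   # dE/dW_ji at this time step
        w = w - eta * grad                       # per-step optimizer call (gradient descent)
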
+ +For more information on e-prop plasticity, see the documentation on the other e-prop models: + + * :doc:`eprop_iaf<../models/eprop_iaf/>` + * :doc:`eprop_iaf_adapt<../models/eprop_iaf_adapt/>` + * :doc:`eprop_synapse<../models/eprop_synapse/>` + * :doc:`eprop_learning_signal_connection<../models/eprop_learning_signal_connection/>` + +Details on the event-based NEST implementation of e-prop can be found in [2]_. + +Parameters +++++++++++ + +The following parameters can be set in the status dictionary. + +========================= ======= ===================== ================== ===================================== +**Neuron parameters** +---------------------------------------------------------------------------------------------------------------- +Parameter Unit Math equivalent Default Description +========================= ======= ===================== ================== ===================================== +``C_m`` pF :math:`C_\text{m}` 250.0 Capacitance of the membrane +``E_L`` mV :math:`E_\text{L}` 0.0 Leak / resting membrane potential +``I_e`` pA :math:`I_\text{e}` 0.0 Constant external input current +``regular_spike_arrival`` Boolean ``True`` If ``True``, the input spikes arrive + at the end of the time step, if + ``False`` at the beginning + (determines PSC scale) +``tau_m`` ms :math:`\tau_\text{m}` 10.0 Time constant of the membrane +``V_min`` mV :math:`v_\text{min}` negative maximum Absolute lower bound of the membrane + value voltage + representable by a + ``double`` type in + C++ +========================= ======= ===================== ================== ===================================== + +=========================== ======= =========================== ================ =============================== +**E-prop parameters** +---------------------------------------------------------------------------------------------------------------- +Parameter Unit Math equivalent Default Description +=========================== ======= =========================== ================ =============================== +``eprop_isi_trace_cutoff`` ms :math:`{\Delta t}_\text{c}` maximum value Cutoff for integration of + representable e-prop update between two + by a ``long`` spikes + type in C++ +=========================== ======= =========================== ================ =============================== + +Recordables ++++++++++++ + +The following state variables evolve during simulation and can be recorded. 
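
They can be sampled with a ``multimeter`` using the recordable names given in the tables below; a minimal sketch, assuming a NEST build that provides the e-prop models::

    import nest

    nest.ResetKernel()

    nrn = nest.Create("eprop_readout")
    mm = nest.Create(
        "multimeter",
        params={
            "record_from": ["V_m", "error_signal", "readout_signal", "target_signal"],
            "interval": 1.0,
        },
    )
    nest.Connect(mm, nrn)
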
+ +=============== ==== =============== ============= ================ +**Neuron state variables and recordables** +------------------------------------------------------------------- +State variable Unit Math equivalent Initial value Description +=============== ==== =============== ============= ================ +``V_m`` mV :math:`v_j` 0.0 Membrane voltage +=============== ==== =============== ============= ================ + +========================= ==== =============== ============= ============== +**E-prop state variables and recordables** +--------------------------------------------------------------------------- +State variable Unit Math equivalent Initial value Description +========================= ==== =============== ============= ============== +``error_signal`` mV :math:`L_j` 0.0 Error signal +``readout_signal`` mV :math:`y_j` 0.0 Readout signal +``target_signal`` mV :math:`y^*_j` 0.0 Target signal +========================= ==== =============== ============= ============== + +Usage ++++++ + +This model can only be used in combination with the other e-prop models +and the network architecture requires specific wiring, input, and output. +The usage is demonstrated in several +:doc:`supervised regression and classification tasks <../auto_examples/eprop_plasticity/index>` +reproducing among others the original proof-of-concept tasks in [1]_. + +References +++++++++++ + +.. [1] Bellec G, Scherr F, Subramoney F, Hajek E, Salaj D, Legenstein R, + Maass W (2020). A solution to the learning dilemma for recurrent + networks of spiking neurons. Nature Communications, 11:3625. + https://doi.org/10.1038/s41467-020-17236-y + +.. [2] Korcsak-Gorzo A, Stapmanns J, Espinoza Valverde JA, Plesser HE, + Dahmen D, Bolten M, Van Albada SJ*, Diesmann M*. Event-based + implementation of eligibility propagation (in preparation) + +Sends ++++++ + +LearningSignalConnectionEvent, DelayedRateConnectionEvent + +Receives +++++++++ + +SpikeEvent, CurrentEvent, DelayedRateConnectionEvent, DataLoggingRequest + +See also +++++++++ + +Examples using this model ++++++++++++++++++++++++++ + +.. listexamples:: eprop_readout + +EndUserDocs */ + +void register_eprop_readout( const std::string& name ); + +/** + * @brief Class implementing a readout neuron model for e-prop plasticity with additional biological features. + * + * Class implementing a current-based leaky integrate readout neuron model with delta-shaped postsynaptic currents for + * e-prop plasticity according to Bellec et al. (2020) with additional biological features described in + * Korcsak-Gorzo, Stapmanns, and Espinoza Valverde et al. (in preparation). + */ +class eprop_readout : public EpropArchivingNodeReadout +{ + +public: + //! Default constructor. + eprop_readout(); + + //! Copy constructor. 
+ eprop_readout( const eprop_readout& ); + + using Node::handle; + using Node::handles_test_event; + + using Node::sends_secondary_event; + + void + sends_secondary_event( LearningSignalConnectionEvent& ) override + { + } + + void + sends_secondary_event( DelayedRateConnectionEvent& ) override + { + } + + void handle( SpikeEvent& ) override; + void handle( CurrentEvent& ) override; + void handle( DelayedRateConnectionEvent& ) override; + void handle( DataLoggingRequest& ) override; + + size_t handles_test_event( SpikeEvent&, size_t ) override; + size_t handles_test_event( CurrentEvent&, size_t ) override; + size_t handles_test_event( DelayedRateConnectionEvent&, size_t ) override; + size_t handles_test_event( DataLoggingRequest&, size_t ) override; + + void get_status( DictionaryDatum& ) const override; + void set_status( const DictionaryDatum& ) override; + +private: + void init_buffers_() override; + void pre_run_hook() override; + + void update( Time const&, const long, const long ) override; + + void compute_gradient( const long, + const long, + double&, + double&, + double&, + double&, + double&, + double&, + const CommonSynapseProperties&, + WeightOptimizer* ) override; + + long get_shift() const override; + bool is_eprop_recurrent_node() const override; + long get_eprop_isi_trace_cutoff() const override; + + //! Compute the error signal based on the mean-squared error loss. + void compute_error_signal_mean_squared_error( const long lag ); + + //! Compute the error signal based on a loss function. + void ( eprop_readout::*compute_error_signal )( const long lag ); + + //! Map for storing a static set of recordables. + friend class RecordablesMap< eprop_readout >; + + //! Logger for universal data supporting the data logging request / reply mechanism. Populated with a recordables map. + friend class UniversalDataLogger< eprop_readout >; + + //! Structure of parameters. + struct Parameters_ + { + //! Capacitance of the membrane (pF). + double C_m_; + + //! Leak / resting membrane potential (mV). + double E_L_; + + //! Constant external input current (pA). + double I_e_; + + //! If True, the input spikes arrive at the beginning of the time step, if False at the end (determines PSC scale). + bool regular_spike_arrival_; + + //! Time constant of the membrane (ms). + double tau_m_; + + //! Absolute lower bound of the membrane voltage relative to the leak membrane potential (mV). + double V_min_; + + //! Time interval from the previous spike until the cutoff of e-prop update integration between two spikes (ms). + double eprop_isi_trace_cutoff_; + + //! Default constructor. + Parameters_(); + + //! Get the parameters and their values. + void get( DictionaryDatum& ) const; + + //! Set the parameters and throw errors in case of invalid values. + double set( const DictionaryDatum&, Node* ); + }; + + //! Structure of state variables. + struct State_ + { + //! Error signal. Deviation between the readout and the target signal. + double error_signal_; + + //! Readout signal. Leaky integrated spikes emitted by the recurrent network. + double readout_signal_; + + //! Target / teacher signal that the network is supposed to learn. + double target_signal_; + + //! Signal indicating whether the readout neurons are in a learning phase. + double learning_window_signal_; + + //! Input current (pA). + double i_in_; + + //! Membrane voltage relative to the leak membrane potential (mV). + double v_m_; + + //! 
Binary input spike state variable - 1.0 if the neuron has spiked in the previous time step and 0.0 otherwise. + double z_in_; + + //! Default constructor. + State_(); + + //! Get the state variables and their values. + void get( DictionaryDatum&, const Parameters_& ) const; + + //! Set the state variables. + void set( const DictionaryDatum&, const Parameters_&, double, Node* ); + }; + + //! Structure of buffers. + struct Buffers_ + { + //! Default constructor. + Buffers_( eprop_readout& ); + + //! Copy constructor. + Buffers_( const Buffers_&, eprop_readout& ); + + //! Buffer for incoming spikes. + RingBuffer spikes_; + + //! Buffer for incoming currents. + RingBuffer currents_; + + //! Logger for universal data. + UniversalDataLogger< eprop_readout > logger_; + }; + + //! Structure of internal variables. + struct Variables_ + { + //! Propagator matrix entry for evolving the membrane voltage (mathematical symbol "kappa" in user documentation). + double P_v_m_; + + //! Propagator matrix entry for evolving the incoming spike state variables (mathematical symbol "zeta" in user + //! documentation). + double P_z_in_; + + //! Propagator matrix entry for evolving the incoming currents. + double P_i_in_; + + //! Time steps from the previous spike until the cutoff of e-prop update integration between two spikes. + long eprop_isi_trace_cutoff_steps_; + }; + + //! Minimal spike receptor type. Start with 1 to forbid port 0 and avoid accidental creation of connections with no + //! receptor type set. + static const size_t MIN_RATE_RECEPTOR = 1; + + //! Enumeration of spike receptor types. + enum RateSynapseTypes + { + LEARNING_WINDOW_SIG = MIN_RATE_RECEPTOR, + TARGET_SIG, + SUP_RATE_RECEPTOR + }; + + //! Get the current value of the membrane voltage. + double + get_v_m_() const + { + return S_.v_m_ + P_.E_L_; + } + + //! Get the current value of the normalized readout signal. + double + get_readout_signal_() const + { + return S_.readout_signal_; + } + + //! Get the current value of the target signal. + double + get_target_signal_() const + { + return S_.target_signal_; + } + + //! Get the current value of the error signal. + double + get_error_signal_() const + { + return S_.error_signal_; + } + + // the order in which the structure instances are defined is important for speed + + //! Structure of parameters. + Parameters_ P_; + + //! Structure of state variables. + State_ S_; + + //! Structure of internal variables. + Variables_ V_; + + //! Structure of buffers. + Buffers_ B_; + + //! Map storing a static set of recordables. 
+ static RecordablesMap< eprop_readout > recordablesMap_; +}; + +inline long +eprop_readout::get_eprop_isi_trace_cutoff() const +{ + return V_.eprop_isi_trace_cutoff_steps_; +} + +inline size_t +eprop_readout::handles_test_event( SpikeEvent&, size_t receptor_type ) +{ + if ( receptor_type != 0 ) + { + throw UnknownReceptorType( receptor_type, get_name() ); + } + + return 0; +} + +inline size_t +eprop_readout::handles_test_event( CurrentEvent&, size_t receptor_type ) +{ + if ( receptor_type != 0 ) + { + throw UnknownReceptorType( receptor_type, get_name() ); + } + + return 0; +} + +inline size_t +eprop_readout::handles_test_event( DelayedRateConnectionEvent& e, size_t receptor_type ) +{ + size_t step_rate_model_id = kernel().model_manager.get_node_model_id( "step_rate_generator" ); + size_t model_id = e.get_sender().get_model_id(); + + if ( step_rate_model_id == model_id and receptor_type != TARGET_SIG and receptor_type != LEARNING_WINDOW_SIG ) + { + throw IllegalConnection( + "eprop_readout neurons expect a connection with a step_rate_generator node through receptor_type " + "1 or 2." ); + } + + if ( receptor_type < MIN_RATE_RECEPTOR or receptor_type >= SUP_RATE_RECEPTOR ) + { + throw UnknownReceptorType( receptor_type, get_name() ); + } + + return receptor_type; +} + +inline size_t +eprop_readout::handles_test_event( DataLoggingRequest& dlr, size_t receptor_type ) +{ + if ( receptor_type != 0 ) + { + throw UnknownReceptorType( receptor_type, get_name() ); + } + + return B_.logger_.connect_logging_device( dlr, recordablesMap_ ); +} + +inline void +eprop_readout::get_status( DictionaryDatum& d ) const +{ + P_.get( d ); + S_.get( d, P_ ); + ( *d )[ names::recordables ] = recordablesMap_.get_list(); + + DictionaryDatum receptor_dict_ = new Dictionary(); + ( *receptor_dict_ )[ names::eprop_learning_window ] = LEARNING_WINDOW_SIG; + ( *receptor_dict_ )[ names::target_signal ] = TARGET_SIG; + + ( *d )[ names::receptor_types ] = receptor_dict_; +} + +inline void +eprop_readout::set_status( const DictionaryDatum& d ) +{ + // temporary copies in case of errors + Parameters_ ptmp = P_; + State_ stmp = S_; + + // make sure that ptmp and stmp consistent - throw BadProperty if not + const double delta_EL = ptmp.set( d, this ); + stmp.set( d, ptmp, delta_EL, this ); + + P_ = ptmp; + S_ = stmp; +} + +} // namespace nest + +#endif // EPROP_READOUT_H diff --git a/models/eprop_readout_bsshslm_2020.cpp b/models/eprop_readout_bsshslm_2020.cpp index 76317bc643..50027fe621 100644 --- a/models/eprop_readout_bsshslm_2020.cpp +++ b/models/eprop_readout_bsshslm_2020.cpp @@ -219,7 +219,7 @@ eprop_readout_bsshslm_2020::pre_run_hook() const double dt = Time::get_resolution().get_ms(); - V_.P_v_m_ = std::exp( -dt / P_.tau_m_ ); // called kappa in reference [1] + V_.P_v_m_ = std::exp( -dt / P_.tau_m_ ); V_.P_i_in_ = P_.tau_m_ / P_.C_m_ * ( 1.0 - V_.P_v_m_ ); V_.P_z_in_ = P_.regular_spike_arrival_ ? 
1.0 : 1.0 - V_.P_v_m_; } @@ -259,10 +259,8 @@ eprop_readout_bsshslm_2020::update( Time const& origin, const long from, const l const long interval_step = ( t - shift ) % update_interval; const long interval_step_signals = ( t - shift - delay_out_norm_ ) % update_interval; - if ( interval_step == 0 ) { - erase_used_update_history(); erase_used_eprop_history(); if ( with_reset ) @@ -294,6 +292,7 @@ eprop_readout_bsshslm_2020::update( Time const& origin, const long from, const l error_signal_buffer[ lag ] = S_.error_signal_; + append_new_eprop_history_entry( t ); write_error_signal_to_history( t, S_.error_signal_ ); S_.i_in_ = B_.currents_.get_value( lag ) + P_.I_e_; diff --git a/models/eprop_readout_bsshslm_2020.h b/models/eprop_readout_bsshslm_2020.h index ba25d07d36..8a708269ed 100644 --- a/models/eprop_readout_bsshslm_2020.h +++ b/models/eprop_readout_bsshslm_2020.h @@ -46,7 +46,7 @@ postsynaptic currents for e-prop plasticity Description +++++++++++ -``eprop_readout_bsshslm_2020`` is an implementation of a integrate-and-fire neuron model +``eprop_readout_bsshslm_2020`` is an implementation of an integrate-and-fire neuron model with delta-shaped postsynaptic currents used as readout neuron for eligibility propagation (e-prop) plasticity. E-prop plasticity was originally introduced and implemented in TensorFlow in [1]_. @@ -55,47 +55,84 @@ The suffix ``_bsshslm_2020`` follows the NEST convention to indicate in the model name the paper that introduced it by the first letter of the authors' last names and the publication year. - The membrane voltage time course :math:`v_j^t` of the neuron :math:`j` is given by: .. math:: - v_j^t &= \kappa v_j^{t-1}+\sum_{i \neq j}W_{ji}^\mathrm{out}z_i^{t-1} - -z_j^{t-1}v_\mathrm{th} \,, \\ - \kappa &= e^{-\frac{\Delta t}{\tau_\mathrm{m}}} \,, + v_j^t &= \kappa v_j^{t-1} + \zeta \sum_{i \neq j} W_{ji}^\text{out} z_i^{t-1} \,, \\ + \kappa &= e^{ -\frac{ \Delta t }{ \tau_\text{m} } } \,, \\ + \zeta &= + \begin{cases} + 1 \\ + 1 - \kappa + \end{cases} \,, \\ -whereby :math:`W_{ji}^\mathrm{out}` are the output synaptic weights and -:math:`z_i^{t-1}` are the recurrent presynaptic spike state variables. +where :math:`W_{ji}^\text{out}` is the output synaptic weight matrix and +:math:`z_i^{t-1}` is the recurrent presynaptic spike state variable. Descriptions of further parameters and variables can be found in the table below. -An additional state variable and the corresponding differential -equation represents a piecewise constant external current. +The spike state variable of a presynaptic neuron is expressed by a Heaviside function: + +.. math:: + z_i^t = H \left( v_i^t - v_\text{th} \right) \,. \\ -See the documentation on the ``iaf_psc_delta`` neuron model for more information -on the integration of the subthreshold dynamics. +An additional state variable and the corresponding differential equation +represents a piecewise constant external current. + +See the documentation on the :doc:`iaf_psc_delta<../models/iaf_psc_delta/>` neuron model +for more information on the integration of the subthreshold dynamics. The change of the synaptic weight is calculated from the gradient :math:`g` of the loss :math:`E` with respect to the synaptic weight :math:`W_{ji}`: -The change of the synaptic weight is calculated from the gradient -:math:`\frac{\mathrm{d}{E}}{\mathrm{d}{W_{ij}}}=g` +:math:`\frac{ \text{d}E }{ \text{d} W_{ij} }` which depends on the presynaptic spikes :math:`z_i^{t-1}` and the learning signal :math:`L_j^t` emitted by the readout neurons. .. 
math:: - \frac{\mathrm{d}E}{\mathrm{d}W_{ji}} = g &= \sum_t L_j^t \bar{z}_i^{t-1}\,. \\ + \frac{ \text{d} E }{ \text{d} W_{ji} } = \sum_t L_j^t \bar{z}_i^{t-1} \,. \\ -The presynaptic spike trains are low-pass filtered with an exponential kernel: +The presynaptic spike trains are low-pass filtered with the following exponential kernel: .. math:: - \bar{z}_i^t &=\mathcal{F}_\kappa(z_i^t)\,, \\ - \mathcal{F}_\kappa(z_i^t) &= \kappa\, \mathcal{F}_\kappa(z_i^{t-1}) + z_i^t - \;\text{with}\, \mathcal{F}_\kappa(z_i^0)=z_i^0\,\,. + \bar{z}_i^t &=\mathcal{F}_\kappa(z_i^t) \,, \\ + \mathcal{F}_\kappa(z_i^t) &= \kappa \mathcal{F}_\kappa \left( z_i^{t-1} \right) + z_i^t \,, \\ + \mathcal{F}_\kappa(z_i^0) &= z_i^0 \,. \\ Since readout neurons are leaky integrators without a spiking mechanism, the formula for computing the gradient lacks the surrogate gradient / pseudo-derivative and a firing regularization term. +The learning signal :math:`L_j^t` is given by the non-plastic feedback weight +matrix :math:`B_{jk}` and the continuous error signal :math:`e_k^t` emitted by +readout neuron :math:`k`: + +.. math:: + L_j^t = B_{jk} e_k^t \,. \\ + +The error signal depends on the selected loss function. +If a mean squared error loss is selected, then: + +.. math:: + e_k^t = y_k^t - y_k^{*,t} \,, \\ + +where the readout signal :math:`y_k^t` corresponds to the membrane voltage of +readout neuron :math:`k` and :math:`y_k^{*,t}` is the real-valued target signal. + +If a cross-entropy loss is selected, then: + +.. math:: + e^k_t &= \pi_k^t - \pi_k^{*,t} \,, \\ + \pi_k^t &= \text{softmax}_k \left( y_1^t, ..., y_K^t \right) = + \frac{ \exp \left( y_k^t\right) }{ \sum_{k'} \exp \left( y_{k'}^t \right) } \,, \\ + +where the readout signal :math:`\pi_k^t` corresponds to the softmax of the +membrane voltage of readout neuron :math:`k` and :math:`\pi_k^{*,t}` is the +one-hot encoded target signal. + +Furthermore, the readout and target signal are zero before the onset of the +learning window in each update interval. + For more information on e-prop plasticity, see the documentation on the other e-prop models: * :doc:`eprop_iaf_bsshslm_2020<../models/eprop_iaf_bsshslm_2020/>` @@ -110,53 +147,64 @@ Parameters The following parameters can be set in the status dictionary. 
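
For instance, a readout population can be created with explicit values for some of the parameters listed in the tables below (a minimal sketch; the parameter values are illustrative only)::

    import nest

    nest.ResetKernel()

    nrns_out = nest.Create(
        "eprop_readout_bsshslm_2020",
        2,
        params={"loss": "cross_entropy", "tau_m": 10.0, "C_m": 250.0},
    )
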
-===================== ======= ===================== ================== =============================================== +========================= ======= ===================== ================== ===================================== **Neuron parameters** ----------------------------------------------------------------------------------------------------------------------- -Parameter Unit Math equivalent Default Description -===================== ======= ===================== ================== =============================================== -C_m pF :math:`C_\text{m}` 250.0 Capacitance of the membrane -E_L mV :math:`E_\text{L}` 0.0 Leak / resting membrane potential -I_e pA :math:`I_\text{e}` 0.0 Constant external input current -loss :math:`E` mean_squared_error Loss function - ["mean_squared_error", "cross_entropy"] -regular_spike_arrival Boolean True If True, the input spikes arrive at the - end of the time step, if False at the - beginning (determines PSC scale) -tau_m ms :math:`\tau_\text{m}` 10.0 Time constant of the membrane -V_min mV :math:`v_\text{min}` -1.79e+308 Absolute lower bound of the membrane voltage -===================== ======= ===================== ================== =============================================== - -The following state variables evolve during simulation. - -===================== ==== =============== ============= ========================== -**Neuron state variables and recordables** ------------------------------------------------------------------------------------ -State variable Unit Math equivalent Initial value Description -===================== ==== =============== ============= ========================== -error_signal mV :math:`L_j` 0.0 Error signal -readout_signal mV :math:`y_j` 0.0 Readout signal -readout_signal_unnorm mV 0.0 Unnormalized readout signal -target_signal mV :math:`y^*_j` 0.0 Target signal -V_m mV :math:`v_j` 0.0 Membrane voltage -===================== ==== =============== ============= ========================== +---------------------------------------------------------------------------------------------------------------- +Parameter Unit Math equivalent Default Description +========================= ======= ===================== ================== ===================================== +``C_m`` pF :math:`C_\text{m}` 250.0 Capacitance of the membrane +``E_L`` mV :math:`E_\text{L}` 0.0 Leak / resting membrane potential +``I_e`` pA :math:`I_\text{e}` 0.0 Constant external input current +``regular_spike_arrival`` Boolean ``True`` If ``True``, the input spikes arrive + at the end of the time step, if + ``False`` at the beginning + (determines PSC scale) +``tau_m`` ms :math:`\tau_\text{m}` 10.0 Time constant of the membrane +``V_min`` mV :math:`v_\text{min}` negative maximum Absolute lower bound of the membrane + value voltage + representable by a + ``double`` type in + C++ +========================= ======= ===================== ================== ===================================== + +========== ======= ===================== ==================== ========================================= +**E-prop parameters** +------------------------------------------------------------------------------------------------------- +Parameter Unit Math equivalent Default Description +========== ======= ===================== ==================== ========================================= +``loss`` :math:`E` "mean_squared_error" Loss function + ["mean_squared_error", "cross_entropy"] +========== ======= ===================== ==================== 
========================================= Recordables +++++++++++ -The following variables can be recorded: +The following state variables evolve during simulation and can be recorded. - - error signal ``error_signal`` - - readout signal ``readout_signal`` - - readout signal ``readout_signal_unnorm`` - - target signal ``target_signal`` - - membrane potential ``V_m`` +=============== ==== =============== ============= ================ +**Neuron state variables and recordables** +------------------------------------------------------------------- +State variable Unit Math equivalent Initial value Description +=============== ==== =============== ============= ================ +``V_m`` mV :math:`v_j` 0.0 Membrane voltage +=============== ==== =============== ============= ================ + +========================= ==== =============== ============= =============================== +**E-prop state variables and recordables** +-------------------------------------------------------------------------------------------- +State variable Unit Math equivalent Initial value Description +========================= ==== =============== ============= =============================== +``error_signal`` mV :math:`L_j` 0.0 Error signal +``readout_signal`` mV :math:`y_j` 0.0 Readout signal +``readout_signal_unnorm`` mV 0.0 Unnormalized readout signal +``target_signal`` mV :math:`y^*_j` 0.0 Target signal +========================= ==== =============== ============= =============================== Usage +++++ -This model can only be used in combination with the other e-prop models, -whereby the network architecture requires specific wiring, input, and output. +This model can only be used in combination with the other e-prop models +and the network architecture requires specific wiring, input, and output. The usage is demonstrated in several :doc:`supervised regression and classification tasks <../auto_examples/eprop_plasticity/index>` reproducing among others the original proof-of-concept tasks in [1]_. @@ -168,12 +216,13 @@ References Maass W (2020). A solution to the learning dilemma for recurrent networks of spiking neurons. Nature Communications, 11:3625. https://doi.org/10.1038/s41467-020-17236-y -.. [2] Korcsak-Gorzo A, Stapmanns J, Espinoza Valverde JA, Dahmen D, - van Albada SJ, Bolten M, Diesmann M. Event-based implementation of - eligibility propagation (in preparation) + +.. [2] Korcsak-Gorzo A, Stapmanns J, Espinoza Valverde JA, Plesser HE, + Dahmen D, Bolten M, Van Albada SJ*, Diesmann M*. Event-based + implementation of eligibility propagation (in preparation) Sends -++++++++ ++++++ LearningSignalConnectionEvent, DelayedRateConnectionEvent @@ -186,7 +235,7 @@ See also ++++++++ Examples using this model -++++++++++++++++++++++++++ ++++++++++++++++++++++++++ .. listexamples:: eprop_readout_bsshslm_2020 @@ -195,6 +244,8 @@ EndUserDocs */ void register_eprop_readout_bsshslm_2020( const std::string& name ); /** + * @brief Class implementing a readout neuron model for e-prop plasticity. + * * Class implementing a current-based leaky integrate readout neuron model with delta-shaped postsynaptic currents for * e-prop plasticity according to Bellec et al. (2020). 
*/ @@ -236,21 +287,21 @@ class eprop_readout_bsshslm_2020 : public EpropArchivingNodeReadout void get_status( DictionaryDatum& ) const override; void set_status( const DictionaryDatum& ) override; +private: + void init_buffers_() override; + void pre_run_hook() override; + + void update( Time const&, const long, const long ) override; + double compute_gradient( std::vector< long >& presyn_isis, const long t_previous_update, const long t_previous_trigger_spike, const double kappa, const bool average_gradient ) override; - void pre_run_hook() override; long get_shift() const override; bool is_eprop_recurrent_node() const override; - void update( Time const&, const long, const long ) override; - -protected: - void init_buffers_() override; -private: //! Compute the error signal based on the mean-squared error loss. void compute_error_signal_mean_squared_error( const long lag ); @@ -321,7 +372,7 @@ class eprop_readout_bsshslm_2020 : public EpropArchivingNodeReadout //! Membrane voltage relative to the leak membrane potential (mV). double v_m_; - //! Binary input spike variables - 1.0 if the neuron has spiked in the previous time step and 0.0 otherwise. + //! Binary input spike state variable - 1.0 if the neuron has spiked in the previous time step and 0.0 otherwise. double z_in_; //! Default constructor. @@ -356,13 +407,14 @@ class eprop_readout_bsshslm_2020 : public EpropArchivingNodeReadout UniversalDataLogger< eprop_readout_bsshslm_2020 > logger_; }; - //! Structure of general variables. + //! Structure of internal variables. struct Variables_ { - //! Propagator matrix entry for evolving the membrane voltage. + //! Propagator matrix entry for evolving the membrane voltage (mathematical symbol "kappa" in user documentation). double P_v_m_; - //! Propagator matrix entry for evolving the incoming spike variables. + //! Propagator matrix entry for evolving the incoming spike state variables (mathematical symbol "zeta" in user + //! documentation). double P_z_in_; //! Propagator matrix entry for evolving the incoming currents. @@ -421,16 +473,16 @@ class eprop_readout_bsshslm_2020 : public EpropArchivingNodeReadout // the order in which the structure instances are defined is important for speed - //!< Structure of parameters. + //! Structure of parameters. Parameters_ P_; - //!< Structure of state variables. + //! Structure of state variables. State_ S_; - //!< Structure of general variables. + //! Structure of internal variables. Variables_ V_; - //!< Structure of buffers. + //! Structure of buffers. Buffers_ B_; //! Map storing a static set of recordables. diff --git a/models/eprop_synapse.cpp b/models/eprop_synapse.cpp new file mode 100644 index 0000000000..f167592024 --- /dev/null +++ b/models/eprop_synapse.cpp @@ -0,0 +1,145 @@ +/* + * eprop_synapse.cpp + * + * This file is part of NEST. + * + * Copyright (C) 2004 The NEST Initiative + * + * NEST is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 2 of the License, or + * (at your option) any later version. + * + * NEST is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with NEST. If not, see . 
+ * + */ + +#include "eprop_synapse.h" + +// nestkernel +#include "nest_impl.h" + +namespace nest +{ + +void +register_eprop_synapse( const std::string& name ) +{ + register_connection_model< eprop_synapse >( name ); +} + +EpropSynapseCommonProperties::EpropSynapseCommonProperties() + : CommonSynapseProperties() + , optimizer_cp_( new WeightOptimizerCommonPropertiesGradientDescent() ) +{ +} + +EpropSynapseCommonProperties::EpropSynapseCommonProperties( const EpropSynapseCommonProperties& cp ) + : CommonSynapseProperties( cp ) + , optimizer_cp_( cp.optimizer_cp_->clone() ) +{ +} + +EpropSynapseCommonProperties::~EpropSynapseCommonProperties() +{ + delete optimizer_cp_; +} + +void +EpropSynapseCommonProperties::get_status( DictionaryDatum& d ) const +{ + CommonSynapseProperties::get_status( d ); + def< std::string >( d, names::optimizer, optimizer_cp_->get_name() ); + DictionaryDatum optimizer_dict = new Dictionary; + optimizer_cp_->get_status( optimizer_dict ); + ( *d )[ names::optimizer ] = optimizer_dict; +} + +void +EpropSynapseCommonProperties::set_status( const DictionaryDatum& d, ConnectorModel& cm ) +{ + CommonSynapseProperties::set_status( d, cm ); + + if ( d->known( names::optimizer ) ) + { + DictionaryDatum optimizer_dict = getValue< DictionaryDatum >( d->lookup( names::optimizer ) ); + + std::string new_optimizer; + const bool set_optimizer = updateValue< std::string >( optimizer_dict, names::type, new_optimizer ); + if ( set_optimizer and new_optimizer != optimizer_cp_->get_name() ) + { + if ( kernel().connection_manager.get_num_connections( cm.get_syn_id() ) > 0 ) + { + throw BadParameter( "The optimizer cannot be changed because synapses have been created." ); + } + + // TODO: selection here should be based on an optimizer registry and a factory + // delete is in if/else if because we must delete only when we are sure that we have a valid optimizer + if ( new_optimizer == "gradient_descent" ) + { + delete optimizer_cp_; + optimizer_cp_ = new WeightOptimizerCommonPropertiesGradientDescent(); + } + else if ( new_optimizer == "adam" ) + { + delete optimizer_cp_; + optimizer_cp_ = new WeightOptimizerCommonPropertiesAdam(); + } + else + { + throw BadProperty( "optimizer from [\"gradient_descent\", \"adam\"] required." ); + } + } + + // we can now set the defaults on the new optimizer common properties + optimizer_cp_->set_status( optimizer_dict ); + } +} + +template <> +void +Connector< eprop_synapse< TargetIdentifierPtrRport > >::disable_connection( const size_t lcid ) +{ + assert( not C_[ lcid ].is_disabled() ); + C_[ lcid ].disable(); + C_[ lcid ].delete_optimizer(); +} + +template <> +void +Connector< eprop_synapse< TargetIdentifierIndex > >::disable_connection( const size_t lcid ) +{ + assert( not C_[ lcid ].is_disabled() ); + C_[ lcid ].disable(); + C_[ lcid ].delete_optimizer(); +} + + +template <> +Connector< eprop_synapse< TargetIdentifierPtrRport > >::~Connector() +{ + for ( auto& c : C_ ) + { + c.delete_optimizer(); + } + C_.clear(); +} + +template <> +Connector< eprop_synapse< TargetIdentifierIndex > >::~Connector() +{ + for ( auto& c : C_ ) + { + c.delete_optimizer(); + } + C_.clear(); +} + + +} // namespace nest diff --git a/models/eprop_synapse.h b/models/eprop_synapse.h new file mode 100644 index 0000000000..ec36f8f025 --- /dev/null +++ b/models/eprop_synapse.h @@ -0,0 +1,563 @@ +/* + * eprop_synapse.h + * + * This file is part of NEST. 
+ * + * Copyright (C) 2004 The NEST Initiative + * + * NEST is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 2 of the License, or + * (at your option) any later version. + * + * NEST is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with NEST. If not, see . + * + */ + +#ifndef EPROP_SYNAPSE_H +#define EPROP_SYNAPSE_H + +// nestkernel +#include "connection.h" +#include "connector_base.h" +#include "eprop_archiving_node.h" +#include "target_identifier.h" +#include "weight_optimizer.h" + +namespace nest +{ + +/* BeginUserDocs: synapse, e-prop plasticity + +Short description ++++++++++++++++++ + +Synapse type for e-prop plasticity + +Description ++++++++++++ + +``eprop_synapse`` is an implementation of a connector model to create synapses between postsynaptic +neurons :math:`j` and presynaptic neurons and :math:`i` for eligibility propagation (e-prop) plasticity. + +E-prop plasticity was originally introduced and implemented in TensorFlow in [1]_. + +The e-prop synapse triggers the calculation of the gradient at each spike +over an interval that begins at the previous spike and ends at a cutoff specified by the user or the +current spike, depending on which of the two time points is earlier. +The gradient calculation is specific to the post-synaptic neuron and thus defined there. + +Eventually, it optimizes the weight with the specified optimizer. + +E-prop synapses require archiving of continuous quantities. Therefore e-prop +synapses can only be connected to neuron models that are capable of +archiving. So far, compatible models are ``eprop_iaf``, ``eprop_iaf_psc_delta``, ``eprop_iaf_psc_delta_adapt``, +``eprop_iaf_adapt``, and ``eprop_readout``. + +For more information on e-prop plasticity, see the documentation on the other e-prop models: + + * :doc:`eprop_iaf<../models/eprop_iaf/>` + * :doc:`eprop_iaf_adapt<../models/eprop_iaf_adapt/>` + * :doc:`eprop_readout<../models/eprop_readout/>` + * :doc:`eprop_learning_signal_connection<../models/eprop_learning_signal_connection/>` + +For more information on the optimizers, see the documentation of the weight optimizer: + + * :doc:`weight_optimizer<../models/weight_optimizer/>` + +Details on the event-based NEST implementation of e-prop can be found in [2]_. + +.. warning:: + + This synaptic plasticity rule does not take + :ref:`precise spike timing ` into + account. When calculating the weight update, the precise spike time part + of the timestamp is ignored. + +Parameters +++++++++++ + +The following parameters can be set in the status dictionary. 
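
For instance, the common ``optimizer`` property listed below selects and configures the weight optimizer; its type can only be changed as long as no synapses of this model exist. A minimal sketch, assuming the nested parameter names documented for the :doc:`weight_optimizer<../models/weight_optimizer/>`::

    import nest

    nest.ResetKernel()

    nest.SetDefaults(
        "eprop_synapse",
        {
            "optimizer": {
                "type": "adam",   # or "gradient_descent"
                "eta": 5e-3,      # illustrative learning rate
                "Wmin": -100.0,   # illustrative weight bounds
                "Wmax": 100.0,
            }
        },
    )
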
+ +================ ==== =============== ======= ====================================================== +**Common e-prop synapse parameters** +---------------------------------------------------------------------------------------------------- +Parameter Unit Math equivalent Default Description +================ ==== =============== ======= ====================================================== +``optimizer`` {} Dictionary of optimizer parameters +================ ==== =============== ======= ====================================================== + +============= ==== ========================= ======= ========================================================= +**Individual synapse parameters** +-------------------------------------------------------------------------------------------------------------- +Parameter Unit Math equivalent Default Description +============= ==== ========================= ======= ========================================================= +``delay`` ms :math:`d_{ji}` 1.0 Dendritic delay +``weight`` pA :math:`W_{ji}` 1.0 Initial value of synaptic weight +============= ==== ========================= ======= ========================================================= + +Recordables ++++++++++++ + +The following variables can be recorded. + +================== ==== =============== ============= ========================================================== +**Synapse recordables** +---------------------------------------------------------------------------------------------------------------- +State variable Unit Math equivalent Initial value Description +================== ==== =============== ============= ========================================================== +``weight`` pA :math:`B_{jk}` 1.0 Synaptic weight +================== ==== =============== ============= ========================================================== + +Usage ++++++ + +This model can only be used in combination with the other e-prop models +and the network architecture requires specific wiring, input, and output. +The usage is demonstrated in several +:doc:`supervised regression and classification tasks <../auto_examples/eprop_plasticity/index>` +reproducing among others the original proof-of-concept tasks in [1]_. + +Transmits ++++++++++ + +SpikeEvent, DSSpikeEvent + +References +++++++++++ + +.. [1] Bellec G, Scherr F, Subramoney F, Hajek E, Salaj D, Legenstein R, + Maass W (2020). A solution to the learning dilemma for recurrent + networks of spiking neurons. Nature Communications, 11:3625. + https://doi.org/10.1038/s41467-020-17236-y + +.. [2] Korcsak-Gorzo A, Stapmanns J, Espinoza Valverde JA, Plesser HE, + Dahmen D, Bolten M, Van Albada SJ*, Diesmann M*. Event-based + implementation of eligibility propagation (in preparation) + +See also +++++++++ + +Examples using this model ++++++++++++++++++++++++++ + +.. listexamples:: eprop_synapse + +EndUserDocs */ + +/** + * @brief Base class implementing common properties for e-prop synapses with additional biological features. + * + * Base class implementing common properties for the e-prop synapse model according to Bellec et al. (2020) with + * additional biological features described in Korcsak-Gorzo, Stapmanns, and Espinoza Valverde et al. + * (in preparation). + * + * This class in particular manages a pointer to weight-optimizer common properties to support + * exchanging the weight optimizer at runtime. Setting the weight-optimizer common properties + * determines the WO type. 
It can only be exchanged as long as no synapses for the model exist. + * The WO CP object is responsible for providing individual optimizer objects to synapses upon + * connection. + * + * @see WeightOptimizerCommonProperties + */ +class EpropSynapseCommonProperties : public CommonSynapseProperties +{ +public: + // Default constructor. + EpropSynapseCommonProperties(); + + //! Copy constructor. + EpropSynapseCommonProperties( const EpropSynapseCommonProperties& ); + + //! Assignment operator. + EpropSynapseCommonProperties& operator=( const EpropSynapseCommonProperties& ) = delete; + + //! Destructor. + ~EpropSynapseCommonProperties(); + + //! Get parameter dictionary. + void get_status( DictionaryDatum& d ) const; + + //! Update values in parameter dictionary. + void set_status( const DictionaryDatum& d, ConnectorModel& cm ); + + /** + * Pointer to common properties object for weight optimizer. + * + * @note Must only be changed as long as no synapses of the model exist. + */ + WeightOptimizerCommonProperties* optimizer_cp_; +}; + +//! Register the eprop synapse model. +void register_eprop_synapse( const std::string& name ); + +/** + * @brief Class implementing a synapse model for e-prop plasticity with additional biological features. + * + * Class implementing a synapse model for e-prop plasticity according to Bellec et al. (2020) with + * additional biological features described in Korcsak-Gorzo, Stapmanns, and Espinoza Valverde et al. (in preparation). + * + * @note Each synapse has a optimizer_ object managed through a `WeightOptimizer*`, pointing to an object of + * a specific weight optimizer type. This optimizer, drawing also on parameters in the `WeightOptimizerCommonProperties` + * accessible via the synapse models `CommonProperties::optimizer_cp_` pointer, computes the weight update for the + * neuron. The actual optimizer type can be selected at runtime (before creating any synapses) by exchanging the + * `optimizer_cp_` pointer. Individual optimizer objects are created by `check_connection()` when a synapse is actually + * created. It is important that the constructors of `eprop_synapse` **do not** create optimizer objects + * and that the destructor **does not** delete optimizer objects; this currently leads to bugs when using Boosts's + * `spreadsort()` due to use of the copy constructor where it should suffice to use the move constructor. Therefore, + * `check_connection()`creates the optimizer object when it is needed and specializations of `Connector::~Connector()` + * and `Connector::disable_connection()` delete it by calling `delete_optimizer()`. A disadvantage of this approach is + * that the `default_connection` in the connector model does not have an optimizer object, whence it is not possible to + * set default (initial) values for the per-synapse optimizer. + * + * @note If we can find a way to modify our co-sorting of source and target tables in Boost's `spreadsort()` to only use + * move operations, it should be possible to create the individual optimizers in the copy constructor of + * `eprop_synapse` and to delete it in the destructor. The `default_connection` can then own an optimizer + * and default values could be set on it. + */ +template < typename targetidentifierT > +class eprop_synapse : public Connection< targetidentifierT > +{ + +public: + //! Type of the common synapse properties. + typedef EpropSynapseCommonProperties CommonPropertiesType; + + //! Type of the connection base. 
+ typedef Connection< targetidentifierT > ConnectionBase; + + /** + * Properties of the connection model. + * + * @note Does not support LBL at present because we cannot properly cast GenericModel common props in that case. + */ + static constexpr ConnectionModelProperties properties = ConnectionModelProperties::HAS_DELAY + | ConnectionModelProperties::IS_PRIMARY | ConnectionModelProperties::REQUIRES_EPROP_ARCHIVING + | ConnectionModelProperties::SUPPORTS_HPC; + + //! Default constructor. + eprop_synapse(); + + //! Destructor + ~eprop_synapse(); + + //! Parameterized copy constructor. + eprop_synapse( const eprop_synapse& ); + + //! Assignment operator + eprop_synapse& operator=( const eprop_synapse& ); + + //! Move constructor + eprop_synapse( eprop_synapse&& ); + + //! Move assignment operator + eprop_synapse& operator=( eprop_synapse&& ); + + using ConnectionBase::get_delay; + using ConnectionBase::get_delay_steps; + using ConnectionBase::get_rport; + using ConnectionBase::get_target; + + //! Get parameter dictionary. + void get_status( DictionaryDatum& d ) const; + + //! Update values in parameter dictionary. + void set_status( const DictionaryDatum& d, ConnectorModel& cm ); + + //! Send the spike event. + bool send( Event& e, size_t thread, const EpropSynapseCommonProperties& cp ); + + //! Dummy node for testing the connection. + class ConnTestDummyNode : public ConnTestDummyNodeBase + { + public: + using ConnTestDummyNodeBase::handles_test_event; + + size_t + handles_test_event( SpikeEvent&, size_t ) + { + return invalid_port; + } + + size_t + handles_test_event( DSSpikeEvent&, size_t ) + { + return invalid_port; + } + }; + + /** + * Check if the target accepts the event and receptor type requested by the sender. + * + * @note This sets the optimizer_ member. + */ + void check_connection( Node& s, Node& t, size_t receptor_type, const CommonPropertiesType& cp ); + + //! Set the synaptic weight to the provided value. + void + set_weight( const double w ) + { + weight_ = w; + } + + //! Delete optimizer + void delete_optimizer(); + +private: + //! Synaptic weight. + double weight_; + + //! The time step when the previous spike arrived. + long t_spike_previous_ = 0; + + //! The time step when the spike arrived that triggered the previous e-prop update. + long t_previous_trigger_spike_ = 0; + + //! Low-pass filtered spiking variable. + double z_bar_ = 0.0; + + //! Low-pass filtered eligibility trace. + double e_bar_ = 0.0; + + //! Low-pass filtered eligibility trace for firing rate regularization. + double e_bar_reg_ = 0.0; + + //! Adaptive threshold component of the eligibility vector. + double epsilon_ = 0.0; + + //! Value of spiking variable one time step before t_previous_spike_. + double z_previous_buffer_ = 0.0; + + /** + * Optimizer + * + * @note Pointer is set by check_connection() and deleted by delete_optimizer(). 
+ */ + WeightOptimizer* optimizer_; +}; + +template < typename targetidentifierT > +constexpr ConnectionModelProperties eprop_synapse< targetidentifierT >::properties; + +// Explicitly declare specializations of Connector methods that need to do special things for eprop_synapse +template <> +void Connector< eprop_synapse< TargetIdentifierPtrRport > >::disable_connection( const size_t lcid ); + +template <> +void Connector< eprop_synapse< TargetIdentifierIndex > >::disable_connection( const size_t lcid ); + +template <> +Connector< eprop_synapse< TargetIdentifierPtrRport > >::~Connector(); + +template <> +Connector< eprop_synapse< TargetIdentifierIndex > >::~Connector(); + + +template < typename targetidentifierT > +eprop_synapse< targetidentifierT >::eprop_synapse() + : ConnectionBase() + , weight_( 1.0 ) + , t_spike_previous_( 0 ) + , t_previous_trigger_spike_( 0 ) + , optimizer_( nullptr ) +{ +} + +template < typename targetidentifierT > +eprop_synapse< targetidentifierT >::~eprop_synapse() +{ +} + +// This copy constructor is used to create instances from prototypes. +// Therefore, only parameter values are copied. +template < typename targetidentifierT > +eprop_synapse< targetidentifierT >::eprop_synapse( const eprop_synapse& es ) + : ConnectionBase( es ) + , weight_( es.weight_ ) + , optimizer_( es.optimizer_ ) +{ +} + +// This assignment operator is used to write a connection into the connection array. +template < typename targetidentifierT > +eprop_synapse< targetidentifierT >& +eprop_synapse< targetidentifierT >::operator=( const eprop_synapse& es ) +{ + if ( this == &es ) + { + return *this; + } + + ConnectionBase::operator=( es ); + + weight_ = es.weight_; + t_spike_previous_ = es.t_spike_previous_; + t_previous_trigger_spike_ = es.t_previous_trigger_spike_; + z_bar_ = es.z_bar_; + e_bar_ = es.e_bar_; + e_bar_reg_ = es.e_bar_reg_; + epsilon_ = es.epsilon_; + z_previous_buffer_ = es.z_previous_buffer_; + optimizer_ = es.optimizer_; + + return *this; +} + +template < typename targetidentifierT > +eprop_synapse< targetidentifierT >::eprop_synapse( eprop_synapse&& es ) + : ConnectionBase( es ) + , weight_( es.weight_ ) + , t_spike_previous_( es.t_spike_previous_ ) + , t_previous_trigger_spike_( es.t_spike_previous_ ) + , z_bar_( es.z_bar_ ) + , e_bar_( es.e_bar_ ) + , e_bar_reg_( es.e_bar_reg_ ) + , epsilon_( es.epsilon_ ) + , optimizer_( es.optimizer_ ) +{ + es.optimizer_ = nullptr; +} + +// This assignment operator is used to write a connection into the connection array. +template < typename targetidentifierT > +eprop_synapse< targetidentifierT >& +eprop_synapse< targetidentifierT >::operator=( eprop_synapse&& es ) +{ + if ( this == &es ) + { + return *this; + } + + ConnectionBase::operator=( es ); + + weight_ = es.weight_; + t_spike_previous_ = es.t_spike_previous_; + t_previous_trigger_spike_ = es.t_previous_trigger_spike_; + z_bar_ = es.z_bar_; + e_bar_ = es.e_bar_; + e_bar_reg_ = es.e_bar_reg_; + epsilon_ = es.epsilon_; + z_previous_buffer_ = es.z_previous_buffer_; + + optimizer_ = es.optimizer_; + es.optimizer_ = nullptr; + + return *this; +} + +template < typename targetidentifierT > +inline void +eprop_synapse< targetidentifierT >::check_connection( Node& s, + Node& t, + size_t receptor_type, + const CommonPropertiesType& cp ) +{ + // When we get here, delay has been set so we can check it. 
+ if ( get_delay_steps() != 1 ) + { + throw IllegalConnection( "eprop synapses currently require a delay of one simulation step" ); + } + + ConnTestDummyNode dummy_target; + ConnectionBase::check_connection_( dummy_target, s, t, receptor_type ); + + t.register_eprop_connection( false ); + + optimizer_ = cp.optimizer_cp_->get_optimizer(); +} + +template < typename targetidentifierT > +inline void +eprop_synapse< targetidentifierT >::delete_optimizer() +{ + delete optimizer_; + // do not set to nullptr to allow detection of double deletion +} + +template < typename targetidentifierT > +bool +eprop_synapse< targetidentifierT >::send( Event& e, size_t thread, const EpropSynapseCommonProperties& cp ) +{ + Node* target = get_target( thread ); + assert( target ); + + const long t_spike = e.get_stamp().get_steps(); + + if ( t_spike_previous_ != 0 ) + { + target->compute_gradient( + t_spike, t_spike_previous_, z_previous_buffer_, z_bar_, e_bar_, e_bar_reg_, epsilon_, weight_, cp, optimizer_ ); + } + + const long eprop_isi_trace_cutoff = target->get_eprop_isi_trace_cutoff(); + target->write_update_to_history( t_spike_previous_, t_spike, eprop_isi_trace_cutoff, false ); + + t_spike_previous_ = t_spike; + + e.set_receiver( *target ); + e.set_weight( weight_ ); + e.set_delay_steps( get_delay_steps() ); + e.set_rport( get_rport() ); + e(); + + return true; +} + +template < typename targetidentifierT > +void +eprop_synapse< targetidentifierT >::get_status( DictionaryDatum& d ) const +{ + ConnectionBase::get_status( d ); + def< double >( d, names::weight, weight_ ); + def< long >( d, names::size_of, sizeof( *this ) ); + + DictionaryDatum optimizer_dict = new Dictionary(); + + // The default_connection_ has no optimizer, therefore we need to protect it + if ( optimizer_ ) + { + optimizer_->get_status( optimizer_dict ); + ( *d )[ names::optimizer ] = optimizer_dict; + } +} + +template < typename targetidentifierT > +void +eprop_synapse< targetidentifierT >::set_status( const DictionaryDatum& d, ConnectorModel& cm ) +{ + ConnectionBase::set_status( d, cm ); + if ( d->known( names::optimizer ) ) + { + // We must pass here if called by SetDefaults. In that case, the user will get and error + // message because the parameters for the synapse-specific optimizer have not been accessed. + if ( optimizer_ ) + { + optimizer_->set_status( getValue< DictionaryDatum >( d->lookup( names::optimizer ) ) ); + } + } + + updateValue< double >( d, names::weight, weight_ ); + + const auto& gcm = dynamic_cast< const GenericConnectorModel< eprop_synapse< targetidentifierT > >& >( cm ); + const CommonPropertiesType& epcp = gcm.get_common_properties(); + if ( weight_ < epcp.optimizer_cp_->get_Wmin() ) + { + throw BadProperty( "Minimal weight Wmin ≤ weight required." ); + } + + if ( weight_ > epcp.optimizer_cp_->get_Wmax() ) + { + throw BadProperty( "weight ≤ maximal weight Wmax required." ); + } +} + +} // namespace nest + +#endif // EPROP_SYNAPSE_H diff --git a/models/eprop_synapse_bsshslm_2020.h b/models/eprop_synapse_bsshslm_2020.h index 52223ca67b..8755caf0a3 100644 --- a/models/eprop_synapse_bsshslm_2020.h +++ b/models/eprop_synapse_bsshslm_2020.h @@ -88,37 +88,52 @@ Parameters The following parameters can be set in the status dictionary. 
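
For instance, the individual parameters listed in the tables below can be passed via the ``syn_spec`` dictionary of ``nest.Connect`` (a minimal sketch; the population sizes, initial weight, and the one-step delay follow the e-prop example scripts and are illustrative only)::

    import nest

    nest.ResetKernel()

    nrns_rec = nest.Create("eprop_iaf_bsshslm_2020", 5)
    nrns_out = nest.Create("eprop_readout_bsshslm_2020", 1)

    nest.Connect(
        nrns_rec,
        nrns_out,
        "all_to_all",
        syn_spec={
            "synapse_model": "eprop_synapse_bsshslm_2020",
            "weight": 0.5,                # illustrative initial weight
            "delay": nest.resolution,     # one simulation step, as in the e-prop examples
            "tau_m_readout": 10.0,        # matches tau_m of the readout neurons here
        },
    )
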
-================ ======= =============== ======= ====================================================== -**Common synapse parameters** -------------------------------------------------------------------------------------------------------- -Parameter Unit Math equivalent Default Description -================ ======= =============== ======= ====================================================== -average_gradient Boolean False If True, average the gradient over the learning window -optimizer {} Dictionary of optimizer parameters -================ ======= =============== ======= ====================================================== - -============= ==== ========================= ======= ========================================================= +==================== ======= =============== ========= ====================================================== +**Common e-prop synapse parameters** +------------------------------------------------------------------------------------------------------------- +Parameter Unit Math equivalent Default Description +==================== ======= =============== ========= ====================================================== +``average_gradient`` Boolean ``False`` If ``True``, average the gradient over the learning + window +``optimizer`` {} Dictionary of optimizer parameters +==================== ======= =============== ========= ====================================================== + +============= ==== ========================= ======= ================================ **Individual synapse parameters** --------------------------------------------------------------------------------------------------------------- +------------------------------------------------------------------------------------- Parameter Unit Math equivalent Default Description -============= ==== ========================= ======= ========================================================= -delay ms :math:`d_{ji}` 1.0 Dendritic delay -tau_m_readout ms :math:`\tau_\text{m,out}` 10.0 Time constant for low-pass filtering of eligibility trace -weight pA :math:`W_{ji}` 1.0 Initial value of synaptic weight -============= ==== ========================= ======= ========================================================= +============= ==== ========================= ======= ================================ +``delay`` ms :math:`d_{ji}` 1.0 Dendritic delay +``weight`` pA :math:`W_{ji}` 1.0 Initial value of synaptic weight +============= ==== ========================= ======= ================================ + +================= ==== ========================= ======= ============================== +**Individual e-prop synapse parameters** +--------------------------------------------------------------------------------------- +Parameter Unit Math equivalent Default Description +================= ==== ========================= ======= ============================== +``tau_m_readout`` ms :math:`\tau_\text{m,out}` 10.0 Time constant for low-pass + filtering of eligibility trace +================= ==== ========================= ======= ============================== Recordables +++++++++++ The following variables can be recorded. 
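The ``weight`` recordable listed in the table below is typically captured by attaching a ``weight_recorder`` to a copy of the synapse model. A brief, hedged sketch following the standard NEST weight-recording pattern (the copied model name is invented for illustration):

.. code-block:: python

    import nest

    nest.ResetKernel()

    wr = nest.Create("weight_recorder")

    # Attach the recorder as a common property of a copied synapse model.
    nest.CopyModel(
        "eprop_synapse_bsshslm_2020",
        "eprop_synapse_bsshslm_2020_rec",  # hypothetical name
        {"weight_recorder": wr},
    )

    # Plastic connections are then created with the copied model, e.g.
    # nest.Connect(pre, post, syn_spec={"synapse_model": "eprop_synapse_bsshslm_2020_rec"})

    # After nest.Simulate(...), inspect wr.events["times"] and wr.events["weights"].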
- - synaptic weight ``weight`` +================== ==== =============== ============= =============== +**Synapse recordables** +--------------------------------------------------------------------- +State variable Unit Math equivalent Initial value Description +================== ==== =============== ============= =============== +``weight`` pA :math:`B_{jk}` 1.0 Synaptic weight +================== ==== =============== ============= =============== Usage +++++ -This model can only be used in combination with the other e-prop models, -whereby the network architecture requires specific wiring, input, and output. +This model can only be used in combination with the other e-prop models +and the network architecture requires specific wiring, input, and output. The usage is demonstrated in several :doc:`supervised regression and classification tasks <../auto_examples/eprop_plasticity/index>` reproducing among others the original proof-of-concept tasks in [1]_. @@ -136,22 +151,24 @@ References networks of spiking neurons. Nature Communications, 11:3625. https://doi.org/10.1038/s41467-020-17236-y -.. [2] Korcsak-Gorzo A, Stapmanns J, Espinoza Valverde JA, Dahmen D, - van Albada SJ, Bolten M, Diesmann M. Event-based implementation of - eligibility propagation (in preparation) +.. [2] Korcsak-Gorzo A, Stapmanns J, Espinoza Valverde JA, Plesser HE, + Dahmen D, Bolten M, Van Albada SJ*, Diesmann M*. Event-based + implementation of eligibility propagation (in preparation) See also ++++++++ Examples using this model -++++++++++++++++++++++++++ ++++++++++++++++++++++++++ .. listexamples:: eprop_synapse_bsshslm_2020 EndUserDocs */ /** - * Base class implementing common properties for the e-prop synapse model. + * @brief Base class implementing common properties for e-prop synapses. + * + * Base class implementing common properties for the e-prop synapse model according to Bellec et al. (2020). * * This class in particular manages a pointer to weight-optimizer common properties to support * exchanging the weight optimizer at runtime. Setting the weight-optimizer common properties @@ -197,9 +214,12 @@ class EpropSynapseBSSHSLM2020CommonProperties : public CommonSynapseProperties void register_eprop_synapse_bsshslm_2020( const std::string& name ); /** + * @brief Class implementing a synapse model for e-prop plasticity. + * * Class implementing a synapse model for e-prop plasticity according to Bellec et al. (2020). * - * @note Several aspects of this synapse are in place to reproduce the Tensorflow implementation of Bellec et al (2020). + * @note Several aspects of this synapse are in place to reproduce the Tensorflow implementation of Bellec et al. + * (2020). * * @note Each synapse has a optimizer_ object managed through a `WeightOptimizer*`, pointing to an object of * a specific weight optimizer type. This optimizer, drawing also on parameters in the `WeightOptimizerCommonProperties` @@ -312,7 +332,7 @@ class eprop_synapse_bsshslm_2020 : public Connection< targetidentifierT > double weight_; //! The time step when the previous spike arrived. - long t_previous_spike_; + long t_spike_previous_; //! The time step when the previous e-prop update was. 
long t_previous_update_; @@ -364,7 +384,7 @@ template < typename targetidentifierT > eprop_synapse_bsshslm_2020< targetidentifierT >::eprop_synapse_bsshslm_2020() : ConnectionBase() , weight_( 1.0 ) - , t_previous_spike_( 0 ) + , t_spike_previous_( 0 ) , t_previous_update_( 0 ) , t_next_update_( 0 ) , t_previous_trigger_spike_( 0 ) @@ -386,7 +406,7 @@ template < typename targetidentifierT > eprop_synapse_bsshslm_2020< targetidentifierT >::eprop_synapse_bsshslm_2020( const eprop_synapse_bsshslm_2020& es ) : ConnectionBase( es ) , weight_( es.weight_ ) - , t_previous_spike_( 0 ) + , t_spike_previous_( 0 ) , t_previous_update_( 0 ) , t_next_update_( kernel().simulation_manager.get_eprop_update_interval().get_steps() ) , t_previous_trigger_spike_( 0 ) @@ -397,7 +417,7 @@ eprop_synapse_bsshslm_2020< targetidentifierT >::eprop_synapse_bsshslm_2020( con { } -// This assignment operator is used to write a connection into the connection array. +// This copy assignment operator is used to write a connection into the connection array. template < typename targetidentifierT > eprop_synapse_bsshslm_2020< targetidentifierT >& eprop_synapse_bsshslm_2020< targetidentifierT >::operator=( const eprop_synapse_bsshslm_2020& es ) @@ -410,7 +430,7 @@ eprop_synapse_bsshslm_2020< targetidentifierT >::operator=( const eprop_synapse_ ConnectionBase::operator=( es ); weight_ = es.weight_; - t_previous_spike_ = es.t_previous_spike_; + t_spike_previous_ = es.t_spike_previous_; t_previous_update_ = es.t_previous_update_; t_next_update_ = es.t_next_update_; t_previous_trigger_spike_ = es.t_previous_trigger_spike_; @@ -426,7 +446,7 @@ template < typename targetidentifierT > eprop_synapse_bsshslm_2020< targetidentifierT >::eprop_synapse_bsshslm_2020( eprop_synapse_bsshslm_2020&& es ) : ConnectionBase( es ) , weight_( es.weight_ ) - , t_previous_spike_( 0 ) + , t_spike_previous_( 0 ) , t_previous_update_( 0 ) , t_next_update_( es.t_next_update_ ) , t_previous_trigger_spike_( 0 ) @@ -438,7 +458,7 @@ eprop_synapse_bsshslm_2020< targetidentifierT >::eprop_synapse_bsshslm_2020( epr es.optimizer_ = nullptr; } -// This assignment operator is used to write a connection into the connection array. +// This move assignment operator is used to write a connection into the connection array. template < typename targetidentifierT > eprop_synapse_bsshslm_2020< targetidentifierT >& eprop_synapse_bsshslm_2020< targetidentifierT >::operator=( eprop_synapse_bsshslm_2020&& es ) @@ -451,7 +471,7 @@ eprop_synapse_bsshslm_2020< targetidentifierT >::operator=( eprop_synapse_bsshsl ConnectionBase::operator=( es ); weight_ = es.weight_; - t_previous_spike_ = es.t_previous_spike_; + t_spike_previous_ = es.t_spike_previous_; t_previous_update_ = es.t_previous_update_; t_next_update_ = es.t_next_update_; t_previous_trigger_spike_ = es.t_previous_trigger_spike_; @@ -519,10 +539,10 @@ eprop_synapse_bsshslm_2020< targetidentifierT >::send( Event& e, t_previous_trigger_spike_ = t_spike; } - if ( t_previous_spike_ > 0 ) + if ( t_spike_previous_ > 0 ) { const long t = t_spike >= t_next_update_ + shift ? 
t_next_update_ + shift : t_spike; - presyn_isis_.push_back( t - t_previous_spike_ ); + presyn_isis_.push_back( t - t_spike_previous_ ); } if ( t_spike > t_next_update_ + shift ) @@ -543,7 +563,7 @@ eprop_synapse_bsshslm_2020< targetidentifierT >::send( Event& e, t_previous_trigger_spike_ = t_spike; } - t_previous_spike_ = t_spike; + t_spike_previous_ = t_spike; e.set_receiver( *target ); e.set_weight( weight_ ); @@ -582,6 +602,13 @@ eprop_synapse_bsshslm_2020< targetidentifierT >::set_status( const DictionaryDat { // We must pass here if called by SetDefaults. In that case, the user will get and error // message because the parameters for the synapse-specific optimizer have not been accessed. + auto optimizer_dict = getValue< DictionaryDatum >( d->lookup( names::optimizer ) ); + auto it = optimizer_dict->find( names::optimize_each_step ); + if ( it != optimizer_dict->end() ) + { + throw BadProperty( + "eprop_synapse_bsshslm_2020 only supports optimization in each step optimize_each_step == False." ); + } if ( optimizer_ ) { optimizer_->set_status( getValue< DictionaryDatum >( d->lookup( names::optimizer ) ) ); diff --git a/models/weight_optimizer.cpp b/models/weight_optimizer.cpp index db0a07fedc..eed7dd3dad 100644 --- a/models/weight_optimizer.cpp +++ b/models/weight_optimizer.cpp @@ -34,16 +34,22 @@ namespace nest WeightOptimizerCommonProperties::WeightOptimizerCommonProperties() : batch_size_( 1 ) , eta_( 1e-4 ) + , eta_first_( 1e-4 ) + , n_eta_change_( 0 ) , Wmin_( -100.0 ) , Wmax_( 100.0 ) + , optimize_each_step_( true ) { } WeightOptimizerCommonProperties::WeightOptimizerCommonProperties( const WeightOptimizerCommonProperties& cp ) : batch_size_( cp.batch_size_ ) , eta_( cp.eta_ ) + , eta_first_( cp.eta_first_ ) + , n_eta_change_( cp.n_eta_change_ ) , Wmin_( cp.Wmin_ ) , Wmax_( cp.Wmax_ ) + , optimize_each_step_( cp.optimize_each_step_ ) { } @@ -55,6 +61,7 @@ WeightOptimizerCommonProperties::get_status( DictionaryDatum& d ) const def< double >( d, names::eta, eta_ ); def< double >( d, names::Wmin, Wmin_ ); def< double >( d, names::Wmax, Wmax_ ); + def< bool >( d, names::optimize_each_step, optimize_each_step_ ); } void @@ -74,6 +81,16 @@ WeightOptimizerCommonProperties::set_status( const DictionaryDatum& d ) { throw BadProperty( "Learning rate eta ≥ 0 required." 
); } + + if ( new_eta != eta_ ) + { + if ( n_eta_change_ == 0 ) + { + eta_first_ = new_eta; + } + n_eta_change_ += 1; + } + eta_ = new_eta; double new_Wmin = Wmin_; @@ -86,11 +103,15 @@ WeightOptimizerCommonProperties::set_status( const DictionaryDatum& d ) } Wmin_ = new_Wmin; Wmax_ = new_Wmax; + + updateValue< bool >( d, names::optimize_each_step, optimize_each_step_ ); } WeightOptimizer::WeightOptimizer() : sum_gradients_( 0.0 ) , optimization_step_( 1 ) + , eta_( 1e-4 ) + , n_optimize_( 0 ) { } @@ -105,18 +126,29 @@ WeightOptimizer::set_status( const DictionaryDatum& d ) } double -WeightOptimizer::optimized_weight( const WeightOptimizerCommonProperties& cp, +WeightOptimizer::optimized_weight( WeightOptimizerCommonProperties& cp, const size_t idx_current_update, const double gradient, double weight ) { + if ( cp.n_eta_change_ != 0 and n_optimize_ == 0 ) + { + eta_ = cp.eta_first_; + } sum_gradients_ += gradient; + if ( optimization_step_ == 0 ) + { + optimization_step_ = idx_current_update; + } + const size_t current_optimization_step = 1 + idx_current_update / cp.batch_size_; if ( optimization_step_ < current_optimization_step ) { sum_gradients_ /= cp.batch_size_; weight = std::max( cp.Wmin_, std::min( optimize_( cp, weight, current_optimization_step ), cp.Wmax_ ) ); + eta_ = cp.eta_; + n_optimize_ += 1; optimization_step_ = current_optimization_step; } return weight; @@ -142,8 +174,8 @@ WeightOptimizerGradientDescent::WeightOptimizerGradientDescent() double WeightOptimizerGradientDescent::optimize_( const WeightOptimizerCommonProperties& cp, double weight, size_t ) { - weight -= cp.eta_ * sum_gradients_; - sum_gradients_ = 0; + weight -= eta_ * sum_gradients_; + sum_gradients_ = 0.0; return weight; } @@ -151,7 +183,7 @@ WeightOptimizerCommonPropertiesAdam::WeightOptimizerCommonPropertiesAdam() : WeightOptimizerCommonProperties() , beta_1_( 0.9 ) , beta_2_( 0.999 ) - , epsilon_( 1e-8 ) + , epsilon_( 1e-7 ) { } @@ -207,6 +239,8 @@ WeightOptimizerAdam::WeightOptimizerAdam() : WeightOptimizer() , m_( 0.0 ) , v_( 0.0 ) + , beta_1_power_( 1.0 ) + , beta_2_power_( 1.0 ) { } @@ -236,10 +270,10 @@ WeightOptimizerAdam::optimize_( const WeightOptimizerCommonProperties& cp, for ( ; optimization_step_ < current_optimization_step; ++optimization_step_ ) { - const double beta_1_factor = 1.0 - std::pow( acp.beta_1_, optimization_step_ ); - const double beta_2_factor = 1.0 - std::pow( acp.beta_2_, optimization_step_ ); + beta_1_power_ *= acp.beta_1_; + beta_2_power_ *= acp.beta_2_; - const double alpha = cp.eta_ * std::sqrt( beta_2_factor ) / beta_1_factor; + const double alpha = eta_ * std::sqrt( 1.0 - beta_2_power_ ) / ( 1.0 - beta_1_power_ ); m_ = acp.beta_1_ * m_ + ( 1.0 - acp.beta_1_ ) * sum_gradients_; v_ = acp.beta_2_ * v_ + ( 1.0 - acp.beta_2_ ) * sum_gradients_ * sum_gradients_; diff --git a/models/weight_optimizer.h b/models/weight_optimizer.h index 9cacba0745..d14205fcb3 100644 --- a/models/weight_optimizer.h +++ b/models/weight_optimizer.h @@ -29,7 +29,7 @@ namespace nest { -/* BeginUserDocs: e-prop plasticity +/* BeginUserDocs: e-prop plasticity, synapse Short description +++++++++++++++++ @@ -49,55 +49,69 @@ Currently two weight optimizers are implemented: gradient descent and the Adam o In gradient descent [1]_ the weights are optimized via: .. 
math:: - W_t = W_{t-1} - \eta \, g_t \,, + W_t = W_{t-1} - \eta g_t \,, \\ -whereby :math:`\eta` denotes the learning rate and :math:`g_t` the gradient of the current +where :math:`\eta` denotes the learning rate and :math:`g_t` the gradient of the current time step :math:`t`. In the Adam scheme [2]_ the weights are optimized via: .. math:: m_0 &= 0, v_0 = 0, t = 1 \,, \\ - m_t &= \beta_1 \, m_{t-1} + \left(1-\beta_1\right) \, g_t \,, \\ - v_t &= \beta_2 \, v_{t-1} + \left(1-\beta_2\right) \, g_t^2 \,, \\ - \hat{m}_t &= \frac{m_t}{1-\beta_1^t} \,, \\ - \hat{v}_t &= \frac{v_t}{1-\beta_2^t} \,, \\ - W_t &= W_{t-1} - \eta\frac{\hat{m_t}}{\sqrt{\hat{v}_t} + \epsilon} \,. + m_t &= \beta_1 m_{t-1} + \left( 1- \beta_1 \right) g_t \,, \\ + v_t &= \beta_2 v_{t-1} + \left( 1 - \beta_2 \right) g_t^2 \,, \\ + \alpha_t &= \eta \frac{ \sqrt{ 1- \beta_2^t } }{ 1 - \beta_1^t } \,, \\ + W_t &= W_{t-1} - \alpha_t \frac{ m_t }{ \sqrt{v_t} + \hat{\epsilon} } \,. \\ + +Note that the implementation follows the implementation in TensorFlow [3]_ for comparability. +The TensorFlow implementation deviates from [1]_ in that it assumes +:math:`\hat{\epsilon} = \epsilon \sqrt{ 1 - \beta_2^t }` to be constant, whereas [1]_ +assumes :math:`\epsilon = \hat{\epsilon} \sqrt{ 1 - \beta_2^t }` to be constant. + +When `optimize_each_step` is set to `True`, the weights are optimized at every +time step. If set to `False`, optimization occurs once per spike, resulting in a +significant speed-up. For gradient descent, both settings yield the same +results under exact arithmetic; however, small numerical differences may be +observed due to floating point precision. For the Adam optimizer, only setting +`optimize_each_step` to `True` precisely implements the algorithm as described +in [2]_. The impact of this setting on learning performance may vary depending +on the task. Parameters ++++++++++ The following parameters can be set in the status dictionary. 
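To make the two update rules above concrete, here is a small plain-Python sketch (no NEST involved; the gradient values are fabricated) applying the documented gradient descent step and the TensorFlow-style Adam step with :math:`\hat{\epsilon}` held constant:

.. code-block:: python

    import math

    eta = 1e-4             # learning rate
    beta_1, beta_2 = 0.9, 0.999
    epsilon_hat = 1e-7     # constant epsilon-hat, TensorFlow-style variant

    def gradient_descent_step(w, g):
        # W_t = W_{t-1} - eta * g_t
        return w - eta * g

    def adam_step(w, g, m, v, t):
        # m_t, v_t: first and second moment estimates; t starts at 1
        m = beta_1 * m + (1.0 - beta_1) * g
        v = beta_2 * v + (1.0 - beta_2) * g * g
        alpha_t = eta * math.sqrt(1.0 - beta_2**t) / (1.0 - beta_1**t)
        return w - alpha_t * m / (math.sqrt(v) + epsilon_hat), m, v

    w, m, v = 1.0, 0.0, 0.0
    for t, g in enumerate([0.2, -0.1, 0.05], start=1):  # fabricated gradients
        w, m, v = adam_step(w, g, m, v, t)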
-========== ==== ========================= ======= ================================= +====================== ==== ========================= ========= ================================= **Common optimizer parameters** ------------------------------------------------------------------------------------ -Parameter Unit Math equivalent Default Description -========== ==== ========================= ======= ================================= -batch_size 1 Size of batch -eta :math:`\eta` 1e-4 Learning rate -Wmax pA :math:`W_{ji}^\text{max}` 100.0 Maximal value for synaptic weight -Wmin pA :math:`W_{ji}^\text{min}` -100.0 Minimal value for synaptic weight -========== ==== ========================= ======= ================================= - -========= ==== =============== ================ ============== +------------------------------------------------------------------------------------------------- +Parameter Unit Math equivalent Default Description +====================== ==== ========================= ========= ================================= +``batch_size`` 1 Size of batch +``eta`` :math:`\eta` 1e-4 Learning rate +``optimize_each_step`` ``True`` +``Wmax`` pA :math:`W_{ji}^\text{max}` 100.0 Maximal value for synaptic weight +``Wmin`` pA :math:`W_{ji}^\text{min}` -100.0 Minimal value for synaptic weight +====================== ==== ========================= ========= ================================= + +========= ==== =============== ================== ============== **Gradient descent parameters (default optimizer)** --------------------------------------------------------------- -Parameter Unit Math equivalent Default Description -========= ==== =============== ================ ============== -type gradient_descent Optimizer type -========= ==== =============== ================ ============== +---------------------------------------------------------------- +Parameter Unit Math equivalent Default Description +========= ==== =============== ================== ============== +``type`` "gradient_descent" Optimizer type +========= ==== =============== ================== ============== -========= ==== ================ ======= ================================================= +=========== ==== ================ ======= ================================================= **Adam optimizer parameters** ------------------------------------------------------------------------------------------ -Parameter Unit Math equivalent Default Description -========= ==== ================ ======= ================================================= -type adam Optimizer type -beta_1 :math:`\beta_1` 0.9 Exponential decay rate for first moment estimate -beta_2 :math:`\beta_2` 0.999 Exponential decay rate for second moment estimate -epsilon :math:`\epsilon` 1e-8 Small constant for numerical stability -========= ==== ================ ======= ================================================= +------------------------------------------------------------------------------------------- +Parameter Unit Math equivalent Default Description +=========== ==== ================ ======= ================================================= +``type`` "adam" Optimizer type +``beta_1`` :math:`\beta_1` 0.9 Exponential decay rate for first moment estimate +``beta_2`` :math:`\beta_2` 0.999 Exponential decay rate for second moment estimate +``epsilon`` :math:`\epsilon` 1e-7 Small constant for numerical stability +=========== ==== ================ ======= ================================================= The following state variables evolve during 
simulation. @@ -106,24 +120,29 @@ The following state variables evolve during simulation. ---------------------------------------------------------------------------- State variable Unit Math equivalent Initial value Description ============== ==== =============== ============= ========================== -m :math:`m` 0.0 First moment estimate -v :math:`v` 0.0 Second moment raw estimate +``m`` :math:`m` 0.0 First moment estimate +``v`` :math:`v` 0.0 Second moment raw estimate ============== ==== =============== ============= ========================== References ++++++++++ -.. [1] Huh, D. & Sejnowski, T. J. Gradient descent for spiking neural networks. 32nd - Conference on Neural Information Processing Systems (2018). + +.. [1] Huh D, Sejnowski TJ (2018). Gradient descent for spiking neural networks. + Advances in Neural Information Processing Systems, 31:1433-1443. + https://proceedings.neurips.cc/paper_files/paper/2018/hash/185e65bc40581880c4f2c82958de8cfe-Abstract.html + .. [2] Kingma DP, Ba JL (2015). Adam: A method for stochastic optimization. - Proceedings of International Conference on Learning Representations (ICLR). + Proceedings of 3rd International Conference for Learning Representations (ICLR). https://doi.org/10.48550/arXiv.1412.6980 +.. [3] https://github.com/keras-team/keras/blob/v2.15.0/keras/optimizers/adam.py#L26-L220 + See also ++++++++ Examples using this model -++++++++++++++++++++++++++ ++++++++++++++++++++++++++ .. listexamples:: eprop_synapse_bsshslm_2020 @@ -188,14 +207,23 @@ class WeightOptimizerCommonProperties //! Size of an optimization batch. size_t batch_size_; - //! Learning rate. + //! Learning rate common to all synapses. double eta_; + //! First learning rate that differs from the default. + double eta_first_; + + //! Number of changes to the learning rate. + long n_eta_change_; + //! Minimal value for synaptic weight. double Wmin_; //! Maximal value for synaptic weight. double Wmax_; + + //! If true, optimize each step, else once per spike. + bool optimize_each_step_; }; /** @@ -230,7 +258,7 @@ class WeightOptimizer virtual void set_status( const DictionaryDatum& d ); //! Return optimized weight based on current weight. - double optimized_weight( const WeightOptimizerCommonProperties& cp, + double optimized_weight( WeightOptimizerCommonProperties& cp, const size_t idx_current_update, const double gradient, double weight ); @@ -244,6 +272,12 @@ class WeightOptimizer //! Current optimization step, whereby optimization happens every batch_size_ steps. size_t optimization_step_; + + //! Learning rate private to the synapse. + double eta_; + + //! Number of optimizations. + long n_optimize_; }; /** @@ -314,6 +348,12 @@ class WeightOptimizerAdam : public WeightOptimizer //! Second moment estimate variable. double v_; + + //! Power of beta_1 factor. + double beta_1_power_; + + //! Power of beta_2 factor. 
+ double beta_2_power_; }; /** diff --git a/modelsets/eprop b/modelsets/eprop index c373835f6b..0a9a5ed696 100644 --- a/modelsets/eprop +++ b/modelsets/eprop @@ -1,19 +1,27 @@ -# Minimal modelset for spiking neuron simulations with e-prop plasticity +# Minimal model set for spiking neuron simulations with e-prop plasticity multimeter spike_recorder weight_recorder +parrot_neuron +poisson_generator spike_generator step_rate_generator -poisson_generator + +rate_connection_delayed +static_synapse eprop_iaf_bsshslm_2020 eprop_iaf_adapt_bsshslm_2020 eprop_readout_bsshslm_2020 -parrot_neuron - -eprop_learning_signal_connection_bsshslm_2020 eprop_synapse_bsshslm_2020 -rate_connection_delayed -static_synapse +eprop_learning_signal_connection_bsshslm_2020 + +eprop_iaf +eprop_iaf_adapt +eprop_iaf_psc_delta +eprop_iaf_psc_delta_adapt +eprop_readout +eprop_synapse +eprop_learning_signal_connection diff --git a/modelsets/full b/modelsets/full index 2d8bbf461c..46d6c6ee5a 100644 --- a/modelsets/full +++ b/modelsets/full @@ -26,6 +26,13 @@ eprop_iaf_adapt_bsshslm_2020 eprop_readout_bsshslm_2020 eprop_synapse_bsshslm_2020 eprop_learning_signal_connection_bsshslm_2020 +eprop_iaf +eprop_iaf_adapt +eprop_iaf_psc_delta +eprop_iaf_psc_delta_adapt +eprop_readout +eprop_synapse +eprop_learning_signal_connection erfc_neuron gamma_sup_generator gap_junction diff --git a/nestkernel/eprop_archiving_node.cpp b/nestkernel/eprop_archiving_node.cpp index f607ca34ec..c66de6c3fa 100644 --- a/nestkernel/eprop_archiving_node.cpp +++ b/nestkernel/eprop_archiving_node.cpp @@ -31,18 +31,124 @@ namespace nest { +std::map< std::string, EpropArchivingNodeRecurrent::surrogate_gradient_function > + EpropArchivingNodeRecurrent::surrogate_gradient_funcs_ = { + { "piecewise_linear", &EpropArchivingNodeRecurrent::compute_piecewise_linear_surrogate_gradient }, + { "exponential", &EpropArchivingNodeRecurrent::compute_exponential_surrogate_gradient }, + { "fast_sigmoid_derivative", &EpropArchivingNodeRecurrent::compute_fast_sigmoid_derivative_surrogate_gradient }, + { "arctan", &EpropArchivingNodeRecurrent::compute_arctan_surrogate_gradient } + }; + + EpropArchivingNodeRecurrent::EpropArchivingNodeRecurrent() : EpropArchivingNode() + , firing_rate_reg_( 0.0 ) + , f_av_( 0.0 ) , n_spikes_( 0 ) { } EpropArchivingNodeRecurrent::EpropArchivingNodeRecurrent( const EpropArchivingNodeRecurrent& n ) : EpropArchivingNode( n ) + , firing_rate_reg_( n.firing_rate_reg_ ) + , f_av_( n.f_av_ ) , n_spikes_( n.n_spikes_ ) { } +EpropArchivingNodeRecurrent::surrogate_gradient_function +EpropArchivingNodeRecurrent::select_surrogate_gradient( const std::string& surrogate_gradient_function_name ) +{ + const auto found_entry_it = surrogate_gradient_funcs_.find( surrogate_gradient_function_name ); + + if ( found_entry_it != surrogate_gradient_funcs_.end() ) + { + return found_entry_it->second; + } + + std::string error_message = "Surrogate gradient / pseudo-derivate function surrogate_gradient_function from ["; + for ( const auto& surrogate_gradient_func : surrogate_gradient_funcs_ ) + { + error_message += " \"" + surrogate_gradient_func.first + "\","; + } + error_message.pop_back(); + error_message += " ] required."; + + throw BadProperty( error_message ); +} + + +double +EpropArchivingNodeRecurrent::compute_piecewise_linear_surrogate_gradient( const double r, + const double v_m, + const double v_th, + const double beta, + const double gamma ) +{ + if ( r > 0 ) + { + return 0.0; + } + + return gamma * std::max( 0.0, 1.0 - beta * std::abs( ( v_m - v_th ) ) ); 
+} + +double +EpropArchivingNodeRecurrent::compute_exponential_surrogate_gradient( const double r, + const double v_m, + const double v_th, + const double beta, + const double gamma ) +{ + if ( r > 0 ) + { + return 0.0; + } + + return gamma * std::exp( -beta * std::abs( v_m - v_th ) ); +} + +double +EpropArchivingNodeRecurrent::compute_fast_sigmoid_derivative_surrogate_gradient( const double r, + const double v_m, + const double v_th, + const double beta, + const double gamma ) +{ + if ( r > 0 ) + { + return 0.0; + } + + return gamma * std::pow( 1.0 + beta * std::abs( v_m - v_th ), -2 ); +} + +double +EpropArchivingNodeRecurrent::compute_arctan_surrogate_gradient( const double r, + const double v_m, + const double v_th, + const double beta, + const double gamma ) +{ + if ( r > 0 ) + { + return 0.0; + } + + return gamma / M_PI * ( 1.0 / ( 1.0 + std::pow( beta * M_PI * ( v_m - v_th ), 2 ) ) ); +} + +void +EpropArchivingNodeRecurrent::append_new_eprop_history_entry( const long time_step ) +{ + if ( eprop_indegree_ == 0 ) + { + return; + } + + eprop_history_.emplace_back( time_step, 0.0, 0.0, 0.0 ); +} + void EpropArchivingNodeRecurrent::write_surrogate_gradient_to_history( const long time_step, const double surrogate_gradient ) @@ -52,18 +158,27 @@ EpropArchivingNodeRecurrent::write_surrogate_gradient_to_history( const long tim return; } - eprop_history_.emplace_back( time_step, surrogate_gradient, 0.0 ); + auto it_hist = get_eprop_history( time_step ); + it_hist->surrogate_gradient_ = surrogate_gradient; } void -EpropArchivingNodeRecurrent::write_learning_signal_to_history( const long time_step, const double learning_signal ) +EpropArchivingNodeRecurrent::write_learning_signal_to_history( const long time_step, + const double learning_signal, + const bool has_norm_step ) { if ( eprop_indegree_ == 0 ) { return; } - const long shift = delay_rec_out_ + delay_out_norm_ + delay_out_rec_; + long shift = delay_rec_out_ + delay_out_rec_; + + if ( has_norm_step ) + { + shift += delay_out_norm_; + } + auto it_hist = get_eprop_history( time_step - shift ); const auto it_hist_end = get_eprop_history( time_step - shift + delay_out_rec_ ); @@ -95,19 +210,48 @@ EpropArchivingNodeRecurrent::write_firing_rate_reg_to_history( const long t_curr firing_rate_reg_history_.emplace_back( t_current_update + shift, firing_rate_reg ); } -std::vector< HistEntryEpropFiringRateReg >::iterator +void +EpropArchivingNodeRecurrent::write_firing_rate_reg_to_history( const long time_step, + const double z, + const double f_target, + const double kappa_reg, + const double c_reg ) +{ + if ( eprop_indegree_ == 0 ) + { + return; + } + + const double dt = Time::get_resolution().get_ms(); + + const double f_target_ = f_target * dt; // convert from spikes/ms to spikes/step + + f_av_ = kappa_reg * f_av_ + ( 1.0 - kappa_reg ) * z / dt; + + firing_rate_reg_ = c_reg * ( f_av_ - f_target_ ); + + auto it_hist = get_eprop_history( time_step ); + it_hist->firing_rate_reg_ = firing_rate_reg_; +} + +double EpropArchivingNodeRecurrent::get_firing_rate_reg_history( const long time_step ) { const auto it_hist = std::lower_bound( firing_rate_reg_history_.begin(), firing_rate_reg_history_.end(), time_step ); assert( it_hist != firing_rate_reg_history_.end() ); - return it_hist; + return it_hist->firing_rate_reg_; } double -EpropArchivingNodeRecurrent::get_learning_signal_from_history( const long time_step ) +EpropArchivingNodeRecurrent::get_learning_signal_from_history( const long time_step, const bool has_norm_step ) { - const long shift = 
delay_rec_out_ + delay_out_norm_ + delay_out_rec_; + long shift = delay_rec_out_ + delay_out_rec_; + + if ( has_norm_step ) + { + shift += delay_out_norm_; + } const auto it = get_eprop_history( time_step - shift ); if ( it == eprop_history_.end() ) @@ -149,16 +293,32 @@ EpropArchivingNodeReadout::EpropArchivingNodeReadout( const EpropArchivingNodeRe } void -EpropArchivingNodeReadout::write_error_signal_to_history( const long time_step, const double error_signal ) +EpropArchivingNodeReadout::append_new_eprop_history_entry( const long time_step, const bool has_norm_step ) { if ( eprop_indegree_ == 0 ) { return; } - const long shift = delay_out_norm_; + const long shift = has_norm_step ? delay_out_norm_ : 0; - eprop_history_.emplace_back( time_step - shift, error_signal ); + eprop_history_.emplace_back( time_step - shift, 0.0 ); +} + +void +EpropArchivingNodeReadout::write_error_signal_to_history( const long time_step, + const double error_signal, + const bool has_norm_step ) +{ + if ( eprop_indegree_ == 0 ) + { + return; + } + + const long shift = has_norm_step ? delay_out_norm_ : 0; + + auto it_hist = get_eprop_history( time_step - shift ); + it_hist->error_signal_ = error_signal; } diff --git a/nestkernel/eprop_archiving_node.h b/nestkernel/eprop_archiving_node.h index 187b22341b..f020b9ec0a 100644 --- a/nestkernel/eprop_archiving_node.h +++ b/nestkernel/eprop_archiving_node.h @@ -34,49 +34,81 @@ namespace nest { - /** - * Base class implementing an intermediate archiving node model for node models supporting e-prop plasticity. + * @brief Base class implementing archiving for node models supporting e-prop plasticity. + * + * Base class implementing an intermediate archiving node model for node models supporting e-prop plasticity + * according to Bellec et al. (2020) and supporting additional biological features described in Korcsak-Gorzo, + * Stapmanns, and Espinoza Valverde et al. (in preparation). * * A node which archives the history of dynamic variables, the firing rate * regularization, and update times needed to calculate the weight updates for * e-prop plasticity. It further provides a set of get, write, and set functions * for these histories and the hardcoded shifts to synchronize the factors of * the plasticity rule. + * + * @tparam HistEntryT The type of history entry. */ template < typename HistEntryT > class EpropArchivingNode : public Node { public: - //! Default constructor. + /** + * Constructs a new EpropArchivingNode object. + */ EpropArchivingNode(); - //! Copy constructor. - EpropArchivingNode( const EpropArchivingNode& ); - - //! Initialize the update history and register the eprop synapse. - void register_eprop_connection() override; - - //! Register current update in the update history and deregister previous update. - void write_update_to_history( const long t_previous_update, const long t_current_update ) override; - - //! Get an iterator pointing to the update history entry of the given time step. + /** + * Constructs a new EpropArchivingNode object by copying another EpropArchivingNode object. + * + * @param other The other object to copy. + */ + EpropArchivingNode( const EpropArchivingNode& other ); + + void register_eprop_connection( const bool is_bsshslm_2020_model = true ) override; + + void write_update_to_history( const long t_previous_update, + const long t_current_update, + const long eprop_isi_trace_cutoff = 0, + const bool erase = false ) override; + + /** + * Retrieves the update history entry for a specific time step. 
+ * + * @param time_step The time step. + * @return An iterator pointing to the update history for the specified time step. + */ std::vector< HistEntryEpropUpdate >::iterator get_update_history( const long time_step ); - //! Get an iterator pointing to the eprop history entry of the given time step. + /** + * Retrieves the eprop history entry for a specified time step. + * + * @param time_step The time step. + * @return An iterator pointing to the eprop history entry for the specified time step. + */ typename std::vector< HistEntryT >::iterator get_eprop_history( const long time_step ); - //! Erase update history parts for which the access counter has decreased to zero since no synapse needs them - //! any longer. - void erase_used_update_history(); - - //! Erase update intervals from the e-prop history in which each synapse has either not transmitted a spike or has - //! transmitted a spike in a more recent update interval. + /** + * @brief Erases the used eprop history for `bsshslm_2020` models. + * + * Erases e-prop history entries for update intervals during which no spikes were sent to the target neuron, + * and any entries older than the earliest time stamp required by the first update in the history. + */ void erase_used_eprop_history(); + /** + * @brief Erases the used eprop history. + * + * Erases e-prop history entries between the last and penultimate updates if they exceed the inter-spike + * interval trace cutoff and any entries older than the earliest time stamp required by the first update. + * + * @param eprop_isi_trace_cutoff The cutoff value for the inter-spike integration of the eprop trace. + */ + void erase_used_eprop_history( const long eprop_isi_trace_cutoff ); + protected: - //!< Number of incoming eprop synapses + //! Number of incoming eprop synapses size_t eprop_indegree_; //! History of updates still needed by at least one synapse. @@ -91,16 +123,16 @@ class EpropArchivingNode : public Node //! Offset since generator signals start from time step 1. const long offset_gen_ = 1; - //! Connection delay from input to recurrent neurons. + //! Transmission delay from input to recurrent neurons. const long delay_in_rec_ = 1; - //! Connection delay from recurrent to output neurons. + //! Transmission delay from recurrent to output neurons. const long delay_rec_out_ = 1; - //! Connection delay between output neurons for normalization. + //! Transmission delay between output neurons for normalization. const long delay_out_norm_ = 1; - //! Connection delay from output neurons to recurrent neurons. + //! Transmission delay from output neurons to recurrent neurons. const long delay_out_rec_ = 1; }; @@ -111,44 +143,225 @@ class EpropArchivingNodeRecurrent : public EpropArchivingNode< HistEntryEpropRec { public: - //! Default constructor. + /** + * Constructs a new EpropArchivingNodeRecurrent object. + */ EpropArchivingNodeRecurrent(); - //! Copy constructor. - EpropArchivingNodeRecurrent( const EpropArchivingNodeRecurrent& ); - - //! Create an entry in the eprop history for the given time step and surrogate gradient. + /** + * Constructs an EpropArchivingNodeRecurrent object by copying another EpropArchivingNodeRecurrent object. + * + * @param other The EpropArchivingNodeRecurrent object to copy. + */ + EpropArchivingNodeRecurrent( const EpropArchivingNodeRecurrent& other ); + + /** + * Defines the pointer-to-member function type for the surrogate gradient function. + * + * @note The typename is `surrogate_gradient_function`. All parentheses in the expression are required. 
+ */ + typedef double ( + EpropArchivingNodeRecurrent::*surrogate_gradient_function )( double, double, double, double, double ); + + /** + * Selects a surrogate gradient function based on the specified name. + * + * @param surrogate_gradient_function_name The name of the surrogate gradient function. + * @return The selected surrogate gradient function. + */ + surrogate_gradient_function select_surrogate_gradient( const std::string& surrogate_gradient_function_name ); + + /** + * @brief Computes the surrogate gradient with a piecewise linear function around the spike time. + * + * The piecewise linear surrogate function is used, for example, in Bellec et al. (2020). + * + * @param r The number of remaining refractory steps. + * @param v_m The membrane voltage. + * @param v_th The spike threshold voltage. For adaptive neurons, the adaptive spike threshold voltage. + * @param beta The width scaling of the surrogate gradient function. + * @param gamma The height scaling of the surrogate gradient function. + * @return The surrogate gradient of the membrane voltage. + */ + double compute_piecewise_linear_surrogate_gradient( const double r, + const double v_m, + const double v_th, + const double beta, + const double gamma ); + + /** + * @brief Computes the surrogate gradient with an exponentially decaying function around the spike time. + * + * The exponential surrogate function is used, for example, in Shrestha and Orchard (2018). + * + * @param r The number of remaining refractory steps. + * @param v_m The membrane voltage. + * @param v_th The threshold membrane voltage. For adaptive neurons, this is the adaptive threshold. + * @param v_th The spike threshold voltage. For adaptive neurons, the adaptive spike threshold voltage. + * @param beta The width scaling of the surrogate gradient function. + * @param gamma The height scaling of the surrogate gradient function. + * + * @return The surrogate gradient of the membrane voltage. + */ + double compute_exponential_surrogate_gradient( const double r, + const double v_m, + const double v_th, + const double beta, + const double gamma ); + + /** + * @brief Computes the surrogate gradient with a function reflecting the derivative of a fast sigmoid around the spike + * time. + * + * The derivative of fast sigmoid surrogate function is used, for example, in Zenke and Ganguli (2018). + * + * @param r The number of remaining refractory steps. + * @param v_m The membrane voltage. + * @param v_th The spike threshold voltage. For adaptive neurons, the adaptive spike threshold voltage. + * @param beta The width scaling of the surrogate gradient function. + * @param gamma The height scaling of the surrogate gradient function. + * + * @return The surrogate gradient of the membrane voltage. + */ + double compute_fast_sigmoid_derivative_surrogate_gradient( const double r, + const double v_m, + const double v_th, + const double beta, + const double gamma ); + + /** + * @brief Computes the surrogate gradient with an inverse tangent function around the spike time. + * + * The inverse tangent surrogate gradient function is used, for example, in Fang et al. (2021). + * + * @param r The number of remaining refractory steps. + * @param v_m The membrane voltage. + * @param v_th The spike threshold voltage. For adaptive neurons, the adaptive spike threshold voltage. + * @param beta The width scaling of the surrogate gradient function. + * @param gamma The height scaling of the surrogate gradient function. + * + * @return The surrogate gradient of the membrane voltage. 
+ */ + double compute_arctan_surrogate_gradient( const double r, + const double v_m, + const double v_th, + const double beta, + const double gamma ); + + /** + * Creates an entry for the specified time step at the end of the eprop history. + * + * @param time_step The time step. + */ + void append_new_eprop_history_entry( const long time_step ); + + /** + * Writes the surrogate gradient to the eprop history entry at the specified time step. + * + * @param time_step The time step. + * @param surrogate_gradient The surrogate gradient. + */ void write_surrogate_gradient_to_history( const long time_step, const double surrogate_gradient ); - //! Update the learning signal in the eprop history entry of the given time step by writing the value of the incoming - //! learning signal to the history or adding it to the existing value in case of multiple readout neurons. - void write_learning_signal_to_history( const long time_step, const double learning_signal ); - - //! Create an entry in the firing rate regularization history for the current update. + /** + * @brief Writes the learning signal to the eprop history entry at the specifed time step. + * + * Updates the learning signal in the eprop history entry of the specified time step by writing the value of the + * incoming learning signal to the history or adding it to the existing value in case of multiple readout + * neurons. + * + * @param time_step The time step. + * @param learning_signal The learning signal. + * @param has_norm_step Flag indicating if an extra time step is used for communication between readout + * neurons to normalize the readout signal outputs, as for softmax. + */ + void write_learning_signal_to_history( const long time_step, + const double learning_signal, + const bool has_norm_step = true ); + + /** + * Calculates the firing rate regularization for the current update and writes it to a new entry in the firing rate + * regularization history. + * + * @param t_current_update The current update time. + * @param f_target The target firing rate. + * @param c_reg The firing rate regularization coefficient. + */ void write_firing_rate_reg_to_history( const long t_current_update, const double f_target, const double c_reg ); - //! Get an iterator pointing to the firing rate regularization history of the given time step. - std::vector< HistEntryEpropFiringRateReg >::iterator get_firing_rate_reg_history( const long time_step ); - - //! Return learning signal from history for given time step or zero if time step not in history - double get_learning_signal_from_history( const long time_step ); - - //! Erase parts of the firing rate regularization history for which the access counter in the update history has - //! decreased to zero since no synapse needs them any longer. + /** + * Calculates the current firing rate regularization and writes it to the eprop history at the specified time step. + * + * @param time_step The time step. + * @param z The spike state variable. + * @param f_target The target firing rate. + * @param kappa_reg The low-pass filter of the firing rate regularization. + * @param c_reg The firing rate regularization coefficient. + */ + void write_firing_rate_reg_to_history( const long time_step, + const double z, + const double f_target, + const double kappa_reg, + const double c_reg ); + + /** + * Retrieves the firing rate regularization at the specified time step from the firing rate regularization history. + * + * @param time_step The time step. 
+ * + * @return The firing rate regularization at the specified time step. + */ + double get_firing_rate_reg_history( const long time_step ); + + /** + * Retrieves the learning signal from the eprop history at the specified time step. + * + * @param time_step The time step. + * @param has_norm_step Flag indicating if an extra time step is used for communication between readout neurons to + * normalize the readout signal outputs, as for softmax. + * + * @return The learning signal at the specified time step or zero if time step is not in the history. + */ + double get_learning_signal_from_history( const long time_step, const bool has_norm_step = true ); + + /** + * @brief Erases the history of the used firing rate regularization history. + * + * Erases parts of the firing rate regularization history for which the access counter in the update history has + * decreased to zero since no synapse needs them any longer. + */ void erase_used_firing_rate_reg_history(); - //! Count emitted spike for the firing rate regularization. + /** + * Counts an emitted spike for the firing rate regularization. + */ void count_spike(); - //! Reset spike count for the firing rate regularization. + /** + * Resets the spike count for the firing rate regularization. + */ void reset_spike_count(); + //! Firing rate regularization. + double firing_rate_reg_; + + //! Average firing rate. + double f_av_; + private: //! Count of the emitted spikes for the firing rate regularization. size_t n_spikes_; //! History of the firing rate regularization. std::vector< HistEntryEpropFiringRateReg > firing_rate_reg_history_; + + /** + * Maps provided names of surrogate gradients to corresponding pointers to member functions. + * + * @todo In the long run, this map should be handled by a manager with proper registration functions, + * so that external modules can add their own gradient functions. + */ + static std::map< std::string, surrogate_gradient_function > surrogate_gradient_funcs_; }; inline void @@ -169,14 +382,37 @@ EpropArchivingNodeRecurrent::reset_spike_count() class EpropArchivingNodeReadout : public EpropArchivingNode< HistEntryEpropReadout > { public: - //! Default constructor. + /** + * Constructs a new EpropArchivingNodeReadout object. + */ EpropArchivingNodeReadout(); - //! Copy constructor. - EpropArchivingNodeReadout( const EpropArchivingNodeReadout& ); - - //! Create an entry in the eprop history for the given time step and error signal. - void write_error_signal_to_history( const long time_step, const double error_signal ); + /** + * Constructs a new EpropArchivingNodeReadout object by copying another EpropArchivingNodeReadout object. + * + * @param other The EpropArchivingNodeReadout object to copy. + */ + EpropArchivingNodeReadout( const EpropArchivingNodeReadout& other ); + + /** + * Creates an entry for the specified time step at the end of the eprop history. + * + * @param time_step The time step. + * @param has_norm_step Flag indicating if an extra time step is used for communication between readout neurons to + * normalize the readout signal outputs, as for softmax. + */ + void append_new_eprop_history_entry( const long time_step, const bool has_norm_step = true ); + + /** + * Writes the error signal to the eprop history at the specified time step. + * + * @param time_step The time step. + * @param error_signal The error signal. + * @param has_norm_step Flag indicating if an extra time step is used for communication between readout neurons to + * normalize the readout signal outputs, as for softmax. 
+ */ + void + write_error_signal_to_history( const long time_step, const double error_signal, const bool has_norm_step = true ); }; } // namespace nest diff --git a/nestkernel/eprop_archiving_node_impl.h b/nestkernel/eprop_archiving_node_impl.h index e2798337a5..1b2e422284 100644 --- a/nestkernel/eprop_archiving_node_impl.h +++ b/nestkernel/eprop_archiving_node_impl.h @@ -50,17 +50,17 @@ EpropArchivingNode< HistEntryT >::EpropArchivingNode( const EpropArchivingNode& template < typename HistEntryT > void -EpropArchivingNode< HistEntryT >::register_eprop_connection() +EpropArchivingNode< HistEntryT >::register_eprop_connection( const bool is_bsshslm_2020_model ) { ++eprop_indegree_; - const long shift = get_shift(); + const long t_first_entry = is_bsshslm_2020_model ? get_shift() : -delay_rec_out_; - const auto it_hist = get_update_history( shift ); + const auto it_hist = get_update_history( t_first_entry ); - if ( it_hist == update_history_.end() or it_hist->t_ != shift ) + if ( it_hist == update_history_.end() or it_hist->t_ != t_first_entry ) { - update_history_.insert( it_hist, HistEntryEpropUpdate( shift, 1 ) ); + update_history_.insert( it_hist, HistEntryEpropUpdate( t_first_entry, 1 ) ); } else { @@ -70,14 +70,17 @@ EpropArchivingNode< HistEntryT >::register_eprop_connection() template < typename HistEntryT > void -EpropArchivingNode< HistEntryT >::write_update_to_history( const long t_previous_update, const long t_current_update ) +EpropArchivingNode< HistEntryT >::write_update_to_history( const long t_previous_update, + const long t_current_update, + const long eprop_isi_trace_cutoff, + const bool is_bsshslm_2020_model ) { if ( eprop_indegree_ == 0 ) { return; } - const long shift = get_shift(); + const long shift = is_bsshslm_2020_model ? get_shift() : -delay_rec_out_; const auto it_hist_curr = get_update_history( t_current_update + shift ); @@ -88,6 +91,10 @@ EpropArchivingNode< HistEntryT >::write_update_to_history( const long t_previous else { update_history_.insert( it_hist_curr, HistEntryEpropUpdate( t_current_update + shift, 1 ) ); + if ( not is_bsshslm_2020_model ) + { + erase_used_eprop_history( eprop_isi_trace_cutoff ); + } } const auto it_hist_prev = get_update_history( t_previous_update + shift ); @@ -96,6 +103,10 @@ EpropArchivingNode< HistEntryT >::write_update_to_history( const long t_previous { // If an entry exists for the previous update time, decrement its access counter --it_hist_prev->access_counter_; + if ( it_hist_prev->access_counter_ == 0 ) + { + update_history_.erase( it_hist_prev ); + } } } @@ -138,33 +149,37 @@ EpropArchivingNode< HistEntryT >::erase_used_eprop_history() } else { - const auto it_eprop_hist_from = get_eprop_history( t ); - const auto it_eprop_hist_to = get_eprop_history( t + update_interval ); - eprop_history_.erase( it_eprop_hist_from, it_eprop_hist_to ); // erase found entries since no longer used + // erase no longer needed entries for update intervals with no spikes sent to the target neuron + eprop_history_.erase( get_eprop_history( t ), get_eprop_history( t + update_interval ) ); } } - const auto it_eprop_hist_from = get_eprop_history( 0 ); - const auto it_eprop_hist_to = get_eprop_history( update_history_.begin()->t_ ); - eprop_history_.erase( it_eprop_hist_from, it_eprop_hist_to ); // erase found entries since no longer used + // erase no longer needed entries before the earliest current update + eprop_history_.erase( get_eprop_history( 0 ), get_eprop_history( update_history_.begin()->t_ ) ); } template < typename HistEntryT > void 
-EpropArchivingNode< HistEntryT >::erase_used_update_history() +EpropArchivingNode< HistEntryT >::erase_used_eprop_history( const long eprop_isi_trace_cutoff ) { - auto it_hist = update_history_.begin(); - while ( it_hist != update_history_.end() ) + if ( eprop_history_.empty() // nothing to remove + or update_history_.size() < 2 // no time markers to check + ) { - if ( it_hist->access_counter_ == 0 ) - { - // erase() invalidates the iterator, but returns a new, valid iterator - it_hist = update_history_.erase( it_hist ); - } - else - { - ++it_hist; - } + return; + } + + const long t_prev = ( update_history_.end() - 2 )->t_; + const long t_curr = ( update_history_.end() - 1 )->t_; + + if ( t_prev + eprop_isi_trace_cutoff < t_curr ) + { + // erase no longer needed entries to be ignored by trace cutoff + eprop_history_.erase( get_eprop_history( t_prev + eprop_isi_trace_cutoff ), get_eprop_history( t_curr ) ); } + + // erase no longer needed entries before the earliest current update + eprop_history_.erase( + get_eprop_history( std::numeric_limits< long >::min() ), get_eprop_history( update_history_.begin()->t_ - 1 ) ); } } // namespace nest diff --git a/nestkernel/histentry.cpp b/nestkernel/histentry.cpp index b56bb1ce87..a98a79346e 100644 --- a/nestkernel/histentry.cpp +++ b/nestkernel/histentry.cpp @@ -42,10 +42,14 @@ nest::HistEntryEprop::HistEntryEprop( long t ) { } -nest::HistEntryEpropRecurrent::HistEntryEpropRecurrent( long t, double surrogate_gradient, double learning_signal ) +nest::HistEntryEpropRecurrent::HistEntryEpropRecurrent( long t, + double surrogate_gradient, + double learning_signal, + double firing_rate_reg ) : HistEntryEprop( t ) , surrogate_gradient_( surrogate_gradient ) , learning_signal_( learning_signal ) + , firing_rate_reg_( firing_rate_reg ) { } diff --git a/nestkernel/histentry.h b/nestkernel/histentry.h index 0d5b1392bf..7b63c00a8f 100644 --- a/nestkernel/histentry.h +++ b/nestkernel/histentry.h @@ -92,10 +92,11 @@ operator<( const HistEntryEprop& he, long t ) class HistEntryEpropRecurrent : public HistEntryEprop { public: - HistEntryEpropRecurrent( long t, double surrogate_gradient, double learning_signal ); + HistEntryEpropRecurrent( long t, double surrogate_gradient, double learning_signal, double firing_rate_reg ); double surrogate_gradient_; double learning_signal_; + double firing_rate_reg_; }; /** diff --git a/nestkernel/nest_names.cpp b/nestkernel/nest_names.cpp index c897a9e6cf..cebe2f128e 100644 --- a/nestkernel/nest_names.cpp +++ b/nestkernel/nest_names.cpp @@ -176,6 +176,7 @@ const Name elements( "elements" ); const Name elementsize( "elementsize" ); const Name ellipsoidal( "ellipsoidal" ); const Name elliptical( "elliptical" ); +const Name eprop_isi_trace_cutoff( "eprop_isi_trace_cutoff" ); const Name eprop_learning_window( "eprop_learning_window" ); const Name eprop_reset_neurons_on_update( "eprop_reset_neurons_on_update" ); const Name eprop_update_interval( "eprop_update_interval" ); @@ -277,6 +278,8 @@ const Name instantiations( "instantiations" ); const Name interval( "interval" ); const Name is_refractory( "is_refractory" ); +const Name kappa( "kappa" ); +const Name kappa_reg( "kappa_reg" ); const Name Kd_act( "Kd_act" ); const Name Kd_IP3_1( "Kd_IP3_1" ); const Name Kd_IP3_2( "Kd_IP3_2" ); @@ -360,6 +363,7 @@ const Name offset( "offset" ); const Name offsets( "offsets" ); const Name omega( "omega" ); const Name optimizer( "optimizer" ); +const Name optimize_each_step( "optimize_each_step" ); const Name order( "order" ); const Name origin( 
"origin" ); const Name other( "other" ); diff --git a/nestkernel/nest_names.h b/nestkernel/nest_names.h index b6d964e04c..61a55380f3 100644 --- a/nestkernel/nest_names.h +++ b/nestkernel/nest_names.h @@ -203,6 +203,7 @@ extern const Name elements; extern const Name elementsize; extern const Name ellipsoidal; extern const Name elliptical; +extern const Name eprop_isi_trace_cutoff; extern const Name eprop_learning_window; extern const Name eprop_reset_neurons_on_update; extern const Name eprop_update_interval; @@ -305,6 +306,8 @@ extern const Name instantiations; extern const Name interval; extern const Name is_refractory; +extern const Name kappa; +extern const Name kappa_reg; extern const Name Kd_act; extern const Name Kd_IP3_1; extern const Name Kd_IP3_2; @@ -388,6 +391,7 @@ extern const Name offset; extern const Name offsets; extern const Name omega; extern const Name optimizer; +extern const Name optimize_each_step; extern const Name order; extern const Name origin; extern const Name other; diff --git a/nestkernel/node.cpp b/nestkernel/node.cpp index eb3c0f0497..482efc297d 100644 --- a/nestkernel/node.cpp +++ b/nestkernel/node.cpp @@ -219,7 +219,7 @@ Node::register_stdp_connection( double, double ) } void -Node::register_eprop_connection() +Node::register_eprop_connection( const bool ) { throw IllegalConnection( "The target node does not support eprop synapses." ); } @@ -231,7 +231,13 @@ Node::get_shift() const } void -Node::write_update_to_history( const long t_previous_update, const long t_current_update ) +Node::write_update_to_history( const long, const long, const long, const bool ) +{ + throw IllegalConnection( "The target node is not an e-prop neuron." ); +} + +long +Node::get_eprop_isi_trace_cutoff() const { throw IllegalConnection( "The target node is not an e-prop neuron." ); } @@ -543,6 +549,21 @@ nest::Node::get_tau_syn_in( int ) throw UnexpectedEvent(); } +void +nest::Node::compute_gradient( const long, + const long, + double&, + double&, + double&, + double&, + double&, + double&, + const CommonSynapseProperties&, + WeightOptimizer* ) +{ + throw IllegalConnection( "The target node does not support compute_gradient()." ); +} + double nest::Node::compute_gradient( std::vector< long >&, const long, const long, const double, const bool ) { diff --git a/nestkernel/node.h b/nestkernel/node.h index 9fde5624fb..75d7a0c953 100644 --- a/nestkernel/node.h +++ b/nestkernel/node.h @@ -32,6 +32,7 @@ #include // Includes from nestkernel: +#include "common_synapse_properties.h" #include "deprecation_warning.h" #include "event.h" #include "histentry.h" @@ -39,6 +40,7 @@ #include "nest_time.h" #include "nest_types.h" #include "secondary_event.h" +#include "weight_optimizer.h" // Includes from sli: #include "dictdatum.h" @@ -482,14 +484,22 @@ class Node virtual void register_stdp_connection( double, double ); /** - * Initialize the update history and register the eprop synapse. + * @brief Registers an eprop synapse and initializes the update history. + * + * The time for the first entry of the update history is set to the neuron specific shift if `is_bsshslm_2020` + * is true and to the negative transmission delay from the recurrent to the output layer otherwise. + * + * @param is_bsshslm_2020_model A boolean indicating whether the connection is for the bsshslm_2020 model(optional, + * default = true). 
* * @throws IllegalConnection */ - virtual void register_eprop_connection(); + virtual void register_eprop_connection( const bool is_bsshslm_2020_model = true ); /** - * Get the number of steps the time-point of the signal has to be shifted to + * @brief Retrieves the temporal shift of the signal. + * + * Retrieves the number of steps the time-point of the signal has to be shifted to * place it at the correct location in the e-prop-related histories. * * @note Unlike the original e-prop, where signals arise instantaneously, NEST @@ -497,29 +507,50 @@ class Node * compensate for the delays and synchronize the signals by shifting the * history. * + * @return The number of time steps to shift. + * * @throws IllegalConnection */ virtual long get_shift() const; /** - * Register current update in the update history and deregister previous update. + * Registers the current update in the update history and deregisters the previous update. + * + * @param t_previous_update The time step of the previous update. + * @param t_current_update The time step of the current update. + * @param eprop_isi_trace_cutoff The cutoff value for the eprop inter-spike interval trace (optional, default: 0). + * @param is_bsshslm_2020_model Flag indicating whether the model is the bsshslm_2020 model (optional, default = + * true). + * + * @throws IllegalConnection + */ + virtual void write_update_to_history( const long t_previous_update, + const long t_current_update, + const long eprop_isi_trace_cutoff = 0, + const bool is_bsshslm_2020_model = true ); + + /** + * Retrieves the maximum number of time steps integrated between two consecutive spikes. + * + * @return The cutoff value for the inter-spike interval eprop trace. * * @throws IllegalConnection */ - virtual void write_update_to_history( const long t_previous_update, const long t_current_update ); + virtual long get_eprop_isi_trace_cutoff() const; /** - * Return if the node is part of the recurrent network (and thus not a readout neuron). + * Checks if the node is part of the recurrent network and thus not a readout neuron. * * @note The e-prop synapse calls this function of the target node. If true, * it skips weight updates within the first interval step of the update * interval. * + * @return true if the node is an eprop recurrent node, false otherwise. + * * @throws IllegalConnection */ virtual bool is_eprop_recurrent_node() const; - /** * Handle incoming spike events. * @@ -804,9 +835,47 @@ class Node /** * Compute gradient change for eprop synapses. * - * This method is called from an eprop synapse on the eprop target neuron and returns the change in gradient. + * This method is called from an eprop synapse on the eprop target neuron. It updates various parameters related to + * e-prop plasticity according to Bellec et al. (2020) with additional biological features described in Korcsak-Gorzo, + * Stapmanns, and Espinoza Valverde et al. (in preparation). + * + * @param t_spike [in] Time of the current spike. + * @param t_spike_previous [in] Time of the previous spike. + * @param z_previous_buffer [in, out] Value of presynaptic spiking variable from previous time step. + * @param z_bar [in, out] Filtered presynaptic spiking variable. + * @param e_bar [in, out] Filtered eligibility trace. + * @param e_bar_reg [in, out] Filtered eligibility trace for firing rate regularization. + * @param epsilon [out] Component of eligibility vector corresponding to the adaptive firing threshold variable. + * @param weight [in, out] Synaptic weight. 
+ * @param cp [in] Common properties for synapses. + * @param optimizer [in] Instance of weight optimizer. + * + */ + virtual void compute_gradient( const long t_spike, + const long t_spike_previous, + double& z_previous_buffer, + double& z_bar, + double& e_bar, + double& e_bar_reg, + double& epsilon, + double& weight, + const CommonSynapseProperties& cp, + WeightOptimizer* optimizer ); + + /** + * Compute gradient change for eprop synapses. + * + * This method is called from an eprop synapse on the eprop target neuron. It updates various parameters related to + * e-prop plasticity according to Bellec et al. (2020). + * + * @param presyn_isis [in, out] Vector of inter-spike intervals. + * @param t_previous_update [in] Time of the last update. + * @param t_previous_trigger_spike [in] Time of the last trigger spike. + * @param kappa [in] Decay factor for the eligibility trace. + * @param average_gradient [in] Boolean flag determining whether to compute an average of the gradients over the given + * period. * - * @params presyn_isis is cleared during call + * @return Returns the computed gradient value. */ virtual double compute_gradient( std::vector< long >& presyn_isis, const long t_previous_update, diff --git a/pynest/examples/eprop_plasticity/NMNIST_pixels_blocklist.txt b/pynest/examples/eprop_plasticity/NMNIST_pixels_blocklist.txt new file mode 100644 index 0000000000..e0c2627635 --- /dev/null +++ b/pynest/examples/eprop_plasticity/NMNIST_pixels_blocklist.txt @@ -0,0 +1,1116 @@ +0 +1 +2 +3 +4 +5 +6 +7 +8 +9 +10 +11 +12 +13 +14 +15 +16 +17 +18 +19 +20 +21 +22 +23 +24 +25 +26 +27 +28 +29 +30 +31 +32 +33 +34 +35 +36 +37 +38 +39 +40 +41 +42 +43 +44 +45 +46 +47 +48 +49 +50 +51 +52 +53 +54 +55 +56 +57 +58 +59 +60 +61 +62 +64 +65 +66 +67 +68 +69 +70 +71 +72 +73 +74 +75 +76 +77 +78 +79 +80 +81 +82 +83 +84 +85 +86 +87 +88 +89 +90 +91 +92 +93 +94 +95 +96 +97 +98 +99 +100 +101 +102 +103 +104 +105 +106 +107 +108 +109 +110 +111 +112 +113 +114 +115 +116 +117 +118 +119 +120 +121 +122 +123 +124 +125 +126 +127 +128 +129 +130 +131 +132 +133 +134 +135 +136 +137 +138 +139 +140 +141 +142 +143 +144 +145 +146 +147 +148 +149 +150 +151 +159 +160 +161 +162 +163 +164 +165 +166 +167 +168 +169 +170 +171 +172 +173 +174 +175 +176 +177 +178 +179 +180 +196 +197 +198 +199 +200 +201 +202 +203 +204 +205 +206 +207 +208 +209 +210 +211 +212 +232 +233 +234 +235 +236 +237 +238 +239 +240 +241 +242 +243 +244 +245 +267 +268 +269 +270 +271 +272 +273 +274 +275 +276 +277 +278 +301 +302 +303 +304 +305 +306 +307 +308 +309 +310 +311 +312 +335 +336 +337 +338 +339 +340 +341 +342 +343 +344 +345 +370 +371 +372 +373 +374 +375 +376 +377 +378 +379 +404 +405 +406 +407 +408 +409 +410 +411 +412 +413 +438 +439 +440 +441 +442 +444 +445 +446 +447 +471 +472 +473 +474 +475 +476 +477 +478 +479 +480 +481 +506 +507 +508 +509 +510 +511 +512 +513 +514 +539 +540 +541 +542 +543 +544 +545 +546 +547 +548 +549 +573 +574 +575 +576 +577 +578 +579 +580 +581 +582 +608 +609 +610 +611 +612 +613 +614 +615 +616 +617 +641 +642 +643 +644 +645 +646 +647 +648 +649 +650 +675 +676 +677 +678 +679 +680 +681 +682 +683 +684 +709 +710 +711 +712 +713 +714 +715 +716 +717 +718 +743 +744 +745 +746 +747 +748 +749 +750 +751 +752 +776 +777 +778 +779 +780 +781 +782 +783 +784 +785 +786 +810 +811 +812 +813 +814 +815 +816 +817 +818 +819 +820 +843 +844 +845 +846 +847 +848 +849 +850 +851 +852 +853 +854 +877 +878 +879 +880 +881 +882 +883 +884 +885 +886 +887 +888 +889 +910 +911 +912 +913 +914 +915 +916 +917 +918 +919 +920 +921 +922 +923 +943 +944 +945 +946 +947 +948 +949 +950 +951 +952 +953 
+954 +955 +956 +957 +958 +976 +977 +978 +979 +980 +981 +982 +983 +984 +985 +986 +987 +988 +989 +990 +991 +992 +993 +1009 +1010 +1011 +1012 +1013 +1014 +1015 +1016 +1017 +1018 +1019 +1020 +1021 +1022 +1023 +1024 +1025 +1026 +1027 +1028 +1042 +1043 +1044 +1045 +1046 +1047 +1048 +1049 +1050 +1051 +1052 +1053 +1054 +1055 +1056 +1057 +1058 +1059 +1060 +1061 +1062 +1063 +1064 +1074 +1075 +1076 +1077 +1078 +1079 +1080 +1081 +1082 +1083 +1084 +1085 +1086 +1087 +1088 +1089 +1090 +1091 +1092 +1093 +1094 +1095 +1096 +1097 +1098 +1099 +1100 +1101 +1102 +1103 +1104 +1105 +1106 +1107 +1108 +1109 +1110 +1111 +1112 +1113 +1114 +1115 +1116 +1117 +1118 +1119 +1120 +1121 +1122 +1123 +1124 +1125 +1126 +1127 +1128 +1129 +1130 +1131 +1132 +1133 +1134 +1135 +1136 +1137 +1138 +1139 +1140 +1141 +1142 +1143 +1144 +1145 +1146 +1147 +1148 +1149 +1150 +1151 +1152 +1153 +1154 +1155 +1156 +1157 +1158 +1159 +1160 +1161 +1162 +1163 +1164 +1165 +1166 +1167 +1168 +1169 +1170 +1171 +1172 +1173 +1174 +1175 +1176 +1177 +1178 +1179 +1180 +1181 +1182 +1183 +1184 +1185 +1186 +1187 +1188 +1189 +1190 +1191 +1192 +1193 +1194 +1195 +1196 +1197 +1198 +1199 +1200 +1201 +1202 +1203 +1204 +1205 +1206 +1207 +1208 +1209 +1210 +1211 +1212 +1213 +1214 +1215 +1216 +1217 +1218 +1220 +1221 +1222 +1223 +1224 +1225 +1226 +1227 +1228 +1229 +1230 +1231 +1232 +1233 +1234 +1235 +1236 +1237 +1238 +1239 +1240 +1241 +1242 +1243 +1244 +1245 +1246 +1247 +1248 +1249 +1250 +1251 +1252 +1253 +1254 +1255 +1256 +1257 +1258 +1259 +1260 +1261 +1262 +1263 +1264 +1265 +1266 +1267 +1268 +1269 +1270 +1271 +1272 +1273 +1274 +1275 +1277 +1278 +1279 +1280 +1281 +1282 +1283 +1284 +1285 +1286 +1287 +1288 +1290 +1291 +1292 +1293 +1294 +1295 +1296 +1297 +1298 +1299 +1300 +1301 +1302 +1303 +1304 +1305 +1316 +1317 +1318 +1319 +1320 +1321 +1322 +1323 +1324 +1325 +1326 +1327 +1328 +1329 +1330 +1331 +1332 +1333 +1334 +1335 +1353 +1354 +1355 +1356 +1357 +1358 +1359 +1360 +1361 +1362 +1363 +1364 +1365 +1366 +1367 +1368 +1388 +1389 +1390 +1391 +1393 +1394 +1395 +1396 +1397 +1398 +1399 +1400 +1401 +1424 +1425 +1426 +1427 +1428 +1429 +1430 +1431 +1432 +1433 +1457 +1459 +1461 +1462 +1463 +1464 +1465 +1466 +1467 +1492 +1493 +1494 +1495 +1496 +1497 +1498 +1499 +1500 +1501 +1526 +1527 +1528 +1529 +1530 +1531 +1532 +1533 +1560 +1561 +1562 +1563 +1564 +1565 +1566 +1567 +1568 +1593 +1594 +1595 +1596 +1597 +1598 +1600 +1601 +1602 +1627 +1628 +1629 +1630 +1631 +1632 +1633 +1634 +1635 +1636 +1662 +1663 +1664 +1665 +1666 +1667 +1668 +1669 +1670 +1696 +1697 +1698 +1699 +1700 +1701 +1702 +1703 +1704 +1730 +1731 +1732 +1733 +1734 +1735 +1736 +1737 +1738 +1763 +1764 +1765 +1766 +1767 +1768 +1769 +1770 +1771 +1772 +1797 +1798 +1799 +1800 +1801 +1802 +1803 +1804 +1805 +1806 +1831 +1832 +1833 +1834 +1835 +1836 +1837 +1838 +1839 +1866 +1867 +1868 +1869 +1870 +1871 +1872 +1873 +1899 +1900 +1901 +1902 +1903 +1904 +1905 +1906 +1907 +1933 +1934 +1935 +1936 +1937 +1938 +1939 +1940 +1941 +1966 +1967 +1968 +1969 +1970 +1971 +1972 +1973 +1974 +1975 +2000 +2001 +2002 +2003 +2004 +2005 +2006 +2007 +2008 +2009 +2010 +2033 +2034 +2035 +2036 +2037 +2038 +2039 +2040 +2041 +2042 +2043 +2067 +2068 +2069 +2070 +2071 +2072 +2073 +2074 +2075 +2076 +2077 +2078 +2079 +2100 +2101 +2102 +2103 +2104 +2105 +2106 +2107 +2108 +2109 +2111 +2112 +2113 +2133 +2134 +2135 +2136 +2137 +2138 +2139 +2140 +2141 +2142 +2143 +2144 +2145 +2146 +2147 +2148 +2149 +2165 +2166 +2167 +2168 +2169 +2170 +2171 +2172 +2173 +2174 +2175 +2176 +2177 +2178 +2179 +2180 +2181 +2182 +2198 +2199 +2200 +2201 +2202 +2203 +2204 +2205 +2206 +2207 +2208 +2209 
+2210 +2211 +2212 +2213 +2214 +2215 +2216 +2217 +2218 +2219 +2230 +2231 +2232 +2233 +2234 +2235 +2236 +2237 +2238 +2239 +2240 +2241 +2242 +2243 +2244 +2245 +2246 +2247 +2248 +2249 +2250 +2251 +2252 +2253 +2254 +2255 +2256 +2259 +2261 +2262 +2263 +2264 +2265 +2266 +2267 +2268 +2269 +2270 +2271 +2272 +2273 +2274 +2275 +2276 +2277 +2278 +2279 +2280 +2281 +2282 +2283 +2284 +2285 +2286 +2287 +2288 +2289 +2290 +2291 +2292 +2293 +2294 +2295 +2296 +2297 +2298 +2299 +2300 +2301 +2302 +2303 +2304 +2305 +2306 +2307 +2308 +2309 +2310 +2311 diff --git a/pynest/examples/eprop_plasticity/README.rst b/pynest/examples/eprop_plasticity/README.rst index 6a23010f0c..f6153c59fd 100644 --- a/pynest/examples/eprop_plasticity/README.rst +++ b/pynest/examples/eprop_plasticity/README.rst @@ -2,14 +2,34 @@ E-prop plasticity examples ========================== -.. image:: eprop_supervised_regression_schematic_sine-waves.png +.. image:: eprop_supervised_regression_sine-waves.png -Eligibility propagation (e-prop) [1]_ is a three-factor learning rule for spiking neural networks -that approximates backpropagation through time. The original TensorFlow implementation of e-prop -was demonstrated, among others, on a supervised regression task to generate temporal patterns and a -supervised classification task to accumulate evidence [2]_. Here, you find tutorials on how to -reproduce these two tasks as well as two more advanced regression tasks using the NEST implementation -of e-prop [3]_ and how to visualize the simulation recordings. +Eligibility propagation (e-prop) [1]_ is a three-factor learning rule for spiking neural networks that +approaches the performance of backpropagation through time (BPTT). The original TensorFlow implementation of +e-prop was demonstrated, among others, on a supervised regression task to generate temporal patterns and a +supervised classification task to accumulate evidence [2]_. Here, you find tutorials on how to reproduce these +two tasks as well as two more advanced regression tasks using the NEST implementation of e-prop [3]_ and how to +visualize the simulation recordings. + +The tutorials labeled "after Bellec et al. (2020)" use the original e-prop model [1]_, while the other +tutorials use a version of e-prop that includes additional biological features as described in [3]_. + +See below for a diagram that describes the relationships between the different models for e-prop. + +Users interested in endowing an existing model with e-prop plasticity, may compare the .cpp and .h files of the +:doc:`iaf_psc_delta` and :doc:`eprop_iaf_psc_delta` model. +Parameters to run the `eprop_iaf_psc_delta` model are provided in +:doc:`eprop_supervised_regression_sine-waves.py `. + +e-prop model map +---------------- + +.. grid:: + + .. grid-item-card:: + :columns: 12 + + .. image:: /static/img/eprop_model_diagram.svg References ---------- @@ -21,6 +41,6 @@ References .. [2] https://github.com/IGITUGraz/eligibility_propagation/blob/master/Figure_3_and_S7_e_prop_tutorials/ -.. [3] Korcsak-Gorzo A, Stapmanns J, Espinoza Valverde JA, Dahmen D, - van Albada SJ, Bolten M, Diesmann M. Event-based implementation of - eligibility propagation (in preparation) +.. [3] Korcsak-Gorzo A, Stapmanns J, Espinoza Valverde JA, Plesser HE, + Dahmen D, Bolten M, Van Albada SJ*, Diesmann M*. 
Event-based + implementation of eligibility propagation (in preparation) diff --git a/pynest/examples/eprop_plasticity/eprop_supervised_classification_evidence-accumulation.png b/pynest/examples/eprop_plasticity/eprop_supervised_classification_evidence-accumulation.png new file mode 100644 index 0000000000..bad5ce5f4a Binary files /dev/null and b/pynest/examples/eprop_plasticity/eprop_supervised_classification_evidence-accumulation.png differ diff --git a/pynest/examples/eprop_plasticity/eprop_supervised_classification_evidence-accumulation.py b/pynest/examples/eprop_plasticity/eprop_supervised_classification_evidence-accumulation.py index a82a018e93..78b5290d2c 100644 --- a/pynest/examples/eprop_plasticity/eprop_supervised_classification_evidence-accumulation.py +++ b/pynest/examples/eprop_plasticity/eprop_supervised_classification_evidence-accumulation.py @@ -40,9 +40,9 @@ infer the underlying rationale of the task. Here, the solution is to turn to the side in which more cues were presented. -.. image:: eprop_supervised_classification_schematic_evidence-accumulation.png +.. image:: eprop_supervised_classification_evidence-accumulation.png :width: 70 % - :alt: See Figure 1 below. + :alt: Schematic of network architecture. Same as Figure 1 in the code. :align: center Learning in the neural network model is achieved by optimizing the connection weights with e-prop plasticity. @@ -55,7 +55,7 @@ compares the network signal :math:`\pi_k` with the teacher target signal :math:`\pi_k^*`, which it receives from a rate generator. Since the decision is at the end and all the cues are relevant, the network has to keep the cues in memory. Additional adaptive neurons in the network enable this memory. The network's training error is -assessed by employing a cross-entropy error loss. +assessed by employing a mean squared error loss. Details on the event-based NEST implementation of e-prop can be found in [3]_. @@ -68,8 +68,10 @@ .. [2] https://github.com/IGITUGraz/eligibility_propagation/blob/master/Figure_3_and_S7_e_prop_tutorials/tutorial_evidence_accumulation_with_alif.py -.. [3] Korcsak-Gorzo A, Stapmanns J, Espinoza Valverde JA, Dahmen D, van Albada SJ, Bolten M, Diesmann M. - Event-based implementation of eligibility propagation (in preparation) +.. [3] Korcsak-Gorzo A, Stapmanns J, Espinoza Valverde JA, Plesser HE, + Dahmen D, Bolten M, Van Albada SJ*, Diesmann M*. Event-based + implementation of eligibility propagation (in preparation) + """ # pylint: disable=line-too-long # noqa: E501 # %% ########################################################################################################### @@ -88,11 +90,11 @@ # Schematic of network architecture # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # This figure, identical to the one in the description, shows the required network architecture in the center, -# the input and output of the pattern generation task above, and lists of the required NEST device, neuron, and -# synapse models below. The connections that must be established are numbered 1 to 7. +# the input and output of the evidence accumulation task above, and lists of the required NEST device, neuron, +# and synapse models below. The connections that must be established are numbered 1 to 7. try: - Image(filename="./eprop_supervised_classification_schematic_evidence-accumulation.png") + Image(filename="./eprop_supervised_classification_evidence-accumulation.png") except Exception: pass @@ -113,15 +115,21 @@ # Define timing of task # ..................... 
# The task's temporal structure is then defined, once as time steps and once as durations in milliseconds. -# Using a batch size larger than one aids the network in generalization, facilitating the solution to this task. -# The original number of iterations requires distributed computing. - -n_batch = 1 # batch size, 64 in reference [2], 32 in the README to reference [2] -n_iter = 5 # number of iterations, 2000 in reference [2], 50 with n_batch 32 converges - -n_input_symbols = 4 # number of input populations, e.g. 4 = left, right, recall, noise -n_cues = 7 # number of cues given before decision -prob_group = 0.3 # probability with which one input group is present +# Even though each sample is processed independently during training, we aggregate predictions and true +# labels across a group of samples during the evaluation phase. The number of samples in this group is +# determined by the `group_size` parameter. This data is then used to assess the neural network's +# performance metrics, such as average accuracy and mean error. Increasing the number of iterations enhances +# learning performance up to the point where overfitting occurs. + +group_size = 32 # number of instances over which to evaluate the learning performance +n_iter = 50 # number of iterations + +input = { + "n_symbols": 4, # number of input populations, e.g. 4 = left, right, recall, noise + "n_cues": 7, # number of cues given before decision + "prob_group": 0.3, # probability with which one input group is present + "spike_prob": 0.04, # spike probability of frozen input noise +} steps = { "cue": 100, # time steps in one cue presentation @@ -130,22 +138,20 @@ "recall": 150, # time steps of recall } -steps["cues"] = n_cues * (steps["cue"] + steps["spacing"]) # time steps of all cues +steps["cues"] = input["n_cues"] * (steps["cue"] + steps["spacing"]) # time steps of all cues steps["sequence"] = steps["cues"] + steps["bg_noise"] + steps["recall"] # time steps of one full sequence steps["learning_window"] = steps["recall"] # time steps of window with non-zero learning signals -steps["task"] = n_iter * n_batch * steps["sequence"] # time steps of task +steps["task"] = n_iter * group_size * steps["sequence"] # time steps of task steps.update( { "offset_gen": 1, # offset since generator signals start from time step 1 "delay_in_rec": 1, # connection delay between input and recurrent neurons - "delay_rec_out": 1, # connection delay between recurrent and output neurons - "delay_out_norm": 1, # connection delay between output neurons for normalization - "extension_sim": 1, # extra time step to close right-open simulation time interval in Simulate() + "extension_sim": 3, # extra time step to close right-open simulation time interval in Simulate() } ) -steps["delays"] = steps["delay_in_rec"] + steps["delay_rec_out"] + steps["delay_out_norm"] # time steps of delays +steps["delays"] = steps["delay_in_rec"] # time steps of delays steps["total_offset"] = steps["offset_gen"] + steps["delays"] # time steps of total offset @@ -159,12 +165,9 @@ # Set up simulation # ................. # As last step of the setup, we reset the NEST kernel to remove all existing NEST simulation settings and -# objects and set some NEST kernel parameters, some of which are e-prop-related. +# objects and set some NEST kernel parameters. 
params_setup = { - "eprop_learning_window": duration["learning_window"], - "eprop_reset_neurons_on_update": True, # if True, reset dynamic variables at start of each update interval - "eprop_update_interval": duration["sequence"], # ms, time interval for updating the synaptic weights "print_time": False, # if True, print time progress bar during simulation, set False if run as code cell "resolution": duration["step"], "total_num_virtual_procs": 1, # number of virtual processes, set in case of distributed computing @@ -189,31 +192,48 @@ n_rec = n_ad + n_reg # number of recurrent neurons n_out = 2 # number of readout neurons - -params_nrn_reg = { +params_nrn_out = { "C_m": 1.0, # pF, membrane capacitance - takes effect only if neurons get current input (here not the case) - "c_reg": 2.0, # firing rate regularization scaling - double the TF c_reg for technical reasons "E_L": 0.0, # mV, leak / resting membrane potential - "f_target": 10.0, # spikes/s, target firing rate for firing rate regularization - "gamma": 0.3, # scaling of the pseudo derivative + "eprop_isi_trace_cutoff": 100, # cutoff of integration of eprop trace between spikes "I_e": 0.0, # pA, external current input - "regular_spike_arrival": True, # If True, input spikes arrive at end of time step, if False at beginning - "surrogate_gradient_function": "piecewise_linear", # surrogate gradient / pseudo-derivative function - "t_ref": 5.0, # ms, duration of refractory period + "regular_spike_arrival": False, # If True, input spikes arrive at end of time step, if False at beginning "tau_m": 20.0, # ms, membrane time constant "V_m": 0.0, # mV, initial value of the membrane voltage +} + +params_nrn_reg = { + "beta": 1.7, # width scaling of the pseudo-derivative + "C_m": 1.0, + "c_reg": 300.0 / duration["sequence"] * duration["learning_window"], # coefficient of firing rate regularization + "E_L": 0.0, + "eprop_isi_trace_cutoff": 100, + "f_target": 10.0, # spikes/s, target firing rate for firing rate regularization + "gamma": 0.5, # height scaling of the pseudo-derivative + "I_e": 0.0, + "kappa": 0.97, # low-pass filter of the eligibility trace + "kappa_reg": 0.97, # low-pass filter of the firing rate for regularization + "regular_spike_arrival": True, + "surrogate_gradient_function": "piecewise_linear", # surrogate gradient / pseudo-derivative function + "t_ref": 5.0, # ms, duration of refractory period + "tau_m": 20.0, + "V_m": 0.0, "V_th": 0.6, # mV, spike threshold membrane voltage } params_nrn_ad = { + "beta": 1.7, "adapt_tau": 2000.0, # ms, time constant of adaptive threshold "adaptation": 0.0, # initial value of the spike threshold adaptation "C_m": 1.0, - "c_reg": 2.0, + "c_reg": 300.0 / duration["sequence"] * duration["learning_window"], "E_L": 0.0, + "eprop_isi_trace_cutoff": 100, # cutoff of integration of eprop trace between spikes "f_target": 10.0, - "gamma": 0.3, + "gamma": 0.5, "I_e": 0.0, + "kappa": 0.97, + "kappa_reg": 0.97, "regular_spike_arrival": True, "surrogate_gradient_function": "piecewise_linear", "t_ref": 5.0, @@ -227,16 +247,6 @@ / (1.0 - np.exp(-duration["step"] / params_nrn_ad["tau_m"])) ) # prefactor of adaptive threshold -params_nrn_out = { - "C_m": 1.0, - "E_L": 0.0, - "I_e": 0.0, - "loss": "cross_entropy", # loss function - "regular_spike_arrival": False, - "tau_m": 20.0, - "V_m": 0.0, -} - #################### # Intermediate parrot neurons required between input spike generators and recurrent neurons, @@ -245,13 +255,11 @@ gen_spk_in = nest.Create("spike_generator", n_in) nrns_in = 
nest.Create("parrot_neuron", n_in) -# The suffix _bsshslm_2020 follows the NEST convention to indicate in the model name the paper -# that introduced it by the first letter of the authors' last names and the publication year. - -nrns_reg = nest.Create("eprop_iaf_bsshslm_2020", n_reg, params_nrn_reg) -nrns_ad = nest.Create("eprop_iaf_adapt_bsshslm_2020", n_ad, params_nrn_ad) -nrns_out = nest.Create("eprop_readout_bsshslm_2020", n_out, params_nrn_out) +nrns_reg = nest.Create("eprop_iaf", n_reg, params_nrn_reg) +nrns_ad = nest.Create("eprop_iaf_adapt", n_ad, params_nrn_ad) +nrns_out = nest.Create("eprop_readout", n_out, params_nrn_out) gen_rate_target = nest.Create("step_rate_generator", n_out) +gen_learning_window = nest.Create("step_rate_generator") nrns_rec = nrns_reg + nrns_ad @@ -265,7 +273,7 @@ # default, recordings are stored in memory but can also be written to file. n_record = 1 # number of neurons per type to record dynamic variables from - this script requires n_record >= 1 -n_record_w = 3 # number of senders and targets to record weights from - this script requires n_record_w >=1 +n_record_w = 5 # number of senders and targets to record weights from - this script requires n_record_w >=1 if n_record == 0 or n_record_w == 0: raise ValueError("n_record and n_record_w >= 1 required") @@ -275,6 +283,7 @@ "record_from": ["V_m", "surrogate_gradient", "learning_signal"], # dynamic variables to record "start": duration["offset_gen"] + duration["delay_in_rec"], # start time of recording "stop": duration["offset_gen"] + duration["delay_in_rec"] + duration["task"], # stop time of recording + "label": "multimeter_reg", } params_mm_ad = { @@ -282,13 +291,15 @@ "record_from": params_mm_reg["record_from"] + ["V_th_adapt", "adaptation"], "start": duration["offset_gen"] + duration["delay_in_rec"], "stop": duration["offset_gen"] + duration["delay_in_rec"] + duration["task"], + "label": "multimeter_ad", } params_mm_out = { "interval": duration["step"], - "record_from": ["V_m", "readout_signal", "readout_signal_unnorm", "target_signal", "error_signal"], + "record_from": ["V_m", "readout_signal", "target_signal", "error_signal"], "start": duration["total_offset"], "stop": duration["total_offset"] + duration["task"], + "label": "multimeter_out", } params_wr = { @@ -296,11 +307,25 @@ "targets": nrns_rec[:n_record_w] + nrns_out, # limit targets to subsample weights to record from "start": duration["total_offset"], "stop": duration["total_offset"] + duration["task"], + "label": "weight_recorder", } -params_sr = { - "start": duration["total_offset"], +params_sr_in = { + "start": duration["offset_gen"], + "stop": duration["total_offset"] + duration["task"], + "label": "spike_recorder_in", +} + +params_sr_reg = { + "start": duration["offset_gen"], "stop": duration["total_offset"] + duration["task"], + "label": "spike_recorder_reg", +} + +params_sr_ad = { + "start": duration["offset_gen"], + "stop": duration["total_offset"] + duration["task"], + "label": "spike_recorder_ad", } #################### @@ -308,7 +333,9 @@ mm_reg = nest.Create("multimeter", params_mm_reg) mm_ad = nest.Create("multimeter", params_mm_ad) mm_out = nest.Create("multimeter", params_mm_out) -sr = nest.Create("spike_recorder", params_sr) +sr_in = nest.Create("spike_recorder", params_sr_in) +sr_reg = nest.Create("spike_recorder", params_sr_reg) +sr_ad = nest.Create("spike_recorder", params_sr_ad) wr = nest.Create("weight_recorder", params_wr) nrns_reg_record = nrns_reg[:n_record] @@ -342,22 +369,22 @@ def calculate_glorot_dist(fan_in, 
fan_out): params_common_syn_eprop = { "optimizer": { "type": "adam", # algorithm to optimize the weights - "batch_size": n_batch, + "batch_size": 1, "beta_1": 0.9, # exponential decay rate for 1st moment estimate of Adam optimizer "beta_2": 0.999, # exponential decay rate for 2nd moment raw estimate of Adam optimizer "epsilon": 1e-8, # small numerical stabilization constant of Adam optimizer - "eta": 5e-3, # learning rate + "eta": 5e-3 / duration["learning_window"], # learning rate + "optimize_each_step": True, # call optimizer every time step (True) or once per spike (False); only + # True implements original Adam algorithm, False offers speed-up; choice can affect learning performance "Wmin": -100.0, # pA, minimal limit of the synaptic weights "Wmax": 100.0, # pA, maximal limit of the synaptic weights }, - "average_gradient": True, # if True, average the gradient over the learning window "weight_recorder": wr, } params_syn_base = { - "synapse_model": "eprop_synapse_bsshslm_2020", + "synapse_model": "eprop_synapse", "delay": duration["step"], # ms, dendritic delay - "tau_m_readout": params_nrn_out["tau_m"], # ms, for technical reasons pass readout neuron membrane time constant } params_syn_in = params_syn_base.copy() @@ -369,18 +396,16 @@ def calculate_glorot_dist(fan_in, fan_out): params_syn_out = params_syn_base.copy() params_syn_out["weight"] = weights_rec_out - params_syn_feedback = { - "synapse_model": "eprop_learning_signal_connection_bsshslm_2020", + "synapse_model": "eprop_learning_signal_connection", "delay": duration["step"], "weight": weights_out_rec, } -params_syn_out_out = { +params_syn_learning_window = { "synapse_model": "rate_connection_delayed", "delay": duration["step"], - "receptor_type": 1, # receptor type of readout neuron to receive other readout neuron's signals for softmax - "weight": 1.0, # pA, weight 1.0 required for correct softmax computation for technical reasons + "receptor_type": 1, # receptor type over which readout neuron receives learning window signal } params_syn_rate_target = { @@ -403,7 +428,7 @@ def calculate_glorot_dist(fan_in, fan_out): #################### -nest.SetDefaults("eprop_synapse_bsshslm_2020", params_common_syn_eprop) +nest.SetDefaults("eprop_synapse", params_common_syn_eprop) nest.Connect(gen_spk_in, nrns_in, params_conn_one_to_one, params_syn_static) # connection 1 nest.Connect(nrns_in, nrns_rec, params_conn_all_to_all, params_syn_in) # connection 2 @@ -411,9 +436,11 @@ def calculate_glorot_dist(fan_in, fan_out): nest.Connect(nrns_rec, nrns_out, params_conn_all_to_all, params_syn_out) # connection 4 nest.Connect(nrns_out, nrns_rec, params_conn_all_to_all, params_syn_feedback) # connection 5 nest.Connect(gen_rate_target, nrns_out, params_conn_one_to_one, params_syn_rate_target) # connection 6 -nest.Connect(nrns_out, nrns_out, params_conn_all_to_all, params_syn_out_out) # connection 7 +nest.Connect(gen_learning_window, nrns_out, params_conn_all_to_all, params_syn_learning_window) # connection 7 -nest.Connect(nrns_in + nrns_rec, sr, params_conn_all_to_all, params_syn_static) +nest.Connect(nrns_in, sr_in, params_conn_all_to_all, params_syn_static) +nest.Connect(nrns_reg, sr_reg, params_conn_all_to_all, params_syn_static) +nest.Connect(nrns_ad, sr_ad, params_conn_all_to_all, params_syn_static) nest.Connect(mm_reg, nrns_reg_record, params_conn_all_to_all, params_syn_static) nest.Connect(mm_ad, nrns_ad_record, params_conn_all_to_all, params_syn_static) @@ -433,25 +460,23 @@ def calculate_glorot_dist(fan_in, fan_out): # assigned randomly to 
the left or right. -def generate_evidence_accumulation_input_output( - n_batch, n_in, prob_group, input_spike_prob, n_cues, n_input_symbols, steps -): - n_pop_nrn = n_in // n_input_symbols +def generate_evidence_accumulation_input_output(batch_size, n_in, steps, input): + n_pop_nrn = n_in // input["n_symbols"] - prob_choices = np.array([prob_group, 1 - prob_group], dtype=np.float32) - idx = np.random.choice([0, 1], n_batch) - probs = np.zeros((n_batch, 2), dtype=np.float32) + prob_choices = np.array([input["prob_group"], 1 - input["prob_group"]], dtype=np.float32) + idx = np.random.choice([0, 1], batch_size) + probs = np.zeros((batch_size, 2), dtype=np.float32) probs[:, 0] = prob_choices[idx] probs[:, 1] = prob_choices[1 - idx] - batched_cues = np.zeros((n_batch, n_cues), dtype=int) - for b_idx in range(n_batch): - batched_cues[b_idx, :] = np.random.choice([0, 1], n_cues, p=probs[b_idx]) + batched_cues = np.zeros((batch_size, input["n_cues"]), dtype=int) + for b_idx in range(batch_size): + batched_cues[b_idx, :] = np.random.choice([0, 1], input["n_cues"], p=probs[b_idx]) - input_spike_probs = np.zeros((n_batch, steps["sequence"], n_in)) + input_spike_probs = np.zeros((batch_size, steps["sequence"], n_in)) - for b_idx in range(n_batch): - for c_idx in range(n_cues): + for b_idx in range(batch_size): + for c_idx in range(input["n_cues"]): cue = batched_cues[b_idx, c_idx] step_start = c_idx * (steps["cue"] + steps["spacing"]) + steps["spacing"] @@ -460,31 +485,28 @@ def generate_evidence_accumulation_input_output( pop_nrn_start = cue * n_pop_nrn pop_nrn_stop = pop_nrn_start + n_pop_nrn - input_spike_probs[b_idx, step_start:step_stop, pop_nrn_start:pop_nrn_stop] = input_spike_prob + input_spike_probs[b_idx, step_start:step_stop, pop_nrn_start:pop_nrn_stop] = input["spike_prob"] - input_spike_probs[:, -steps["recall"] :, 2 * n_pop_nrn : 3 * n_pop_nrn] = input_spike_prob - input_spike_probs[:, :, 3 * n_pop_nrn :] = input_spike_prob / 4.0 + input_spike_probs[:, -steps["recall"] :, 2 * n_pop_nrn : 3 * n_pop_nrn] = input["spike_prob"] + input_spike_probs[:, :, 3 * n_pop_nrn :] = input["spike_prob"] / 4.0 input_spike_bools = input_spike_probs > np.random.rand(input_spike_probs.size).reshape(input_spike_probs.shape) input_spike_bools[:, 0, :] = 0 # remove spikes in 0th time step of every sequence for technical reasons - target_cues = np.zeros(n_batch, dtype=int) - target_cues[:] = np.sum(batched_cues, axis=1) > int(n_cues / 2) + target_cues = np.zeros(batch_size, dtype=int) + target_cues[:] = np.sum(batched_cues, axis=1) > int(input["n_cues"] / 2) return input_spike_bools, target_cues -input_spike_prob = 0.04 # spike probability of frozen input noise dtype_in_spks = np.float32 # data type of input spikes - for reproducing TF results set to np.float32 input_spike_bools_list = [] target_cues_list = [] -for iteration in range(n_iter): - input_spike_bools, target_cues = generate_evidence_accumulation_input_output( - n_batch, n_in, prob_group, input_spike_prob, n_cues, n_input_symbols, steps - ) +for _ in range(n_iter): + input_spike_bools, target_cues = generate_evidence_accumulation_input_output(group_size, n_in, steps, input) input_spike_bools_list.append(input_spike_bools) - target_cues_list.extend(target_cues.tolist()) + target_cues_list.extend(target_cues) input_spike_bools_arr = np.array(input_spike_bools_list).reshape(steps["task"], n_in) timeline_task = np.arange(0.0, duration["task"], duration["step"]) + duration["offset_gen"] @@ -494,8 +516,8 @@ def generate_evidence_accumulation_input_output( 
for nrn_in_idx in range(n_in) ] -target_rate_changes = np.zeros((n_out, n_batch * n_iter)) -target_rate_changes[np.array(target_cues_list), np.arange(n_batch * n_iter)] = 1 +target_rate_changes = np.zeros((n_out, group_size * n_iter)) +target_rate_changes[np.array(target_cues_list), np.arange(group_size * n_iter)] = 1 params_gen_rate_target = [ { @@ -505,12 +527,38 @@ def generate_evidence_accumulation_input_output( for nrn_out_idx in range(n_out) ] - #################### nest.SetStatus(gen_spk_in, params_gen_spk_in) nest.SetStatus(gen_rate_target, params_gen_rate_target) +# %% ########################################################################################################### +# Create learning window +# ~~~~~~~~~~~~~~~~~~~~~~ +# Custom learning windows, in which the network learns, can be defined with an additional signal. The error +# signal is internally multiplied with this learning window signal. Passing a learning window signal of value 1 +# opens the learning window while passing a value of 0 closes it. + +amplitude_times = np.hstack( + [ + np.array([0.0, duration["sequence"] - duration["learning_window"]]) + + duration["total_offset"] + + i * duration["sequence"] + for i in range(group_size * n_iter) + ] +) + +amplitude_values = np.array([0.0, 1.0] * group_size * n_iter) + +params_gen_learning_window = { + "amplitude_times": amplitude_times, + "amplitude_values": amplitude_values, +} + +#################### + +nest.SetStatus(gen_learning_window, params_gen_learning_window) + # %% ########################################################################################################### # Force final update # ~~~~~~~~~~~~~~~~~~ @@ -574,29 +622,31 @@ def get_weights(pop_pre, pop_post): events_mm_reg = mm_reg.get("events") events_mm_ad = mm_ad.get("events") events_mm_out = mm_out.get("events") -events_sr = sr.get("events") +events_sr_in = sr_in.get("events") +events_sr_reg = sr_reg.get("events") +events_sr_ad = sr_ad.get("events") events_wr = wr.get("events") # %% ########################################################################################################### # Evaluate training error # ~~~~~~~~~~~~~~~~~~~~~~~ -# We evaluate the network's training error by calculating a loss - in this case, the cross-entropy error between +# We evaluate the network's training error by calculating a loss - in this case, the mean squared error between # the integrated recurrent network activity and the target rate. 
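# A minimal self-contained sketch of the mean-squared-error evaluation used below (toy dimensions;
# the array shapes and the loss expression follow the script, but the sizes and the recall-error
# line are assumptions for illustration, not values taken from this file):
import numpy as np

n_out, n_iter, group_size, n_win = 2, 3, 4, 5  # toy stand-ins for this script's sizes
readout_signal = np.random.rand(n_out, n_iter, group_size, n_win)  # signals cut to the learning window
target_signal = np.random.rand(n_out, n_iter, group_size, n_win)

# mean squared error per iteration: sum over time, average over readout neurons and group samples
loss = 0.5 * np.mean(np.sum((readout_signal - target_signal) ** 2, axis=3), axis=(0, 2))

# classification outcome: the readout neuron with the larger time-averaged signal wins
y_prediction = np.argmax(np.mean(readout_signal, axis=3), axis=0)
y_target = np.argmax(np.mean(target_signal, axis=3), axis=0)
recall_errors = np.mean(y_prediction != y_target, axis=1)  # fraction of wrong decisions per iteration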
-readout_signal = events_mm_out["readout_signal"] # corresponds to softmax +readout_signal = events_mm_out["readout_signal"] target_signal = events_mm_out["target_signal"] senders = events_mm_out["senders"] readout_signal = np.array([readout_signal[senders == i] for i in set(senders)]) target_signal = np.array([target_signal[senders == i] for i in set(senders)]) -readout_signal = readout_signal.reshape((n_out, n_iter, n_batch, steps["sequence"])) -readout_signal = readout_signal[:, :, :, -steps["learning_window"] :] +readout_signal = readout_signal.reshape((n_out, n_iter, group_size, steps["sequence"])) +target_signal = target_signal.reshape((n_out, n_iter, group_size, steps["sequence"])) -target_signal = target_signal.reshape((n_out, n_iter, n_batch, steps["sequence"])) +readout_signal = readout_signal[:, :, :, -steps["learning_window"] :] target_signal = target_signal[:, :, :, -steps["learning_window"] :] -loss = -np.mean(np.sum(target_signal * np.log(readout_signal), axis=0), axis=(1, 2)) +loss = 0.5 * np.mean(np.sum((readout_signal - target_signal) ** 2, axis=3), axis=(0, 2)) y_prediction = np.argmax(np.mean(readout_signal, axis=3), axis=0) y_target = np.argmax(np.mean(target_signal, axis=3), axis=0) @@ -621,7 +671,6 @@ def get_weights(pop_pre, pop_post): plt.rcParams.update( { - "font.sans-serif": "Arial", "axes.spines.right": False, "axes.spines.top": False, "axes.prop_cycle": cycler(color=[colors["blue"], colors["red"]]), @@ -635,9 +684,10 @@ def get_weights(pop_pre, pop_post): # plotted against the iterations. fig, axs = plt.subplots(2, 1, sharex=True) +fig.suptitle("Training error") axs[0].plot(range(1, n_iter + 1), loss) -axs[0].set_ylabel(r"$E = -\sum_{t,k} \pi_k^{*,t} \log \pi_k^t$") +axs[0].set_ylabel(r"$E = \frac{1}{2} \sum_{t,k} \left( y_k^t -y_k^{*,t}\right)^2$") axs[1].plot(range(1, n_iter + 1), recall_errors) axs[1].set_ylabel("recall errors") @@ -665,11 +715,10 @@ def plot_recordable(ax, events, recordable, ylabel, xlims): ax.set_ylim(np.min(events[recordable]) - margin, np.max(events[recordable]) + margin) -def plot_spikes(ax, events, nrns, ylabel, xlims): +def plot_spikes(ax, events, ylabel, xlims): idc_times = (events["times"] > xlims[0]) & (events["times"] < xlims[1]) - idc_sender = np.isin(events["senders"][idc_times], nrns.tolist()) - senders_subset = events["senders"][idc_times][idc_sender] - times_subset = events["times"][idc_times][idc_sender] + senders_subset = events["senders"][idc_times] + times_subset = events["times"][idc_times] ax.scatter(times_subset, senders_subset, s=0.1) ax.set_ylabel(ylabel) @@ -677,17 +726,21 @@ def plot_spikes(ax, events, nrns, ylabel, xlims): ax.set_ylim(np.min(senders_subset) - margin, np.max(senders_subset) + margin) -for xlims in [(0, steps["sequence"]), (steps["task"] - steps["sequence"], steps["task"])]: +for title, xlims in zip( + ["Dynamic variables before training", "Dynamic variables after training"], + [(0, steps["sequence"]), (steps["task"] - steps["sequence"], steps["task"])], +): fig, axs = plt.subplots(14, 1, sharex=True, figsize=(8, 14), gridspec_kw={"hspace": 0.4, "left": 0.2}) + fig.suptitle(title) - plot_spikes(axs[0], events_sr, nrns_in, r"$z_i$" + "\n", xlims) - plot_spikes(axs[1], events_sr, nrns_reg, r"$z_j$" + "\n", xlims) + plot_spikes(axs[0], events_sr_in, r"$z_i$" + "\n", xlims) + plot_spikes(axs[1], events_sr_reg, r"$z_j$" + "\n", xlims) plot_recordable(axs[2], events_mm_reg, "V_m", r"$v_j$" + "\n(mV)", xlims) plot_recordable(axs[3], events_mm_reg, "surrogate_gradient", r"$\psi_j$" + "\n", xlims) 
plot_recordable(axs[4], events_mm_reg, "learning_signal", r"$L_j$" + "\n(pA)", xlims) - plot_spikes(axs[5], events_sr, nrns_ad, r"$z_j$" + "\n", xlims) + plot_spikes(axs[5], events_sr_ad, r"$z_j$" + "\n", xlims) plot_recordable(axs[6], events_mm_ad, "V_m", r"$v_j$" + "\n(mV)", xlims) plot_recordable(axs[7], events_mm_ad, "surrogate_gradient", r"$\psi_j$" + "\n", xlims) @@ -695,9 +748,9 @@ def plot_spikes(ax, events, nrns, ylabel, xlims): plot_recordable(axs[9], events_mm_ad, "learning_signal", r"$L_j$" + "\n(pA)", xlims) plot_recordable(axs[10], events_mm_out, "V_m", r"$v_k$" + "\n(mV)", xlims) - plot_recordable(axs[11], events_mm_out, "target_signal", r"$\pi^*_k$" + "\n", xlims) - plot_recordable(axs[12], events_mm_out, "readout_signal", r"$\pi_k$" + "\n", xlims) - plot_recordable(axs[13], events_mm_out, "error_signal", r"$\pi_k-\pi^*_k$" + "\n", xlims) + plot_recordable(axs[11], events_mm_out, "target_signal", r"$y^*_k$" + "\n", xlims) + plot_recordable(axs[12], events_mm_out, "readout_signal", r"$y_k$" + "\n", xlims) + plot_recordable(axs[13], events_mm_out, "error_signal", r"$y_k-y^*_k$" + "\n", xlims) axs[-1].set_xlabel(r"$t$ (ms)") axs[-1].set_xlim(*xlims) @@ -713,7 +766,10 @@ def plot_spikes(ax, events, nrns, ylabel, xlims): # the first time step and we add the initial weights manually. -def plot_weight_time_course(ax, events, nrns_senders, nrns_targets, label, ylabel): +def plot_weight_time_course(ax, events, nrns_weight_record, label, ylabel): + sender_label, target_label = label.split("_") + nrns_senders = nrns_weight_record[sender_label] + nrns_targets = nrns_weight_record[target_label] for sender in nrns_senders.tolist(): for target in nrns_targets.tolist(): idc_syn = (events["senders"] == sender) & (events["targets"] == target) @@ -726,16 +782,21 @@ def plot_weight_time_course(ax, events, nrns_senders, nrns_targets, label, ylabe ax.step(times, weights, c=colors["blue"]) ax.set_ylabel(ylabel) - ax.set_ylim(-0.6, 0.6) + ax.set_ylim(-1.5, 1.5) fig, axs = plt.subplots(3, 1, sharex=True, figsize=(3, 4)) +fig.suptitle("Weight time courses") -plot_weight_time_course(axs[0], events_wr, nrns_in[:n_record_w], nrns_rec[:n_record_w], "in_rec", r"$W_\text{in}$ (pA)") -plot_weight_time_course( - axs[1], events_wr, nrns_rec[:n_record_w], nrns_rec[:n_record_w], "rec_rec", r"$W_\text{rec}$ (pA)" -) -plot_weight_time_course(axs[2], events_wr, nrns_rec[:n_record_w], nrns_out, "rec_out", r"$W_\text{out}$ (pA)") +nrns_weight_record = { + "in": nrns_in[:n_record_w], + "rec": nrns_rec[:n_record_w], + "out": nrns_out, +} + +plot_weight_time_course(axs[0], events_wr, nrns_weight_record, "in_rec", r"$W_\text{in}$ (pA)") +plot_weight_time_course(axs[1], events_wr, nrns_weight_record, "rec_rec", r"$W_\text{rec}$ (pA)") +plot_weight_time_course(axs[2], events_wr, nrns_weight_record, "rec_out", r"$W_\text{out}$ (pA)") axs[-1].set_xlabel(r"$t$ (ms)") axs[-1].set_xlim(0, steps["task"]) @@ -755,6 +816,7 @@ def plot_weight_time_course(ax, events, nrns_senders, nrns_targets, label, ylabe ) fig, axs = plt.subplots(3, 2, sharex="col", sharey="row") +fig.suptitle("Weight matrices") all_w_extrema = [] @@ -777,8 +839,8 @@ def plot_weight_time_course(ax, events, nrns_senders, nrns_targets, label, ylabe axs[2, 0].set_ylabel("readout\nneurons") fig.align_ylabels(axs[:, 0]) -axs[0, 0].text(0.5, 1.1, "pre-training", transform=axs[0, 0].transAxes, ha="center") -axs[0, 1].text(0.5, 1.1, "post-training", transform=axs[0, 1].transAxes, ha="center") +axs[0, 0].text(0.5, 1.1, "before training", transform=axs[0, 0].transAxes, 
ha="center") +axs[0, 1].text(0.5, 1.1, "after training", transform=axs[0, 1].transAxes, ha="center") axs[2, 0].yaxis.get_major_locator().set_params(integer=True) diff --git a/pynest/examples/eprop_plasticity/eprop_supervised_classification_evidence-accumulation_bsshslm_2020.png b/pynest/examples/eprop_plasticity/eprop_supervised_classification_evidence-accumulation_bsshslm_2020.png new file mode 100644 index 0000000000..aaad5baf54 Binary files /dev/null and b/pynest/examples/eprop_plasticity/eprop_supervised_classification_evidence-accumulation_bsshslm_2020.png differ diff --git a/pynest/examples/eprop_plasticity/eprop_supervised_classification_evidence-accumulation_bsshslm_2020.py b/pynest/examples/eprop_plasticity/eprop_supervised_classification_evidence-accumulation_bsshslm_2020.py new file mode 100644 index 0000000000..a1f44c3711 --- /dev/null +++ b/pynest/examples/eprop_plasticity/eprop_supervised_classification_evidence-accumulation_bsshslm_2020.py @@ -0,0 +1,959 @@ +# -*- coding: utf-8 -*- +# +# eprop_supervised_classification_evidence-accumulation_bsshslm_2020.py +# +# This file is part of NEST. +# +# Copyright (C) 2004 The NEST Initiative +# +# NEST is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 2 of the License, or +# (at your option) any later version. +# +# NEST is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with NEST. If not, see . + +r""" +Tutorial on learning to accumulate evidence with e-prop after Bellec et al. (2020) +---------------------------------------------------------------------------------- + +Training a classification model using supervised e-prop plasticity to accumulate evidence. + +Description +~~~~~~~~~~~ + +This script demonstrates supervised learning of a classification task with the eligibility propagation (e-prop) +plasticity mechanism by Bellec et al. [1]_. + +This type of learning is demonstrated at the proof-of-concept task in [1]_. We based this script on their +TensorFlow script given in [2]_. + +The task, a so-called evidence accumulation task, is inspired by behavioral tasks, where a lab animal (e.g., a +mouse) runs along a track, gets cues on the left and right, and has to decide at the end of the track between +taking a left and a right turn of which one is correct. After a number of iterations, the animal is able to +infer the underlying rationale of the task. Here, the solution is to turn to the side in which more cues were +presented. + +.. image:: eprop_supervised_classification_evidence-accumulation_bsshslm_2020.png + :width: 70 % + :alt: Schematic of network architecture. Same as Figure 1 in the code. + :align: center + +Learning in the neural network model is achieved by optimizing the connection weights with e-prop plasticity. +This plasticity rule requires a specific network architecture depicted in Figure 1. The neural network model +consists of a recurrent network that receives input from Poisson generators and projects onto two readout +neurons - one for the left and one for the right turn at the end. 
The input neuron population consists of four +groups: one group providing background noise of a specific rate for some base activity throughout the +experiment, one group providing the input spikes of the left cues and one group providing them for the right +cues, and a last group defining the recall window, in which the network has to decide. The readout neuron +compares the network signal :math:`\pi_k` with the teacher target signal :math:`\pi_k^*`, which it receives from +a rate generator. Since the decision is at the end and all the cues are relevant, the network has to keep the +cues in memory. Additional adaptive neurons in the network enable this memory. The network's training error is +assessed by employing a cross-entropy error loss. + +Details on the event-based NEST implementation of e-prop can be found in [3]_. + +References +~~~~~~~~~~ + +.. [1] Bellec G, Scherr F, Subramoney F, Hajek E, Salaj D, Legenstein R, Maass W (2020). A solution to the + learning dilemma for recurrent networks of spiking neurons. Nature Communications, 11:3625. + https://doi.org/10.1038/s41467-020-17236-y + +.. [2] https://github.com/IGITUGraz/eligibility_propagation/blob/master/Figure_3_and_S7_e_prop_tutorials/tutorial_evidence_accumulation_with_alif.py + +.. [3] Korcsak-Gorzo A, Stapmanns J, Espinoza Valverde JA, Plesser HE, + Dahmen D, Bolten M, Van Albada SJ*, Diesmann M*. Event-based + implementation of eligibility propagation (in preparation) + +""" # pylint: disable=line-too-long # noqa: E501 + +# %% ########################################################################################################### +# Import libraries +# ~~~~~~~~~~~~~~~~ +# We begin by importing all libraries required for the simulation, analysis, and visualization. + +import matplotlib as mpl +import matplotlib.pyplot as plt +import nest +import numpy as np +from cycler import cycler +from IPython.display import Image + +# %% ########################################################################################################### +# Schematic of network architecture +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +# This figure, identical to the one in the description, shows the required network architecture in the center, +# the input and output of the pattern generation task above, and lists of the required NEST device, neuron, and +# synapse models below. The connections that must be established are numbered 1 to 7. + +try: + Image(filename="./eprop_supervised_classification_evidence-accumulation_bsshslm_2020.png") +except Exception: + pass + +# %% ########################################################################################################### +# Setup +# ~~~~~ + +# %% ########################################################################################################### +# Initialize random generator +# ........................... +# We seed the numpy random generator, which will generate random initial weights as well as random input and +# output. + +rng_seed = 1 # numpy random seed +np.random.seed(rng_seed) # fix numpy random seed + +# %% ########################################################################################################### +# Define timing of task +# ..................... +# The task's temporal structure is then defined, once as time steps and once as durations in milliseconds. +# Using a batch size larger than one aids the network in generalization, facilitating the solution to this task. +# The original number of iterations requires distributed computing. 
Increasing the number of iterations +# enhances learning performance up to the point where overfitting occurs. If early stopping is enabled, the +# classification error is tested in regular intervals and the training stopped as soon as the error selected as +# stop criterion is reached. After training, the performance can be tested over a number of test iterations. + +batch_size = 32 # batch size, 64 in reference [2], 32 in the README to reference [2] +n_iter = 50 # number of iterations, 2000 in reference [2] + +input = { + "n_symbols": 4, # number of input populations, e.g. 4 = left, right, recall, noise + "n_cues": 7, # number of cues given before decision + "prob_group": 0.3, # probability with which one input group is present + "spike_prob": 0.04, # spike probability of frozen input noise +} + +do_early_stopping = True # if True, stop training as soon as stop criterion fulfilled +n_validate_every = 10 # number of training iterations before validation +n_early_stop = 8 # number of iterations to average over to evaluate early stopping condition +stop_crit = 0.07 # error value corresponding to stop criterion for early stopping + +n_test = 4 # number of iterations for final test + +n_val = np.ceil(n_iter / n_validate_every) +n_iter_max = int(n_iter + n_val + (n_val - 1) * (n_early_stop + 1) + n_test) + +steps = { + "cue": 100, # time steps in one cue presentation + "spacing": 50, # time steps of break between two cues + "bg_noise": 1050, # time steps of background noise + "recall": 150, # time steps of recall +} + +steps["cues"] = input["n_cues"] * (steps["cue"] + steps["spacing"]) # time steps of all cues +steps["sequence"] = steps["cues"] + steps["bg_noise"] + steps["recall"] # time steps of one full sequence +steps["learning_window"] = steps["recall"] # time steps of window with non-zero learning signals + +steps.update( + { + "offset_gen": 1, # offset since generator signals start from time step 1 + "delay_in_rec": 1, # connection delay between input and recurrent neurons + "delay_rec_out": 1, # connection delay between recurrent and output neurons + "delay_out_norm": 1, # connection delay between output neurons for normalization + "extension_sim": 2, # extra time step to close right-open simulation time interval in Simulate() + } +) + +steps["delays"] = steps["delay_in_rec"] + steps["delay_rec_out"] + steps["delay_out_norm"] # time steps of delays + +steps["total_offset"] = steps["offset_gen"] + steps["delays"] # time steps of total offset + +duration = {"step": 1.0} # ms, temporal resolution of the simulation + +duration.update({key: value * duration["step"] for key, value in steps.items()}) # ms, durations + +# %% ########################################################################################################### +# Set up simulation +# ................. +# As last step of the setup, we reset the NEST kernel to remove all existing NEST simulation settings and +# objects and set some NEST kernel parameters, some of which are e-prop-related. 
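# A worked check of the iteration budget defined further above, using the values set in this
# script (plain arithmetic, illustrative only):
#   n_val      = ceil(n_iter / n_validate_every) = ceil(50 / 10) = 5
#   n_iter_max = n_iter + n_val + (n_val - 1) * (n_early_stop + 1) + n_test
#              = 50 + 5 + 4 * 9 + 4 = 95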
+ +params_setup = { + "eprop_learning_window": duration["learning_window"], + "eprop_reset_neurons_on_update": True, # if True, reset dynamic variables at start of each update interval + "eprop_update_interval": duration["sequence"], # ms, time interval for updating the synaptic weights + "print_time": False, # if True, print time progress bar during simulation, set False if run as code cell + "resolution": duration["step"], + "total_num_virtual_procs": 1, # number of virtual processes, set in case of distributed computing +} + +#################### + +nest.ResetKernel() +nest.set(**params_setup) + +# %% ########################################################################################################### +# Create neurons +# ~~~~~~~~~~~~~~ +# We proceed by creating a certain number of input, recurrent, and readout neurons and setting their parameters. +# Additionally, we already create an input spike generator and an output target rate generator, which we will +# configure later. Within the recurrent network, alongside a population of regular neurons, we introduce a +# population of adaptive neurons, to enhance the network's memory retention. + +n_in = 40 # number of input neurons +n_ad = 50 # number of adaptive neurons +n_reg = 50 # number of regular neurons +n_rec = n_ad + n_reg # number of recurrent neurons +n_out = 2 # number of readout neurons + +params_nrn_out = { + "C_m": 1.0, # pF, membrane capacitance - takes effect only if neurons get current input (here not the case) + "E_L": 0.0, # mV, leak / resting membrane potential + "I_e": 0.0, # pA, external current input + "loss": "cross_entropy", # loss function + "regular_spike_arrival": False, # If True, input spikes arrive at end of time step, if False at beginning + "tau_m": 20.0, # ms, membrane time constant + "V_m": 0.0, # mV, initial value of the membrane voltage +} + +params_nrn_reg = { + "beta": 1.0, # width scaling of the pseudo-derivative + "C_m": 1.0, + "c_reg": 300.0, # coefficient of firing rate regularization - 2*learning_window*(TF c_reg) for technical reasons + "E_L": 0.0, + "f_target": 10.0, # spikes/s, target firing rate for firing rate regularization + "gamma": 0.3, # height scaling of the pseudo-derivative + "I_e": 0.0, + "regular_spike_arrival": True, + "surrogate_gradient_function": "piecewise_linear", # surrogate gradient / pseudo-derivative function + "t_ref": 5.0, # ms, duration of refractory period + "tau_m": 20.0, + "V_m": 0.0, + "V_th": 0.6, # mV, spike threshold membrane voltage +} + +# factors from the original pseudo-derivative definition are incorporated into the parameters +params_nrn_reg["gamma"] /= params_nrn_reg["V_th"] +params_nrn_reg["beta"] /= np.abs(params_nrn_reg["V_th"]) # prefactor is inside abs in the original definition + +params_nrn_ad = { + "beta": 1.0, + "adapt_tau": 2000.0, # ms, time constant of adaptive threshold + "adaptation": 0.0, # initial value of the spike threshold adaptation + "C_m": 1.0, + "c_reg": 300.0, + "E_L": 0.0, + "f_target": 10.0, + "gamma": 0.3, + "I_e": 0.0, + "regular_spike_arrival": True, + "surrogate_gradient_function": "piecewise_linear", + "t_ref": 5.0, + "tau_m": 20.0, + "V_m": 0.0, + "V_th": 0.6, +} + +params_nrn_ad["gamma"] /= params_nrn_ad["V_th"] +params_nrn_ad["beta"] /= np.abs(params_nrn_ad["V_th"]) + +params_nrn_ad["adapt_beta"] = 1.7 * ( + (1.0 - np.exp(-duration["step"] / params_nrn_ad["adapt_tau"])) + / (1.0 - np.exp(-duration["step"] / params_nrn_ad["tau_m"])) +) # prefactor of adaptive threshold + +#################### + +# Intermediate parrot 
neurons required between input spike generators and recurrent neurons, +# since devices cannot establish plastic synapses for technical reasons + +gen_spk_in = nest.Create("spike_generator", n_in) +nrns_in = nest.Create("parrot_neuron", n_in) + +# The suffix _bsshslm_2020 follows the NEST convention to indicate in the model name the paper +# that introduced it by the first letter of the authors' last names and the publication year. + +nrns_reg = nest.Create("eprop_iaf_bsshslm_2020", n_reg, params_nrn_reg) +nrns_ad = nest.Create("eprop_iaf_adapt_bsshslm_2020", n_ad, params_nrn_ad) +nrns_out = nest.Create("eprop_readout_bsshslm_2020", n_out, params_nrn_out) +gen_rate_target = nest.Create("step_rate_generator", n_out) + +nrns_rec = nrns_reg + nrns_ad + +# %% ########################################################################################################### +# Create recorders +# ~~~~~~~~~~~~~~~~ +# We also create recorders, which, while not required for the training, will allow us to track various dynamic +# variables of the neurons, spikes, and changes in synaptic weights. To save computing time and memory, the +# recorders, the recorded variables, neurons, and synapses can be limited to the ones relevant to the +# experiment, and the recording interval can be increased (see the documentation on the specific recorders). By +# default, recordings are stored in memory but can also be written to file. + +n_record = 1 # number of neurons per type to record dynamic variables from - this script requires n_record >= 1 +n_record_w = 5 # number of senders and targets to record weights from - this script requires n_record_w >=1 + +if n_record == 0 or n_record_w == 0: + raise ValueError("n_record and n_record_w >= 1 required") + +params_mm_reg = { + "interval": duration["step"], # interval between two recorded time points + "record_from": ["V_m", "surrogate_gradient", "learning_signal"], # dynamic variables to record + "start": duration["offset_gen"] + duration["delay_in_rec"], # start time of recording + "label": "multimeter_reg", +} + +params_mm_ad = { + "interval": duration["step"], + "record_from": params_mm_reg["record_from"] + ["V_th_adapt", "adaptation"], + "start": duration["offset_gen"] + duration["delay_in_rec"], + "label": "multimeter_ad", +} + +params_mm_out = { + "interval": duration["step"], + "record_from": ["V_m", "readout_signal", "readout_signal_unnorm", "target_signal", "error_signal"], + "start": duration["total_offset"], + "label": "multimeter_out", +} + +params_wr = { + "senders": nrns_in[:n_record_w] + nrns_rec[:n_record_w], # limit senders to subsample weights to record + "targets": nrns_rec[:n_record_w] + nrns_out, # limit targets to subsample weights to record from + "start": duration["total_offset"], + "label": "weight_recorder", +} + +params_sr_in = { + "start": duration["offset_gen"], + "label": "spike_recorder_in", +} + +params_sr_reg = { + "start": duration["offset_gen"], + "label": "spike_recorder_reg", +} + +params_sr_ad = { + "start": duration["offset_gen"], + "label": "spike_recorder_ad", +} + +#################### + +mm_reg = nest.Create("multimeter", params_mm_reg) +mm_ad = nest.Create("multimeter", params_mm_ad) +mm_out = nest.Create("multimeter", params_mm_out) +sr_in = nest.Create("spike_recorder", params_sr_in) +sr_reg = nest.Create("spike_recorder", params_sr_reg) +sr_ad = nest.Create("spike_recorder", params_sr_ad) +wr = nest.Create("weight_recorder", params_wr) + +nrns_reg_record = nrns_reg[:n_record] +nrns_ad_record = nrns_ad[:n_record] + +# %% 
########################################################################################################### +# Create connections +# ~~~~~~~~~~~~~~~~~~ +# Now, we define the connectivity and set up the synaptic parameters, with the synaptic weights drawn from +# normal distributions. After these preparations, we establish the enumerated connections of the core network, +# as well as additional connections to the recorders. + +params_conn_all_to_all = {"rule": "all_to_all", "allow_autapses": False} +params_conn_one_to_one = {"rule": "one_to_one"} + + +def calculate_glorot_dist(fan_in, fan_out): + glorot_scale = 1.0 / max(1.0, (fan_in + fan_out) / 2.0) + glorot_limit = np.sqrt(3.0 * glorot_scale) + glorot_distribution = np.random.uniform(low=-glorot_limit, high=glorot_limit, size=(fan_in, fan_out)) + return glorot_distribution + + +dtype_weights = np.float32 # data type of weights - for reproducing TF results set to np.float32 +weights_in_rec = np.array(np.random.randn(n_in, n_rec).T / np.sqrt(n_in), dtype=dtype_weights) +weights_rec_rec = np.array(np.random.randn(n_rec, n_rec).T / np.sqrt(n_rec), dtype=dtype_weights) +np.fill_diagonal(weights_rec_rec, 0.0) # since no autapses set corresponding weights to zero +weights_rec_out = np.array(calculate_glorot_dist(n_rec, n_out).T, dtype=dtype_weights) +weights_out_rec = np.array(np.random.randn(n_rec, n_out), dtype=dtype_weights) + +params_common_syn_eprop = { + "optimizer": { + "type": "adam", # algorithm to optimize the weights + "batch_size": batch_size, + "beta_1": 0.9, # exponential decay rate for 1st moment estimate of Adam optimizer + "beta_2": 0.999, # exponential decay rate for 2nd moment raw estimate of Adam optimizer + "epsilon": 1e-8, # small numerical stabilization constant of Adam optimizer + "Wmin": -100.0, # pA, minimal limit of the synaptic weights + "Wmax": 100.0, # pA, maximal limit of the synaptic weights + }, + "average_gradient": True, # if True, average the gradient over the learning window + "weight_recorder": wr, +} + +eta_test = 0.0 +eta_train = 5e-3 + +params_syn_base = { + "synapse_model": "eprop_synapse_bsshslm_2020", + "delay": duration["step"], # ms, dendritic delay + "tau_m_readout": params_nrn_out["tau_m"], # ms, for technical reasons pass readout neuron membrane time constant +} + +params_syn_in = params_syn_base.copy() +params_syn_in["weight"] = weights_in_rec # pA, initial values for the synaptic weights + +params_syn_rec = params_syn_base.copy() +params_syn_rec["weight"] = weights_rec_rec + +params_syn_out = params_syn_base.copy() +params_syn_out["weight"] = weights_rec_out + +params_syn_feedback = { + "synapse_model": "eprop_learning_signal_connection_bsshslm_2020", + "delay": duration["step"], + "weight": weights_out_rec, +} + +params_syn_out_out = { + "synapse_model": "rate_connection_delayed", + "delay": duration["step"], + "receptor_type": 1, # receptor type of readout neuron to receive other readout neuron's signals for softmax + "weight": 1.0, # pA, weight 1.0 required for correct softmax computation for technical reasons +} + +params_syn_rate_target = { + "synapse_model": "rate_connection_delayed", + "delay": duration["step"], + "receptor_type": 2, # receptor type over which readout neuron receives target signal +} + +params_syn_static = { + "synapse_model": "static_synapse", + "delay": duration["step"], +} + +params_init_optimizer = { + "optimizer": { + "m": 0.0, # initial 1st moment estimate m of Adam optimizer + "v": 0.0, # initial 2nd moment raw estimate v of Adam optimizer + } +} + 
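
# For orientation, a minimal NumPy sketch of a single Adam step with the parameters defined above (standard
# Adam following Kingma & Ba, 2015). This is only an illustration of how the entries eta, beta_1, beta_2,
# epsilon, m, and v interact; it is not the synapse's internal implementation, and adam_step_sketch is not
# used elsewhere in this script.


def adam_step_sketch(weight, gradient, m, v, step_count, opt=params_common_syn_eprop["optimizer"]):
    m = opt["beta_1"] * m + (1.0 - opt["beta_1"]) * gradient  # biased 1st moment estimate of the gradient
    v = opt["beta_2"] * v + (1.0 - opt["beta_2"]) * gradient**2  # biased 2nd raw moment estimate
    m_hat = m / (1.0 - opt["beta_1"] ** step_count)  # bias correction of 1st moment
    v_hat = v / (1.0 - opt["beta_2"] ** step_count)  # bias correction of 2nd moment
    weight -= eta_train * m_hat / (np.sqrt(v_hat) + opt["epsilon"])  # gradient step scaled by learning rate
    return np.clip(weight, opt["Wmin"], opt["Wmax"]), m, v  # weights are clipped to [Wmin, Wmax]
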
+#################### + +nest.SetDefaults("eprop_synapse_bsshslm_2020", params_common_syn_eprop) + +nest.Connect(gen_spk_in, nrns_in, params_conn_one_to_one, params_syn_static) # connection 1 +nest.Connect(nrns_in, nrns_rec, params_conn_all_to_all, params_syn_in) # connection 2 +nest.Connect(nrns_rec, nrns_rec, params_conn_all_to_all, params_syn_rec) # connection 3 +nest.Connect(nrns_rec, nrns_out, params_conn_all_to_all, params_syn_out) # connection 4 +nest.Connect(nrns_out, nrns_rec, params_conn_all_to_all, params_syn_feedback) # connection 5 +nest.Connect(gen_rate_target, nrns_out, params_conn_one_to_one, params_syn_rate_target) # connection 6 +nest.Connect(nrns_out, nrns_out, params_conn_all_to_all, params_syn_out_out) # connection 7 + +nest.Connect(nrns_in, sr_in, params_conn_all_to_all, params_syn_static) +nest.Connect(nrns_reg, sr_reg, params_conn_all_to_all, params_syn_static) +nest.Connect(nrns_ad, sr_ad, params_conn_all_to_all, params_syn_static) + +nest.Connect(mm_reg, nrns_reg_record, params_conn_all_to_all, params_syn_static) +nest.Connect(mm_ad, nrns_ad_record, params_conn_all_to_all, params_syn_static) +nest.Connect(mm_out, nrns_out, params_conn_all_to_all, params_syn_static) + +# After creating the connections, we can individually initialize the optimizer's +# dynamic variables for single synapses (here exemplarily for two connections). + +nest.GetConnections(nrns_rec[0], nrns_rec[1:3]).set([params_init_optimizer] * 2) + +# %% ########################################################################################################### +# Create input and output +# ~~~~~~~~~~~~~~~~~~~~~~~ +# We generate the input as four neuron populations, two producing the left and right cues, respectively, one the +# recall signal and one the background input throughout the task. The sequence of cues is drawn with a +# probability that favors one side. For each such sequence, the favored side, the solution or target, is +# assigned randomly to the left or right. 
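# As implemented in the function below, the n_in input neurons are divided into input["n_symbols"] groups of
# equal size: the first two groups emit the spikes encoding the two cue directions (cue values 0 and 1), the
# third group is activated during the recall period at the end of each sequence, and the remaining neurons
# provide background noise at a quarter of the cue spike probability throughout the sequence.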
+ + +def generate_evidence_accumulation_input_output(batch_size, n_in, steps, input): + n_pop_nrn = n_in // input["n_symbols"] + + prob_choices = np.array([input["prob_group"], 1 - input["prob_group"]], dtype=np.float32) + idx = np.random.choice([0, 1], batch_size) + probs = np.zeros((batch_size, 2), dtype=np.float32) + probs[:, 0] = prob_choices[idx] + probs[:, 1] = prob_choices[1 - idx] + + batched_cues = np.zeros((batch_size, input["n_cues"]), dtype=int) + for b_idx in range(batch_size): + batched_cues[b_idx, :] = np.random.choice([0, 1], input["n_cues"], p=probs[b_idx]) + + input_spike_probs = np.zeros((batch_size, steps["sequence"], n_in)) + + for b_idx in range(batch_size): + for c_idx in range(input["n_cues"]): + cue = batched_cues[b_idx, c_idx] + + step_start = c_idx * (steps["cue"] + steps["spacing"]) + steps["spacing"] + step_stop = step_start + steps["cue"] + + pop_nrn_start = cue * n_pop_nrn + pop_nrn_stop = pop_nrn_start + n_pop_nrn + + input_spike_probs[b_idx, step_start:step_stop, pop_nrn_start:pop_nrn_stop] = input["spike_prob"] + + input_spike_probs[:, -steps["recall"] :, 2 * n_pop_nrn : 3 * n_pop_nrn] = input["spike_prob"] + input_spike_probs[:, :, 3 * n_pop_nrn :] = input["spike_prob"] / 4.0 + input_spike_bools = input_spike_probs > np.random.rand(input_spike_probs.size).reshape(input_spike_probs.shape) + input_spike_bools[:, 0, :] = 0 # remove spikes in 0th time step of every sequence for technical reasons + + target_cues = np.zeros(batch_size, dtype=int) + target_cues[:] = np.sum(batched_cues, axis=1) > int(input["n_cues"] / 2) + + return input_spike_bools, target_cues + + +def get_params_task_input_output(n_iter_interval): + dtype_in_spks = np.float32 # data type of input spikes - for reproducing TF results set to np.float32 + + input_spike_bools_list = [] + target_cues_list = [] + + for _ in range(n_iter_interval): + input_spike_bools, target_cues = generate_evidence_accumulation_input_output(batch_size, n_in, steps, input) + input_spike_bools_list.append(input_spike_bools) + target_cues_list.extend(target_cues) + + input_spike_bools_arr = np.array(input_spike_bools_list).reshape( + n_iter_interval * batch_size * steps["sequence"], n_in + ) + timeline_task = ( + np.arange( + 0.0, + n_iter_interval * batch_size * duration["sequence"], + duration["step"], + ) + + duration["offset_gen"] + ) + + params_gen_spk_in = [ + {"spike_times": timeline_task[input_spike_bools_arr[:, nrn_in_idx]].astype(dtype_in_spks)} + for nrn_in_idx in range(n_in) + ] + + target_rate_changes = np.zeros((n_out, batch_size * n_iter_interval)) + target_rate_changes[np.array(target_cues_list), np.arange(batch_size * n_iter_interval)] = 1 + + params_gen_rate_target = [ + { + "amplitude_times": np.arange( + 0.0, + n_iter_interval * batch_size * duration["sequence"], + duration["sequence"], + ) + + duration["total_offset"], + "amplitude_values": target_rate_changes[nrn_out_idx], + } + for nrn_out_idx in range(n_out) + ] + + return params_gen_spk_in, params_gen_rate_target + + +params_gen_spk_in, params_gen_rate_target = get_params_task_input_output(n_iter_max) + +#################### + +nest.SetStatus(gen_spk_in, params_gen_spk_in) +nest.SetStatus(gen_rate_target, params_gen_rate_target) + +# %% ########################################################################################################### +# Force final update +# ~~~~~~~~~~~~~~~~~~ +# Synapses only get active, that is, the correct weight update calculated and applied, when they transmit a +# spike. 
To still be able to read out the correct weights at the end of the simulation, we force spiking of the +# presynaptic neuron and thus an update of all synapses, including those that have not transmitted a spike in +# the last update interval, by sending a strong spike to all neurons that form the presynaptic side of an eprop +# synapse. This step is required purely for technical reasons. + +gen_spk_final_update = nest.Create("spike_generator", 1) + +nest.Connect(gen_spk_final_update, nrns_in + nrns_rec, "all_to_all", {"weight": 1000.0}) + +# %% ########################################################################################################### +# Read out pre-training weights +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +# Before we begin training, we read out the initial weight matrices so that we can eventually compare them to +# the optimized weights. + + +def get_weights(pop_pre, pop_post): + conns = nest.GetConnections(pop_pre, pop_post).get(["source", "target", "weight"]) + conns["senders"] = np.array(conns["source"]) - np.min(conns["source"]) + conns["targets"] = np.array(conns["target"]) - np.min(conns["target"]) + + conns["weight_matrix"] = np.zeros((len(pop_post), len(pop_pre))) + conns["weight_matrix"][conns["targets"], conns["senders"]] = conns["weight"] + return conns + + +weights_pre_train = { + "in_rec": get_weights(nrns_in, nrns_rec), + "rec_rec": get_weights(nrns_rec, nrns_rec), + "rec_out": get_weights(nrns_rec, nrns_out), +} + +# %% ########################################################################################################### +# Simulate and evaluate +# ~~~~~~~~~~~~~~~~~~~~~ +# We train the network by simulating for a number of training iterations with the set learning rate. If early +# stopping is turned on, we evaluate the network's performance on the validation set in regular intervals and, +# if the error is below a certain threshold, we stop the training early. If the error is not below the +# threshold, we continue training until the end of the set number of iterations. Finally, we evaluate the +# network's performance on the test set. +# Furthermore, we evaluate the network's training error by calculating a loss - in this case, the cross-entropy +# error between the integrated recurrent network activity and the target rate. 
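
# As a minimal illustration of this loss, assume a single time step with softmax readout pi over the readout
# neurons and a one-hot target pi_star; the cross-entropy error is then E = -sum_k pi_star_k * log(pi_k). The
# helper cross_entropy_sketch below is only for illustration and is not used for training.


def cross_entropy_sketch(pi, pi_star):
    return -np.sum(pi_star * np.log(pi))


example_loss = cross_entropy_sketch(np.array([0.9, 0.1]), np.array([1.0, 0.0]))  # ~0.105: confident correct guess
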
+ + +class TrainingPipeline: + def __init__(self): + self.results_dict = { + "error": [], + "loss": [], + "iteration": [], + "label": [], + } + self.n_iter_sim = 0 + self.phase_label_previous = "" + self.error = 0 + self.k_iter = 0 + self.early_stop = False + + def evaluate(self): + events_mm_out = mm_out.get("events") + + readout_signal = events_mm_out["readout_signal"] # corresponds to softmax + target_signal = events_mm_out["target_signal"] + senders = events_mm_out["senders"] + times = events_mm_out["times"] + + cond1 = times > (self.n_iter_sim - 1) * batch_size * duration["sequence"] + duration["total_offset"] + cond2 = times <= self.n_iter_sim * batch_size * duration["sequence"] + duration["total_offset"] + idc = cond1 & cond2 + + readout_signal = np.array([readout_signal[idc][senders[idc] == i] for i in set(senders)]) + target_signal = np.array([target_signal[idc][senders[idc] == i] for i in set(senders)]) + + readout_signal = readout_signal.reshape((n_out, 1, batch_size, steps["sequence"])) + target_signal = target_signal.reshape((n_out, 1, batch_size, steps["sequence"])) + + readout_signal = readout_signal[:, :, :, -steps["learning_window"] :] + target_signal = target_signal[:, :, :, -steps["learning_window"] :] + + loss = -np.mean(np.sum(target_signal * np.log(readout_signal), axis=0), axis=(1, 2)) + + y_prediction = np.argmax(np.mean(readout_signal, axis=3), axis=0) + y_target = np.argmax(np.mean(target_signal, axis=3), axis=0) + accuracy = np.mean((y_target == y_prediction), axis=1) + errors = 1.0 - accuracy + + self.results_dict["iteration"].append(self.n_iter_sim) + self.results_dict["error"].extend(errors) + self.results_dict["loss"].extend(loss) + self.results_dict["label"].append(self.phase_label_previous) + + self.error = errors[0] + + def run(self, phase_label, eta): + params_common_syn_eprop["optimizer"]["eta"] = eta + nest.SetDefaults("eprop_synapse_bsshslm_2020", params_common_syn_eprop) + + nest.Simulate(duration["extension_sim"]) + if self.n_iter_sim > 0: + self.evaluate() + + duration["sim"] = batch_size * duration["sequence"] - duration["extension_sim"] + + nest.Simulate(duration["sim"]) + + self.n_iter_sim += 1 + self.phase_label_previous = phase_label + + def run_training(self): + self.run("training", eta_train) + + def run_validation(self): + if do_early_stopping and self.k_iter % n_validate_every == 0: + self.run("validation", eta_test) + + def run_early_stopping(self): + if do_early_stopping and self.k_iter % n_validate_every == 0: + if self.k_iter > 0 and self.error < stop_crit: + errors_early_stop = [] + for _ in range(n_early_stop): + self.run("early-stopping", eta_test) + errors_early_stop.append(self.error) + + self.early_stop = np.mean(errors_early_stop) < stop_crit + + def run_test(self): + for _ in range(n_test): + self.run("test", eta_test) + + def simulate(self): + nest.Simulate(duration["total_offset"]) + + while self.k_iter < n_iter and not self.early_stop: + self.run_validation() + self.run_early_stopping() + self.run_training() + self.k_iter += 1 + + self.run_test() + + nest.Simulate(steps["extension_sim"]) + + self.evaluate() + + duration["task"] = self.n_iter_sim * batch_size * duration["sequence"] + duration["total_offset"] + + gen_spk_final_update.set({"spike_times": [duration["task"] + duration["extension_sim"] + 1]}) + + nest.Simulate(duration["delays"]) + + def get_results(self): + for k, v in self.results_dict.items(): + self.results_dict[k] = np.array(v) + return self.results_dict + + +training_pipeline = TrainingPipeline() 
+training_pipeline.simulate() + +results_dict = training_pipeline.get_results() +n_iter_sim = training_pipeline.n_iter_sim + +# %% ########################################################################################################### +# Read out post-training weights +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +# After the training, we can read out the optimized final weights. + +weights_post_train = { + "in_rec": get_weights(nrns_in, nrns_rec), + "rec_rec": get_weights(nrns_rec, nrns_rec), + "rec_out": get_weights(nrns_rec, nrns_out), +} + +# %% ########################################################################################################### +# Read out recorders +# ~~~~~~~~~~~~~~~~~~ +# We can also retrieve the recorded history of the dynamic variables and weights, as well as detected spikes. + +events_mm_reg = mm_reg.get("events") +events_mm_ad = mm_ad.get("events") +events_mm_out = mm_out.get("events") +events_sr_in = sr_in.get("events") +events_sr_reg = sr_reg.get("events") +events_sr_ad = sr_ad.get("events") +events_wr = wr.get("events") + +# %% ########################################################################################################### +# Plot results +# ~~~~~~~~~~~~ +# Then, we plot a series of plots. + +do_plotting = True # if True, plot the results + +if not do_plotting: + exit() + +colors = { + "blue": "#2854c5ff", + "red": "#e04b40ff", + "green": "#25aa2cff", + "gold": "#f9c643ff", + "white": "#ffffffff", +} + +plt.rcParams.update( + { + "axes.spines.right": False, + "axes.spines.top": False, + "axes.prop_cycle": cycler(color=[colors[k] for k in ["blue", "red", "green", "gold"]]), + } +) + +# %% ########################################################################################################### +# Plot error +# .......... +# We begin with two plots visualizing the error of the network: the loss and the recall error, both +# plotted against the iterations. + +fig, axs = plt.subplots(2, 1, sharex=True) +fig.suptitle("Training error") + +for color, label in zip(colors, set(results_dict["label"])): + idc = results_dict["label"] == label + axs[0].scatter(results_dict["iteration"][idc], results_dict["loss"][idc], label=label) + axs[1].scatter(results_dict["iteration"][idc], results_dict["error"][idc], label=label) + +axs[0].set_ylabel(r"$E = -\sum_{t,k} \pi_k^{*,t} \log \pi_k^t$") + +axs[1].set_ylabel("recall errors") + +axs[-1].set_xlabel("iteration") +axs[-1].legend(bbox_to_anchor=(1.05, 0.5), loc="center left") +axs[-1].xaxis.get_major_locator().set_params(integer=True) + +fig.tight_layout() + +# %% ########################################################################################################### +# Plot spikes and dynamic variables +# ................................. +# This plotting routine shows how to plot all of the recorded dynamic variables and spikes across time. We take +# one snapshot in the first iteration and one snapshot at the end. 
+ + +def plot_recordable(ax, events, recordable, ylabel, xlims): + for sender in set(events["senders"]): + idc_sender = events["senders"] == sender + idc_times = (events["times"][idc_sender] > xlims[0]) & (events["times"][idc_sender] < xlims[1]) + ax.plot(events["times"][idc_sender][idc_times], events[recordable][idc_sender][idc_times], lw=0.5) + ax.set_ylabel(ylabel) + margin = np.abs(np.max(events[recordable]) - np.min(events[recordable])) * 0.1 + ax.set_ylim(np.min(events[recordable]) - margin, np.max(events[recordable]) + margin) + + +def plot_spikes(ax, events, ylabel, xlims): + idc_times = (events["times"] > xlims[0]) & (events["times"] < xlims[1]) + senders_subset = events["senders"][idc_times] + times_subset = events["times"][idc_times] + + ax.scatter(times_subset, senders_subset, s=0.1) + ax.set_ylabel(ylabel) + margin = np.abs(np.max(senders_subset) - np.min(senders_subset)) * 0.1 + ax.set_ylim(np.min(senders_subset) - margin, np.max(senders_subset) + margin) + + +for title, xlims in zip( + ["Dynamic variables before training", "Dynamic variables after training"], + [ + (0, steps["sequence"]), + ((n_iter_sim - 1) * batch_size * steps["sequence"], n_iter_sim * batch_size * steps["sequence"]), + ], +): + fig, axs = plt.subplots(14, 1, sharex=True, figsize=(8, 14), gridspec_kw={"hspace": 0.4, "left": 0.2}) + fig.suptitle(title) + + plot_spikes(axs[0], events_sr_in, r"$z_i$" + "\n", xlims) + plot_spikes(axs[1], events_sr_reg, r"$z_j$" + "\n", xlims) + + plot_recordable(axs[2], events_mm_reg, "V_m", r"$v_j$" + "\n(mV)", xlims) + plot_recordable(axs[3], events_mm_reg, "surrogate_gradient", r"$\psi_j$" + "\n", xlims) + plot_recordable(axs[4], events_mm_reg, "learning_signal", r"$L_j$" + "\n(pA)", xlims) + + plot_spikes(axs[5], events_sr_ad, r"$z_j$" + "\n", xlims) + + plot_recordable(axs[6], events_mm_ad, "V_m", r"$v_j$" + "\n(mV)", xlims) + plot_recordable(axs[7], events_mm_ad, "surrogate_gradient", r"$\psi_j$" + "\n", xlims) + plot_recordable(axs[8], events_mm_ad, "V_th_adapt", r"$A_j$" + "\n(mV)", xlims) + plot_recordable(axs[9], events_mm_ad, "learning_signal", r"$L_j$" + "\n(pA)", xlims) + + plot_recordable(axs[10], events_mm_out, "V_m", r"$v_k$" + "\n(mV)", xlims) + plot_recordable(axs[11], events_mm_out, "target_signal", r"$\pi^*_k$" + "\n", xlims) + plot_recordable(axs[12], events_mm_out, "readout_signal", r"$\pi_k$" + "\n", xlims) + plot_recordable(axs[13], events_mm_out, "error_signal", r"$\pi_k-\pi^*_k$" + "\n", xlims) + + axs[-1].set_xlabel(r"$t$ (ms)") + axs[-1].set_xlim(*xlims) + + fig.align_ylabels() + +# %% ########################################################################################################### +# Plot weight time courses +# ........................ +# Similarly, we can plot the weight histories. Note that the weight recorder, attached to the synapses, works +# differently than the other recorders. Since synapses only get activated when they transmit a spike, the weight +# recorder only records the weight in those moments. That is why the first weight registrations do not start in +# the first time step and we add the initial weights manually. 
+ + +def plot_weight_time_course(ax, events, nrns_weight_record, label, ylabel): + sender_label, target_label = label.split("_") + nrns_senders = nrns_weight_record[sender_label] + nrns_targets = nrns_weight_record[target_label] + for sender in nrns_senders.tolist(): + for target in nrns_targets.tolist(): + idc_syn = (events["senders"] == sender) & (events["targets"] == target) + idc_syn_pre = (weights_pre_train[label]["source"] == sender) & ( + weights_pre_train[label]["target"] == target + ) + + times = [0.0] + events["times"][idc_syn].tolist() + weights = [weights_pre_train[label]["weight"][idc_syn_pre]] + events["weights"][idc_syn].tolist() + + ax.step(times, weights, c=colors["blue"]) + ax.set_ylabel(ylabel) + ax.set_ylim(-0.6, 0.6) + + +fig, axs = plt.subplots(3, 1, sharex=True, figsize=(3, 4)) +fig.suptitle("Weight time courses") + +nrns_weight_record = { + "in": nrns_in[:n_record_w], + "rec": nrns_rec[:n_record_w], + "out": nrns_out, +} + +plot_weight_time_course(axs[0], events_wr, nrns_weight_record, "in_rec", r"$W_\text{in}$ (pA)") +plot_weight_time_course(axs[1], events_wr, nrns_weight_record, "rec_rec", r"$W_\text{rec}$ (pA)") +plot_weight_time_course(axs[2], events_wr, nrns_weight_record, "rec_out", r"$W_\text{out}$ (pA)") + +axs[-1].set_xlabel(r"$t$ (ms)") +axs[-1].set_xlim(0, duration["task"]) + +fig.align_ylabels() +fig.tight_layout() + +# %% ########################################################################################################### +# Plot weight matrices +# .................... +# If one is not interested in the time course of the weights, it is possible to read out only the initial and +# final weights, which requires less computing time and memory than the weight recorder approach. Here, we plot +# the corresponding weight matrices before and after the optimization. 
+ +cmap = mpl.colors.LinearSegmentedColormap.from_list( + "cmap", ((0.0, colors["blue"]), (0.5, colors["white"]), (1.0, colors["red"])) +) + +fig, axs = plt.subplots(3, 2, sharex="col", sharey="row") +fig.suptitle("Weight matrices") + +all_w_extrema = [] + +for k in weights_pre_train.keys(): + w_pre = weights_pre_train[k]["weight"] + w_post = weights_post_train[k]["weight"] + all_w_extrema.append([np.min(w_pre), np.max(w_pre), np.min(w_post), np.max(w_post)]) + +args = {"cmap": cmap, "vmin": np.min(all_w_extrema), "vmax": np.max(all_w_extrema)} + +for i, weights in zip([0, 1], [weights_pre_train, weights_post_train]): + axs[0, i].pcolormesh(weights["in_rec"]["weight_matrix"].T, **args) + axs[1, i].pcolormesh(weights["rec_rec"]["weight_matrix"], **args) + cmesh = axs[2, i].pcolormesh(weights["rec_out"]["weight_matrix"], **args) + + axs[2, i].set_xlabel("recurrent\nneurons") + +axs[0, 0].set_ylabel("input\nneurons") +axs[1, 0].set_ylabel("recurrent\nneurons") +axs[2, 0].set_ylabel("readout\nneurons") +fig.align_ylabels(axs[:, 0]) + +axs[0, 0].text(0.5, 1.1, "before training", transform=axs[0, 0].transAxes, ha="center") +axs[0, 1].text(0.5, 1.1, "after training", transform=axs[0, 1].transAxes, ha="center") + +axs[2, 0].yaxis.get_major_locator().set_params(integer=True) + +cbar = plt.colorbar(cmesh, cax=axs[1, 1].inset_axes([1.1, 0.2, 0.05, 0.8]), label="weight (pA)") + +fig.tight_layout() + +plt.show() diff --git a/pynest/examples/eprop_plasticity/eprop_supervised_classification_neuromorphic_mnist.png b/pynest/examples/eprop_plasticity/eprop_supervised_classification_neuromorphic_mnist.png new file mode 100644 index 0000000000..72de686527 Binary files /dev/null and b/pynest/examples/eprop_plasticity/eprop_supervised_classification_neuromorphic_mnist.png differ diff --git a/pynest/examples/eprop_plasticity/eprop_supervised_classification_neuromorphic_mnist.py b/pynest/examples/eprop_plasticity/eprop_supervised_classification_neuromorphic_mnist.py new file mode 100644 index 0000000000..e991f8a7f2 --- /dev/null +++ b/pynest/examples/eprop_plasticity/eprop_supervised_classification_neuromorphic_mnist.py @@ -0,0 +1,909 @@ +# -*- coding: utf-8 -*- +# +# eprop_supervised_classification_neuromorphic_mnist.py +# +# This file is part of NEST. +# +# Copyright (C) 2004 The NEST Initiative +# +# NEST is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 2 of the License, or +# (at your option) any later version. +# +# NEST is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with NEST. If not, see . + +r""" +Tutorial on learning N-MNIST classification with e-prop +------------------------------------------------------- + +Training a classification model using supervised e-prop plasticity to classify the Neuromorphic MNIST (N-MNIST) dataset. + +Description +~~~~~~~~~~~ + +This script demonstrates supervised learning of a classification task with the eligibility propagation (e-prop) +plasticity mechanism by Bellec et al. [1]_ with additional biological features described in [3]_. 
+

The primary objective of this task is to classify the N-MNIST dataset [2]_, an adaptation of the traditional
MNIST dataset of handwritten digits specifically designed for neuromorphic computing. The N-MNIST dataset
captures changes in pixel intensity through a dynamic vision sensor, converting static images into sequences of
binary events, which we interpret as spike trains. This conversion closely emulates biological neural
processing, making it a fitting challenge for an e-prop-equipped spiking neural network (SNN).

.. image:: eprop_supervised_classification_neuromorphic_mnist.png
   :width: 70 %
   :alt: Schematic of network architecture. Same as Figure 1 in the code.
   :align: center

Learning in the neural network model is achieved by optimizing the connection weights with e-prop plasticity.
This plasticity rule requires a specific network architecture depicted in Figure 1. The neural network model
consists of a recurrent network that receives input from Poisson generators and projects onto multiple readout
neurons - one for each class. Each input generator is assigned to a pixel of the input image; when an event is
detected in a pixel at time `t`, the corresponding input generator (connected to an input neuron) emits a spike
at that time. Each readout neuron compares the network signal :math:`y_k` with the teacher signal :math:`y_k^*`,
which it receives from a rate generator representing the respective digit class. Unlike conventional neural
network classifiers that may employ softmax functions and cross-entropy loss for classification, this network
model utilizes a mean-squared error loss to evaluate the training error and perform digit classification.

Details on the event-based NEST implementation of e-prop can be found in [3]_.

References
~~~~~~~~~~

.. [1] Bellec G, Scherr F, Subramoney A, Hajek E, Salaj D, Legenstein R, Maass W (2020). A solution to the
       learning dilemma for recurrent networks of spiking neurons. Nature Communications, 11:3625.
       https://doi.org/10.1038/s41467-020-17236-y

.. [2] Orchard, G., Jayawant, A., Cohen, G. K., & Thakor, N. (2015). Converting static image datasets to
       spiking neuromorphic datasets using saccades. Frontiers in Neuroscience, 9, 159859.

.. [3] Korcsak-Gorzo A, Stapmanns J, Espinoza Valverde JA, Plesser HE,
       Dahmen D, Bolten M, Van Albada SJ*, Diesmann M*. Event-based
       implementation of eligibility propagation (in preparation)

""" # pylint: disable=line-too-long # noqa: E501

# %% ###########################################################################################################
# Import libraries
# ~~~~~~~~~~~~~~~~
# We begin by importing all libraries required for the simulation, analysis, and visualization.

import os
import zipfile

import matplotlib as mpl
import matplotlib.pyplot as plt
import nest
import numpy as np
import requests
from cycler import cycler
from IPython.display import Image

# %% ###########################################################################################################
# Schematic of network architecture
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# This figure, identical to the one in the description, shows the required network architecture in the center,
# the input and output of the classification task above, and lists of the required NEST device, neuron, and
# synapse models below. The connections that must be established are numbered 1 to 7.
+ +try: + Image(filename="./eprop_supervised_classification_neuromorphic_mnist.png") +except Exception: + pass + +# %% ########################################################################################################### +# Setup +# ~~~~~ + +# %% ########################################################################################################### +# Initialize random generator +# ........................... +# We seed the numpy random generator, which will generate random initial weights as well as random input and +# output. + +rng_seed = 1 # numpy random seed +np.random.seed(rng_seed) # fix numpy random seed + +# %% ########################################################################################################### +# Define timing of task +# ..................... +# The task's temporal structure is then defined, once as time steps and once as durations in milliseconds. +# Even though each sample is processed independently during training, we aggregate predictions and true +# labels across a group of samples during the evaluation phase. The number of samples in this group is +# determined by the `group_size` parameter. This data is then used to assess the neural network's +# performance metrics, such as average accuracy and mean error. Increasing the number of iterations enhances +# learning performance up to the point where overfitting occurs. + +group_size = 100 # number of instances over which to evaluate the learning performance +n_iter = 200 # number of iterations +test_every = 10 # cyclical number of training iterations after which to test the performance + +steps = {} + +steps["sequence"] = 300 # time steps of one full sequence +steps["learning_window"] = 10 # time steps of window with non-zero learning signals +steps["evaluation_group"] = group_size * steps["sequence"] +steps["task"] = n_iter * group_size * steps["sequence"] # time steps of task + +steps.update( + { + "offset_gen": 1, # offset since generator signals start from time step 1 + "delay_in_rec": 1, # connection delay between input and recurrent neurons + "extension_sim": 1, # extra time step to close right-open simulation time interval in Simulate() + } +) + +steps["delays"] = steps["delay_in_rec"] # time steps of delays + +steps["total_offset"] = steps["offset_gen"] + steps["delays"] # time steps of total offset +steps["pre_sim"] = steps["total_offset"] + steps["extension_sim"] + +duration = {"step": 1.0} # ms, temporal resolution of the simulation + +duration.update({key: value * duration["step"] for key, value in steps.items()}) # ms, durations + +# %% ########################################################################################################### +# Set up simulation +# ................. +# As last step of the setup, we reset the NEST kernel to remove all existing NEST simulation settings and +# objects and set some NEST kernel parameters. + +params_setup = { + "print_time": False, # if True, print time progress bar during simulation, set False if run as code cell + "resolution": duration["step"], + "total_num_virtual_procs": 4, # number of virtual processes, set in case of distributed computing +} + +#################### + +nest.ResetKernel() +nest.set(**params_setup) +nest.set_verbosity("M_FATAL") + +# %% ########################################################################################################### +# Create neurons +# ~~~~~~~~~~~~~~ +# We proceed by creating a certain number of input, recurrent, and readout neurons and setting their parameters. 
+# Additionally, we already create an input spike generator and an output target rate generator, which we will +# configure later. Each input sample, featuring two channels, is mapped out to a 34x34 pixel grid. We allocate +# Poisson generators to each input image pixel to simulate spike events. However, due to the observation +# that some pixels either never record events or do so infrequently, we maintain a blocklist of these inactive +# pixels. By omitting Poisson generators for pixels on this blocklist, we effectively reduce the total number of +# input neurons and Poisson generators required, optimizing the network's resource usage. + +pixels_blocklist = np.loadtxt("./NMNIST_pixels_blocklist.txt") + +n_in = 2 * 34 * 34 - len(pixels_blocklist) # number of input neurons +n_rec = 150 # number of recurrent neurons +n_out = 10 # number of readout neurons + +params_nrn_out = { + "C_m": 1.0, # pF, membrane capacitance - takes effect only if neurons get current input (here not the case) + "E_L": 0.0, # mV, leak / resting membrane potential + "eprop_isi_trace_cutoff": 100, # cutoff of integration of eprop trace between spikes + "I_e": 0.0, # pA, external current input + "regular_spike_arrival": False, # If True, input spikes arrive at end of time step, if False at beginning + "tau_m": 100.0, # ms, membrane time constant + "V_m": 0.0, # mV, initial value of the membrane voltage +} + +params_nrn_rec = { + "beta": 1.7, # width scaling of the pseudo-derivative + "C_m": 1.0, + "c_reg": 2.0 / duration["sequence"], # coefficient of firing rate regularization + "E_L": 0.0, + "eprop_isi_trace_cutoff": 100, + "f_target": 10.0, # spikes/s, target firing rate for firing rate regularization + "gamma": 0.5, # height scaling of the pseudo-derivative + "I_e": 0.0, + "kappa": 0.99, # low-pass filter of the eligibility trace + "kappa_reg": 0.99, # low-pass filter of the firing rate for regularization + "regular_spike_arrival": True, + "surrogate_gradient_function": "piecewise_linear", # surrogate gradient / pseudo-derivative function + "t_ref": 0.0, # ms, duration of refractory period + "tau_m": 30.0, + "V_m": 0.0, + "V_th": 0.6, # mV, spike threshold membrane voltage +} + +#################### + +# Intermediate parrot neurons required between input spike generators and recurrent neurons, +# since devices cannot establish plastic synapses for technical reasons + +gen_spk_in = nest.Create("spike_generator", n_in) +nrns_in = nest.Create("parrot_neuron", n_in) + +nrns_rec = nest.Create("eprop_iaf", n_rec, params_nrn_rec) +nrns_out = nest.Create("eprop_readout", n_out, params_nrn_out) +gen_rate_target = nest.Create("step_rate_generator", n_out) +gen_learning_window = nest.Create("step_rate_generator") + +# %% ########################################################################################################### +# Create recorders +# ~~~~~~~~~~~~~~~~ +# We also create recorders, which, while not required for the training, will allow us to track various dynamic +# variables of the neurons, spikes, and changes in synaptic weights. To save computing time and memory, the +# recorders, the recorded variables, neurons, and synapses can be limited to the ones relevant to the +# experiment, and the recording interval can be increased (see the documentation on the specific recorders). By +# default, recordings are stored in memory but can also be written to file. 
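
# By default, recordings are kept in memory. For long simulations, they can instead be written to file by
# switching the recording backend of any recorder defined below, for example (optional, not used here):
#
#     params_mm_rec["record_to"] = "ascii"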
+ +n_record = 1 # number of neurons to record dynamic variables from - this script requires n_record >= 1 +n_record_w = 5 # number of senders and targets to record weights from - this script requires n_record_w >=1 + +if n_record == 0 or n_record_w == 0: + raise ValueError("n_record and n_record_w >= 1 required") + +params_mm_rec = { + "interval": duration["step"], # interval between two recorded time points + "record_from": ["V_m", "surrogate_gradient", "learning_signal"], # dynamic variables to record + "start": duration["offset_gen"] + duration["delay_in_rec"], # start time of recording + "stop": duration["offset_gen"] + duration["delay_in_rec"] + duration["task"], # stop time of recording + "label": "multimeter_rec", +} + +params_mm_out = { + "interval": duration["step"], + "record_from": ["V_m", "readout_signal", "target_signal", "error_signal"], + "start": duration["total_offset"], + "stop": duration["total_offset"] + duration["task"], + "label": "multimeter_out", +} + +params_wr = { + "senders": nrns_in[:n_record_w] + nrns_rec[:n_record_w], # limit senders to subsample weights to record + "targets": nrns_rec[:n_record_w] + nrns_out, # limit targets to subsample weights to record from + "start": duration["total_offset"], + "stop": duration["total_offset"] + duration["task"], + "label": "weight_recorder", +} + +params_sr_in = { + "start": duration["offset_gen"], + "stop": duration["total_offset"] + duration["task"], + "label": "spike_recorder_in", +} + +params_sr_rec = { + "start": duration["offset_gen"], + "stop": duration["total_offset"] + duration["task"], + "label": "spike_recorder_rec", +} + +#################### + +mm_rec = nest.Create("multimeter", params_mm_rec) +mm_out = nest.Create("multimeter", params_mm_out) +sr_in = nest.Create("spike_recorder", params_sr_in) +sr_rec = nest.Create("spike_recorder", params_sr_rec) +wr = nest.Create("weight_recorder", params_wr) + +nrns_rec_record = nrns_rec[:n_record] + +# %% ########################################################################################################### +# Create connections +# ~~~~~~~~~~~~~~~~~~ +# Now, we define the connectivity and set up the synaptic parameters, with the synaptic weights drawn from +# normal distributions. After these preparations, we establish the enumerated connections of the core network, +# as well as additional connections to the recorders. For this task, we implement a method characterized by +# sparse connectivity designed to enhance resource efficiency during the learning phase. This method involves +# the creation of binary masks that reflect predetermined levels of sparsity across various network connections, +# namely from input-to-recurrent, recurrent-to-recurrent, and recurrent-to-output. These binary masks are +# applied directly to the corresponding weight matrices. Subsequently, we activate only connections +# corresponding to non-zero weights to achieve the targeted sparsity level. For instance, a sparsity level of +# 0.9 means that most connections are turned off. This approach reduces resource consumption and, ideally, +# boosts the learning process's efficiency. 
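
# As a quick illustration of the sparsity masks applied below: each weight is kept independently with
# probability 1 - sparsity_level. The throwaway generator rng_example is used here so that the seeded global
# NumPy state, which draws the actual initial weights, is not disturbed by this example.

rng_example = np.random.default_rng(12345)
example_mask = rng_example.choice([0, 1], size=(100, 100), p=[0.99, 0.01])
example_density = example_mask.mean()  # ~0.01, i.e. roughly 1 % of entries survive a sparsity level of 0.99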
+ +params_conn_all_to_all = {"rule": "all_to_all", "allow_autapses": False} +params_conn_one_to_one = {"rule": "one_to_one"} + + +def calculate_glorot_dist(fan_in, fan_out): + glorot_scale = 1.0 / max(1.0, (fan_in + fan_out) / 2.0) + glorot_limit = np.sqrt(3.0 * glorot_scale) + glorot_distribution = np.random.uniform(low=-glorot_limit, high=glorot_limit, size=(fan_in, fan_out)) + return glorot_distribution + + +def create_mask(weights, sparsity_level): + return np.random.choice([0, 1], weights.shape, p=[sparsity_level, 1 - sparsity_level]) + + +dtype_weights = np.float32 # data type of weights - for reproducing TF results set to np.float32 +weights_in_rec = np.array(np.random.randn(n_in, n_rec).T / np.sqrt(n_in), dtype=dtype_weights) +weights_rec_rec = np.array(np.random.randn(n_rec, n_rec).T / np.sqrt(n_rec), dtype=dtype_weights) +np.fill_diagonal(weights_rec_rec, 0.0) # since no autapses set corresponding weights to zero +weights_rec_out = np.array(calculate_glorot_dist(n_rec, n_out).T, dtype=dtype_weights) +weights_out_rec = np.array(np.random.randn(n_rec, n_out), dtype=dtype_weights) + +weights_in_rec *= create_mask(weights_in_rec, 0.75) +weights_rec_rec *= create_mask(weights_rec_rec, 0.99) +weights_rec_out *= create_mask(weights_rec_out, 0.0) + +params_common_syn_eprop = { + "optimizer": { + "type": "gradient_descent", # algorithm to optimize the weights + "batch_size": 1, + "eta": 5e-3, # learning rate + "optimize_each_step": False, # call optimizer every time step (True) or once per spike (False); both + # yield same results for gradient descent, False offers speed-up + "Wmin": -100.0, # pA, minimal limit of the synaptic weights + "Wmax": 100.0, # pA, maximal limit of the synaptic weights + }, + "weight_recorder": wr, +} + +eta_train = 5e-3 +eta_test = 0.0 + +params_syn_base = { + "synapse_model": "eprop_synapse", + "delay": duration["step"], # ms, dendritic delay +} + +params_syn_in = params_syn_base.copy() +params_syn_rec = params_syn_base.copy() +params_syn_out = params_syn_base.copy() + +params_syn_feedback = { + "synapse_model": "eprop_learning_signal_connection", + "delay": duration["step"], + "weight": weights_out_rec, +} + +params_syn_learning_window = { + "synapse_model": "rate_connection_delayed", + "delay": duration["step"], + "receptor_type": 1, # receptor type over which readout neuron receives learning window signal +} + +params_syn_rate_target = { + "synapse_model": "rate_connection_delayed", + "delay": duration["step"], + "receptor_type": 2, # receptor type over which readout neuron receives target signal +} + +params_syn_static = { + "synapse_model": "static_synapse", + "delay": duration["step"], +} + +params_init_optimizer = { + "optimizer": { + "m": 0.0, # initial 1st moment estimate m of Adam optimizer + "v": 0.0, # initial 2nd moment raw estimate v of Adam optimizer + } +} + +#################### + +nest.SetDefaults("eprop_synapse", params_common_syn_eprop) + +nest.Connect(gen_spk_in, nrns_in, params_conn_one_to_one, params_syn_static) # connection 1 + + +def sparsely_connect(weights, params_syn, nrns_pre, nrns_post): + for j in range(weights.shape[0]): + for i in range(weights.shape[1]): + w = weights[j, i] + if np.abs(w) > 0.0: + params_syn["weight"] = w + nest.Connect(nrns_pre[i], nrns_post[j], params_conn_one_to_one, params_syn) + + +sparsely_connect(weights_in_rec, params_syn_in, nrns_in, nrns_rec) # connection 2 +sparsely_connect(weights_rec_rec, params_syn_rec, nrns_rec, nrns_rec) # connection 3 +sparsely_connect(weights_rec_out, params_syn_out, 
nrns_rec, nrns_out) # connection 4 + +nest.Connect(nrns_out, nrns_rec, params_conn_all_to_all, params_syn_feedback) # connection 5 +nest.Connect(gen_rate_target, nrns_out, params_conn_one_to_one, params_syn_rate_target) # connection 6 +nest.Connect(gen_learning_window, nrns_out, params_conn_all_to_all, params_syn_learning_window) # connection 7 + +nest.Connect(nrns_in, sr_in, params_conn_all_to_all, params_syn_static) +nest.Connect(nrns_rec, sr_rec, params_conn_all_to_all, params_syn_static) + +nest.Connect(mm_rec, nrns_rec_record, params_conn_all_to_all, params_syn_static) +nest.Connect(mm_out, nrns_out, params_conn_all_to_all, params_syn_static) + +# After creating the connections, we can individually initialize the optimizer's +# dynamic variables for single synapses (here exemplarily for two connections). + +nest.GetConnections(nrns_rec[0], nrns_rec[1:3]).set([params_init_optimizer] * 2) + +# %% ########################################################################################################### +# Create input and output +# ~~~~~~~~~~~~~~~~~~~~~~~ +# This section involves downloading the N-MNIST dataset, extracting it, and preparing it for neural network +# training and testing. The dataset consists of two main components: training and test sets. + +# The `download_and_extract_nmnist_dataset` function retrieves the dataset from its public repository and +# extracts it into a specified directory. It checks for the presence of the dataset to avoid re-downloading. +# After downloading, it extracts the main dataset zip file, followed by further extraction of nested zip files +# for training and test data, ensuring that the dataset is ready for loading and processing. + +# The `load_image` function reads a single image file from the dataset, converting the event-based neuromorphic +# data into a format suitable for processing by spiking neural networks. It filters events based on specified +# pixel blocklists, arranging the remaining events into a structured format representing the image. + +# The `DataLoader` class facilitates the loading of the dataset for neural network training and testing. It +# supports selecting specific labels for inclusion, allowing for targeted training on subsets of the dataset. +# The class also includes functionality for random shuffling and grouping of data, ensuring diverse and +# representative samples are used throughout the training process. + + +def unzip(zip_file_path, extraction_path): + print(f"Extracting {zip_file_path}.") + with zipfile.ZipFile(zip_file_path, "r") as zip_file: + zip_file.extractall(extraction_path) + os.remove(zip_file_path) + + +def download_and_extract_nmnist_dataset(save_path="./"): + nmnist_dataset = { + "url": "https://prod-dcd-datasets-cache-zipfiles.s3.eu-west-1.amazonaws.com/468j46mzdv-1.zip", + "directory": "468j46mzdv-1", + "zip": "dataset.zip", + } + + path = os.path.join(save_path, nmnist_dataset["directory"]) + + train_path = os.path.join(path, "Train") + test_path = os.path.join(path, "Test") + + downloaded_zip_path = os.path.join(save_path, nmnist_dataset["zip"]) + + if os.path.exists(path) and os.path.exists(train_path) and os.path.exists(test_path): + print(f"\nThe directory '{path}' already exists with expected contents. 
Skipping download and extraction.") + else: + if not os.path.exists(downloaded_zip_path): + print("\nDownloading the N-MNIST dataset.") + response = requests.get(nmnist_dataset["url"], timeout=10) + with open(downloaded_zip_path, "wb") as file: + file.write(response.content) + + unzip(downloaded_zip_path, save_path) + unzip(f"{train_path}.zip", path) + unzip(f"{test_path}.zip", path) + + return train_path, test_path + + +def load_image(file_path, pixels_blocklist=None): + with open(file_path, "rb") as file: + inputByteArray = file.read() + byte_array = np.asarray([x for x in inputByteArray]) + + x_coords = byte_array[0::5] + y_coords = byte_array[1::5] + polarities = byte_array[2::5] >> 7 + times = ((byte_array[2::5] << 16) | (byte_array[3::5] << 8) | byte_array[4::5]) & 0x7FFFFF + times = np.clip(times // 1000, 1, 299) + + image_full = [[] for _ in range(2 * 34 * 34)] + image = [] + + for polarity, x, y, time in zip(polarities, y_coords, x_coords, times): + pixel = polarity * 34 * 34 + x * 34 + y + image_full[pixel].append(time) + + for pixel, times in enumerate(image_full): + if pixels_blocklist is None or pixel not in pixels_blocklist: + image.append(times) + + return image + + +class DataLoader: + def __init__(self, path, selected_labels, group_size, pixels_blocklist=None): + self.path = path + self.selected_labels = selected_labels + self.group_size = group_size + self.pixels_blocklist = pixels_blocklist + + self.current_index = 0 + self.all_sample_paths, self.all_labels = self.get_all_sample_paths_with_labels() + self.n_all_samples = len(self.all_sample_paths) + self.shuffled_indices = np.random.permutation(self.n_all_samples) + + def get_all_sample_paths_with_labels(self): + all_sample_paths = [] + all_labels = [] + + for label in self.selected_labels: + label_dir_path = os.path.join(self.path, str(label)) + all_files = os.listdir(label_dir_path) + + for sample in all_files: + all_sample_paths.append(os.path.join(label_dir_path, sample)) + all_labels.append(label) + + return all_sample_paths, all_labels + + def get_new_evaluation_group(self): + end_index = self.current_index + self.group_size + + selected_indices = np.take(self.shuffled_indices, range(self.current_index, end_index), mode="wrap") + + self.current_index = (self.current_index + self.group_size) % self.n_all_samples + + images_group = [load_image(self.all_sample_paths[i], self.pixels_blocklist) for i in selected_indices] + labels_group = [self.all_labels[i] for i in selected_indices] + + return images_group, labels_group + + +def create_input_output(loader, t_start_iteration, t_end_iteration, target_signal_value=1.0): + img_group, targets_group = loader.get_new_evaluation_group() + + spike_times = [[] for _ in range(n_in)] + target_rates = np.zeros((n_out, steps["evaluation_group"])) + + for group_elem in range(group_size): + t_start_group_elem = group_elem * steps["sequence"] + t_end_group_elem = t_start_group_elem + steps["sequence"] + t_start_absolute = t_start_iteration + t_start_group_elem + + target_rates[targets_group[group_elem], t_start_group_elem:t_end_group_elem] = target_signal_value + + for n, relative_times in enumerate(img_group[group_elem]): + if len(relative_times) > 0: + spike_times[n].extend(t_start_absolute + np.array(relative_times)) + + params_gen_spk_in = [{"spike_times": spk_times} for spk_times in spike_times] + + amplitude_times = duration["total_offset"] + np.arange(t_start_iteration, t_end_iteration) + + params_gen_rate_target = [ + {"amplitude_times": amplitude_times, "amplitude_values": 
target_rate} for target_rate in target_rates + ] + return params_gen_spk_in, params_gen_rate_target + + +save_path = "./" # path to save the N-MNIST dataset to +train_path, test_path = download_and_extract_nmnist_dataset(save_path) + +selected_labels = [label for label in range(n_out)] + +data_loader_train = DataLoader(train_path, selected_labels, group_size, pixels_blocklist) +data_loader_test = DataLoader(test_path, selected_labels, group_size, pixels_blocklist) + +amplitude_times = np.hstack( + [ + np.array([0.0, duration["sequence"] - duration["learning_window"]]) + + duration["total_offset"] + + i * duration["sequence"] + for i in range(group_size * n_iter) + ] +) + +amplitude_values = np.array([0.0, 1.0] * group_size * n_iter) + +params_gen_learning_window = { + "amplitude_times": amplitude_times, + "amplitude_values": amplitude_values, +} + +# %% ########################################################################################################### +# Force final update +# ~~~~~~~~~~~~~~~~~~ +# Synapses only get active, that is, the correct weight update calculated and applied, when they transmit a +# spike. To still be able to read out the correct weights at the end of the simulation, we force spiking of the +# presynaptic neuron and thus an update of all synapses, including those that have not transmitted a spike in +# the last update interval, by sending a strong spike to all neurons that form the presynaptic side of an eprop +# synapse. This step is required purely for technical reasons. + +gen_spk_final_update = nest.Create("spike_generator", 1, {"spike_times": [duration["task"] + duration["delays"]]}) + +nest.Connect(gen_spk_final_update, nrns_in + nrns_rec, "all_to_all", {"weight": 1000.0}) + +# %% ########################################################################################################### +# Read out pre-training weights +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +# Before we begin training, we read out the initial weight matrices so that we can eventually compare them to +# the optimized weights. + + +def get_weights(pop_pre, pop_post): + conns = nest.GetConnections(pop_pre, pop_post).get(["source", "target", "weight"]) + conns["senders"] = np.array(conns["source"]) - np.min(conns["source"]) + conns["targets"] = np.array(conns["target"]) - np.min(conns["target"]) + + conns["weight_matrix"] = np.zeros((len(pop_post), len(pop_pre))) + conns["weight_matrix"][conns["targets"], conns["senders"]] = conns["weight"] + return conns + + +weights_pre_train = { + "in_rec": get_weights(nrns_in, nrns_rec), + "rec_rec": get_weights(nrns_rec, nrns_rec), + "rec_out": get_weights(nrns_rec, nrns_out), +} + +# %% ########################################################################################################### +# Simulate +# ~~~~~~~~ +# We train the network by simulating for a set simulation time, determined by the number of iterations and the +# evaluation group size and the length of one sequence. 
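
# As a quick orientation (an optional check, not required for the training itself): the total simulated
# biological time is fixed upfront by the number of iterations, the evaluation group size, and the sequence
# length, i.e. 200 * 100 * 300 ms = 6,000,000 ms plus the small pre-simulation offset.

print(f"simulated biological time: {duration['pre_sim'] + duration['task']:.0f} ms")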
+ + +def evaluate(n_iteration, iter_start): + events_mm_out = mm_out.get("events") + + readout_signal = events_mm_out["readout_signal"] + target_signal = events_mm_out["target_signal"] + senders = events_mm_out["senders"] + + readout_signal = np.array([readout_signal[senders == i] for i in set(senders)]) + target_signal = np.array([target_signal[senders == i] for i in set(senders)]) + + readout_signal = readout_signal.reshape((n_out, n_iteration, group_size, steps["sequence"])) + target_signal = target_signal.reshape((n_out, n_iteration, group_size, steps["sequence"])) + + readout_signal = readout_signal[:, iter_start:, :, -steps["learning_window"] :] + target_signal = target_signal[:, iter_start:, :, -steps["learning_window"] :] + + loss = 0.5 * np.mean(np.sum((readout_signal - target_signal) ** 2, axis=3), axis=(0, 2)) + + y_prediction = np.argmax(np.mean(readout_signal, axis=3), axis=0) + y_target = np.argmax(np.mean(target_signal, axis=3), axis=0) + accuracy = np.mean((y_target == y_prediction), axis=1) + recall_errors = 1.0 - accuracy + + return loss, accuracy, recall_errors + + +nest.Simulate(duration["pre_sim"]) + +nest.SetStatus(gen_learning_window, params_gen_learning_window) + +for iteration in range(n_iter): + t_start_iteration = iteration * duration["evaluation_group"] + t_end_iteration = t_start_iteration + duration["evaluation_group"] + + if iteration != 0 and iteration % test_every == 0: + loader, eta = data_loader_test, eta_test + else: + loader, eta = data_loader_train, eta_train + + params_common_syn_eprop["optimizer"]["eta"] = eta + nest.SetDefaults("eprop_synapse", params_common_syn_eprop) + + params_gen_spk_in, params_gen_rate_target = create_input_output(loader, t_start_iteration, t_end_iteration) + + nest.SetStatus(gen_spk_in, params_gen_spk_in) + nest.SetStatus(gen_rate_target, params_gen_rate_target) + nest.Simulate(duration["evaluation_group"]) + + loss, accuracy, recall_errors = evaluate(iteration + 1, -1) + + print(f" iteration: {iteration} loss: {loss[0]:0.5f} accuracy: {accuracy[0]:0.5f}") + +# %% ########################################################################################################### +# Read out post-training weights +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +# After the training, we can read out the optimized final weights. + +weights_post_train = { + "in_rec": get_weights(nrns_in, nrns_rec), + "rec_rec": get_weights(nrns_rec, nrns_rec), + "rec_out": get_weights(nrns_rec, nrns_out), +} + +# %% ########################################################################################################### +# Read out recorders +# ~~~~~~~~~~~~~~~~~~ +# We can also retrieve the recorded history of the dynamic variables and weights, as well as detected spikes. + +events_mm_rec = mm_rec.get("events") +events_mm_out = mm_out.get("events") +events_sr_in = sr_in.get("events") +events_sr_rec = sr_rec.get("events") +events_wr = wr.get("events") + +# %% ########################################################################################################### +# Evaluate training error +# ~~~~~~~~~~~~~~~~~~~~~~~ +# We evaluate the network's training error by calculating a loss - in this case, the mean squared error between +# the integrated recurrent network activity and the target rate. + +loss, accuracy, recall_errors = evaluate(n_iter, 0) + +# %% ########################################################################################################### +# Plot results +# ~~~~~~~~~~~~ +# Then, we plot a series of plots. 
+ +do_plotting = True # if True, plot the results + +if not do_plotting: + exit() + +colors = { + "blue": "#2854c5ff", + "red": "#e04b40ff", + "white": "#ffffffff", +} + +plt.rcParams.update( + { + "axes.spines.right": False, + "axes.spines.top": False, + "axes.prop_cycle": cycler(color=[colors["blue"], colors["red"]]), + } +) + +# %% ########################################################################################################### +# Plot training error +# ................... +# We begin with two plots visualizing the training error of the network: the loss and the recall error, both +# plotted against the iterations. + +fig, axs = plt.subplots(2, 1, sharex=True) +fig.suptitle("Training error") + +axs[0].plot(range(1, n_iter + 1), loss) +axs[0].set_ylabel(r"$E = \frac{1}{2} \sum_{t,k} \left( y_k^t -y_k^{*,t}\right)^2$") + +axs[1].plot(range(1, n_iter + 1), recall_errors) +axs[1].set_ylabel("recall errors") + +axs[-1].set_xlabel("training iteration") +axs[-1].set_xlim(1, n_iter) +axs[-1].xaxis.get_major_locator().set_params(integer=True) + +fig.tight_layout() + +# %% ########################################################################################################### +# Plot spikes and dynamic variables +# ................................. +# This plotting routine shows how to plot all of the recorded dynamic variables and spikes across time. We take +# one snapshot in the first iteration and one snapshot at the end. + + +def plot_recordable(ax, events, recordable, ylabel, xlims): + for sender in set(events["senders"]): + idc_sender = events["senders"] == sender + idc_times = (events["times"][idc_sender] > xlims[0]) & (events["times"][idc_sender] < xlims[1]) + ax.plot(events["times"][idc_sender][idc_times], events[recordable][idc_sender][idc_times], lw=0.5) + ax.set_ylabel(ylabel) + margin = np.abs(np.max(events[recordable]) - np.min(events[recordable])) * 0.1 + ax.set_ylim(np.min(events[recordable]) - margin, np.max(events[recordable]) + margin) + + +def plot_spikes(ax, events, ylabel, xlims): + idc_times = (events["times"] > xlims[0]) & (events["times"] < xlims[1]) + senders_subset = events["senders"][idc_times] + times_subset = events["times"][idc_times] + + ax.scatter(times_subset, senders_subset, s=0.1) + ax.set_ylabel(ylabel) + margin = np.abs(np.max(senders_subset) - np.min(senders_subset)) * 0.1 + ax.set_ylim(np.min(senders_subset) - margin, np.max(senders_subset) + margin) + + +for title, xlims in zip( + ["Dynamic variables before training", "Dynamic variables after training"], + [ + (steps["pre_sim"], steps["pre_sim"] + steps["sequence"]), + (steps["pre_sim"] + steps["task"] - steps["sequence"], steps["pre_sim"] + steps["task"]), + ], +): + fig, axs = plt.subplots(9, 1, sharex=True, figsize=(8, 14), gridspec_kw={"hspace": 0.4, "left": 0.2}) + fig.suptitle(title) + + plot_spikes(axs[0], events_sr_in, r"$z_i$" + "\n", xlims) + plot_spikes(axs[1], events_sr_rec, r"$z_j$" + "\n", xlims) + + plot_recordable(axs[2], events_mm_rec, "V_m", r"$v_j$" + "\n(mV)", xlims) + plot_recordable(axs[3], events_mm_rec, "surrogate_gradient", r"$\psi_j$" + "\n", xlims) + plot_recordable(axs[4], events_mm_rec, "learning_signal", r"$L_j$" + "\n(pA)", xlims) + + plot_recordable(axs[5], events_mm_out, "V_m", r"$v_k$" + "\n(mV)", xlims) + plot_recordable(axs[6], events_mm_out, "target_signal", r"$y^*_k$" + "\n", xlims) + plot_recordable(axs[7], events_mm_out, "readout_signal", r"$y_k$" + "\n", xlims) + plot_recordable(axs[8], events_mm_out, "error_signal", r"$y_k-y^*_k$" + "\n", xlims) + + 
axs[-1].set_xlabel(r"$t$ (ms)") + axs[-1].set_xlim(*xlims) + + fig.align_ylabels() + +# %% ########################################################################################################### +# Plot weight time courses +# ........................ +# Similarly, we can plot the weight histories. Note that the weight recorder, attached to the synapses, works +# differently than the other recorders. Since synapses only get activated when they transmit a spike, the weight +# recorder only records the weight in those moments. That is why the first weight registrations do not start in +# the first time step and we add the initial weights manually. + + +def plot_weight_time_course(ax, events, nrns_weight_record, label, ylabel): + sender_label, target_label = label.split("_") + nrns_senders = nrns_weight_record[sender_label] + nrns_targets = nrns_weight_record[target_label] + for sender in nrns_senders.tolist(): + for target in nrns_targets.tolist(): + idc_syn = (events["senders"] == sender) & (events["targets"] == target) + idc_syn_pre = (weights_pre_train[label]["source"] == sender) & ( + weights_pre_train[label]["target"] == target + ) + + times = [0.0] + events["times"][idc_syn].tolist() + weights = [weights_pre_train[label]["weight"][idc_syn_pre]] + events["weights"][idc_syn].tolist() + + ax.step(times, weights, c=colors["blue"]) + ax.set_ylabel(ylabel) + ax.set_ylim(-0.6, 0.6) + + +fig, axs = plt.subplots(3, 1, sharex=True, figsize=(3, 4)) +fig.suptitle("Weight time courses") + +nrns_weight_record = { + "in": nrns_in[:n_record_w], + "rec": nrns_rec[:n_record_w], + "out": nrns_out, +} + +plot_weight_time_course(axs[0], events_wr, nrns_weight_record, "in_rec", r"$W_\text{in}$ (pA)") +plot_weight_time_course(axs[1], events_wr, nrns_weight_record, "rec_rec", r"$W_\text{rec}$ (pA)") +plot_weight_time_course(axs[2], events_wr, nrns_weight_record, "rec_out", r"$W_\text{out}$ (pA)") + +axs[-1].set_xlabel(r"$t$ (ms)") +axs[-1].set_xlim(0, steps["task"]) + +fig.align_ylabels() +fig.tight_layout() + +# %% ########################################################################################################### +# Plot weight matrices +# .................... +# If one is not interested in the time course of the weights, it is possible to read out only the initial and +# final weights, which requires less computing time and memory than the weight recorder approach. Here, we plot +# the corresponding weight matrices before and after the optimization. 
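# Besides the absolute weights, the change induced by training can be instructive. This is a minimal
# sketch, not part of the original tutorial, based on the weight matrices read out above; the
# resulting dictionary is not used further below.

weights_delta = {
    conn_label: weights_post_train[conn_label]["weight_matrix"] - weights_pre_train[conn_label]["weight_matrix"]
    for conn_label in weights_pre_train
}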
+ +cmap = mpl.colors.LinearSegmentedColormap.from_list( + "cmap", ((0.0, colors["blue"]), (0.5, colors["white"]), (1.0, colors["red"])) +) + +fig, axs = plt.subplots(3, 2, sharex="col", sharey="row") +fig.suptitle("Weight matrices") + +all_w_extrema = [] + +for k in weights_pre_train.keys(): + w_pre = weights_pre_train[k]["weight"] + w_post = weights_post_train[k]["weight"] + all_w_extrema.append([np.min(w_pre), np.max(w_pre), np.min(w_post), np.max(w_post)]) + +args = {"cmap": cmap, "vmin": np.min(all_w_extrema), "vmax": np.max(all_w_extrema)} + +for i, weights in zip([0, 1], [weights_pre_train, weights_post_train]): + axs[0, i].pcolormesh(weights["in_rec"]["weight_matrix"].T, **args) + axs[1, i].pcolormesh(weights["rec_rec"]["weight_matrix"], **args) + cmesh = axs[2, i].pcolormesh(weights["rec_out"]["weight_matrix"], **args) + + axs[2, i].set_xlabel("recurrent\nneurons") + +axs[0, 0].set_ylabel("input\nneurons") +axs[1, 0].set_ylabel("recurrent\nneurons") +axs[2, 0].set_ylabel("readout\nneurons") +fig.align_ylabels(axs[:, 0]) + +axs[0, 0].text(0.5, 1.1, "before training", transform=axs[0, 0].transAxes, ha="center") +axs[0, 1].text(0.5, 1.1, "after training", transform=axs[0, 1].transAxes, ha="center") + +axs[2, 0].yaxis.get_major_locator().set_params(integer=True) + +cbar = plt.colorbar(cmesh, cax=axs[1, 1].inset_axes([1.1, 0.2, 0.05, 0.8]), label="weight (pA)") + +fig.tight_layout() + +plt.show() diff --git a/pynest/examples/eprop_plasticity/eprop_supervised_classification_schematic_evidence-accumulation.png b/pynest/examples/eprop_plasticity/eprop_supervised_classification_schematic_evidence-accumulation.png deleted file mode 100644 index 60738a57b2..0000000000 Binary files a/pynest/examples/eprop_plasticity/eprop_supervised_classification_schematic_evidence-accumulation.png and /dev/null differ diff --git a/pynest/examples/eprop_plasticity/eprop_supervised_regression_handwriting_bsshslm_2020.png b/pynest/examples/eprop_plasticity/eprop_supervised_regression_handwriting_bsshslm_2020.png new file mode 100644 index 0000000000..d427ef89fd Binary files /dev/null and b/pynest/examples/eprop_plasticity/eprop_supervised_regression_handwriting_bsshslm_2020.png differ diff --git a/pynest/examples/eprop_plasticity/eprop_supervised_regression_handwriting.py b/pynest/examples/eprop_plasticity/eprop_supervised_regression_handwriting_bsshslm_2020.py similarity index 84% rename from pynest/examples/eprop_plasticity/eprop_supervised_regression_handwriting.py rename to pynest/examples/eprop_plasticity/eprop_supervised_regression_handwriting_bsshslm_2020.py index 381d878848..a22ab3e822 100644 --- a/pynest/examples/eprop_plasticity/eprop_supervised_regression_handwriting.py +++ b/pynest/examples/eprop_plasticity/eprop_supervised_regression_handwriting_bsshslm_2020.py @@ -1,6 +1,6 @@ # -*- coding: utf-8 -*- # -# eprop_supervised_regression_handwriting.py +# eprop_supervised_regression_handwriting_bsshslm_2020.py # # This file is part of NEST. # @@ -20,8 +20,8 @@ # along with NEST. If not, see . r""" -Tutorial on learning to generate handwritten text with e-prop -------------------------------------------------------------- +Tutorial on learning to generate handwritten text with e-prop after Bellec et al. (2020) +---------------------------------------------------------------------------------------- Training a regression model using supervised e-prop plasticity to generate handwritten text @@ -34,14 +34,13 @@ This type of learning is demonstrated at the proof-of-concept task in [1]_. 
We based this script on their TensorFlow script given in [2]_ and changed the task as well as the parameters slightly. - In this task, the network learns to generate an arbitrary N-dimensional temporal pattern. Here, the network -learns to reproduce with its overall spiking activity a two-dimensional, roughly one-second-long target signal +learns to reproduce with its overall spiking activity a two-dimensional, roughly two-second-long target signal which encode the x and y coordinates of the handwritten word "chaos". -.. image:: eprop_supervised_regression_schematic_handwriting.png +.. image:: eprop_supervised_regression_handwriting_bsshslm_2020.png :width: 70 % - :alt: See Figure 1 below. + :alt: Schematic of network architecture. Same as Figure 1 in the code. :align: center Learning in the neural network model is achieved by optimizing the connection weights with e-prop plasticity. @@ -76,8 +75,10 @@ .. [2] https://github.com/IGITUGraz/eligibility_propagation/blob/master/Figure_3_and_S7_e_prop_tutorials/tutorial_pattern_generation.py -.. [3] Korcsak-Gorzo A, Stapmanns J, Espinoza Valverde JA, Dahmen D, van Albada SJ, Bolten M, Diesmann M. - Event-based implementation of eligibility propagation (in preparation) +.. [3] Korcsak-Gorzo A, Stapmanns J, Espinoza Valverde JA, Plesser HE, + Dahmen D, Bolten M, Van Albada SJ*, Diesmann M*. Event-based + implementation of eligibility propagation (in preparation) + """ # pylint: disable=line-too-long # noqa: E501 # %% ########################################################################################################### @@ -100,7 +101,7 @@ # synapse models below. The connections that must be established are numbered 1 to 6. try: - Image(filename="./eprop_supervised_regression_schematic_handwriting.png") + Image(filename="./eprop_supervised_regression_handwriting_bsshslm_2020.png") except Exception: pass @@ -121,9 +122,10 @@ # Define timing of task # ..................... # The task's temporal structure is then defined, once as time steps and once as durations in milliseconds. +# Increasing the number of iterations enhances learning performance. 
-n_batch = 1 # batch size -n_iter = 5 # number of iterations, 5000 for good convergence +batch_size = 1 # batch size +n_iter = 200 # number of iterations, 5000 to reach convergence as in the figure data_file_name = "chaos_handwriting.txt" # name of file with task data data = np.loadtxt(data_file_name) @@ -134,7 +136,7 @@ steps["sequence"] = len(data) * steps["data_point"] # time steps of one full sequence steps["learning_window"] = steps["sequence"] # time steps of window with non-zero learning signals -steps["task"] = n_iter * n_batch * steps["sequence"] # time steps of task +steps["task"] = n_iter * batch_size * steps["sequence"] # time steps of task steps.update( { @@ -188,38 +190,43 @@ n_rec = 200 # number of recurrent neurons n_out = 2 # number of readout neurons +params_nrn_out = { + "C_m": 1.0, # pF, membrane capacitance - takes effect only if neurons get current input (here not the case) + "E_L": 0.0, # mV, leak / resting membrane potential + "I_e": 0.0, # pA, external current input + "loss": "mean_squared_error", # loss function + "regular_spike_arrival": False, # If True, input spikes arrive at end of time step, if False at beginning + "tau_m": 50.0, # ms, membrane time constant + "V_m": 0.0, # mV, initial value of the membrane voltage +} + tau_m_mean = 30.0 # ms, mean of membrane time constant distribution params_nrn_rec = { + "beta": 1.0, # width scaling of the pseudo-derivative "adapt_tau": 2000.0, # ms, time constant of adaptive threshold - "C_m": 250.0, # pF, membrane capacitance - takes effect only if neurons get current input (here not the case) - "c_reg": 150.0, # firing rate regularization scaling - "E_L": 0.0, # mV, leak / resting membrane potential + "C_m": 250.0, + "c_reg": 150.0, # coefficient of firing rate regularization + "E_L": 0.0, "f_target": 20.0, # spikes/s, target firing rate for firing rate regularization - "gamma": 0.3, # scaling of the pseudo derivative - "I_e": 0.0, # pA, external current input - "regular_spike_arrival": False, # If True, input spikes arrive at end of time step, if False at beginning + "gamma": 0.3, # height scaling of the pseudo-derivative + "I_e": 0.0, + "regular_spike_arrival": False, "surrogate_gradient_function": "piecewise_linear", # surrogate gradient / pseudo-derivative function "t_ref": 0.0, # ms, duration of refractory period - "tau_m": nest.random.normal(mean=tau_m_mean, std=2.0), # ms, membrane time constant - "V_m": 0.0, # mV, initial value of the membrane voltage + "tau_m": nest.random.normal(mean=tau_m_mean, std=2.0), + "V_m": 0.0, "V_th": 0.03, # mV, spike threshold membrane voltage } +# factors from the original pseudo-derivative definition are incorporated into the parameters +params_nrn_rec["gamma"] /= params_nrn_rec["V_th"] +params_nrn_rec["beta"] /= np.abs(params_nrn_rec["V_th"]) # prefactor is inside abs in the original definition + params_nrn_rec["adapt_beta"] = ( 1.7 * (1.0 - np.exp(-1 / params_nrn_rec["adapt_tau"])) / (1.0 - np.exp(-1.0 / tau_m_mean)) ) # prefactor of adaptive threshold -params_nrn_out = { - "C_m": 1.0, - "E_L": 0.0, - "I_e": 0.0, - "loss": "mean_squared_error", # loss function - "regular_spike_arrival": False, - "tau_m": 50.0, - "V_m": 0.0, -} - #################### # Intermediate parrot neurons required between input spike generators and recurrent neurons, @@ -246,7 +253,7 @@ # default, recordings are stored in memory but can also be written to file. 
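# Writing recordings to file only requires selecting a different recording backend when the devices
# are created below. This is an optional illustration, not used in this tutorial; the spike recorder
# created here stays unconnected, and its label determines the prefix of the output file name.

sr_to_file_example = nest.Create("spike_recorder", {"record_to": "ascii", "label": "spikes_to_file"})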
n_record = 1 # number of neurons to record dynamic variables from - this script requires n_record >= 1 -n_record_w = 3 # number of senders and targets to record weights from - this script requires n_record_w >=1 +n_record_w = 5 # number of senders and targets to record weights from - this script requires n_record_w >=1 if n_record == 0 or n_record_w == 0: raise ValueError("n_record and n_record_w >= 1 required") @@ -262,6 +269,7 @@ ], # dynamic variables to record "start": duration["offset_gen"] + duration["delay_in_rec"], # start time of recording "stop": duration["offset_gen"] + duration["delay_in_rec"] + duration["task"], # stop time of recording + "label": "multimeter_rec", } params_mm_out = { @@ -269,6 +277,7 @@ "record_from": ["V_m", "readout_signal", "readout_signal_unnorm", "target_signal", "error_signal"], "start": duration["total_offset"], "stop": duration["total_offset"] + duration["task"], + "label": "multimeter_out", } params_wr = { @@ -276,18 +285,27 @@ "targets": nrns_rec[:n_record_w] + nrns_out, # limit targets to subsample weights to record from "start": duration["total_offset"], "stop": duration["total_offset"] + duration["task"], + "label": "weight_recorder", } -params_sr = { - "start": duration["total_offset"], +params_sr_in = { + "start": duration["offset_gen"], + "stop": duration["total_offset"] + duration["task"], + "label": "spike_recorder_in", +} + +params_sr_rec = { + "start": duration["offset_gen"], "stop": duration["total_offset"] + duration["task"], + "label": "spike_recorder_rec", } #################### mm_rec = nest.Create("multimeter", params_mm_rec) mm_out = nest.Create("multimeter", params_mm_out) -sr = nest.Create("spike_recorder", params_sr) +sr_in = nest.Create("spike_recorder", params_sr_in) +sr_rec = nest.Create("spike_recorder", params_sr_rec) wr = nest.Create("weight_recorder", params_wr) nrns_rec_record = nrns_rec[:n_record] @@ -312,7 +330,7 @@ params_common_syn_eprop = { "optimizer": { "type": "adam", # algorithm to optimize the weights - "batch_size": n_batch, + "batch_size": batch_size, "beta_1": 0.9, # exponential decay rate for 1st moment estimate of Adam optimizer "beta_2": 0.999, # exponential decay rate for 2nd moment raw estimate of Adam optimizer "epsilon": 1e-8, # small numerical stabilization constant of Adam optimizer @@ -374,7 +392,8 @@ nest.Connect(nrns_out, nrns_rec, params_conn_all_to_all, params_syn_feedback) # connection 5 nest.Connect(gen_rate_target, nrns_out, params_conn_one_to_one, params_syn_rate_target) # connection 6 -nest.Connect(nrns_in + nrns_rec, sr, params_conn_all_to_all, params_syn_static) +nest.Connect(nrns_in, sr_in, params_conn_all_to_all, params_syn_static) +nest.Connect(nrns_rec, sr_rec, params_conn_all_to_all, params_syn_static) nest.Connect(mm_rec, nrns_rec_record, params_conn_all_to_all, params_syn_static) nest.Connect(mm_out, nrns_out, params_conn_all_to_all, params_syn_static) @@ -430,7 +449,7 @@ params_gen_rate_target.append( { "amplitude_times": np.arange(0.0, duration["task"], duration["step"]) + duration["total_offset"], - "amplitude_values": np.tile(target_signal, n_iter * n_batch), + "amplitude_values": np.tile(target_signal, n_iter * batch_size), } ) @@ -500,7 +519,8 @@ def get_weights(pop_pre, pop_post): events_mm_rec = mm_rec.get("events") events_mm_out = mm_out.get("events") -events_sr = sr.get("events") +events_sr_in = sr_in.get("events") +events_sr_rec = sr_rec.get("events") events_wr = wr.get("events") # %% 
########################################################################################################### @@ -519,7 +539,13 @@ def get_weights(pop_pre, pop_post): error = (readout_signal[idc] - target_signal[idc]) ** 2 loss_list.append(0.5 * np.add.reduceat(error, np.arange(0, steps["task"], steps["sequence"]))) -loss = np.sum(loss_list, axis=0) +readout_signal = np.array([readout_signal[senders == i] for i in set(senders)]) +target_signal = np.array([target_signal[senders == i] for i in set(senders)]) + +readout_signal = readout_signal.reshape((n_out, n_iter, batch_size, steps["sequence"])) +target_signal = target_signal.reshape((n_out, n_iter, batch_size, steps["sequence"])) + +loss = 0.5 * np.mean(np.sum((readout_signal - target_signal) ** 2, axis=3), axis=(0, 2)) # %% ########################################################################################################### # Plot results @@ -539,7 +565,6 @@ def get_weights(pop_pre, pop_post): plt.rcParams.update( { - "font.sans-serif": "Arial", "axes.spines.right": False, "axes.spines.top": False, "axes.prop_cycle": cycler(color=[colors["blue"], colors["red"]]), @@ -553,20 +578,11 @@ def get_weights(pop_pre, pop_post): # neurons encode the horizontal and vertical coordinate of the pattern respectively. fig, ax = plt.subplots() +fig.suptitle("Pattern") -ax.plot( - readout_signal[senders == list(set(senders))[0]][-steps["sequence"] :], - -readout_signal[senders == list(set(senders))[1]][-steps["sequence"] :], - c=colors["red"], - label="readout", -) +ax.plot(readout_signal[0, -1, 0, :], -readout_signal[1, -1, 0, :], c=colors["red"], label="readout") -ax.plot( - target_signal[senders == list(set(senders))[0]][-steps["sequence"] :], - -target_signal[senders == list(set(senders))[1]][-steps["sequence"] :], - c=colors["blue"], - label="target", -) +ax.plot(target_signal[0, -1, 0, :], -target_signal[1, -1, 0, :], c=colors["blue"], label="target") ax.set_xlabel(r"$y_0$ and $y^*_0$") ax.set_ylabel(r"$y_1$ and $y^*_1$") @@ -581,6 +597,7 @@ def get_weights(pop_pre, pop_post): # We begin with a plot visualizing the training error of the network: the loss plotted against the iterations. 
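# For reference (an illustrative consistency check, not part of the tutorial): with batch_size = 1
# and two readout neurons, the aggregated loss computed above equals the average of the two
# per-readout losses that are plotted below.

np.testing.assert_allclose(loss, 0.5 * (loss_list[0] + loss_list[1]))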
fig, ax = plt.subplots() +fig.suptitle("Training error") ax.plot(range(1, n_iter + 1), loss_list[0], label=r"$E_0$", alpha=0.8, c=colors["blue"], ls="--") ax.plot(range(1, n_iter + 1), loss_list[1], label=r"$E_1$", alpha=0.8, c=colors["blue"], ls="dotted") @@ -590,6 +607,7 @@ def get_weights(pop_pre, pop_post): ax.set_xlim(1, n_iter) ax.xaxis.get_major_locator().set_params(integer=True) ax.legend(bbox_to_anchor=(1.01, 0.5), loc="center left") + fig.tight_layout() # %% ########################################################################################################### @@ -609,11 +627,10 @@ def plot_recordable(ax, events, recordable, ylabel, xlims): ax.set_ylim(np.min(events[recordable]) - margin, np.max(events[recordable]) + margin) -def plot_spikes(ax, events, nrns, ylabel, xlims): +def plot_spikes(ax, events, ylabel, xlims): idc_times = (events["times"] > xlims[0]) & (events["times"] < xlims[1]) - idc_sender = np.isin(events["senders"][idc_times], nrns.tolist()) - senders_subset = events["senders"][idc_times][idc_sender] - times_subset = events["times"][idc_times][idc_sender] + senders_subset = events["senders"][idc_times] + times_subset = events["times"][idc_times] ax.scatter(times_subset, senders_subset, s=0.1) ax.set_ylabel(ylabel) @@ -621,23 +638,25 @@ def plot_spikes(ax, events, nrns, ylabel, xlims): ax.set_ylim(np.min(senders_subset) - margin, np.max(senders_subset) + margin) -for xlims in [(0, steps["sequence"]), (steps["task"] - steps["sequence"], steps["task"])]: - fig, axs = plt.subplots(12, 1, sharex=True, figsize=(8, 12), gridspec_kw={"hspace": 0.4, "left": 0.2}) - - plot_spikes(axs[0], events_sr, nrns_in, r"$z_i$" + "\n", xlims) - plot_spikes(axs[1], events_sr, nrns_rec, r"$z_j$" + "\n", xlims) +for title, xlims in zip( + ["Dynamic variables before training", "Dynamic variables after training"], + [(0, steps["sequence"]), (steps["task"] - steps["sequence"], steps["task"])], +): + fig, axs = plt.subplots(10, 1, sharex=True, figsize=(8, 12), gridspec_kw={"hspace": 0.4, "left": 0.2}) + fig.suptitle(title) - plot_spikes(axs[3], events_sr, nrns_rec, r"$z_j$" + "\n", xlims) + plot_spikes(axs[0], events_sr_in, r"$z_i$" + "\n", xlims) + plot_spikes(axs[1], events_sr_rec, r"$z_j$" + "\n", xlims) - plot_recordable(axs[4], events_mm_rec, "V_m", r"$v_j$" + "\n(mV)", xlims) - plot_recordable(axs[5], events_mm_rec, "surrogate_gradient", r"$\psi_j$" + "\n", xlims) - plot_recordable(axs[6], events_mm_rec, "V_th_adapt", r"$A_j$" + "\n(mV)", xlims) - plot_recordable(axs[7], events_mm_rec, "learning_signal", r"$L_j$" + "\n(pA)", xlims) + plot_recordable(axs[2], events_mm_rec, "V_m", r"$v_j$" + "\n(mV)", xlims) + plot_recordable(axs[3], events_mm_rec, "surrogate_gradient", r"$\psi_j$" + "\n", xlims) + plot_recordable(axs[4], events_mm_rec, "V_th_adapt", r"$A_j$" + "\n(mV)", xlims) + plot_recordable(axs[5], events_mm_rec, "learning_signal", r"$L_j$" + "\n(pA)", xlims) - plot_recordable(axs[8], events_mm_out, "V_m", r"$v_k$" + "\n(mV)", xlims) - plot_recordable(axs[9], events_mm_out, "target_signal", r"$y^*_k$" + "\n", xlims) - plot_recordable(axs[10], events_mm_out, "readout_signal", r"$y_k$" + "\n", xlims) - plot_recordable(axs[11], events_mm_out, "error_signal", r"$y_k-y^*_k$" + "\n", xlims) + plot_recordable(axs[6], events_mm_out, "V_m", r"$v_k$" + "\n(mV)", xlims) + plot_recordable(axs[7], events_mm_out, "target_signal", r"$y^*_k$" + "\n", xlims) + plot_recordable(axs[8], events_mm_out, "readout_signal", r"$y_k$" + "\n", xlims) + plot_recordable(axs[9], events_mm_out, "error_signal", 
r"$y_k-y^*_k$" + "\n", xlims) axs[-1].set_xlabel(r"$t$ (ms)") axs[-1].set_xlim(*xlims) @@ -653,7 +672,10 @@ def plot_spikes(ax, events, nrns, ylabel, xlims): # the first time step and we add the initial weights manually. -def plot_weight_time_course(ax, events, nrns_senders, nrns_targets, label, ylabel): +def plot_weight_time_course(ax, events, nrns_weight_record, label, ylabel): + sender_label, target_label = label.split("_") + nrns_senders = nrns_weight_record[sender_label] + nrns_targets = nrns_weight_record[target_label] for sender in nrns_senders.tolist(): for target in nrns_targets.tolist(): idc_syn = (events["senders"] == sender) & (events["targets"] == target) @@ -670,12 +692,17 @@ def plot_weight_time_course(ax, events, nrns_senders, nrns_targets, label, ylabe fig, axs = plt.subplots(3, 1, sharex=True, figsize=(3, 4)) +fig.suptitle("Weight time courses") -plot_weight_time_course(axs[0], events_wr, nrns_in[:n_record_w], nrns_rec[:n_record_w], "in_rec", r"$W_\text{in}$ (pA)") -plot_weight_time_course( - axs[1], events_wr, nrns_rec[:n_record_w], nrns_rec[:n_record_w], "rec_rec", r"$W_\text{rec}$ (pA)" -) -plot_weight_time_course(axs[2], events_wr, nrns_rec[:n_record_w], nrns_out, "rec_out", r"$W_\text{out}$ (pA)") +nrns_weight_record = { + "in": nrns_in[:n_record_w], + "rec": nrns_rec[:n_record_w], + "out": nrns_out, +} + +plot_weight_time_course(axs[0], events_wr, nrns_weight_record, "in_rec", r"$W_\text{in}$ (pA)") +plot_weight_time_course(axs[1], events_wr, nrns_weight_record, "rec_rec", r"$W_\text{rec}$ (pA)") +plot_weight_time_course(axs[2], events_wr, nrns_weight_record, "rec_out", r"$W_\text{out}$ (pA)") axs[-1].set_xlabel(r"$t$ (ms)") axs[-1].set_xlim(0, steps["task"]) @@ -695,6 +722,7 @@ def plot_weight_time_course(ax, events, nrns_senders, nrns_targets, label, ylabe ) fig, axs = plt.subplots(3, 2, sharex="col", sharey="row") +fig.suptitle("Weight matrices") all_w_extrema = [] @@ -717,8 +745,8 @@ def plot_weight_time_course(ax, events, nrns_senders, nrns_targets, label, ylabe axs[2, 0].set_ylabel("readout\nneurons") fig.align_ylabels(axs[:, 0]) -axs[0, 0].text(0.5, 1.1, "pre-training", transform=axs[0, 0].transAxes, ha="center") -axs[0, 1].text(0.5, 1.1, "post-training", transform=axs[0, 1].transAxes, ha="center") +axs[0, 0].text(0.5, 1.1, "before training", transform=axs[0, 0].transAxes, ha="center") +axs[0, 1].text(0.5, 1.1, "after training", transform=axs[0, 1].transAxes, ha="center") axs[2, 0].yaxis.get_major_locator().set_params(integer=True) diff --git a/pynest/examples/eprop_plasticity/eprop_supervised_regression_lemniscate_bsshslm_2020.png b/pynest/examples/eprop_plasticity/eprop_supervised_regression_lemniscate_bsshslm_2020.png new file mode 100644 index 0000000000..a62a947b9e Binary files /dev/null and b/pynest/examples/eprop_plasticity/eprop_supervised_regression_lemniscate_bsshslm_2020.png differ diff --git a/pynest/examples/eprop_plasticity/eprop_supervised_regression_infinite-loop.py b/pynest/examples/eprop_plasticity/eprop_supervised_regression_lemniscate_bsshslm_2020.py similarity index 84% rename from pynest/examples/eprop_plasticity/eprop_supervised_regression_infinite-loop.py rename to pynest/examples/eprop_plasticity/eprop_supervised_regression_lemniscate_bsshslm_2020.py index 9cd9b07eb6..90e6000e06 100644 --- a/pynest/examples/eprop_plasticity/eprop_supervised_regression_infinite-loop.py +++ b/pynest/examples/eprop_plasticity/eprop_supervised_regression_lemniscate_bsshslm_2020.py @@ -1,6 +1,6 @@ # -*- coding: utf-8 -*- # -# 
eprop_supervised_regression_infinite-loop.py +# eprop_supervised_regression_lemniscate_bsshslm_2020.py # # This file is part of NEST. # @@ -20,10 +20,10 @@ # along with NEST. If not, see . r""" -Tutorial on learning to generate an infinite loop with e-prop -------------------------------------------------------------- +Tutorial on learning to generate a lemniscate with e-prop after Bellec et al. (2020) +------------------------------------------------------------------------------------ -Training a regression model using supervised e-prop plasticity to generate an infinite loop +Training a regression model using supervised e-prop plasticity to generate a lemniscate Description ~~~~~~~~~~~ @@ -34,14 +34,13 @@ This type of learning is demonstrated at the proof-of-concept task in [1]_. We based this script on their TensorFlow script given in [2]_ and changed the task as well as the parameters slightly. - In this task, the network learns to generate an arbitrary N-dimensional temporal pattern. Here, the network -learns to reproduce with its overall spiking activity a two-dimensional, roughly two-second-long target signal -which encode the x and y coordinates of an infinite-loop. +learns to reproduce with its overall spiking activity a two-dimensional, roughly one-second-long target signal +which encode the x and y coordinates of a lemniscate. -.. image:: eprop_supervised_regression_schematic_infinite-loop.png +.. image:: eprop_supervised_regression_lemniscate_bsshslm_2020.png :width: 70 % - :alt: See Figure 1 below. + :alt: Schematic of network architecture. Same as Figure 1 in the code. :align: center Learning in the neural network model is achieved by optimizing the connection weights with e-prop plasticity. @@ -66,8 +65,10 @@ .. [2] https://github.com/IGITUGraz/eligibility_propagation/blob/master/Figure_3_and_S7_e_prop_tutorials/tutorial_pattern_generation.py -.. [3] Korcsak-Gorzo A, Stapmanns J, Espinoza Valverde JA, Dahmen D, van Albada SJ, Bolten M, Diesmann M. - Event-based implementation of eligibility propagation (in preparation) +.. [3] Korcsak-Gorzo A, Stapmanns J, Espinoza Valverde JA, Plesser HE, + Dahmen D, Bolten M, Van Albada SJ*, Diesmann M*. Event-based + implementation of eligibility propagation (in preparation) + """ # pylint: disable=line-too-long # noqa: E501 # %% ########################################################################################################### @@ -90,7 +91,7 @@ # synapse models below. The connections that must be established are numbered 1 to 6. try: - Image(filename="./eprop_supervised_regression_schematic_infinite-loop.png") + Image(filename="./eprop_supervised_regression_lemniscate_bsshslm_2020.png") except Exception: pass @@ -111,16 +112,17 @@ # Define timing of task # ..................... # The task's temporal structure is then defined, once as time steps and once as durations in milliseconds. +# Increasing the number of iterations enhances learning performance. 
-n_batch = 1 # batch size -n_iter = 5 # number of iterations, 5000 for good convergence +batch_size = 1 # batch size +n_iter = 200 # number of iterations, 5000 to reach convergence as in the figure steps = { "sequence": 1258, # time steps of one full sequence } steps["learning_window"] = steps["sequence"] # time steps of window with non-zero learning signals -steps["task"] = n_iter * n_batch * steps["sequence"] # time steps of task +steps["task"] = n_iter * batch_size * steps["sequence"] # time steps of task steps.update( { @@ -174,38 +176,43 @@ n_rec = 200 # number of recurrent neurons n_out = 2 # number of readout neurons +params_nrn_out = { + "C_m": 1.0, # pF, membrane capacitance - takes effect only if neurons get current input (here not the case) + "E_L": 0.0, # mV, leak / resting membrane potential + "I_e": 0.0, # pA, external current input + "loss": "mean_squared_error", # loss function + "regular_spike_arrival": False, # If True, input spikes arrive at end of time step, if False at beginning + "tau_m": 50.0, # ms, membrane time constant + "V_m": 0.0, # mV, initial value of the membrane voltage +} + tau_m_mean = 30.0 # ms, mean of membrane time constant distribution params_nrn_rec = { + "beta": 1.0, # width scaling of the pseudo-derivative "adapt_tau": 2000.0, # ms, time constant of adaptive threshold - "C_m": 250.0, # pF, membrane capacitance - takes effect only if neurons get current input (here not the case) - "c_reg": 150.0, # firing rate regularization scaling - "E_L": 0.0, # mV, leak / resting membrane potential + "C_m": 250.0, + "c_reg": 150.0, # coefficient of firing rate regularization + "E_L": 0.0, "f_target": 20.0, # spikes/s, target firing rate for firing rate regularization - "gamma": 0.3, # scaling of the pseudo derivative - "I_e": 0.0, # pA, external current input - "regular_spike_arrival": False, # If True, input spikes arrive at end of time step, if False at beginning + "gamma": 0.3, # height scaling of the pseudo-derivative + "I_e": 0.0, + "regular_spike_arrival": False, "surrogate_gradient_function": "piecewise_linear", # surrogate gradient / pseudo-derivative function "t_ref": 0.0, # ms, duration of refractory period - "tau_m": nest.random.normal(mean=tau_m_mean, std=2.0), # ms, membrane time constant - "V_m": 0.0, # mV, initial value of the membrane voltage + "tau_m": nest.random.normal(mean=tau_m_mean, std=2.0), + "V_m": 0.0, "V_th": 0.03, # mV, spike threshold membrane voltage } +# factors from the original pseudo-derivative definition are incorporated into the parameters +params_nrn_rec["gamma"] /= params_nrn_rec["V_th"] +params_nrn_rec["beta"] /= np.abs(params_nrn_rec["V_th"]) # prefactor is inside abs in the original definition + params_nrn_rec["adapt_beta"] = ( 1.7 * (1.0 - np.exp(-1 / params_nrn_rec["adapt_tau"])) / (1.0 - np.exp(-1.0 / tau_m_mean)) ) # prefactor of adaptive threshold -params_nrn_out = { - "C_m": 1.0, - "E_L": 0.0, - "I_e": 0.0, - "loss": "mean_squared_error", # loss function - "regular_spike_arrival": False, - "tau_m": 50.0, - "V_m": 0.0, -} - #################### # Intermediate parrot neurons required between input spike generators and recurrent neurons, @@ -232,7 +239,7 @@ # default, recordings are stored in memory but can also be written to file. 
n_record = 1 # number of neurons to record dynamic variables from - this script requires n_record >= 1 -n_record_w = 3 # number of senders and targets to record weights from - this script requires n_record_w >=1 +n_record_w = 5 # number of senders and targets to record weights from - this script requires n_record_w >=1 if n_record == 0 or n_record_w == 0: raise ValueError("n_record and n_record_w >= 1 required") @@ -248,6 +255,7 @@ ], # dynamic variables to record "start": duration["offset_gen"] + duration["delay_in_rec"], # start time of recording "stop": duration["offset_gen"] + duration["delay_in_rec"] + duration["task"], # stop time of recording + "label": "multimeter_rec", } params_mm_out = { @@ -255,6 +263,7 @@ "record_from": ["V_m", "readout_signal", "readout_signal_unnorm", "target_signal", "error_signal"], "start": duration["total_offset"], "stop": duration["total_offset"] + duration["task"], + "label": "multimeter_out", } params_wr = { @@ -264,16 +273,24 @@ "stop": duration["total_offset"] + duration["task"], } -params_sr = { - "start": duration["total_offset"], +params_sr_in = { + "start": duration["offset_gen"], + "stop": duration["total_offset"] + duration["task"], + "label": "spike_recorder_in", +} + +params_sr_rec = { + "start": duration["offset_gen"], "stop": duration["total_offset"] + duration["task"], + "label": "spike_recorder_rec", } #################### mm_rec = nest.Create("multimeter", params_mm_rec) mm_out = nest.Create("multimeter", params_mm_out) -sr = nest.Create("spike_recorder", params_sr) +sr_in = nest.Create("spike_recorder", params_sr_in) +sr_rec = nest.Create("spike_recorder", params_sr_rec) wr = nest.Create("weight_recorder", params_wr) nrns_rec_record = nrns_rec[:n_record] @@ -298,7 +315,7 @@ params_common_syn_eprop = { "optimizer": { "type": "adam", # algorithm to optimize the weights - "batch_size": n_batch, + "batch_size": batch_size, "beta_1": 0.9, # exponential decay rate for 1st moment estimate of Adam optimizer "beta_2": 0.999, # exponential decay rate for 2nd moment raw estimate of Adam optimizer "epsilon": 1e-8, # small numerical stabilization constant of Adam optimizer @@ -325,7 +342,6 @@ params_syn_out = params_syn_base.copy() params_syn_out["weight"] = weights_rec_out - params_syn_feedback = { "synapse_model": "eprop_learning_signal_connection_bsshslm_2020", "delay": duration["step"], @@ -361,7 +377,8 @@ nest.Connect(nrns_out, nrns_rec, params_conn_all_to_all, params_syn_feedback) # connection 5 nest.Connect(gen_rate_target, nrns_out, params_conn_one_to_one, params_syn_rate_target) # connection 6 -nest.Connect(nrns_in + nrns_rec, sr, params_conn_all_to_all, params_syn_static) +nest.Connect(nrns_in, sr_in, params_conn_all_to_all, params_syn_static) +nest.Connect(nrns_rec, sr_rec, params_conn_all_to_all, params_syn_static) nest.Connect(mm_rec, nrns_rec_record, params_conn_all_to_all, params_syn_static) nest.Connect(mm_out, nrns_out, params_conn_all_to_all, params_syn_static) @@ -398,7 +415,7 @@ # %% ########################################################################################################### # Create output # ~~~~~~~~~~~~~ -# Then, we load the x and y values of an image of the word "chaos" written by hand and construct a roughly +# Then, we load the x and y values of an image of a lemniscate and construct a roughly # one-second long target signal from it. This signal, like the input, is repeated for all iterations and fed # into the rate generator that was previously created. 
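# For comparison only (this tutorial loads the coordinates from a data file), a figure-eight target
# of matching length could also be generated analytically, for instance from the lemniscate of
# Gerono; this is an illustrative sketch and is not used below.

t_param = np.linspace(0.0, 2.0 * np.pi, steps["sequence"])
target_x_analytic = np.cos(t_param)
target_y_analytic = np.sin(t_param) * np.cos(t_param)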
@@ -413,7 +430,7 @@ params_gen_rate_target.append( { "amplitude_times": np.arange(0.0, duration["task"], duration["step"]) + duration["total_offset"], - "amplitude_values": np.tile(target_signal, n_iter * n_batch), + "amplitude_values": np.tile(target_signal, n_iter * batch_size), } ) @@ -483,7 +500,8 @@ def get_weights(pop_pre, pop_post): events_mm_rec = mm_rec.get("events") events_mm_out = mm_out.get("events") -events_sr = sr.get("events") +events_sr_in = sr_in.get("events") +events_sr_rec = sr_rec.get("events") events_wr = wr.get("events") # %% ########################################################################################################### @@ -502,8 +520,13 @@ def get_weights(pop_pre, pop_post): error = (readout_signal[idc] - target_signal[idc]) ** 2 loss_list.append(0.5 * np.add.reduceat(error, np.arange(0, steps["task"], steps["sequence"]))) -loss = np.sum(loss_list, axis=0) +readout_signal = np.array([readout_signal[senders == i] for i in set(senders)]) +target_signal = np.array([target_signal[senders == i] for i in set(senders)]) + +readout_signal = readout_signal.reshape((n_out, n_iter, batch_size, steps["sequence"])) +target_signal = target_signal.reshape((n_out, n_iter, batch_size, steps["sequence"])) +loss = 0.5 * np.mean(np.sum((readout_signal - target_signal) ** 2, axis=3), axis=(0, 2)) # %% ########################################################################################################### # Plot results @@ -523,7 +546,6 @@ def get_weights(pop_pre, pop_post): plt.rcParams.update( { - "font.sans-serif": "Arial", "axes.spines.right": False, "axes.spines.top": False, "axes.prop_cycle": cycler(color=[colors["blue"], colors["red"]]), @@ -537,20 +559,11 @@ def get_weights(pop_pre, pop_post): # neurons encode the horizontal and vertical coordinate of the pattern respectively. fig, ax = plt.subplots() +fig.suptitle("Pattern") -ax.plot( - readout_signal[senders == list(set(senders))[0]][-steps["sequence"] :], - -readout_signal[senders == list(set(senders))[1]][-steps["sequence"] :], - c=colors["red"], - label="readout", -) +ax.plot(readout_signal[0, -1, 0, :], -readout_signal[1, -1, 0, :], c=colors["red"], label="readout") -ax.plot( - target_signal[senders == list(set(senders))[0]][-steps["sequence"] :], - -target_signal[senders == list(set(senders))[1]][-steps["sequence"] :], - c=colors["blue"], - label="target", -) +ax.plot(target_signal[0, -1, 0, :], -target_signal[1, -1, 0, :], c=colors["blue"], label="target") ax.set_xlabel(r"$y_0$ and $y^*_0$") ax.set_ylabel(r"$y_1$ and $y^*_1$") @@ -565,6 +578,7 @@ def get_weights(pop_pre, pop_post): # We begin with a plot visualizing the training error of the network: the loss plotted against the iterations. 
fig, ax = plt.subplots() +fig.suptitle("Training error") ax.plot(range(1, n_iter + 1), loss_list[0], label=r"$E_0$", alpha=0.8, c=colors["blue"], ls="--") ax.plot(range(1, n_iter + 1), loss_list[1], label=r"$E_1$", alpha=0.8, c=colors["blue"], ls="dotted") @@ -574,6 +588,7 @@ def get_weights(pop_pre, pop_post): ax.set_xlim(1, n_iter) ax.xaxis.get_major_locator().set_params(integer=True) ax.legend(bbox_to_anchor=(1.01, 0.5), loc="center left") + fig.tight_layout() # %% ########################################################################################################### @@ -593,11 +608,10 @@ def plot_recordable(ax, events, recordable, ylabel, xlims): ax.set_ylim(np.min(events[recordable]) - margin, np.max(events[recordable]) + margin) -def plot_spikes(ax, events, nrns, ylabel, xlims): +def plot_spikes(ax, events, ylabel, xlims): idc_times = (events["times"] > xlims[0]) & (events["times"] < xlims[1]) - idc_sender = np.isin(events["senders"][idc_times], nrns.tolist()) - senders_subset = events["senders"][idc_times][idc_sender] - times_subset = events["times"][idc_times][idc_sender] + senders_subset = events["senders"][idc_times] + times_subset = events["times"][idc_times] ax.scatter(times_subset, senders_subset, s=0.1) ax.set_ylabel(ylabel) @@ -605,23 +619,25 @@ def plot_spikes(ax, events, nrns, ylabel, xlims): ax.set_ylim(np.min(senders_subset) - margin, np.max(senders_subset) + margin) -for xlims in [(0, steps["sequence"]), (steps["task"] - steps["sequence"], steps["task"])]: - fig, axs = plt.subplots(12, 1, sharex=True, figsize=(8, 12), gridspec_kw={"hspace": 0.4, "left": 0.2}) - - plot_spikes(axs[0], events_sr, nrns_in, r"$z_i$" + "\n", xlims) - plot_spikes(axs[1], events_sr, nrns_rec, r"$z_j$" + "\n", xlims) +for title, xlims in zip( + ["Dynamic variables before training", "Dynamic variables after training"], + [(0, steps["sequence"]), (steps["task"] - steps["sequence"], steps["task"])], +): + fig, axs = plt.subplots(10, 1, sharex=True, figsize=(8, 12), gridspec_kw={"hspace": 0.4, "left": 0.2}) + fig.suptitle(title) - plot_spikes(axs[3], events_sr, nrns_rec, r"$z_j$" + "\n", xlims) + plot_spikes(axs[0], events_sr_in, r"$z_i$" + "\n", xlims) + plot_spikes(axs[1], events_sr_rec, r"$z_j$" + "\n", xlims) - plot_recordable(axs[4], events_mm_rec, "V_m", r"$v_j$" + "\n(mV)", xlims) - plot_recordable(axs[5], events_mm_rec, "surrogate_gradient", r"$\psi_j$" + "\n", xlims) - plot_recordable(axs[6], events_mm_rec, "V_th_adapt", r"$A_j$" + "\n(mV)", xlims) - plot_recordable(axs[7], events_mm_rec, "learning_signal", r"$L_j$" + "\n(pA)", xlims) + plot_recordable(axs[2], events_mm_rec, "V_m", r"$v_j$" + "\n(mV)", xlims) + plot_recordable(axs[3], events_mm_rec, "surrogate_gradient", r"$\psi_j$" + "\n", xlims) + plot_recordable(axs[4], events_mm_rec, "V_th_adapt", r"$A_j$" + "\n(mV)", xlims) + plot_recordable(axs[5], events_mm_rec, "learning_signal", r"$L_j$" + "\n(pA)", xlims) - plot_recordable(axs[8], events_mm_out, "V_m", r"$v_k$" + "\n(mV)", xlims) - plot_recordable(axs[9], events_mm_out, "target_signal", r"$y^*_k$" + "\n", xlims) - plot_recordable(axs[10], events_mm_out, "readout_signal", r"$y_k$" + "\n", xlims) - plot_recordable(axs[11], events_mm_out, "error_signal", r"$y_k-y^*_k$" + "\n", xlims) + plot_recordable(axs[6], events_mm_out, "V_m", r"$v_k$" + "\n(mV)", xlims) + plot_recordable(axs[7], events_mm_out, "target_signal", r"$y^*_k$" + "\n", xlims) + plot_recordable(axs[8], events_mm_out, "readout_signal", r"$y_k$" + "\n", xlims) + plot_recordable(axs[9], events_mm_out, "error_signal", 
r"$y_k-y^*_k$" + "\n", xlims) axs[-1].set_xlabel(r"$t$ (ms)") axs[-1].set_xlim(*xlims) @@ -637,7 +653,10 @@ def plot_spikes(ax, events, nrns, ylabel, xlims): # the first time step and we add the initial weights manually. -def plot_weight_time_course(ax, events, nrns_senders, nrns_targets, label, ylabel): +def plot_weight_time_course(ax, events, nrns_weight_record, label, ylabel): + sender_label, target_label = label.split("_") + nrns_senders = nrns_weight_record[sender_label] + nrns_targets = nrns_weight_record[target_label] for sender in nrns_senders.tolist(): for target in nrns_targets.tolist(): idc_syn = (events["senders"] == sender) & (events["targets"] == target) @@ -654,12 +673,17 @@ def plot_weight_time_course(ax, events, nrns_senders, nrns_targets, label, ylabe fig, axs = plt.subplots(3, 1, sharex=True, figsize=(3, 4)) +fig.suptitle("Weight time courses") -plot_weight_time_course(axs[0], events_wr, nrns_in[:n_record_w], nrns_rec[:n_record_w], "in_rec", r"$W_\text{in}$ (pA)") -plot_weight_time_course( - axs[1], events_wr, nrns_rec[:n_record_w], nrns_rec[:n_record_w], "rec_rec", r"$W_\text{rec}$ (pA)" -) -plot_weight_time_course(axs[2], events_wr, nrns_rec[:n_record_w], nrns_out, "rec_out", r"$W_\text{out}$ (pA)") +nrns_weight_record = { + "in": nrns_in[:n_record_w], + "rec": nrns_rec[:n_record_w], + "out": nrns_out, +} + +plot_weight_time_course(axs[0], events_wr, nrns_weight_record, "in_rec", r"$W_\text{in}$ (pA)") +plot_weight_time_course(axs[1], events_wr, nrns_weight_record, "rec_rec", r"$W_\text{rec}$ (pA)") +plot_weight_time_course(axs[2], events_wr, nrns_weight_record, "rec_out", r"$W_\text{out}$ (pA)") axs[-1].set_xlabel(r"$t$ (ms)") axs[-1].set_xlim(0, steps["task"]) @@ -679,6 +703,7 @@ def plot_weight_time_course(ax, events, nrns_senders, nrns_targets, label, ylabe ) fig, axs = plt.subplots(3, 2, sharex="col", sharey="row") +fig.suptitle("Weight matrices") all_w_extrema = [] @@ -701,8 +726,8 @@ def plot_weight_time_course(ax, events, nrns_senders, nrns_targets, label, ylabe axs[2, 0].set_ylabel("readout\nneurons") fig.align_ylabels(axs[:, 0]) -axs[0, 0].text(0.5, 1.1, "pre-training", transform=axs[0, 0].transAxes, ha="center") -axs[0, 1].text(0.5, 1.1, "post-training", transform=axs[0, 1].transAxes, ha="center") +axs[0, 0].text(0.5, 1.1, "before training", transform=axs[0, 0].transAxes, ha="center") +axs[0, 1].text(0.5, 1.1, "after training", transform=axs[0, 1].transAxes, ha="center") axs[2, 0].yaxis.get_major_locator().set_params(integer=True) diff --git a/pynest/examples/eprop_plasticity/eprop_supervised_regression_schematic_handwriting.png b/pynest/examples/eprop_plasticity/eprop_supervised_regression_schematic_handwriting.png deleted file mode 100644 index 84ce96ed5e..0000000000 Binary files a/pynest/examples/eprop_plasticity/eprop_supervised_regression_schematic_handwriting.png and /dev/null differ diff --git a/pynest/examples/eprop_plasticity/eprop_supervised_regression_schematic_infinite-loop.png b/pynest/examples/eprop_plasticity/eprop_supervised_regression_schematic_infinite-loop.png deleted file mode 100644 index 445510390a..0000000000 Binary files a/pynest/examples/eprop_plasticity/eprop_supervised_regression_schematic_infinite-loop.png and /dev/null differ diff --git a/pynest/examples/eprop_plasticity/eprop_supervised_regression_schematic_sine-waves.png b/pynest/examples/eprop_plasticity/eprop_supervised_regression_schematic_sine-waves.png deleted file mode 100644 index 89e9d839fe..0000000000 Binary files 
a/pynest/examples/eprop_plasticity/eprop_supervised_regression_schematic_sine-waves.png and /dev/null differ diff --git a/pynest/examples/eprop_plasticity/eprop_supervised_regression_sine-waves.png b/pynest/examples/eprop_plasticity/eprop_supervised_regression_sine-waves.png new file mode 100644 index 0000000000..d3b8d6b1de Binary files /dev/null and b/pynest/examples/eprop_plasticity/eprop_supervised_regression_sine-waves.png differ diff --git a/pynest/examples/eprop_plasticity/eprop_supervised_regression_sine-waves.py b/pynest/examples/eprop_plasticity/eprop_supervised_regression_sine-waves.py index c74a69cf36..0659fff3c2 100644 --- a/pynest/examples/eprop_plasticity/eprop_supervised_regression_sine-waves.py +++ b/pynest/examples/eprop_plasticity/eprop_supervised_regression_sine-waves.py @@ -29,7 +29,8 @@ ~~~~~~~~~~~ This script demonstrates supervised learning of a regression task with a recurrent spiking neural network that -is equipped with the eligibility propagation (e-prop) plasticity mechanism by Bellec et al. [1]_. +is equipped with the eligibility propagation (e-prop) plasticity mechanism by Bellec et al. [1]_ with +additional biological features described in [3]_. This type of learning is demonstrated at the proof-of-concept task in [1]_. We based this script on their TensorFlow script given in [2]_. @@ -38,9 +39,9 @@ network learns to reproduce with its overall spiking activity a one-dimensional, one-second-long target signal which is a superposition of four sine waves of different amplitudes, phases, and periods. -.. image:: eprop_supervised_regression_schematic_sine-waves.png +.. image:: eprop_supervised_regression_sine-waves.png :width: 70 % - :alt: See Figure 1 below. + :alt: Schematic of network architecture. Same as Figure 1 in the code. :align: center Learning in the neural network model is achieved by optimizing the connection weights with e-prop plasticity. @@ -62,8 +63,10 @@ .. [2] https://github.com/IGITUGraz/eligibility_propagation/blob/master/Figure_3_and_S7_e_prop_tutorials/tutorial_pattern_generation.py -.. [3] Korcsak-Gorzo A, Stapmanns J, Espinoza Valverde JA, Dahmen D, van Albada SJ, Bolten M, Diesmann M. - Event-based implementation of eligibility propagation (in preparation) +.. [3] Korcsak-Gorzo A, Stapmanns J, Espinoza Valverde JA, Plesser HE, + Dahmen D, Bolten M, Van Albada SJ*, Diesmann M*. Event-based + implementation of eligibility propagation (in preparation) + """ # pylint: disable=line-too-long # noqa: E501 # %% ########################################################################################################### @@ -86,7 +89,7 @@ # synapse models below. The connections that must be established are numbered 1 to 6. try: - Image(filename="./eprop_supervised_regression_schematic_sine-waves.png") + Image(filename="./eprop_supervised_regression_sine-waves.png") except Exception: pass @@ -107,28 +110,31 @@ # Define timing of task # ..................... # The task's temporal structure is then defined, once as time steps and once as durations in milliseconds. +# Even though each sample is processed independently during training, we aggregate predictions and true +# labels across a group of samples during the evaluation phase. The number of samples in this group is +# determined by the `group_size` parameter. This data is then used to assess the neural network's +# performance metrics, such as average accuracy and mean error. Increasing the number of iterations enhances +# learning performance. 
-n_batch = 1 # batch size, 1 in reference [2] -n_iter = 5 # number of iterations, 2000 in reference [2] +group_size = 1 # number of instances over which to evaluate the learning performance +n_iter = 200 # number of iterations, 2000 in reference [2] steps = { "sequence": 1000, # time steps of one full sequence } steps["learning_window"] = steps["sequence"] # time steps of window with non-zero learning signals -steps["task"] = n_iter * n_batch * steps["sequence"] # time steps of task +steps["task"] = n_iter * group_size * steps["sequence"] # time steps of task steps.update( { "offset_gen": 1, # offset since generator signals start from time step 1 "delay_in_rec": 1, # connection delay between input and recurrent neurons - "delay_rec_out": 1, # connection delay between recurrent and output neurons - "delay_out_norm": 1, # connection delay between output neurons for normalization - "extension_sim": 1, # extra time step to close right-open simulation time interval in Simulate() + "extension_sim": 3, # extra time step to close right-open simulation time interval in Simulate() } ) -steps["delays"] = steps["delay_in_rec"] + steps["delay_rec_out"] + steps["delay_out_norm"] # time steps of delays +steps["delays"] = steps["delay_in_rec"] # time steps of delays steps["total_offset"] = steps["offset_gen"] + steps["delays"] # time steps of total offset @@ -142,12 +148,9 @@ # Set up simulation # ................. # As last step of the setup, we reset the NEST kernel to remove all existing NEST simulation settings and -# objects and set some NEST kernel parameters, some of which are e-prop-related. +# objects and set some NEST kernel parameters. params_setup = { - "eprop_learning_window": duration["learning_window"], - "eprop_reset_neurons_on_update": True, # if True, reset dynamic variables at start of each update interval - "eprop_update_interval": duration["sequence"], # ms, time interval for updating the synaptic weights "print_time": False, # if True, print time progress bar during simulation, set False if run as code cell "resolution": duration["step"], "total_num_virtual_procs": 1, # number of virtual processes, set in case of distributed computing @@ -169,31 +172,43 @@ n_rec = 100 # number of recurrent neurons n_out = 1 # number of readout neurons -params_nrn_rec = { +model_nrn_rec = "eprop_iaf" + +params_nrn_out = { "C_m": 1.0, # pF, membrane capacitance - takes effect only if neurons get current input (here not the case) - "c_reg": 300.0, # firing rate regularization scaling "E_L": 0.0, # mV, leak / resting membrane potential - "f_target": 10.0, # spikes/s, target firing rate for firing rate regularization - "gamma": 0.3, # scaling of the pseudo derivative + "eprop_isi_trace_cutoff": 100, # cutoff of integration of eprop trace between spikes "I_e": 0.0, # pA, external current input "regular_spike_arrival": False, # If True, input spikes arrive at end of time step, if False at beginning - "surrogate_gradient_function": "piecewise_linear", # surrogate gradient / pseudo-derivative function - "t_ref": 0.0, # ms, duration of refractory period "tau_m": 30.0, # ms, membrane time constant "V_m": 0.0, # mV, initial value of the membrane voltage - "V_th": 0.03, # mV, spike threshold membrane voltage } -params_nrn_out = { +params_nrn_rec = { + "beta": 33.3, # width scaling of the pseudo-derivative "C_m": 1.0, + "c_reg": 300.0 / duration["sequence"], # coefficient of firing rate regularization "E_L": 0.0, + "eprop_isi_trace_cutoff": 100, + "f_target": 10.0, # spikes/s, target firing rate for firing rate 
regularization + "gamma": 10.0, # height scaling of the pseudo-derivative "I_e": 0.0, - "loss": "mean_squared_error", # loss function + "kappa": 0.97, # low-pass filter of the eligibility trace + "kappa_reg": 0.97, # low-pass filter of the firing rate for regularization "regular_spike_arrival": False, + "surrogate_gradient_function": "piecewise_linear", # surrogate gradient / pseudo-derivative function + "t_ref": 0.0, # ms, duration of refractory period "tau_m": 30.0, "V_m": 0.0, + "V_th": 0.03, # mV, spike threshold membrane voltage } +if model_nrn_rec in ["eprop_iaf_psc_delta", "eprop_iaf_psc_delta_adapt"]: + del params_nrn_rec["regular_spike_arrival"] + params_nrn_rec["V_reset"] = -0.5 # mV, reset membrane voltage + params_nrn_rec["c_reg"] = 2.0 / duration["sequence"] + params_nrn_rec["V_th"] = 0.5 + #################### # Intermediate parrot neurons required between input spike generators and recurrent neurons, @@ -202,13 +217,10 @@ gen_spk_in = nest.Create("spike_generator", n_in) nrns_in = nest.Create("parrot_neuron", n_in) -# The suffix _bsshslm_2020 follows the NEST convention to indicate in the model name the paper -# that introduced it by the first letter of the authors' last names and the publication year. - -nrns_rec = nest.Create("eprop_iaf_bsshslm_2020", n_rec, params_nrn_rec) -nrns_out = nest.Create("eprop_readout_bsshslm_2020", n_out, params_nrn_out) +nrns_rec = nest.Create(model_nrn_rec, n_rec, params_nrn_rec) +nrns_out = nest.Create("eprop_readout", n_out, params_nrn_out) gen_rate_target = nest.Create("step_rate_generator", n_out) - +gen_learning_window = nest.Create("step_rate_generator") # %% ########################################################################################################### # Create recorders @@ -220,7 +232,7 @@ # default, recordings are stored in memory but can also be written to file. 
n_record = 1 # number of neurons to record dynamic variables from - this script requires n_record >= 1 -n_record_w = 3 # number of senders and targets to record weights from - this script requires n_record_w >=1 +n_record_w = 5 # number of senders and targets to record weights from - this script requires n_record_w >=1 if n_record == 0 or n_record_w == 0: raise ValueError("n_record and n_record_w >= 1 required") @@ -230,13 +242,15 @@ "record_from": ["V_m", "surrogate_gradient", "learning_signal"], # dynamic variables to record "start": duration["offset_gen"] + duration["delay_in_rec"], # start time of recording "stop": duration["offset_gen"] + duration["delay_in_rec"] + duration["task"], # stop time of recording + "label": "multimeter_rec", } params_mm_out = { "interval": duration["step"], - "record_from": ["V_m", "readout_signal", "readout_signal_unnorm", "target_signal", "error_signal"], + "record_from": ["V_m", "readout_signal", "target_signal", "error_signal"], "start": duration["total_offset"], "stop": duration["total_offset"] + duration["task"], + "label": "multimeter_out", } params_wr = { @@ -244,18 +258,27 @@ "targets": nrns_rec[:n_record_w] + nrns_out, # limit targets to subsample weights to record from "start": duration["total_offset"], "stop": duration["total_offset"] + duration["task"], + "label": "weight_recorder", } -params_sr = { - "start": duration["total_offset"], +params_sr_in = { + "start": duration["offset_gen"], "stop": duration["total_offset"] + duration["task"], + "label": "spike_recorder_in", +} + +params_sr_rec = { + "start": duration["offset_gen"], + "stop": duration["total_offset"] + duration["task"], + "label": "spike_recorder_rec", } #################### mm_rec = nest.Create("multimeter", params_mm_rec) mm_out = nest.Create("multimeter", params_mm_out) -sr = nest.Create("spike_recorder", params_sr) +sr_in = nest.Create("spike_recorder", params_sr_in) +sr_rec = nest.Create("spike_recorder", params_sr_rec) wr = nest.Create("weight_recorder", params_wr) nrns_rec_record = nrns_rec[:n_record] @@ -280,19 +303,19 @@ params_common_syn_eprop = { "optimizer": { "type": "gradient_descent", # algorithm to optimize the weights - "batch_size": n_batch, + "batch_size": 1, "eta": 1e-4, # learning rate + "optimize_each_step": False, # call optimizer every time step (True) or once per spike (False); both + # yield same results for gradient descent, False offers speed-up "Wmin": -100.0, # pA, minimal limit of the synaptic weights "Wmax": 100.0, # pA, maximal limit of the synaptic weights }, - "average_gradient": False, # if True, average the gradient over the learning window "weight_recorder": wr, } params_syn_base = { - "synapse_model": "eprop_synapse_bsshslm_2020", + "synapse_model": "eprop_synapse", "delay": duration["step"], # ms, dendritic delay - "tau_m_readout": params_nrn_out["tau_m"], # ms, for technical reasons pass readout neuron membrane time constant } params_syn_in = params_syn_base.copy() @@ -305,11 +328,17 @@ params_syn_out["weight"] = weights_rec_out params_syn_feedback = { - "synapse_model": "eprop_learning_signal_connection_bsshslm_2020", + "synapse_model": "eprop_learning_signal_connection", "delay": duration["step"], "weight": weights_out_rec, } +params_syn_learning_window = { + "synapse_model": "rate_connection_delayed", + "delay": duration["step"], + "receptor_type": 1, # receptor type over which readout neuron receives learning window signal +} + params_syn_rate_target = { "synapse_model": "rate_connection_delayed", "delay": duration["step"], @@ -323,7 
+352,7 @@ #################### -nest.SetDefaults("eprop_synapse_bsshslm_2020", params_common_syn_eprop) +nest.SetDefaults("eprop_synapse", params_common_syn_eprop) nest.Connect(gen_spk_in, nrns_in, params_conn_one_to_one, params_syn_static) # connection 1 nest.Connect(nrns_in, nrns_rec, params_conn_all_to_all, params_syn_in) # connection 2 @@ -331,8 +360,10 @@ nest.Connect(nrns_rec, nrns_out, params_conn_all_to_all, params_syn_out) # connection 4 nest.Connect(nrns_out, nrns_rec, params_conn_all_to_all, params_syn_feedback) # connection 5 nest.Connect(gen_rate_target, nrns_out, params_conn_one_to_one, params_syn_rate_target) # connection 6 +nest.Connect(gen_learning_window, nrns_out, params_conn_all_to_all, params_syn_learning_window) # connection 7 -nest.Connect(nrns_in + nrns_rec, sr, params_conn_all_to_all, params_syn_static) +nest.Connect(nrns_in, sr_in, params_conn_all_to_all, params_syn_static) +nest.Connect(nrns_rec, sr_rec, params_conn_all_to_all, params_syn_static) nest.Connect(mm_rec, nrns_rec_record, params_conn_all_to_all, params_syn_static) nest.Connect(mm_out, nrns_out, params_conn_all_to_all, params_syn_static) @@ -348,7 +379,6 @@ dtype_in_spks = np.float32 # data type of input spikes - for reproducing TF results set to np.float32 input_spike_bools = (np.random.rand(steps["sequence"], n_in) < input_spike_prob).swapaxes(0, 1) -input_spike_bools[:, 0] = 0 # remove spikes in 0th time step of every sequence for technical reasons sequence_starts = np.arange(0.0, duration["task"], duration["sequence"]) + duration["offset_gen"] params_gen_spk_in = [] @@ -390,13 +420,29 @@ def generate_superimposed_sines(steps_sequence, periods): params_gen_rate_target = { "amplitude_times": np.arange(0.0, duration["task"], duration["step"]) + duration["total_offset"], - "amplitude_values": np.tile(target_signal, n_iter * n_batch), + "amplitude_values": np.tile(target_signal, n_iter * group_size), } #################### nest.SetStatus(gen_rate_target, params_gen_rate_target) +# %% ########################################################################################################### +# Create learning window +# ~~~~~~~~~~~~~~~~~~~~~~ +# Custom learning windows, in which the network learns, can be defined with an additional signal. The error +# signal is internally multiplied with this learning window signal. Passing a learning window signal of value 1 +# opens the learning window while passing a value of 0 closes it. 
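To make the windowing described above concrete before the always-open window configured below: a sketch (not part of the patch; the 200 ms figure and the variable names are illustrative) of a window that is closed at the start of every sequence and only opens for its final 200 ms:

    # Editorial sketch: per-sequence learning window driven by the step_rate_generator signal.
    window_open_offset = duration["sequence"] - 200.0  # ms into each sequence at which learning starts
    sequence_starts_lw = np.arange(0.0, duration["task"], duration["sequence"]) + duration["total_offset"]
    params_gen_learning_window_partial = {
        "amplitude_times": np.sort(np.hstack([sequence_starts_lw, sequence_starts_lw + window_open_offset])),
        "amplitude_values": np.tile([0.0, 1.0], len(sequence_starts_lw)),  # 0 = window closed, 1 = window open
    }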
+ +params_gen_learning_window = { + "amplitude_times": [duration["total_offset"]], + "amplitude_values": [1.0], +} + +#################### + +nest.SetStatus(gen_learning_window, params_gen_learning_window) + # %% ########################################################################################################### # Force final update # ~~~~~~~~~~~~~~~~~~ @@ -459,7 +505,8 @@ def get_weights(pop_pre, pop_post): events_mm_rec = mm_rec.get("events") events_mm_out = mm_out.get("events") -events_sr = sr.get("events") +events_sr_in = sr_in.get("events") +events_sr_rec = sr_rec.get("events") events_wr = wr.get("events") # %% ########################################################################################################### @@ -470,9 +517,15 @@ def get_weights(pop_pre, pop_post): readout_signal = events_mm_out["readout_signal"] target_signal = events_mm_out["target_signal"] +senders = events_mm_out["senders"] + +readout_signal = np.array([readout_signal[senders == i] for i in set(senders)]) +target_signal = np.array([target_signal[senders == i] for i in set(senders)]) -error = (readout_signal - target_signal) ** 2 -loss = 0.5 * np.add.reduceat(error, np.arange(0, steps["task"], steps["sequence"])) +readout_signal = readout_signal.reshape((n_out, n_iter, group_size, steps["sequence"])) +target_signal = target_signal.reshape((n_out, n_iter, group_size, steps["sequence"])) + +loss = 0.5 * np.mean(np.sum((readout_signal - target_signal) ** 2, axis=3), axis=(0, 2)) # %% ########################################################################################################### # Plot results @@ -492,7 +545,6 @@ def get_weights(pop_pre, pop_post): plt.rcParams.update( { - "font.sans-serif": "Arial", "axes.spines.right": False, "axes.spines.top": False, "axes.prop_cycle": cycler(color=[colors["blue"], colors["red"]]), @@ -505,6 +557,7 @@ def get_weights(pop_pre, pop_post): # We begin with a plot visualizing the training error of the network: the loss plotted against the iterations. 
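For orientation, the quantity plotted here is the loss computed above from the reshaped recordings: the squared error is summed over the time steps of one sequence and averaged over the readout neurons and the group. Under the array shapes used above, with notation assumed here rather than taken from the patch, the value for iteration :math:`i` reads

    E_i = \frac{1}{2 \, n_\mathrm{out} \, n_\mathrm{group}} \sum_{k,b} \sum_{t=1}^{T} \left( y_{k,b,i}^{t} - y_{k,b,i}^{*,t} \right)^2

where :math:`T` is the number of time steps per sequence, :math:`k` runs over the readout neurons, and :math:`b` over the group.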
fig, ax = plt.subplots() +fig.suptitle("Training error") ax.plot(range(1, n_iter + 1), loss) ax.set_ylabel(r"$E = \frac{1}{2} \sum_{t,k} \left( y_k^t -y_k^{*,t}\right)^2$") @@ -531,11 +584,10 @@ def plot_recordable(ax, events, recordable, ylabel, xlims): ax.set_ylim(np.min(events[recordable]) - margin, np.max(events[recordable]) + margin) -def plot_spikes(ax, events, nrns, ylabel, xlims): +def plot_spikes(ax, events, ylabel, xlims): idc_times = (events["times"] > xlims[0]) & (events["times"] < xlims[1]) - idc_sender = np.isin(events["senders"][idc_times], nrns.tolist()) - senders_subset = events["senders"][idc_times][idc_sender] - times_subset = events["times"][idc_times][idc_sender] + senders_subset = events["senders"][idc_times] + times_subset = events["times"][idc_times] ax.scatter(times_subset, senders_subset, s=0.1) ax.set_ylabel(ylabel) @@ -543,11 +595,15 @@ def plot_spikes(ax, events, nrns, ylabel, xlims): ax.set_ylim(np.min(senders_subset) - margin, np.max(senders_subset) + margin) -for xlims in [(0, steps["sequence"]), (steps["task"] - steps["sequence"], steps["task"])]: +for title, xlims in zip( + ["Dynamic variables before training", "Dynamic variables after training"], + [(0, steps["sequence"]), (steps["task"] - steps["sequence"], steps["task"])], +): fig, axs = plt.subplots(9, 1, sharex=True, figsize=(6, 8), gridspec_kw={"hspace": 0.4, "left": 0.2}) + fig.suptitle(title) - plot_spikes(axs[0], events_sr, nrns_in, r"$z_i$" + "\n", xlims) - plot_spikes(axs[1], events_sr, nrns_rec, r"$z_j$" + "\n", xlims) + plot_spikes(axs[0], events_sr_in, r"$z_i$" + "\n", xlims) + plot_spikes(axs[1], events_sr_rec, r"$z_j$" + "\n", xlims) plot_recordable(axs[2], events_mm_rec, "V_m", r"$v_j$" + "\n(mV)", xlims) plot_recordable(axs[3], events_mm_rec, "surrogate_gradient", r"$\psi_j$" + "\n", xlims) @@ -572,7 +628,10 @@ def plot_spikes(ax, events, nrns, ylabel, xlims): # the first time step and we add the initial weights manually. 
-def plot_weight_time_course(ax, events, nrns_senders, nrns_targets, label, ylabel): +def plot_weight_time_course(ax, events, nrns_weight_record, label, ylabel): + sender_label, target_label = label.split("_") + nrns_senders = nrns_weight_record[sender_label] + nrns_targets = nrns_weight_record[target_label] for sender in nrns_senders.tolist(): for target in nrns_targets.tolist(): idc_syn = (events["senders"] == sender) & (events["targets"] == target) @@ -589,12 +648,17 @@ def plot_weight_time_course(ax, events, nrns_senders, nrns_targets, label, ylabe fig, axs = plt.subplots(3, 1, sharex=True, figsize=(3, 4)) +fig.suptitle("Weight time courses") -plot_weight_time_course(axs[0], events_wr, nrns_in[:n_record_w], nrns_rec[:n_record_w], "in_rec", r"$W_\text{in}$ (pA)") -plot_weight_time_course( - axs[1], events_wr, nrns_rec[:n_record_w], nrns_rec[:n_record_w], "rec_rec", r"$W_\text{rec}$ (pA)" -) -plot_weight_time_course(axs[2], events_wr, nrns_rec[:n_record_w], nrns_out, "rec_out", r"$W_\text{out}$ (pA)") +nrns_weight_record = { + "in": nrns_in[:n_record_w], + "rec": nrns_rec[:n_record_w], + "out": nrns_out, +} + +plot_weight_time_course(axs[0], events_wr, nrns_weight_record, "in_rec", r"$W_\text{in}$ (pA)") +plot_weight_time_course(axs[1], events_wr, nrns_weight_record, "rec_rec", r"$W_\text{rec}$ (pA)") +plot_weight_time_course(axs[2], events_wr, nrns_weight_record, "rec_out", r"$W_\text{out}$ (pA)") axs[-1].set_xlabel(r"$t$ (ms)") axs[-1].set_xlim(0, steps["task"]) @@ -614,6 +678,7 @@ def plot_weight_time_course(ax, events, nrns_senders, nrns_targets, label, ylabe ) fig, axs = plt.subplots(3, 2, sharex="col", sharey="row") +fig.suptitle("Weight matrices") all_w_extrema = [] @@ -636,8 +701,8 @@ def plot_weight_time_course(ax, events, nrns_senders, nrns_targets, label, ylabe axs[2, 0].set_ylabel("readout\nneurons") fig.align_ylabels(axs[:, 0]) -axs[0, 0].text(0.5, 1.1, "pre-training", transform=axs[0, 0].transAxes, ha="center") -axs[0, 1].text(0.5, 1.1, "post-training", transform=axs[0, 1].transAxes, ha="center") +axs[0, 0].text(0.5, 1.1, "before training", transform=axs[0, 0].transAxes, ha="center") +axs[0, 1].text(0.5, 1.1, "after training", transform=axs[0, 1].transAxes, ha="center") axs[2, 0].yaxis.get_major_locator().set_params(integer=True) diff --git a/pynest/examples/eprop_plasticity/eprop_supervised_regression_sine-waves_bsshslm_2020.png b/pynest/examples/eprop_plasticity/eprop_supervised_regression_sine-waves_bsshslm_2020.png new file mode 100644 index 0000000000..2e7dd1f4bb Binary files /dev/null and b/pynest/examples/eprop_plasticity/eprop_supervised_regression_sine-waves_bsshslm_2020.png differ diff --git a/pynest/examples/eprop_plasticity/eprop_supervised_regression_sine-waves_bsshslm_2020.py b/pynest/examples/eprop_plasticity/eprop_supervised_regression_sine-waves_bsshslm_2020.py new file mode 100644 index 0000000000..a1375ee279 --- /dev/null +++ b/pynest/examples/eprop_plasticity/eprop_supervised_regression_sine-waves_bsshslm_2020.py @@ -0,0 +1,687 @@ +# -*- coding: utf-8 -*- +# +# eprop_supervised_regression_sine-waves_bsshslm_2020.py +# +# This file is part of NEST. +# +# Copyright (C) 2004 The NEST Initiative +# +# NEST is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 2 of the License, or +# (at your option) any later version. 
+#
+# NEST is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with NEST. If not, see <http://www.gnu.org/licenses/>.
+
+r"""
+Tutorial on learning to generate sine waves with e-prop after Bellec et al. (2020)
+----------------------------------------------------------------------------------
+
+Training a regression model using supervised e-prop plasticity to generate sine waves
+
+Description
+~~~~~~~~~~~
+
+This script demonstrates supervised learning of a regression task with a recurrent spiking neural network that
+is equipped with the eligibility propagation (e-prop) plasticity mechanism by Bellec et al. [1]_.
+
+This type of learning is demonstrated with the proof-of-concept task in [1]_. We based this script on their
+TensorFlow script given in [2]_.
+
+In this task, the network learns to generate an arbitrary N-dimensional temporal pattern. Here, the
+network learns to reproduce with its overall spiking activity a one-dimensional, one-second-long target signal
+which is a superposition of four sine waves of different amplitudes, phases, and periods.
+
+.. image:: eprop_supervised_regression_sine-waves_bsshslm_2020.png
+   :width: 70 %
+   :alt: Schematic of network architecture. Same as Figure 1 in the code.
+   :align: center
+
+Learning in the neural network model is achieved by optimizing the connection weights with e-prop plasticity.
+This plasticity rule requires a specific network architecture depicted in Figure 1. The neural network model
+consists of a recurrent network that receives frozen noise input from Poisson generators and projects onto one
+readout neuron. The readout neuron compares the network signal :math:`y` with the teacher target signal
+:math:`y^*`, which it receives from a rate generator. In scenarios with multiple readout neurons, each individual
+readout signal denoted as :math:`y_k` is compared with a corresponding target signal represented as
+:math:`y_k^*`. The network's training error is assessed by employing a mean-squared error loss.
+
+Details on the event-based NEST implementation of e-prop can be found in [3]_.
+
+References
+~~~~~~~~~~
+
+.. [1] Bellec G, Scherr F, Subramoney A, Hajek E, Salaj D, Legenstein R, Maass W (2020). A solution to the
+       learning dilemma for recurrent networks of spiking neurons. Nature Communications, 11:3625.
+       https://doi.org/10.1038/s41467-020-17236-y
+
+.. [2] https://github.com/IGITUGraz/eligibility_propagation/blob/master/Figure_3_and_S7_e_prop_tutorials/tutorial_pattern_generation.py
+
+.. [3] Korcsak-Gorzo A, Stapmanns J, Espinoza Valverde JA, Plesser HE,
+       Dahmen D, Bolten M, Van Albada SJ*, Diesmann M*. Event-based
+       implementation of eligibility propagation (in preparation)
+
+""" # pylint: disable=line-too-long # noqa: E501
+
+# %% ###########################################################################################################
+# Import libraries
+# ~~~~~~~~~~~~~~~~
+# We begin by importing all libraries required for the simulation, analysis, and visualization.
+ +import matplotlib as mpl +import matplotlib.pyplot as plt +import nest +import numpy as np +from cycler import cycler +from IPython.display import Image + +# %% ########################################################################################################### +# Schematic of network architecture +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +# This figure, identical to the one in the description, shows the required network architecture in the center, +# the input and output of the pattern generation task above, and lists of the required NEST device, neuron, and +# synapse models below. The connections that must be established are numbered 1 to 6. + +try: + Image(filename="./eprop_supervised_regression_sine-waves_bsshslm_2020.png") +except Exception: + pass + +# %% ########################################################################################################### +# Setup +# ~~~~~ + +# %% ########################################################################################################### +# Initialize random generator +# ........................... +# We seed the numpy random generator, which will generate random initial weights as well as random input and +# output. + +rng_seed = 1 # numpy random seed +np.random.seed(rng_seed) # fix numpy random seed + +# %% ########################################################################################################### +# Define timing of task +# ..................... +# The task's temporal structure is then defined, once as time steps and once as durations in milliseconds. +# Increasing the number of iterations enhances learning performance. + +batch_size = 1 # batch size, 1 in reference [2] +n_iter = 200 # number of iterations, 2000 in reference [2] + +steps = { + "sequence": 1000, # time steps of one full sequence +} + +steps["learning_window"] = steps["sequence"] # time steps of window with non-zero learning signals +steps["task"] = n_iter * batch_size * steps["sequence"] # time steps of task + +steps.update( + { + "offset_gen": 1, # offset since generator signals start from time step 1 + "delay_in_rec": 1, # connection delay between input and recurrent neurons + "delay_rec_out": 1, # connection delay between recurrent and output neurons + "delay_out_norm": 1, # connection delay between output neurons for normalization + "extension_sim": 1, # extra time step to close right-open simulation time interval in Simulate() + } +) + +steps["delays"] = steps["delay_in_rec"] + steps["delay_rec_out"] + steps["delay_out_norm"] # time steps of delays + +steps["total_offset"] = steps["offset_gen"] + steps["delays"] # time steps of total offset + +steps["sim"] = steps["task"] + steps["total_offset"] + steps["extension_sim"] # time steps of simulation + +duration = {"step": 1.0} # ms, temporal resolution of the simulation + +duration.update({key: value * duration["step"] for key, value in steps.items()}) # ms, durations + +# %% ########################################################################################################### +# Set up simulation +# ................. +# As last step of the setup, we reset the NEST kernel to remove all existing NEST simulation settings and +# objects and set some NEST kernel parameters, some of which are e-prop-related. 
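A small optional read-back (not part of the script) can confirm that the e-prop kernel settings take effect; it assumes the nest.set(**params_setup) call below has already run, and the expected values follow from the timing defined above:

    print(nest.GetKernelStatus("eprop_update_interval"))  # expected: 1000.0 ms, i.e. one sequence
    print(nest.GetKernelStatus("eprop_learning_window"))  # expected: 1000.0 ms, the full sequence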
+ +params_setup = { + "eprop_learning_window": duration["learning_window"], + "eprop_reset_neurons_on_update": True, # if True, reset dynamic variables at start of each update interval + "eprop_update_interval": duration["sequence"], # ms, time interval for updating the synaptic weights + "print_time": False, # if True, print time progress bar during simulation, set False if run as code cell + "resolution": duration["step"], + "total_num_virtual_procs": 1, # number of virtual processes, set in case of distributed computing +} + +#################### + +nest.ResetKernel() +nest.set(**params_setup) + +# %% ########################################################################################################### +# Create neurons +# ~~~~~~~~~~~~~~ +# We proceed by creating a certain number of input, recurrent, and readout neurons and setting their parameters. +# Additionally, we already create an input spike generator and an output target rate generator, which we will +# configure later. + +n_in = 100 # number of input neurons +n_rec = 100 # number of recurrent neurons +n_out = 1 # number of readout neurons + +params_nrn_out = { + "C_m": 1.0, # pF, membrane capacitance - takes effect only if neurons get current input (here not the case) + "E_L": 0.0, # mV, leak / resting membrane potential + "I_e": 0.0, # pA, external current input + "loss": "mean_squared_error", # loss function + "regular_spike_arrival": False, # If True, input spikes arrive at end of time step, if False at beginning + "tau_m": 30.0, # ms, membrane time constant + "V_m": 0.0, # mV, initial value of the membrane voltage +} + +params_nrn_rec = { + "beta": 1.0, # width scaling of the pseudo-derivative + "C_m": 1.0, + "c_reg": 300.0, # coefficient of firing rate regularization + "E_L": 0.0, + "f_target": 10.0, # spikes/s, target firing rate for firing rate regularization + "gamma": 0.3, # height scaling of the pseudo-derivative + "I_e": 0.0, + "regular_spike_arrival": False, + "surrogate_gradient_function": "piecewise_linear", # surrogate gradient / pseudo-derivative function + "t_ref": 0.0, # ms, duration of refractory period + "tau_m": 30.0, + "V_m": 0.0, + "V_th": 0.03, # mV, spike threshold membrane voltage +} + +# factors from the original pseudo-derivative definition are incorporated into the parameters +params_nrn_rec["gamma"] /= params_nrn_rec["V_th"] +params_nrn_rec["beta"] /= np.abs(params_nrn_rec["V_th"]) # prefactor is inside abs in the original definition + +#################### + +# Intermediate parrot neurons required between input spike generators and recurrent neurons, +# since devices cannot establish plastic synapses for technical reasons + +gen_spk_in = nest.Create("spike_generator", n_in) +nrns_in = nest.Create("parrot_neuron", n_in) + +# The suffix _bsshslm_2020 follows the NEST convention to indicate in the model name the paper +# that introduced it by the first letter of the authors' last names and the publication year. + +nrns_rec = nest.Create("eprop_iaf_bsshslm_2020", n_rec, params_nrn_rec) +nrns_out = nest.Create("eprop_readout_bsshslm_2020", n_out, params_nrn_out) +gen_rate_target = nest.Create("step_rate_generator", n_out) + + +# %% ########################################################################################################### +# Create recorders +# ~~~~~~~~~~~~~~~~ +# We also create recorders, which, while not required for the training, will allow us to track various dynamic +# variables of the neurons, spikes, and changes in synaptic weights. 
To save computing time and memory, the +# recorders, the recorded variables, neurons, and synapses can be limited to the ones relevant to the +# experiment, and the recording interval can be increased (see the documentation on the specific recorders). By +# default, recordings are stored in memory but can also be written to file. + +n_record = 1 # number of neurons to record dynamic variables from - this script requires n_record >= 1 +n_record_w = 5 # number of senders and targets to record weights from - this script requires n_record_w >=1 + +if n_record == 0 or n_record_w == 0: + raise ValueError("n_record and n_record_w >= 1 required") + +params_mm_rec = { + "interval": duration["step"], # interval between two recorded time points + "record_from": ["V_m", "surrogate_gradient", "learning_signal"], # dynamic variables to record + "start": duration["offset_gen"] + duration["delay_in_rec"], # start time of recording + "stop": duration["offset_gen"] + duration["delay_in_rec"] + duration["task"], # stop time of recording + "label": "multimeter_rec", +} + +params_mm_out = { + "interval": duration["step"], + "record_from": ["V_m", "readout_signal", "readout_signal_unnorm", "target_signal", "error_signal"], + "start": duration["total_offset"], + "stop": duration["total_offset"] + duration["task"], + "label": "multimeter_out", +} + +params_wr = { + "senders": nrns_in[:n_record_w] + nrns_rec[:n_record_w], # limit senders to subsample weights to record + "targets": nrns_rec[:n_record_w] + nrns_out, # limit targets to subsample weights to record from + "start": duration["total_offset"], + "stop": duration["total_offset"] + duration["task"], + "label": "weight_recorder", +} + +params_sr_in = { + "start": duration["offset_gen"], + "stop": duration["total_offset"] + duration["task"], + "label": "spike_recorder_in", +} + +params_sr_rec = { + "start": duration["offset_gen"], + "stop": duration["total_offset"] + duration["task"], + "label": "spike_recorder_rec", +} + +#################### + +mm_rec = nest.Create("multimeter", params_mm_rec) +mm_out = nest.Create("multimeter", params_mm_out) +sr_in = nest.Create("spike_recorder", params_sr_in) +sr_rec = nest.Create("spike_recorder", params_sr_rec) +wr = nest.Create("weight_recorder", params_wr) + +nrns_rec_record = nrns_rec[:n_record] + +# %% ########################################################################################################### +# Create connections +# ~~~~~~~~~~~~~~~~~~ +# Now, we define the connectivity and set up the synaptic parameters, with the synaptic weights drawn from +# normal distributions. After these preparations, we establish the enumerated connections of the core network, +# as well as additional connections to the recorders. 
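Once the numbered nest.Connect calls below have been issued, a quick consistency check is possible; this sketch is not part of the script and simply assumes the all-to-all wiring without autapses established there:

    # Expected plastic e-prop connections: in->rec, rec->rec (no self-connections), rec->out.
    n_expected = n_in * n_rec + n_rec * (n_rec - 1) + n_rec * n_out
    n_plastic = len(nest.GetConnections(synapse_model="eprop_synapse_bsshslm_2020"))
    assert n_plastic == n_expected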
+ +params_conn_all_to_all = {"rule": "all_to_all", "allow_autapses": False} +params_conn_one_to_one = {"rule": "one_to_one"} + +dtype_weights = np.float32 # data type of weights - for reproducing TF results set to np.float32 +weights_in_rec = np.array(np.random.randn(n_in, n_rec).T / np.sqrt(n_in), dtype=dtype_weights) +weights_rec_rec = np.array(np.random.randn(n_rec, n_rec).T / np.sqrt(n_rec), dtype=dtype_weights) +np.fill_diagonal(weights_rec_rec, 0.0) # since no autapses set corresponding weights to zero +weights_rec_out = np.array(np.random.randn(n_rec, n_out).T / np.sqrt(n_rec), dtype=dtype_weights) +weights_out_rec = np.array(np.random.randn(n_rec, n_out) / np.sqrt(n_rec), dtype=dtype_weights) + +params_common_syn_eprop = { + "optimizer": { + "type": "gradient_descent", # algorithm to optimize the weights + "batch_size": batch_size, + "eta": 1e-4, # learning rate + "Wmin": -100.0, # pA, minimal limit of the synaptic weights + "Wmax": 100.0, # pA, maximal limit of the synaptic weights + }, + "average_gradient": False, # if True, average the gradient over the learning window + "weight_recorder": wr, +} + +params_syn_base = { + "synapse_model": "eprop_synapse_bsshslm_2020", + "delay": duration["step"], # ms, dendritic delay + "tau_m_readout": params_nrn_out["tau_m"], # ms, for technical reasons pass readout neuron membrane time constant +} + +params_syn_in = params_syn_base.copy() +params_syn_in["weight"] = weights_in_rec # pA, initial values for the synaptic weights + +params_syn_rec = params_syn_base.copy() +params_syn_rec["weight"] = weights_rec_rec + +params_syn_out = params_syn_base.copy() +params_syn_out["weight"] = weights_rec_out + +params_syn_feedback = { + "synapse_model": "eprop_learning_signal_connection_bsshslm_2020", + "delay": duration["step"], + "weight": weights_out_rec, +} + +params_syn_rate_target = { + "synapse_model": "rate_connection_delayed", + "delay": duration["step"], + "receptor_type": 2, # receptor type over which readout neuron receives target signal +} + +params_syn_static = { + "synapse_model": "static_synapse", + "delay": duration["step"], +} + +#################### + +nest.SetDefaults("eprop_synapse_bsshslm_2020", params_common_syn_eprop) + +nest.Connect(gen_spk_in, nrns_in, params_conn_one_to_one, params_syn_static) # connection 1 +nest.Connect(nrns_in, nrns_rec, params_conn_all_to_all, params_syn_in) # connection 2 +nest.Connect(nrns_rec, nrns_rec, params_conn_all_to_all, params_syn_rec) # connection 3 +nest.Connect(nrns_rec, nrns_out, params_conn_all_to_all, params_syn_out) # connection 4 +nest.Connect(nrns_out, nrns_rec, params_conn_all_to_all, params_syn_feedback) # connection 5 +nest.Connect(gen_rate_target, nrns_out, params_conn_one_to_one, params_syn_rate_target) # connection 6 + +nest.Connect(nrns_in, sr_in, params_conn_all_to_all, params_syn_static) +nest.Connect(nrns_rec, sr_rec, params_conn_all_to_all, params_syn_static) + +nest.Connect(mm_rec, nrns_rec_record, params_conn_all_to_all, params_syn_static) +nest.Connect(mm_out, nrns_out, params_conn_all_to_all, params_syn_static) + +# %% ########################################################################################################### +# Create input +# ~~~~~~~~~~~~ +# We generate some frozen Poisson spike noise of a fixed rate that is repeated in each iteration and feed these +# spike times to the previously created input spike generator. The network will use these spike times as a +# temporal backbone for encoding the target signal into its recurrent spiking activity. 
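As a rough check on the noise statistics (illustrative only, not part of the script): with the per-step spike probability defined just below and the 1 ms resolution used here, each input neuron emits on the order of 50 spikes per second.

    expected_rate = input_spike_prob / (duration["step"] * 1e-3)  # 0.05 / 0.001 s = 50 spikes/s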
+ +input_spike_prob = 0.05 # spike probability of frozen input noise +dtype_in_spks = np.float32 # data type of input spikes - for reproducing TF results set to np.float32 + +input_spike_bools = (np.random.rand(steps["sequence"], n_in) < input_spike_prob).swapaxes(0, 1) +input_spike_bools[:, 0] = 0 # remove spikes in 0th time step of every sequence for technical reasons + +sequence_starts = np.arange(0.0, duration["task"], duration["sequence"]) + duration["offset_gen"] +params_gen_spk_in = [] +for input_spike_bool in input_spike_bools: + input_spike_times = np.arange(0.0, duration["sequence"], duration["step"])[input_spike_bool] + input_spike_times_all = [input_spike_times + start for start in sequence_starts] + params_gen_spk_in.append({"spike_times": np.hstack(input_spike_times_all).astype(dtype_in_spks)}) + +#################### + +nest.SetStatus(gen_spk_in, params_gen_spk_in) + +# %% ########################################################################################################### +# Create output +# ~~~~~~~~~~~~~ +# Then, as a superposition of four sine waves with various durations, amplitudes, and phases, we construct a +# one-second target signal. This signal, like the input, is repeated for all iterations and fed into the rate +# generator that was previously created. + + +def generate_superimposed_sines(steps_sequence, periods): + n_sines = len(periods) + + amplitudes = np.random.uniform(low=0.5, high=2.0, size=n_sines) + phases = np.random.uniform(low=0.0, high=2.0 * np.pi, size=n_sines) + + sines = [ + A * np.sin(np.linspace(phi, phi + 2.0 * np.pi * (steps_sequence // T), steps_sequence)) + for A, phi, T in zip(amplitudes, phases, periods) + ] + + superposition = sum(sines) + superposition -= superposition[0] + superposition /= max(np.abs(superposition).max(), 1e-6) + return superposition + + +target_signal = generate_superimposed_sines(steps["sequence"], [1000, 500, 333, 200]) # periods in steps + +params_gen_rate_target = { + "amplitude_times": np.arange(0.0, duration["task"], duration["step"]) + duration["total_offset"], + "amplitude_values": np.tile(target_signal, n_iter * batch_size), +} + +#################### + +nest.SetStatus(gen_rate_target, params_gen_rate_target) + +# %% ########################################################################################################### +# Force final update +# ~~~~~~~~~~~~~~~~~~ +# Synapses only get active, that is, the correct weight update calculated and applied, when they transmit a +# spike. To still be able to read out the correct weights at the end of the simulation, we force spiking of the +# presynaptic neuron and thus an update of all synapses, including those that have not transmitted a spike in +# the last update interval, by sending a strong spike to all neurons that form the presynaptic side of an eprop +# synapse. This step is required purely for technical reasons. + +gen_spk_final_update = nest.Create("spike_generator", 1, {"spike_times": [duration["task"] + duration["delays"]]}) + +nest.Connect(gen_spk_final_update, nrns_in + nrns_rec, "all_to_all", {"weight": 1000.0}) + +# %% ########################################################################################################### +# Read out pre-training weights +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +# Before we begin training, we read out the initial weight matrices so that we can eventually compare them to +# the optimized weights. 
+ + +def get_weights(pop_pre, pop_post): + conns = nest.GetConnections(pop_pre, pop_post).get(["source", "target", "weight"]) + conns["senders"] = np.array(conns["source"]) - np.min(conns["source"]) + conns["targets"] = np.array(conns["target"]) - np.min(conns["target"]) + + conns["weight_matrix"] = np.zeros((len(pop_post), len(pop_pre))) + conns["weight_matrix"][conns["targets"], conns["senders"]] = conns["weight"] + return conns + + +weights_pre_train = { + "in_rec": get_weights(nrns_in, nrns_rec), + "rec_rec": get_weights(nrns_rec, nrns_rec), + "rec_out": get_weights(nrns_rec, nrns_out), +} + +# %% ########################################################################################################### +# Simulate +# ~~~~~~~~ +# We train the network by simulating for a set simulation time, determined by the number of iterations and the +# batch size and the length of one sequence. + +nest.Simulate(duration["sim"]) + +# %% ########################################################################################################### +# Read out post-training weights +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +# After the training, we can read out the optimized final weights. + +weights_post_train = { + "in_rec": get_weights(nrns_in, nrns_rec), + "rec_rec": get_weights(nrns_rec, nrns_rec), + "rec_out": get_weights(nrns_rec, nrns_out), +} + +# %% ########################################################################################################### +# Read out recorders +# ~~~~~~~~~~~~~~~~~~ +# We can also retrieve the recorded history of the dynamic variables and weights, as well as detected spikes. + +events_mm_rec = mm_rec.get("events") +events_mm_out = mm_out.get("events") +events_sr_in = sr_in.get("events") +events_sr_rec = sr_rec.get("events") +events_wr = wr.get("events") + +# %% ########################################################################################################### +# Evaluate training error +# ~~~~~~~~~~~~~~~~~~~~~~~ +# We evaluate the network's training error by calculating a loss - in this case, the mean squared error between +# the integrated recurrent network activity and the target rate. + +readout_signal = events_mm_out["readout_signal"] +target_signal = events_mm_out["target_signal"] +senders = events_mm_out["senders"] + +readout_signal = np.array([readout_signal[senders == i] for i in set(senders)]) +target_signal = np.array([target_signal[senders == i] for i in set(senders)]) + +readout_signal = readout_signal.reshape((n_out, n_iter, batch_size, steps["sequence"])) +target_signal = target_signal.reshape((n_out, n_iter, batch_size, steps["sequence"])) + +loss = 0.5 * np.mean(np.sum((readout_signal - target_signal) ** 2, axis=3), axis=(0, 2)) + +# %% ########################################################################################################### +# Plot results +# ~~~~~~~~~~~~ +# Then, we plot a series of plots. + +do_plotting = True # if True, plot the results + +if not do_plotting: + exit() + +colors = { + "blue": "#2854c5ff", + "red": "#e04b40ff", + "white": "#ffffffff", +} + +plt.rcParams.update( + { + "axes.spines.right": False, + "axes.spines.top": False, + "axes.prop_cycle": cycler(color=[colors["blue"], colors["red"]]), + } +) + +# %% ########################################################################################################### +# Plot training error +# ................... +# We begin with a plot visualizing the training error of the network: the loss plotted against the iterations. 
+ +fig, ax = plt.subplots() +fig.suptitle("Training error") + +ax.plot(range(1, n_iter + 1), loss) +ax.set_ylabel(r"$E = \frac{1}{2} \sum_{t,k} \left( y_k^t -y_k^{*,t}\right)^2$") +ax.set_xlabel("training iteration") +ax.set_xlim(1, n_iter) +ax.xaxis.get_major_locator().set_params(integer=True) + +fig.tight_layout() + +# %% ########################################################################################################### +# Plot spikes and dynamic variables +# ................................. +# This plotting routine shows how to plot all of the recorded dynamic variables and spikes across time. We take +# one snapshot in the first iteration and one snapshot at the end. + + +def plot_recordable(ax, events, recordable, ylabel, xlims): + for sender in set(events["senders"]): + idc_sender = events["senders"] == sender + idc_times = (events["times"][idc_sender] > xlims[0]) & (events["times"][idc_sender] < xlims[1]) + ax.plot(events["times"][idc_sender][idc_times], events[recordable][idc_sender][idc_times], lw=0.5) + ax.set_ylabel(ylabel) + margin = np.abs(np.max(events[recordable]) - np.min(events[recordable])) * 0.1 + ax.set_ylim(np.min(events[recordable]) - margin, np.max(events[recordable]) + margin) + + +def plot_spikes(ax, events, ylabel, xlims): + idc_times = (events["times"] > xlims[0]) & (events["times"] < xlims[1]) + senders_subset = events["senders"][idc_times] + times_subset = events["times"][idc_times] + + ax.scatter(times_subset, senders_subset, s=0.1) + ax.set_ylabel(ylabel) + margin = np.abs(np.max(senders_subset) - np.min(senders_subset)) * 0.1 + ax.set_ylim(np.min(senders_subset) - margin, np.max(senders_subset) + margin) + + +for title, xlims in zip( + ["Dynamic variables before training", "Dynamic variables after training"], + [(0, steps["sequence"]), (steps["task"] - steps["sequence"], steps["task"])], +): + fig, axs = plt.subplots(9, 1, sharex=True, figsize=(6, 8), gridspec_kw={"hspace": 0.4, "left": 0.2}) + fig.suptitle(title) + + plot_spikes(axs[0], events_sr_in, r"$z_i$" + "\n", xlims) + plot_spikes(axs[1], events_sr_rec, r"$z_j$" + "\n", xlims) + + plot_recordable(axs[2], events_mm_rec, "V_m", r"$v_j$" + "\n(mV)", xlims) + plot_recordable(axs[3], events_mm_rec, "surrogate_gradient", r"$\psi_j$" + "\n", xlims) + plot_recordable(axs[4], events_mm_rec, "learning_signal", r"$L_j$" + "\n(pA)", xlims) + + plot_recordable(axs[5], events_mm_out, "V_m", r"$v_k$" + "\n(mV)", xlims) + plot_recordable(axs[6], events_mm_out, "target_signal", r"$y^*_k$" + "\n", xlims) + plot_recordable(axs[7], events_mm_out, "readout_signal", r"$y_k$" + "\n", xlims) + plot_recordable(axs[8], events_mm_out, "error_signal", r"$y_k-y^*_k$" + "\n", xlims) + + axs[-1].set_xlabel(r"$t$ (ms)") + axs[-1].set_xlim(*xlims) + + fig.align_ylabels() + +# %% ########################################################################################################### +# Plot weight time courses +# ........................ +# Similarly, we can plot the weight histories. Note that the weight recorder, attached to the synapses, works +# differently than the other recorders. Since synapses only get activated when they transmit a spike, the weight +# recorder only records the weight in those moments. That is why the first weight registrations do not start in +# the first time step and we add the initial weights manually. 
+ + +def plot_weight_time_course(ax, events, nrns_weight_record, label, ylabel): + sender_label, target_label = label.split("_") + nrns_senders = nrns_weight_record[sender_label] + nrns_targets = nrns_weight_record[target_label] + for sender in nrns_senders.tolist(): + for target in nrns_targets.tolist(): + idc_syn = (events["senders"] == sender) & (events["targets"] == target) + idc_syn_pre = (weights_pre_train[label]["source"] == sender) & ( + weights_pre_train[label]["target"] == target + ) + + times = [0.0] + events["times"][idc_syn].tolist() + weights = [weights_pre_train[label]["weight"][idc_syn_pre]] + events["weights"][idc_syn].tolist() + + ax.step(times, weights, c=colors["blue"]) + ax.set_ylabel(ylabel) + ax.set_ylim(-0.6, 0.6) + + +fig, axs = plt.subplots(3, 1, sharex=True, figsize=(3, 4)) +fig.suptitle("Weight time courses") + +nrns_weight_record = { + "in": nrns_in[:n_record_w], + "rec": nrns_rec[:n_record_w], + "out": nrns_out, +} + +plot_weight_time_course(axs[0], events_wr, nrns_weight_record, "in_rec", r"$W_\text{in}$ (pA)") +plot_weight_time_course(axs[1], events_wr, nrns_weight_record, "rec_rec", r"$W_\text{rec}$ (pA)") +plot_weight_time_course(axs[2], events_wr, nrns_weight_record, "rec_out", r"$W_\text{out}$ (pA)") + +axs[-1].set_xlabel(r"$t$ (ms)") +axs[-1].set_xlim(0, steps["task"]) + +fig.align_ylabels() +fig.tight_layout() + +# %% ########################################################################################################### +# Plot weight matrices +# .................... +# If one is not interested in the time course of the weights, it is possible to read out only the initial and +# final weights, which requires less computing time and memory than the weight recorder approach. Here, we plot +# the corresponding weight matrices before and after the optimization. 
+ +cmap = mpl.colors.LinearSegmentedColormap.from_list( + "cmap", ((0.0, colors["blue"]), (0.5, colors["white"]), (1.0, colors["red"])) +) + +fig, axs = plt.subplots(3, 2, sharex="col", sharey="row") +fig.suptitle("Weight matrices") + +all_w_extrema = [] + +for k in weights_pre_train.keys(): + w_pre = weights_pre_train[k]["weight"] + w_post = weights_post_train[k]["weight"] + all_w_extrema.append([np.min(w_pre), np.max(w_pre), np.min(w_post), np.max(w_post)]) + +args = {"cmap": cmap, "vmin": np.min(all_w_extrema), "vmax": np.max(all_w_extrema)} + +for i, weights in zip([0, 1], [weights_pre_train, weights_post_train]): + axs[0, i].pcolormesh(weights["in_rec"]["weight_matrix"].T, **args) + axs[1, i].pcolormesh(weights["rec_rec"]["weight_matrix"], **args) + cmesh = axs[2, i].pcolormesh(weights["rec_out"]["weight_matrix"], **args) + + axs[2, i].set_xlabel("recurrent\nneurons") + +axs[0, 0].set_ylabel("input\nneurons") +axs[1, 0].set_ylabel("recurrent\nneurons") +axs[2, 0].set_ylabel("readout\nneurons") +fig.align_ylabels(axs[:, 0]) + +axs[0, 0].text(0.5, 1.1, "before training", transform=axs[0, 0].transAxes, ha="center") +axs[0, 1].text(0.5, 1.1, "after training", transform=axs[0, 1].transAxes, ha="center") + +axs[2, 0].yaxis.get_major_locator().set_params(integer=True) + +cbar = plt.colorbar(cmesh, cax=axs[1, 1].inset_axes([1.1, 0.2, 0.05, 0.8]), label="weight (pA)") + +fig.tight_layout() + +plt.show() diff --git a/testsuite/pytests/sli2py_regressions/test_issue_77.py b/testsuite/pytests/sli2py_regressions/test_issue_77.py index 651a85a77a..ba1bb8c151 100644 --- a/testsuite/pytests/sli2py_regressions/test_issue_77.py +++ b/testsuite/pytests/sli2py_regressions/test_issue_77.py @@ -61,6 +61,11 @@ "eprop_readout_bsshslm_2020", # does not send spikes "eprop_iaf_bsshslm_2020", # does not support stdp synapses "eprop_iaf_adapt_bsshslm_2020", # does not support stdp synapses + "eprop_readout", # does not send spikes + "eprop_iaf", # does not support stdp synapses + "eprop_iaf_adapt", # does not support stdp synapses + "eprop_iaf_psc_delta", # does not support stdp synapses + "eprop_iaf_psc_delta_adapt", # does not support stdp synapses ] # The following models require connections to rport 1 or other specific parameters: diff --git a/testsuite/pytests/test_eprop_bsshslm_2020_plasticity.py b/testsuite/pytests/test_eprop_bsshslm_2020_plasticity.py index 0f65167d62..f8daaf7fe8 100644 --- a/testsuite/pytests/test_eprop_bsshslm_2020_plasticity.py +++ b/testsuite/pytests/test_eprop_bsshslm_2020_plasticity.py @@ -41,7 +41,9 @@ def fix_resolution(): @pytest.mark.parametrize("source_model", supported_source_models) @pytest.mark.parametrize("target_model", supported_target_models) def test_connect_with_eprop_synapse(source_model, target_model): - """Ensures that the restriction to supported neuron models works.""" + """ + Ensure that the restriction to supported neuron models works. + """ # Connect supported models with e-prop synapse src = nest.Create(source_model) @@ -51,7 +53,9 @@ def test_connect_with_eprop_synapse(source_model, target_model): @pytest.mark.parametrize("target_model", set(nest.node_models) - set(supported_target_models)) def test_unsupported_model_raises(target_model): - """Confirm that connecting a non-eprop neuron as target via an eprop_synapse_bsshslm_2020 raises an error.""" + """ + Confirm that connecting a non-eprop neuron as target via an eprop_synapse_bsshslm_2020 raises an error. 
+ """ src_nrn = nest.Create(supported_source_models[0]) tgt_nrn = nest.Create(target_model) @@ -62,21 +66,23 @@ def test_unsupported_model_raises(target_model): def test_eprop_regression(): """ - Test correct computation of losses for a regression task - (for details on the task, see nest-simulator/pynest/examples/eprop_plasticity/eprop_supervised_regression_sine-waves.py) + Test correct computation of losses for a regression task (for details on the task, see + nest-simulator/pynest/examples/eprop_plasticity/eprop_supervised_regression_sine-waves_bsshslm_2020.py) by comparing the simulated losses with - 1. NEST reference losses to catch scenarios in which the e-prop model does not work as intended (e.g., - potential future changes to the NEST code base or a faulty installation). These reference losses - were obtained from a simulation with the verified NEST e-prop implementation run with - Linux 4.15.0-213-generic, Python v3.11.6, Numpy v1.26.0, and NEST@3304c6b5c. - - 2. TensorFlow reference losses to check the faithfulness to the original model. These reference losses were - obtained from a simulation with the original TensorFlow implementation - (https://github.com/INM-6/eligibility_propagation/blob/eprop_in_nest/Figure_3_and_S7_e_prop_tutorials/tutorial_pattern_generation.py, - a modified fork of the original model at https://github.com/IGITUGraz/eligibility_propagation) run with - Linux 4.15.0-213-generic, Python v3.6.10, Numpy v1.18.0, TensorFlow v1.15.0, and - INM6/eligibility_propagation@7df7d2627. + 1. NEST reference losses to catch scenarios in which the e-prop model does not work as + intended (e.g., potential future changes to the NEST code base or a faulty installation). + These reference losses were obtained from a simulation with the verified NEST e-prop + implementation run with Linux 4.15.0-213-generic, Python v3.11.6, Numpy v1.26.0, and + NEST@3304c6b5c. + + 2. TensorFlow reference losses to check the faithfulness to the original model. These + reference losses were obtained from a simulation with the original TensorFlow implementation + (https://github.com/INM-6/eligibility_propagation/blob/eprop_in_nest/Figure_3_and_S7_e_prop_tutorials/tutorial_pattern_generation.py, + a modified fork of the original model at + https://github.com/IGITUGraz/eligibility_propagation) run with Linux 4.15.0-213-generic, + Python v3.6.10, Numpy v1.18.0, TensorFlow v1.15.0, and + INM6/eligibility_propagation@7df7d2627. 
""" # pylint: disable=line-too-long # noqa: E501 # Initialize random generator @@ -85,7 +91,7 @@ def test_eprop_regression(): # Define timing of task - n_batch = 1 + batch_size = 1 n_iter = 5 steps = { @@ -93,7 +99,7 @@ def test_eprop_regression(): } steps["learning_window"] = steps["sequence"] - steps["task"] = n_iter * n_batch * steps["sequence"] + steps["task"] = n_iter * batch_size * steps["sequence"] steps.update( { @@ -105,9 +111,9 @@ def test_eprop_regression(): } ) - steps["total_offset"] = ( - steps["offset_gen"] + steps["delay_in_rec"] + steps["delay_rec_out"] + steps["delay_out_norm"] - ) + steps["delays"] = steps["delay_in_rec"] + steps["delay_rec_out"] + steps["delay_out_norm"] + + steps["total_offset"] = steps["offset_gen"] + steps["delays"] steps["sim"] = steps["task"] + steps["total_offset"] + steps["extension_sim"] @@ -135,31 +141,35 @@ def test_eprop_regression(): n_rec = 100 n_out = 1 - params_nrn_rec = { + params_nrn_out = { "C_m": 1.0, - "c_reg": 300.0, - "gamma": 0.3, "E_L": 0.0, - "f_target": 10.0, "I_e": 0.0, + "loss": "mean_squared_error", "regular_spike_arrival": False, - "surrogate_gradient_function": "piecewise_linear", - "t_ref": 0.0, "tau_m": 30.0, "V_m": 0.0, - "V_th": 0.03, } - params_nrn_out = { + params_nrn_rec = { + "beta": 1.0, "C_m": 1.0, + "c_reg": 300.0, "E_L": 0.0, + "f_target": 10.0, + "gamma": 0.3, "I_e": 0.0, - "loss": "mean_squared_error", "regular_spike_arrival": False, + "surrogate_gradient_function": "piecewise_linear", + "t_ref": 0.0, "tau_m": 30.0, "V_m": 0.0, + "V_th": 0.03, } + params_nrn_rec["gamma"] /= params_nrn_rec["V_th"] + params_nrn_rec["beta"] /= np.abs(params_nrn_rec["V_th"]) + gen_spk_in = nest.Create("spike_generator", n_in) nrns_in = nest.Create("parrot_neuron", n_in) nrns_rec = nest.Create("eprop_iaf_bsshslm_2020", n_rec, params_nrn_rec) @@ -172,25 +182,34 @@ def test_eprop_regression(): n_record_w = 1 params_mm_rec = { + "interval": duration["sequence"], "record_from": ["V_m", "surrogate_gradient", "learning_signal"], "start": duration["offset_gen"] + duration["delay_in_rec"], - "interval": duration["sequence"], + "stop": duration["offset_gen"] + duration["delay_in_rec"] + duration["task"], } params_mm_out = { + "interval": duration["step"], "record_from": ["V_m", "readout_signal", "readout_signal_unnorm", "target_signal", "error_signal"], "start": duration["total_offset"], - "interval": duration["step"], + "stop": duration["total_offset"] + duration["task"], } params_wr = { "senders": nrns_in[:n_record_w] + nrns_rec[:n_record_w], "targets": nrns_rec[:n_record_w] + nrns_out, + "start": duration["total_offset"], + "stop": duration["total_offset"] + duration["task"], + } + + params_sr = { + "start": duration["offset_gen"], + "stop": duration["total_offset"] + duration["task"], } mm_rec = nest.Create("multimeter", params_mm_rec) mm_out = nest.Create("multimeter", params_mm_out) - sr = nest.Create("spike_recorder") + sr = nest.Create("spike_recorder", params_sr) wr = nest.Create("weight_recorder", params_wr) nrns_rec_record = nrns_rec[:n_record] @@ -210,35 +229,29 @@ def test_eprop_regression(): params_common_syn_eprop = { "optimizer": { "type": "gradient_descent", - "batch_size": n_batch, + "batch_size": batch_size, "eta": 1e-4, "Wmin": -100.0, "Wmax": 100.0, }, - "weight_recorder": wr, "average_gradient": False, + "weight_recorder": wr, } - params_syn_in = { + params_syn_base = { "synapse_model": "eprop_synapse_bsshslm_2020", "delay": duration["step"], "tau_m_readout": params_nrn_out["tau_m"], - "weight": weights_in_rec, } - 
params_syn_rec = { - "synapse_model": "eprop_synapse_bsshslm_2020", - "delay": duration["step"], - "tau_m_readout": params_nrn_out["tau_m"], - "weight": weights_rec_rec, - } + params_syn_in = params_syn_base.copy() + params_syn_in["weight"] = weights_in_rec - params_syn_out = { - "synapse_model": "eprop_synapse_bsshslm_2020", - "delay": duration["step"], - "tau_m_readout": params_nrn_out["tau_m"], - "weight": weights_rec_out, - } + params_syn_rec = params_syn_base.copy() + params_syn_rec["weight"] = weights_rec_rec + + params_syn_out = params_syn_base.copy() + params_syn_out["weight"] = weights_rec_out params_syn_feedback = { "synapse_model": "eprop_learning_signal_connection_bsshslm_2020", @@ -276,14 +289,13 @@ def test_eprop_regression(): input_spike_prob = 0.05 dtype_in_spks = np.float32 - input_spike_bools = np.random.rand(n_batch, steps["sequence"], n_in) < input_spike_prob - input_spike_bools = np.hstack(input_spike_bools.swapaxes(1, 2)) + input_spike_bools = (np.random.rand(steps["sequence"], n_in) < input_spike_prob).swapaxes(0, 1) input_spike_bools[:, 0] = 0 sequence_starts = np.arange(0.0, duration["task"], duration["sequence"]) + duration["offset_gen"] params_gen_spk_in = [] for input_spike_bool in input_spike_bools: - input_spike_times = np.arange(0.0, duration["sequence"] * n_batch, duration["step"])[input_spike_bool] + input_spike_times = np.arange(0.0, duration["sequence"], duration["step"])[input_spike_bool] input_spike_times_all = [input_spike_times + start for start in sequence_starts] params_gen_spk_in.append({"spike_times": np.hstack(input_spike_times_all).astype(dtype_in_spks)}) @@ -311,7 +323,7 @@ def generate_superimposed_sines(steps_sequence, periods): params_gen_rate_target = { "amplitude_times": np.arange(0.0, duration["task"], duration["step"]) + duration["total_offset"], - "amplitude_values": np.tile(target_signal, n_iter * n_batch), + "amplitude_values": np.tile(target_signal, n_iter * batch_size), } nest.SetStatus(gen_rate_target, params_gen_rate_target) @@ -328,23 +340,26 @@ def generate_superimposed_sines(steps_sequence, periods): readout_signal = events_mm_out["readout_signal"] target_signal = events_mm_out["target_signal"] + senders = events_mm_out["senders"] + + readout_signal = np.array([readout_signal[senders == i] for i in set(senders)]) + target_signal = np.array([target_signal[senders == i] for i in set(senders)]) - error = (readout_signal - target_signal) ** 2 - loss = 0.5 * np.add.reduceat(error, np.arange(0, steps["task"], steps["sequence"])) + readout_signal = readout_signal.reshape((n_out, n_iter, batch_size, steps["sequence"])) + target_signal = target_signal.reshape((n_out, n_iter, batch_size, steps["sequence"])) - # Verify results + loss = 0.5 * np.mean(np.sum((readout_signal - target_signal) ** 2, axis=3), axis=(0, 2)) - loss_NEST_reference = np.array( - [ - 101.964356999041, - 103.466731126205, - 103.340607074771, - 103.680244037686, - 104.412775748752, - ] - ) + # Verify results + loss_nest_reference = [ + 101.964356999041, + 103.466731126205, + 103.340607074771, + 103.680244037686, + 104.412775748752, + ] - loss_TF_reference = np.array( + loss_tf_reference = np.array( [ 101.964363098144, 103.466735839843, @@ -354,27 +369,54 @@ def generate_superimposed_sines(steps_sequence, periods): ] ) - assert np.allclose(loss, loss_NEST_reference, rtol=1e-8) - assert np.allclose(loss, loss_TF_reference, rtol=1e-7) - - -def test_eprop_classification(): + assert np.allclose(loss, loss_tf_reference, rtol=1e-7) + assert np.allclose(loss, 
loss_nest_reference, rtol=1e-8) + + +@pytest.mark.parametrize( + "batch_size,loss_nest_reference", + [ + ( + 1, + [ + 0.741152550006, + 0.740388187700, + 0.665785233177, + 0.663644193322, + 0.729428962844, + ], + ), + ( + 2, + [ + 0.702163370672, + 0.735555303152, + 0.740354864111, + 0.683882815282, + 0.707841122268, + ], + ), + ], +) +def test_eprop_classification(batch_size, loss_nest_reference): """ - Test correct computation of losses for a classification task - (for details on the task, see nest-simulator/pynest/examples/eprop_plasticity/eprop_supervised_classification_evidence-accumulation.py) + Test correct computation of losses for a classification task (for details on the task, see + nest-simulator/pynest/examples/eprop_plasticity/eprop_supervised_classification_evidence-accumulation_bsshslm_2020.py) by comparing the simulated losses with - 1. NEST reference losses to catch scenarios in which the e-prop model does not work as intended (e.g., - potential future changes to the NEST code base or a faulty installation). These reference losses - were obtained from a simulation with the verified NEST e-prop implementation run with - Linux 4.15.0-213-generic, Python v3.11.6, Numpy v1.26.0, and NEST@3304c6b5c. - - 2. TensorFlow reference losses to check the faithfulness to the original model. These reference losses were - obtained from a simulation with the original TensorFlow implementation - (https://github.com/INM-6/eligibility_propagation/blob/eprop_in_nest/Figure_3_and_S7_e_prop_tutorials/tutorial_evidence_accumulation_with_alif.py, - a modified fork of the original model at https://github.com/IGITUGraz/eligibility_propagation) run with - Linux 4.15.0-213-generic, Python v3.6.10, Numpy v1.18.0, TensorFlow v1.15.0, and - INM6/eligibility_propagation@7df7d2627. + 1. NEST reference losses to catch scenarios in which the e-prop model does not work as + intended (e.g., potential future changes to the NEST code base or a faulty installation). + These reference losses were obtained from a simulation with the verified NEST e-prop + implementation run with Linux 4.15.0-213-generic, Python v3.11.6, Numpy v1.26.0, and + NEST@3304c6b5c. + + 2. TensorFlow reference losses to check the faithfulness to the original model. These + reference losses were obtained from a simulation with the original TensorFlow implementation + (https://github.com/INM-6/eligibility_propagation/blob/eprop_in_nest/Figure_3_and_S7_e_prop_tutorials/tutorial_evidence_accumulation_with_alif.py, + a modified fork of the original model at + https://github.com/IGITUGraz/eligibility_propagation) run with Linux 4.15.0-213-generic, + Python v3.6.10, Numpy v1.18.0, TensorFlow v1.15.0, and + INM6/eligibility_propagation@7df7d2627. 
""" # pylint: disable=line-too-long # noqa: E501 # Initialize random generator @@ -384,12 +426,14 @@ def test_eprop_classification(): # Define timing of task - n_batch = 1 n_iter = 5 - n_input_symbols = 4 - n_cues = 7 - prob_group = 0.3 + input = { + "n_symbols": 4, + "n_cues": 7, + "prob_group": 0.3, + "spike_prob": 0.04, + } steps = { "cue": 100, @@ -398,10 +442,10 @@ def test_eprop_classification(): "recall": 150, } - steps["cues"] = n_cues * (steps["cue"] + steps["spacing"]) + steps["cues"] = input["n_cues"] * (steps["cue"] + steps["spacing"]) steps["sequence"] = steps["cues"] + steps["bg_noise"] + steps["recall"] steps["learning_window"] = steps["recall"] - steps["task"] = n_iter * n_batch * steps["sequence"] + steps["task"] = n_iter * batch_size * steps["sequence"] steps.update( { @@ -413,9 +457,9 @@ def test_eprop_classification(): } ) - steps["total_offset"] = ( - steps["offset_gen"] + steps["delay_in_rec"] + steps["delay_rec_out"] + steps["delay_out_norm"] - ) + steps["delays"] = steps["delay_in_rec"] + steps["delay_rec_out"] + steps["delay_out_norm"] + + steps["total_offset"] = steps["offset_gen"] + steps["delays"] steps["sim"] = steps["task"] + steps["total_offset"] + steps["extension_sim"] @@ -445,9 +489,20 @@ def test_eprop_classification(): n_rec = n_ad + n_reg n_out = 2 + params_nrn_out = { + "C_m": 1.0, + "E_L": 0.0, + "I_e": 0.0, + "loss": "cross_entropy", + "regular_spike_arrival": False, + "tau_m": 20.0, + "V_m": 0.0, + } + params_nrn_reg = { + "beta": 1.0, "C_m": 1.0, - "c_reg": 2.0, + "c_reg": 300.0, "E_L": 0.0, "f_target": 10.0, "gamma": 0.3, @@ -460,11 +515,15 @@ def test_eprop_classification(): "V_th": 0.6, } + params_nrn_reg["gamma"] /= params_nrn_reg["V_th"] + params_nrn_reg["beta"] /= np.abs(params_nrn_reg["V_th"]) + params_nrn_ad = { + "beta": 1.0, "adapt_tau": 2000.0, "adaptation": 0.0, "C_m": 1.0, - "c_reg": 2.0, + "c_reg": 300.0, "E_L": 0.0, "f_target": 10.0, "gamma": 0.3, @@ -477,19 +536,13 @@ def test_eprop_classification(): "V_th": 0.6, } - params_nrn_ad["adapt_beta"] = ( - 1.7 * (1.0 - np.exp(-1.0 / params_nrn_ad["adapt_tau"])) / (1.0 - np.exp(-1.0 / params_nrn_ad["tau_m"])) - ) + params_nrn_ad["gamma"] /= params_nrn_ad["V_th"] + params_nrn_ad["beta"] /= np.abs(params_nrn_ad["V_th"]) - params_nrn_out = { - "C_m": 1.0, - "E_L": 0.0, - "I_e": 0.0, - "loss": "cross_entropy", - "regular_spike_arrival": False, - "tau_m": 20.0, - "V_m": 0.0, - } + params_nrn_ad["adapt_beta"] = 1.7 * ( + (1.0 - np.exp(-duration["step"] / params_nrn_ad["adapt_tau"])) + / (1.0 - np.exp(-duration["step"] / params_nrn_ad["tau_m"])) + ) gen_spk_in = nest.Create("spike_generator", n_in) nrns_in = nest.Create("parrot_neuron", n_in) @@ -505,29 +558,47 @@ def test_eprop_classification(): n_record = 1 n_record_w = 1 - params_mm_rec = { + params_mm_reg = { + "interval": duration["step"], "record_from": ["V_m", "surrogate_gradient", "learning_signal"], "start": duration["offset_gen"] + duration["delay_in_rec"], - "interval": duration["sequence"], + "stop": duration["offset_gen"] + duration["delay_in_rec"] + duration["task"], + } + + params_mm_ad = { + "interval": duration["step"], + "record_from": params_mm_reg["record_from"] + ["V_th_adapt", "adaptation"], + "start": duration["offset_gen"] + duration["delay_in_rec"], + "stop": duration["offset_gen"] + duration["delay_in_rec"] + duration["task"], } params_mm_out = { + "interval": duration["step"], "record_from": ["V_m", "readout_signal", "readout_signal_unnorm", "target_signal", "error_signal"], "start": duration["total_offset"], - 
"interval": duration["step"], + "stop": duration["total_offset"] + duration["task"], } params_wr = { "senders": nrns_in[:n_record_w] + nrns_rec[:n_record_w], "targets": nrns_rec[:n_record_w] + nrns_out, + "start": duration["total_offset"], + "stop": duration["total_offset"] + duration["task"], } - mm_rec = nest.Create("multimeter", params_mm_rec) + params_sr = { + "start": duration["offset_gen"], + "stop": duration["total_offset"] + duration["task"], + } + + mm_reg = nest.Create("multimeter", params_mm_reg) + mm_ad = nest.Create("multimeter", params_mm_ad) mm_out = nest.Create("multimeter", params_mm_out) - sr = nest.Create("spike_recorder") + sr = nest.Create("spike_recorder", params_sr) wr = nest.Create("weight_recorder", params_wr) - nrns_rec_record = nrns_rec[:n_record] + nrns_reg_record = nrns_reg[:n_record] + nrns_ad_record = nrns_ad[:n_record] # Create connections @@ -550,7 +621,7 @@ def calculate_glorot_dist(fan_in, fan_out): params_common_syn_eprop = { "optimizer": { "type": "adam", - "batch_size": n_batch, + "batch_size": batch_size, "beta_1": 0.9, "beta_2": 0.999, "epsilon": 1e-8, @@ -558,30 +629,24 @@ def calculate_glorot_dist(fan_in, fan_out): "Wmin": -100.0, "Wmax": 100.0, }, - "weight_recorder": wr, "average_gradient": True, + "weight_recorder": wr, } - params_syn_in = { + params_syn_base = { "synapse_model": "eprop_synapse_bsshslm_2020", "delay": duration["step"], "tau_m_readout": params_nrn_out["tau_m"], - "weight": weights_in_rec, } - params_syn_rec = { - "synapse_model": "eprop_synapse_bsshslm_2020", - "delay": duration["step"], - "tau_m_readout": params_nrn_out["tau_m"], - "weight": weights_rec_rec, - } + params_syn_in = params_syn_base.copy() + params_syn_in["weight"] = weights_in_rec - params_syn_out = { - "synapse_model": "eprop_synapse_bsshslm_2020", - "delay": duration["step"], - "tau_m_readout": params_nrn_out["tau_m"], - "weight": weights_rec_out, - } + params_syn_rec = params_syn_base.copy() + params_syn_rec["weight"] = weights_rec_rec + + params_syn_out = params_syn_base.copy() + params_syn_out["weight"] = weights_rec_out params_syn_feedback = { "synapse_model": "eprop_learning_signal_connection_bsshslm_2020", @@ -607,6 +672,13 @@ def calculate_glorot_dist(fan_in, fan_out): "delay": duration["step"], } + params_init_optimizer = { + "optimizer": { + "m": 0.0, + "v": 0.0, + } + } + nest.SetDefaults("eprop_synapse_bsshslm_2020", params_common_syn_eprop) nest.Connect(gen_spk_in, nrns_in, params_conn_one_to_one, params_syn_static) @@ -619,30 +691,31 @@ def calculate_glorot_dist(fan_in, fan_out): nest.Connect(nrns_in + nrns_rec, sr, params_conn_all_to_all, params_syn_static) - nest.Connect(mm_rec, nrns_rec_record, params_conn_all_to_all, params_syn_static) + nest.Connect(mm_reg, nrns_reg_record, params_conn_all_to_all, params_syn_static) + nest.Connect(mm_ad, nrns_ad_record, params_conn_all_to_all, params_syn_static) nest.Connect(mm_out, nrns_out, params_conn_all_to_all, params_syn_static) + nest.GetConnections(nrns_rec[0], nrns_rec[1:3]).set([params_init_optimizer] * 2) + # Create input and output - def generate_evidence_accumulation_input_output( - n_batch, n_in, prob_group, input_spike_prob, n_cues, n_input_symbols, steps - ): - n_pop_nrn = n_in // n_input_symbols + def generate_evidence_accumulation_input_output(batch_size, n_in, steps, input): + n_pop_nrn = n_in // input["n_symbols"] - prob_choices = np.array([prob_group, 1 - prob_group], dtype=np.float32) - idx = np.random.choice([0, 1], n_batch) - probs = np.zeros((n_batch, 2), dtype=np.float32) + prob_choices 
= np.array([input["prob_group"], 1 - input["prob_group"]], dtype=np.float32) + idx = np.random.choice([0, 1], batch_size) + probs = np.zeros((batch_size, 2), dtype=np.float32) probs[:, 0] = prob_choices[idx] probs[:, 1] = prob_choices[1 - idx] - batched_cues = np.zeros((n_batch, n_cues), dtype=int) - for b_idx in range(n_batch): - batched_cues[b_idx, :] = np.random.choice([0, 1], n_cues, p=probs[b_idx]) + batched_cues = np.zeros((batch_size, input["n_cues"]), dtype=int) + for b_idx in range(batch_size): + batched_cues[b_idx, :] = np.random.choice([0, 1], input["n_cues"], p=probs[b_idx]) - input_spike_probs = np.zeros((n_batch, steps["sequence"], n_in)) + input_spike_probs = np.zeros((batch_size, steps["sequence"], n_in)) - for b_idx in range(n_batch): - for c_idx in range(n_cues): + for b_idx in range(batch_size): + for c_idx in range(input["n_cues"]): cue = batched_cues[b_idx, c_idx] step_start = c_idx * (steps["cue"] + steps["spacing"]) + steps["spacing"] @@ -651,30 +724,27 @@ def generate_evidence_accumulation_input_output( pop_nrn_start = cue * n_pop_nrn pop_nrn_stop = pop_nrn_start + n_pop_nrn - input_spike_probs[b_idx, step_start:step_stop, pop_nrn_start:pop_nrn_stop] = input_spike_prob + input_spike_probs[b_idx, step_start:step_stop, pop_nrn_start:pop_nrn_stop] = input["spike_prob"] - input_spike_probs[:, -steps["recall"] :, 2 * n_pop_nrn : 3 * n_pop_nrn] = input_spike_prob - input_spike_probs[:, :, 3 * n_pop_nrn :] = input_spike_prob / 4.0 + input_spike_probs[:, -steps["recall"] :, 2 * n_pop_nrn : 3 * n_pop_nrn] = input["spike_prob"] + input_spike_probs[:, :, 3 * n_pop_nrn :] = input["spike_prob"] / 4.0 input_spike_bools = input_spike_probs > np.random.rand(input_spike_probs.size).reshape(input_spike_probs.shape) input_spike_bools[:, 0, :] = 0 - target_cues = np.zeros(n_batch, dtype=int) - target_cues[:] = np.sum(batched_cues, axis=1) > int(n_cues / 2) + target_cues = np.zeros(batch_size, dtype=int) + target_cues[:] = np.sum(batched_cues, axis=1) > int(input["n_cues"] / 2) return input_spike_bools, target_cues - input_spike_prob = 0.04 dtype_in_spks = np.float32 input_spike_bools_list = [] target_cues_list = [] - for iteration in range(n_iter): - input_spike_bools, target_cues = generate_evidence_accumulation_input_output( - n_batch, n_in, prob_group, input_spike_prob, n_cues, n_input_symbols, steps - ) + for _ in range(n_iter): + input_spike_bools, target_cues = generate_evidence_accumulation_input_output(batch_size, n_in, steps, input) input_spike_bools_list.append(input_spike_bools) - target_cues_list.extend(target_cues.tolist()) + target_cues_list.extend(target_cues) input_spike_bools_arr = np.array(input_spike_bools_list).reshape(steps["task"], n_in) timeline_task = np.arange(0.0, duration["task"], duration["step"]) + duration["offset_gen"] @@ -684,8 +754,8 @@ def generate_evidence_accumulation_input_output( for nrn_in_idx in range(n_in) ] - target_rate_changes = np.zeros((n_out, n_batch * n_iter)) - target_rate_changes[np.array(target_cues_list), np.arange(n_batch * n_iter)] = 1 + target_rate_changes = np.zeros((n_out, batch_size * n_iter)) + target_rate_changes[np.array(target_cues_list), np.arange(batch_size * n_iter)] = 1 params_gen_rate_target = [ { @@ -715,27 +785,17 @@ def generate_evidence_accumulation_input_output( readout_signal = np.array([readout_signal[senders == i] for i in set(senders)]) target_signal = np.array([target_signal[senders == i] for i in set(senders)]) - readout_signal = readout_signal.reshape((n_out, n_iter, n_batch, steps["sequence"])) - 
readout_signal = readout_signal[:, :, :, -steps["learning_window"] :] + readout_signal = readout_signal.reshape((n_out, n_iter, batch_size, steps["sequence"])) + target_signal = target_signal.reshape((n_out, n_iter, batch_size, steps["sequence"])) - target_signal = target_signal.reshape((n_out, n_iter, n_batch, steps["sequence"])) + readout_signal = readout_signal[:, :, :, -steps["learning_window"] :] target_signal = target_signal[:, :, :, -steps["learning_window"] :] loss = -np.mean(np.sum(target_signal * np.log(readout_signal), axis=0), axis=(1, 2)) # Verify results - loss_NEST_reference = np.array( - [ - 0.741152550006, - 0.740388187700, - 0.665785233177, - 0.663644193322, - 0.729428962844, - ] - ) - - loss_TF_reference = np.array( + loss_tf_reference = np.array( [ 0.741152524948, 0.740388214588, @@ -745,5 +805,6 @@ def generate_evidence_accumulation_input_output( ] ) - assert np.allclose(loss, loss_NEST_reference, rtol=1e-8) - assert np.allclose(loss, loss_TF_reference, rtol=1e-6) + if batch_size == 1: + assert np.allclose(loss, loss_tf_reference, rtol=1e-6) + assert np.allclose(loss, loss_nest_reference, rtol=1e-8) diff --git a/testsuite/pytests/test_eprop_plasticity.py b/testsuite/pytests/test_eprop_plasticity.py new file mode 100644 index 0000000000..c355ac75b1 --- /dev/null +++ b/testsuite/pytests/test_eprop_plasticity.py @@ -0,0 +1,552 @@ +# -*- coding: utf-8 -*- +# +# test_eprop_plasticity.py +# +# This file is part of NEST. +# +# Copyright (C) 2004 The NEST Initiative +# +# NEST is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 2 of the License, or +# (at your option) any later version. +# +# NEST is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with NEST. If not, see . + +""" +Test functionality of e-prop plasticity. +""" + +import nest +import numpy as np +import pytest + +nest.set_verbosity("M_WARNING") + +supported_source_models = ["eprop_iaf", "eprop_iaf_adapt", "eprop_iaf_psc_delta", "eprop_iaf_psc_delta_adapt"] +supported_target_models = supported_source_models + ["eprop_readout"] + + +@pytest.fixture(autouse=True) +def fix_resolution(): + nest.ResetKernel() + + +@pytest.mark.parametrize("source_model", supported_source_models) +@pytest.mark.parametrize("target_model", supported_target_models) +def test_connect_with_eprop_synapse(source_model, target_model): + """ + Ensure that the restriction to supported neuron models works. + """ + + # Connect supported models with e-prop synapse + src = nest.Create(source_model) + tgt = nest.Create(target_model) + nest.Connect(src, tgt, "all_to_all", {"synapse_model": "eprop_synapse", "delay": nest.resolution}) + + +@pytest.mark.parametrize("target_model", set(nest.node_models) - set(supported_target_models)) +def test_unsupported_model_raises(target_model): + """ + Confirm that connecting a non-eprop neuron as target via an eprop_synapse raises an error. 
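    Editor's note: a concrete instance of the parametrized check below (illustration only, not
    part of this patch; it assumes iaf_psc_alpha is among the unsupported target models):

        import nest
        import pytest

        nest.ResetKernel()
        src = nest.Create("eprop_iaf")
        tgt = nest.Create("iaf_psc_alpha")  # not an e-prop-compatible target
        with pytest.raises(nest.kernel.NESTError):
            nest.Connect(src, tgt, "all_to_all", {"synapse_model": "eprop_synapse"})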
+ """ + + src_nrn = nest.Create(supported_source_models[0]) + tgt_nrn = nest.Create(target_model) + + with pytest.raises(nest.kernel.NESTError): + nest.Connect(src_nrn, tgt_nrn, "all_to_all", {"synapse_model": "eprop_synapse"}) + + +@pytest.mark.parametrize( + "neuron_model,optimizer,loss_nest_reference", + [ + ( + "eprop_iaf", + "adam", + [ + 0.13126137747586, + 0.09395562983704, + 0.00734052541014, + 0.02909589949313, + 0.00279041902009, + ], + ), + ( + "eprop_iaf_adapt", + "gradient_descent", + [ + 0.04298221363883, + 0.03100545785399, + 0.00930311104052, + 0.00455478436740, + 0.00017408818078, + ], + ), + ( + "eprop_iaf_psc_delta", + "gradient_descent", + [ + 0.32286231964124, + 0.61322219696014, + 0.63745062813969, + 0.63844466107304, + 0.58671835471489, + ], + ), + ( + "eprop_iaf_psc_delta_adapt", + "gradient_descent", + [ + 0.19603671513741, + 0.33370485743782, + 0.35727428693343, + 0.31206408953001, + 0.31411885659561, + ], + ), + ], +) +def test_eprop_regression(neuron_model, optimizer, loss_nest_reference): + """ + Test correct computation of losses for a regression task (for details on the task, see + nest-simulator/pynest/examples/eprop_plasticity/eprop_supervised_regression_sine-waves.py) by + comparing the simulated losses with NEST reference losses to catch scenarios in which the e-prop + model does not work as intended (e.g., potential future changes to the NEST code base or a + faulty installation). These reference losses were obtained from a simulation with the verified + NEST e-prop implementation run with Linux 6.5.0-28-generic, Python v3.12.3, Numpy v1.26.4, and + NEST@9b65de4bf. + """ + + # Initialize random generator + rng_seed = 1 + np.random.seed(rng_seed) + + # Define timing of task + + group_size = 1 + n_iter = 5 + + steps = { + "sequence": 100, + } + + steps["learning_window"] = steps["sequence"] + steps["task"] = n_iter * group_size * steps["sequence"] + + steps.update( + { + "offset_gen": 1, + "delay_in_rec": 1, + "extension_sim": 3, + } + ) + + steps["delays"] = steps["delay_in_rec"] + + steps["total_offset"] = steps["offset_gen"] + steps["delays"] + + steps["sim"] = steps["task"] + steps["total_offset"] + steps["extension_sim"] + + duration = {"step": 1.0} + + duration.update({key: value * duration["step"] for key, value in steps.items()}) + + # Set up simulation + + params_setup = { + "print_time": False, + "resolution": duration["step"], + "total_num_virtual_procs": 1, + } + + nest.ResetKernel() + nest.set(**params_setup) + + # Create neurons + + n_in = 100 + n_rec = 100 + n_out = 1 + + params_nrn_out = { + "C_m": 1.0, + "E_L": 0.0, + "eprop_isi_trace_cutoff": 100, + "I_e": 0.0, + "regular_spike_arrival": False, + "tau_m": 30.0, + "V_m": 0.0, + } + + params_nrn_rec = { + "beta": 33.3, + "C_m": 1.0, + "c_reg": 300.0 / duration["sequence"], + "E_L": 0.0, + "eprop_isi_trace_cutoff": 100, + "f_target": 10.0, + "gamma": 10.0, + "I_e": 0.0, + "kappa": 0.97, + "kappa_reg": 0.97, + "regular_spike_arrival": False, + "surrogate_gradient_function": "piecewise_linear", + "t_ref": 0.0, + "tau_m": 30.0, + "V_m": 0.0, + "V_th": 0.03, + } + + if neuron_model in ["eprop_iaf_psc_delta", "eprop_iaf_psc_delta_adapt"]: + del params_nrn_rec["regular_spike_arrival"] + params_nrn_rec["V_reset"] = -0.5 + params_nrn_rec["c_reg"] = 2.0 / duration["sequence"] + params_nrn_rec["V_th"] = 0.5 + elif neuron_model == "eprop_iaf_adapt": + params_nrn_rec["adapt_beta"] = 0.0174 + params_nrn_rec["adapt_tau"] = 2000.0 + params_nrn_rec["adaptation"] = 0.0 + + gen_spk_in = 
nest.Create("spike_generator", n_in) + nrns_in = nest.Create("parrot_neuron", n_in) + nrns_rec = nest.Create(neuron_model, n_rec, params_nrn_rec) + nrns_out = nest.Create("eprop_readout", n_out, params_nrn_out) + gen_rate_target = nest.Create("step_rate_generator", n_out) + gen_learning_window = nest.Create("step_rate_generator") + + # Create recorders + + n_record = 1 + n_record_w = 1 + + params_mm_rec = { + "interval": duration["step"], + "record_from": ["V_m", "surrogate_gradient", "learning_signal"], + "start": duration["offset_gen"] + duration["delay_in_rec"], + "stop": duration["offset_gen"] + duration["delay_in_rec"] + duration["task"], + } + + params_mm_out = { + "interval": duration["step"], + "record_from": ["V_m", "readout_signal", "target_signal", "error_signal"], + "start": duration["total_offset"], + "stop": duration["total_offset"] + duration["task"], + } + + params_wr = { + "senders": nrns_in[:n_record_w] + nrns_rec[:n_record_w], + "targets": nrns_rec[:n_record_w] + nrns_out, + "start": duration["total_offset"], + "stop": duration["total_offset"] + duration["task"], + } + + params_sr = { + "start": duration["offset_gen"], + "stop": duration["total_offset"] + duration["task"], + } + + mm_rec = nest.Create("multimeter", params_mm_rec) + mm_out = nest.Create("multimeter", params_mm_out) + sr = nest.Create("spike_recorder", params_sr) + wr = nest.Create("weight_recorder", params_wr) + + nrns_rec_record = nrns_rec[:n_record] + + # Create connections + + params_conn_all_to_all = {"rule": "all_to_all", "allow_autapses": False} + params_conn_one_to_one = {"rule": "one_to_one"} + + dtype_weights = np.float32 + weights_in_rec = np.array(np.random.randn(n_in, n_rec).T / np.sqrt(n_in), dtype=dtype_weights) + weights_rec_rec = np.array(np.random.randn(n_rec, n_rec).T / np.sqrt(n_rec), dtype=dtype_weights) + np.fill_diagonal(weights_rec_rec, 0.0) + weights_rec_out = np.array(np.random.randn(n_rec, n_out).T / np.sqrt(n_rec), dtype=dtype_weights) + weights_out_rec = np.array(np.random.randn(n_rec, n_out) / np.sqrt(n_rec), dtype=dtype_weights) + + params_common_syn_eprop = { + "optimizer": { + "type": optimizer, + "batch_size": 1, + "eta": 1e-4, + "optimize_each_step": True, + "Wmin": -100.0, + "Wmax": 100.0, + }, + "weight_recorder": wr, + } + + if optimizer == "adam": + params_common_syn_eprop["optimizer"]["beta_1"] = 0.9 + params_common_syn_eprop["optimizer"]["beta_2"] = 0.999 + params_common_syn_eprop["optimizer"]["epsilon"] = 1e-7 + + params_syn_base = { + "synapse_model": "eprop_synapse", + "delay": duration["step"], + } + + params_syn_in = params_syn_base.copy() + params_syn_in["weight"] = weights_in_rec + + params_syn_rec = params_syn_base.copy() + params_syn_rec["weight"] = weights_rec_rec + + params_syn_out = params_syn_base.copy() + params_syn_out["weight"] = weights_rec_out + + params_syn_feedback = { + "synapse_model": "eprop_learning_signal_connection", + "delay": duration["step"], + "weight": weights_out_rec, + } + + params_syn_learning_window = { + "synapse_model": "rate_connection_delayed", + "delay": duration["step"], + "receptor_type": 1, + } + + params_syn_rate_target = { + "synapse_model": "rate_connection_delayed", + "delay": duration["step"], + "receptor_type": 2, + } + + params_syn_static = { + "synapse_model": "static_synapse", + "delay": duration["step"], + } + + nest.SetDefaults("eprop_synapse", params_common_syn_eprop) + + nest.Connect(gen_spk_in, nrns_in, params_conn_one_to_one, params_syn_static) + nest.Connect(nrns_in, nrns_rec, params_conn_all_to_all, 
params_syn_in) + nest.Connect(nrns_rec, nrns_rec, params_conn_all_to_all, params_syn_rec) + nest.Connect(nrns_rec, nrns_out, params_conn_all_to_all, params_syn_out) + nest.Connect(nrns_out, nrns_rec, params_conn_all_to_all, params_syn_feedback) + nest.Connect(gen_rate_target, nrns_out, params_conn_one_to_one, params_syn_rate_target) + nest.Connect(gen_learning_window, nrns_out, params_conn_all_to_all, params_syn_learning_window) + + nest.Connect(nrns_in + nrns_rec, sr, params_conn_all_to_all, params_syn_static) + + nest.Connect(mm_rec, nrns_rec_record, params_conn_all_to_all, params_syn_static) + nest.Connect(mm_out, nrns_out, params_conn_all_to_all, params_syn_static) + + # Create input + + input_spike_prob = 0.05 + dtype_in_spks = np.float32 + + input_spike_bools = (np.random.rand(steps["sequence"], n_in) < input_spike_prob).swapaxes(0, 1) + + sequence_starts = np.arange(0.0, duration["task"], duration["sequence"]) + duration["offset_gen"] + params_gen_spk_in = [] + for input_spike_bool in input_spike_bools: + input_spike_times = np.arange(0.0, duration["sequence"], duration["step"])[input_spike_bool] + input_spike_times_all = [input_spike_times + start for start in sequence_starts] + params_gen_spk_in.append({"spike_times": np.hstack(input_spike_times_all).astype(dtype_in_spks)}) + + nest.SetStatus(gen_spk_in, params_gen_spk_in) + + # Create output + + def generate_superimposed_sines(steps_sequence, periods): + n_sines = len(periods) + + amplitudes = np.random.uniform(low=0.5, high=2.0, size=n_sines) + phases = np.random.uniform(low=0.0, high=2.0 * np.pi, size=n_sines) + + sines = [ + A * np.sin(np.linspace(phi, phi + 2.0 * np.pi * (steps_sequence // T), steps_sequence)) + for A, phi, T in zip(amplitudes, phases, periods) + ] + + superposition = sum(sines) + superposition -= superposition[0] + superposition /= max(np.abs(superposition).max(), 1e-6) + return superposition + + target_signal = generate_superimposed_sines(steps["sequence"], [1000, 500, 333, 200]) + + params_gen_rate_target = { + "amplitude_times": np.arange(0.0, duration["task"], duration["step"]) + duration["total_offset"], + "amplitude_values": np.tile(target_signal, n_iter * group_size), + } + + nest.SetStatus(gen_rate_target, params_gen_rate_target) + + # Create learning window + + params_gen_learning_window = { + "amplitude_times": [duration["total_offset"]], + "amplitude_values": [1.0], + } + + nest.SetStatus(gen_learning_window, params_gen_learning_window) + + # Simulate + + nest.Simulate(duration["sim"]) + + # Read out recorders + + events_mm_out = mm_out.get("events") + + # Evaluate training error + + readout_signal = events_mm_out["readout_signal"] + target_signal = events_mm_out["target_signal"] + senders = events_mm_out["senders"] + + readout_signal = np.array([readout_signal[senders == i] for i in set(senders)]) + target_signal = np.array([target_signal[senders == i] for i in set(senders)]) + + readout_signal = readout_signal.reshape((n_out, n_iter, group_size, steps["sequence"])) + target_signal = target_signal.reshape((n_out, n_iter, group_size, steps["sequence"])) + + loss = 0.5 * np.mean(np.sum((readout_signal - target_signal) ** 2, axis=3), axis=(0, 2)) + + # Verify results + + assert np.allclose(loss, loss_nest_reference, rtol=1e-8) + + +def test_unsupported_surrogate_gradient(): + """ + Confirm that selecting an unsupported surrogate gradient raises an error. 
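    Editor's note (illustration only, not part of this patch): the surrogate gradient shape is
    selected by name when the neuron is created; any of the names parametrized in
    test_eprop_surrogate_gradients below is accepted, while an unknown name triggers the error
    checked in this test. A minimal sketch, assuming default values for all other parameters:

        import nest

        nest.ResetKernel()
        nrn = nest.Create("eprop_iaf", params={"surrogate_gradient_function": "piecewise_linear"})
        print(nrn.get("surrogate_gradient_function"))  # -> "piecewise_linear"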
+ """ + + params_nrn_rec = { + "surrogate_gradient_function": "unsupported_surrogate_gradient", + } + + with pytest.raises(nest.kernel.NESTError): + nrn = nest.Create("eprop_iaf", 1, params_nrn_rec) + nest.Simulate(1.0) + + +@pytest.mark.parametrize( + "surrogate_gradient_type,surrogate_gradient_reference", + [ + ( + "piecewise_linear", + [ + 0.06135126216450, + 0.05456129183053, + 0.04841747260500, + 0.04285831508010, + 0.03782818133881, + ], + ), + ( + "exponential", + [ + 0.20795269433458, + 0.20514779735629, + 0.20264243938260, + 0.20040187562686, + 0.19839588646645, + ], + ), + ( + "fast_sigmoid_derivative", + [ + 0.14187432621036, + 0.13984381221999, + 0.13804386008020, + 0.13644497667035, + 0.13502206566839, + ], + ), + ( + "arctan", + [ + 0.01851467830587, + 0.01801794405092, + 0.01758480869073, + 0.01720567699047, + 0.01687268127553, + ], + ), + ], +) +def test_eprop_surrogate_gradients(surrogate_gradient_type, surrogate_gradient_reference): + """ + Test correct computation of surrogate gradients by comparing the simulated surrogate gradients + with NEST reference surrogate gradients. These reference surrogate gradients were obtained from + a simulation with the verified NEST e-prop implementation run with Linux 5.8.7-1-default, Python + v3.12.5, Numpy v2.0.1, and NEST@d04fe550d. + """ + + rng_seed = 1 + np.random.seed(rng_seed) + + duration = { + "step": 1.0, + "sim": 20.0, + } + + params_setup = { + "print_time": False, + "resolution": duration["step"], + "total_num_virtual_procs": 1, + } + + nest.ResetKernel() + nest.set(**params_setup) + + params_nrn_rec = { + "beta": 1.7, + "C_m": 1.0, + "c_reg": 0.0, + "E_L": 0.0, + "gamma": 0.5, + "I_e": 0.0, + "regular_spike_arrival": False, + "surrogate_gradient_function": surrogate_gradient_type, + "t_ref": 3.0, + "V_m": 0.0, + "V_th": 0.6, + } + + gen_spk_in = nest.Create("spike_generator", 1) + nrns_in = nest.Create("parrot_neuron", 1) + nrns_rec = nest.Create("eprop_iaf", 1, params_nrn_rec) + + params_mm_rec = { + "interval": duration["step"], + "record_from": ["surrogate_gradient", "V_m"], + } + + mm_rec = nest.Create("multimeter", params_mm_rec) + + params_conn_one_to_one = {"rule": "one_to_one"} + params_syn = { + "synapse_model": "eprop_synapse", + "delay": duration["step"], + "weight": 0.3, + } + params_syn_static = { + "synapse_model": "static_synapse", + "delay": duration["step"], + } + + nest.SetStatus(gen_spk_in, {"spike_times": [1.0, 2.0, 3.0, 5.0, 9.0, 11.0]}) + + nest.Connect(gen_spk_in, nrns_in, params_conn_one_to_one, params_syn_static) + nest.Connect(nrns_in, nrns_rec, params_conn_one_to_one, params_syn) + nest.Connect(mm_rec, nrns_rec, params_conn_one_to_one, params_syn_static) + + nest.Simulate(duration["sim"]) + events_mm_rec = mm_rec.get("events") + surrogate_gradient = events_mm_rec["surrogate_gradient"][-5:] + + assert np.allclose(surrogate_gradient, surrogate_gradient_reference, rtol=1e-8) diff --git a/testsuite/pytests/test_labeled_synapses.py b/testsuite/pytests/test_labeled_synapses.py index 77e4e2bfd0..0ff60ceec7 100644 --- a/testsuite/pytests/test_labeled_synapses.py +++ b/testsuite/pytests/test_labeled_synapses.py @@ -57,13 +57,20 @@ def default_network(self, syn_model): self.urbanczik_synapses = ["urbanczik_synapse", "urbanczik_synapse_lbl", "urbanczik_synapse_hpc"] - self.eprop_synapses = ["eprop_synapse_bsshslm_2020", "eprop_synapse_bsshslm_2020_hpc"] - self.eprop_connections = [ + self.eprop_synapses_bsshslm_2020 = ["eprop_synapse_bsshslm_2020", "eprop_synapse_bsshslm_2020_hpc"] + 
self.eprop_connections_bsshslm_2020 = [ "eprop_learning_signal_connection_bsshslm_2020", "eprop_learning_signal_connection_bsshslm_2020_lbl", "eprop_learning_signal_connection_bsshslm_2020_hpc", ] + self.eprop_synapses = ["eprop_synapse", "eprop_synapse_hpc"] + self.eprop_connections = [ + "eprop_learning_signal_connection", + "eprop_learning_signal_connection_lbl", + "eprop_learning_signal_connection_hpc", + ] + # create neurons that accept all synapse connections (especially gap # junctions)... hh_psc_alpha_gap is only available with GSL, hence the # skipIf above @@ -88,12 +95,18 @@ def default_network(self, syn_model): syns = nest.GetDefaults("pp_cond_exp_mc_urbanczik")["receptor_types"] r_type = syns["soma_exc"] - if syn_model in self.eprop_synapses: + if syn_model in self.eprop_synapses_bsshslm_2020: neurons = nest.Create("eprop_iaf_bsshslm_2020", 5) - if syn_model in self.eprop_connections: + if syn_model in self.eprop_connections_bsshslm_2020: neurons = nest.Create("eprop_readout_bsshslm_2020", 5) + nest.Create("eprop_iaf_bsshslm_2020", 5) + if syn_model in self.eprop_synapses: + neurons = nest.Create("eprop_iaf", 5) + + if syn_model in self.eprop_connections: + neurons = nest.Create("eprop_readout", 5) + nest.Create("eprop_iaf", 5) + return neurons, r_type def test_SetLabelToSynapseOnConnect(self): @@ -197,7 +210,13 @@ def test_SetLabelToNotLabeledSynapse(self): nest.SetDefaults(syn, {"synapse_label": 123}) # plain connection - if syn in self.eprop_connections or syn in self.eprop_synapses: + if ( + syn + in self.eprop_connections_bsshslm_2020 + + self.eprop_connections + + self.eprop_synapses_bsshslm_2020 + + self.eprop_synapses + ): # try set on connect with self.assertRaises(nest.kernel.NESTError): nest.Connect( diff --git a/testsuite/pytests/test_refractory.py b/testsuite/pytests/test_refractory.py index d2d7601ca6..fe49b2662a 100644 --- a/testsuite/pytests/test_refractory.py +++ b/testsuite/pytests/test_refractory.py @@ -60,6 +60,8 @@ neurons_eprop = [ "eprop_iaf_bsshslm_2020", "eprop_iaf_adapt_bsshslm_2020", + "eprop_iaf", + "eprop_iaf_adapt", ] # Models that first clamp the membrane potential at a higher value @@ -85,6 +87,7 @@ "siegert_neuron", # This one does not connect to voltmeter "step_rate_generator", # No regular neuron model "eprop_readout_bsshslm_2020", # This one does not spike + "eprop_readout", # This one does not spike "iaf_tum_2000", # Hijacks the offset field, see #2912 "iaf_bw_2001", # Hijacks the offset field, see #2912 "iaf_bw_2001_exact", # Hijacks the offset field, see #2912 diff --git a/testsuite/pytests/test_sp/test_disconnect.py b/testsuite/pytests/test_sp/test_disconnect.py index 40c22e19a8..79ad552435 100644 --- a/testsuite/pytests/test_sp/test_disconnect.py +++ b/testsuite/pytests/test_sp/test_disconnect.py @@ -78,6 +78,11 @@ def test_synapse_deletion_one_to_one_no_sp(self): syn_dict["delay"] = nest.resolution elif "eprop_learning_signal_connection_bsshslm_2020" in syn_model: neurons = nest.Create("eprop_readout_bsshslm_2020", 2) + nest.Create("eprop_iaf_bsshslm_2020", 2) + elif "eprop_synapse" in syn_model: + neurons = nest.Create("eprop_iaf", 4) + syn_dict["delay"] = nest.resolution + elif "eprop_learning_signal_connection" in syn_model: + neurons = nest.Create("eprop_readout", 2) + nest.Create("eprop_iaf", 2) else: neurons = nest.Create("iaf_psc_alpha", 4) diff --git a/testsuite/pytests/test_sp/test_disconnect_multiple.py b/testsuite/pytests/test_sp/test_disconnect_multiple.py index e0529f1645..a3eeebf0ed 100644 --- 
a/testsuite/pytests/test_sp/test_disconnect_multiple.py +++ b/testsuite/pytests/test_sp/test_disconnect_multiple.py @@ -54,6 +54,11 @@ def setUp(self): "eprop_learning_signal_connection_bsshslm_2020", "eprop_learning_signal_connection_bsshslm_2020_lbl", "eprop_learning_signal_connection_bsshslm_2020_hpc", + "eprop_synapse", + "eprop_synapse_hpc", + "eprop_learning_signal_connection", + "eprop_learning_signal_connection_lbl", + "eprop_learning_signal_connection_hpc", "sic_connection", ] diff --git a/testsuite/regressiontests/ticket-310.sli b/testsuite/regressiontests/ticket-310.sli index 424fc7f96d..eed11a5b25 100644 --- a/testsuite/regressiontests/ticket-310.sli +++ b/testsuite/regressiontests/ticket-310.sli @@ -40,8 +40,13 @@ /skip_list [ /iaf_chxk_2008 % non-standard spiking conditions /correlospinmatrix_detector % not a neuron /eprop_iaf_bsshslm_2020 % no ArchivingNode, thus no t_spike - /eprop_iaf_adapt_bsshslm_2020 % no ArchivingNode, thus no t_spike - /eprop_readout_bsshslm_2020 % no ArchivingNode, thus no t_spike + /eprop_iaf_adapt_bsshslm_2020 % no ArchivingNode, thus no t_spike + /eprop_readout_bsshslm_2020 % no ArchivingNode, thus no t_spike + /eprop_iaf % no ArchivingNode, thus no t_spike + /eprop_iaf_adapt % no ArchivingNode, thus no t_spike + /eprop_iaf_psc_delta % no ArchivingNode, thus no t_spike + /eprop_iaf_psc_delta_adapt % no ArchivingNode, thus no t_spike + /eprop_readout % no ArchivingNode, thus no t_spike ] def { diff --git a/testsuite/regressiontests/ticket-421.sli b/testsuite/regressiontests/ticket-421.sli index 88e3e452fe..a8ccb8b038 100644 --- a/testsuite/regressiontests/ticket-421.sli +++ b/testsuite/regressiontests/ticket-421.sli @@ -53,7 +53,8 @@ Author: Hans Ekkehard Plesser, 2010-05-05 /aeif_psc_exp /aeif_psc_alpha /aeif_psc_delta /aeif_cond_beta_multisynapse /hh_cond_exp_traub /hh_cond_beta_gap_traub /hh_psc_alpha /hh_psc_alpha_clopath /hh_psc_alpha_gap /ht_neuron /ht_neuron_fs /iaf_cond_exp_sfa_rr /izhikevich - /eprop_iaf_bsshslm_2020 /eprop_iaf_adapt_bsshslm_2020 /eprop_readout_bsshslm_2020] def + /eprop_iaf_bsshslm_2020 /eprop_iaf_adapt_bsshslm_2020 /eprop_readout_bsshslm_2020 + /eprop_iaf /eprop_iaf_adapt /eprop_iaf_psc_delta /eprop_iaf_psc_delta_adapt /eprop_readout] def % use power-of-two resolution to avoid round-off problems /res -3 dexp def diff --git a/testsuite/regressiontests/ticket-618.sli b/testsuite/regressiontests/ticket-618.sli index 969ef4a9c9..fe56d45ebd 100644 --- a/testsuite/regressiontests/ticket-618.sli +++ b/testsuite/regressiontests/ticket-618.sli @@ -46,7 +46,8 @@ Author: Hans Ekkehard Plesser, 2012-12-11 M_ERROR setverbosity -/excluded_models [ /eprop_iaf_bsshslm_2020 /eprop_iaf_adapt_bsshslm_2020 /eprop_readout_bsshslm_2020 /iaf_bw_2001 ] def +/excluded_models [ /eprop_iaf_bsshslm_2020 /eprop_iaf_adapt_bsshslm_2020 /eprop_readout_bsshslm_2020 + /eprop_iaf /eprop_iaf_adapt /eprop_iaf_psc_delta /eprop_iaf_psc_delta_adapt /eprop_readout /iaf_bw_2001 ] def { GetKernelStatus /node_models get