From 2db85a06780c8abbd6e997e647e9469334e184c2 Mon Sep 17 00:00:00 2001 From: Pooja Babu Date: Thu, 29 Jun 2023 16:55:24 +0200 Subject: [PATCH] Update templates --- .../@NEURON_NAME@.cu.jinja2 | 174 +++++++++++------- 1 file changed, 105 insertions(+), 69 deletions(-) diff --git a/pynestml/codegeneration/resources_nest_gpu/@NEURON_NAME@.cu.jinja2 b/pynestml/codegeneration/resources_nest_gpu/@NEURON_NAME@.cu.jinja2 index 1e076dbca..98e99778a 100644 --- a/pynestml/codegeneration/resources_nest_gpu/@NEURON_NAME@.cu.jinja2 +++ b/pynestml/codegeneration/resources_nest_gpu/@NEURON_NAME@.cu.jinja2 @@ -45,28 +45,28 @@ extern __constant__ float NESTGPUTimeResolution; #define {{ printer_no_origin.print(variable) }} param[i_{{ printer_no_origin.print(variable) }}] {%- endfor %} -__device__ -double propagator_32( double tau_syn, double tau, double C, double h ) -{ - const double P32_linear = 1.0 / ( 2.0 * C * tau * tau ) * h * h - * ( tau_syn - tau ) * exp( -h / tau ); - const double P32_singular = h / C * exp( -h / tau ); - const double P32 = - -tau / ( C * ( 1.0 - tau / tau_syn ) ) * exp( -h / tau_syn ) - * expm1( h * ( 1.0 / tau_syn - 1.0 / tau ) ); - - const double dev_P32 = fabs( P32 - P32_singular ); - - if ( tau == tau_syn || ( fabs( tau - tau_syn ) < 0.1 && dev_P32 > 2.0 - * fabs( P32_linear ) ) ) - { - return P32_singular; - } - else - { - return P32; - } -} +{#__device__#} +{#double propagator_32( double tau_syn, double tau, double C, double h )#} +{#{#} +{# const double P32_linear = 1.0 / ( 2.0 * C * tau * tau ) * h * h#} +{# * ( tau_syn - tau ) * exp( -h / tau );#} +{# const double P32_singular = h / C * exp( -h / tau );#} +{# const double P32 =#} +{# -tau / ( C * ( 1.0 - tau / tau_syn ) ) * exp( -h / tau_syn )#} +{# * expm1( h * ( 1.0 / tau_syn - 1.0 / tau ) );#} +{##} +{# const double dev_P32 = fabs( P32 - P32_singular );#} +{##} +{# if ( tau == tau_syn || ( fabs( tau - tau_syn ) < 0.1 && dev_P32 > 2.0#} +{# * fabs( P32_linear ) ) )#} +{# {#} +{# return P32_singular;#} +{# }#} +{# else#} +{# {#} +{# return P32;#} +{# }#} +{#}#} __global__ void {{ neuronName }}_Calibrate(int n_node, float *param_arr, @@ -76,12 +76,22 @@ __global__ void {{ neuronName }}_Calibrate(int n_node, float *param_arr, if (i_neuron < n_node) { float *param = param_arr + n_param*i_neuron; - P11ex = exp( -h / tau_ex ); - P11in = exp( -h / tau_in ); - P22 = exp( -h / tau_m ); - P21ex = (float)propagator_32( tau_ex, tau_m, C_m, h ); - P21in = (float)propagator_32( tau_in, tau_m, C_m, h ); - P20 = tau_m / C_m * ( 1.0 - P22 ); +{# P11ex = exp( -h / tau_ex );#} +{# P11in = exp( -h / tau_in );#} +{# P22 = exp( -h / tau_m );#} +{# P21ex = (float)propagator_32( tau_ex, tau_m, C_m, h );#} +{# P21in = (float)propagator_32( tau_in, tau_m, C_m, h );#} +{# P20 = tau_m / C_m * ( 1.0 - P22 );#} +{%- filter indent(4,True) %} +{%- for internals_block in neuron.get_internals_blocks() %} +{%- for decl in internals_block.get_declarations() %} +{%- for variable in decl.get_variables() %} +{%- set variable_symbol = variable.get_scope().resolve_to_symbol(variable.get_complete_name(), SymbolKind.VARIABLE) %} +{%- include "directives/MemberInitialization.jinja2" %} +{%- endfor %} +{%- endfor %} +{%- endfor %} +{%- endfilter %} } } @@ -94,22 +104,35 @@ __global__ void {{ neuronName }}_Update(int n_node, int i_node_0, float *var_arr float *var = var_arr + n_var*i_neuron; float *param = param_arr + n_param*i_neuron; - if ( refractory_step > 0.0 ) { - // neuron is absolute refractory - refractory_step -= 1.0; - } - else { // neuron is not refractory, so evolve V - V_m_rel = V_m_rel * P22 + I_syn_ex * P21ex + I_syn_in * P21in + I_e * P20; - } - // exponential decaying PSCs - I_syn_ex *= P11ex; - I_syn_in *= P11in; - - if (V_m_rel >= Theta_rel ) { // threshold crossing - PushSpike(i_node_0 + i_neuron, 1.0); - V_m_rel = V_reset_rel; - refractory_step = (int)round(t_ref/NESTGPUTimeResolution); - } +{# if ( refractory_step > 0.0 ) {#} +{# // neuron is absolute refractory#} +{# refractory_step -= 1.0;#} +{# }#} +{# else { // neuron is not refractory, so evolve V#} +{# V_m_rel = V_m_rel * P22 + I_syn_ex * P21ex + I_syn_in * P21in + I_e * P20;#} +{# }#} +{# // exponential decaying PSCs#} +{# I_syn_ex *= P11ex;#} +{# I_syn_in *= P11in;#} +{##} +{# if (V_m_rel >= Theta_rel ) { // threshold crossing#} +{# PushSpike(i_node_0 + i_neuron, 1.0);#} +{# V_m_rel = V_reset_rel;#} +{# refractory_step = (int)round(t_ref/NESTGPUTimeResolution);#} +{# }#} +{%- if neuron.get_update_blocks() %} +{%- filter indent(2) %} +{%- for block in neuron.get_update_blocks() %} +{%- set ast = block.get_block() %} +{%- if ast.print_comment('*')|length > 1 %} +/* + {{ast.print_comment('*')}} + */ +{%- endif %} +{%- include "directives/Block.jinja2" %} +{%- endfor %} +{%- endfilter %} +{%- endif %} } } @@ -136,29 +159,42 @@ int {{ neuronName }}::Init(int i_node_0, int n_node, int /*n_port*/, scal_var_name_ = {{ neuronName }}_scal_var_name; scal_param_name_ = {{ neuronName }}_scal_param_name; - SetScalParam(0, n_node, "tau_m", 10.0 ); // in ms - SetScalParam(0, n_node, "C_m", 250.0 ); // in pF - SetScalParam(0, n_node, "E_L", -70.0 ); // in mV - SetScalParam(0, n_node, "I_e", 0.0 ); // in pA - SetScalParam(0, n_node, "Theta_rel", -55.0 - (-70.0) ); // relative to E_L_ - SetScalParam(0, n_node, "V_reset_rel", -70.0 - (-70.0) ); // relative to E_L_ - SetScalParam(0, n_node, "tau_ex", 2.0 ); // in ms - SetScalParam(0, n_node, "tau_in", 2.0 ); // in ms - // SetScalParam(0, n_node, "rho", 0.01 ); // in 1/s - // SetScalParam(0, n_node, "delta", 0.0 ); // in mV - SetScalParam(0, n_node, "t_ref", 2.0 ); // in ms - SetScalParam(0, n_node, "den_delay", 0.0); // in ms - SetScalParam(0, n_node, "P20", 0.0); - SetScalParam(0, n_node, "P11ex", 0.0); - SetScalParam(0, n_node, "P11in", 0.0); - SetScalParam(0, n_node, "P21ex", 0.0); - SetScalParam(0, n_node, "P21in", 0.0); - SetScalParam(0, n_node, "P22", 0.0); - - SetScalVar(0, n_node, "I_syn_ex", 0.0 ); - SetScalVar(0, n_node, "I_syn_in", 0.0 ); - SetScalVar(0, n_node, "V_m_rel", -70.0 - (-70.0) ); // in mV, relative to E_L - SetScalVar(0, n_node, "refractory_step", 0 ); +{# SetScalParam(0, n_node, "tau_m", 10.0 ); // in ms#} +{# SetScalParam(0, n_node, "C_m", 250.0 ); // in pF#} +{# SetScalParam(0, n_node, "E_L", -70.0 ); // in mV#} +{# SetScalParam(0, n_node, "I_e", 0.0 ); // in pA#} +{# SetScalParam(0, n_node, "Theta_rel", -55.0 - (-70.0) ); // relative to E_L_#} +{# SetScalParam(0, n_node, "V_reset_rel", -70.0 - (-70.0) ); // relative to E_L_#} +{# SetScalParam(0, n_node, "tau_ex", 2.0 ); // in ms#} +{# SetScalParam(0, n_node, "tau_in", 2.0 ); // in ms#} +{# // SetScalParam(0, n_node, "rho", 0.01 ); // in 1/s#} +{# // SetScalParam(0, n_node, "delta", 0.0 ); // in mV#} +{# SetScalParam(0, n_node, "t_ref", 2.0 ); // in ms#} +{# SetScalParam(0, n_node, "den_delay", 0.0); // in ms#} +{# SetScalParam(0, n_node, "P20", 0.0);#} +{# SetScalParam(0, n_node, "P11ex", 0.0);#} +{# SetScalParam(0, n_node, "P11in", 0.0);#} +{# SetScalParam(0, n_node, "P21ex", 0.0);#} +{# SetScalParam(0, n_node, "P21in", 0.0);#} +{# SetScalParam(0, n_node, "P22", 0.0);#} +{##} +{# SetScalVar(0, n_node, "I_syn_ex", 0.0 );#} +{# SetScalVar(0, n_node, "I_syn_in", 0.0 );#} +{# SetScalVar(0, n_node, "V_m_rel", -70.0 - (-70.0) ); // in mV, relative to E_L#} +{# SetScalVar(0, n_node, "refractory_step", 0 );#} + +{%- filter indent(2) %} +{%- for variable in neuron.get_parameter_symbols() %} + SetScalParam(0, n_node, {{ printer_no_origin.print(variable) }}, {{printer.print(variable.get_declaring_expression())}}); // as {{variable.get_type_symbol().print_symbol()}} +{%- endfor %} +{%- endfilter %} + + +{%- filter indent(2) %} +{%- for variable in neuron.get_internal_symbols() %} + SetScalParam(0, n_node, {{ printer_no_origin.print(variable) }}, 0.0); +{%- endfor %} +{%- endfilter %} // multiplication factor of input signal is always 1 for all nodes float input_weight = 1.0; @@ -169,11 +205,11 @@ int {{ neuronName }}::Init(int i_node_0, int n_node, int /*n_port*/, port_weight_port_step_ = 0; // input spike signal is stored in I_syn_ex, I_syn_in - port_input_arr_ = GetVarArr() + GetScalVarIdx("I_syn_ex"); + port_input_arr_ = GetVarArr() + GetScalVarIdx("I_kernel_exc__X__exc_spikes"); port_input_arr_step_ = n_var_; port_input_port_step_ = 1; - den_delay_arr_ = GetParamArr() + GetScalParamIdx("den_delay"); +{# den_delay_arr_ = GetParamArr() + GetScalParamIdx("den_delay");#} return 0; }