Skip to content

Commit

Permalink
Update templates
Browse files Browse the repository at this point in the history
  • Loading branch information
pnbabu committed Jun 29, 2023
1 parent 7e1d4ac commit 2db85a0
Showing 1 changed file with 105 additions and 69 deletions.
174 changes: 105 additions & 69 deletions pynestml/codegeneration/resources_nest_gpu/@[email protected]
Original file line number Diff line number Diff line change
Expand Up @@ -45,28 +45,28 @@ extern __constant__ float NESTGPUTimeResolution;
#define {{ printer_no_origin.print(variable) }} param[i_{{ printer_no_origin.print(variable) }}]
{%- endfor %}

__device__
double propagator_32( double tau_syn, double tau, double C, double h )
{
const double P32_linear = 1.0 / ( 2.0 * C * tau * tau ) * h * h
* ( tau_syn - tau ) * exp( -h / tau );
const double P32_singular = h / C * exp( -h / tau );
const double P32 =
-tau / ( C * ( 1.0 - tau / tau_syn ) ) * exp( -h / tau_syn )
* expm1( h * ( 1.0 / tau_syn - 1.0 / tau ) );

const double dev_P32 = fabs( P32 - P32_singular );

if ( tau == tau_syn || ( fabs( tau - tau_syn ) < 0.1 && dev_P32 > 2.0
* fabs( P32_linear ) ) )
{
return P32_singular;
}
else
{
return P32;
}
}
{#__device__#}
{#double propagator_32( double tau_syn, double tau, double C, double h )#}
{#{#}
{# const double P32_linear = 1.0 / ( 2.0 * C * tau * tau ) * h * h#}
{# * ( tau_syn - tau ) * exp( -h / tau );#}
{# const double P32_singular = h / C * exp( -h / tau );#}
{# const double P32 =#}
{# -tau / ( C * ( 1.0 - tau / tau_syn ) ) * exp( -h / tau_syn )#}
{# * expm1( h * ( 1.0 / tau_syn - 1.0 / tau ) );#}
{##}
{# const double dev_P32 = fabs( P32 - P32_singular );#}
{##}
{# if ( tau == tau_syn || ( fabs( tau - tau_syn ) < 0.1 && dev_P32 > 2.0#}
{# * fabs( P32_linear ) ) )#}
{# {#}
{# return P32_singular;#}
{# }#}
{# else#}
{# {#}
{# return P32;#}
{# }#}
{#}#}


__global__ void {{ neuronName }}_Calibrate(int n_node, float *param_arr,
Expand All @@ -76,12 +76,22 @@ __global__ void {{ neuronName }}_Calibrate(int n_node, float *param_arr,
if (i_neuron < n_node) {
float *param = param_arr + n_param*i_neuron;

P11ex = exp( -h / tau_ex );
P11in = exp( -h / tau_in );
P22 = exp( -h / tau_m );
P21ex = (float)propagator_32( tau_ex, tau_m, C_m, h );
P21in = (float)propagator_32( tau_in, tau_m, C_m, h );
P20 = tau_m / C_m * ( 1.0 - P22 );
{# P11ex = exp( -h / tau_ex );#}
{# P11in = exp( -h / tau_in );#}
{# P22 = exp( -h / tau_m );#}
{# P21ex = (float)propagator_32( tau_ex, tau_m, C_m, h );#}
{# P21in = (float)propagator_32( tau_in, tau_m, C_m, h );#}
{# P20 = tau_m / C_m * ( 1.0 - P22 );#}
{%- filter indent(4,True) %}
{%- for internals_block in neuron.get_internals_blocks() %}
{%- for decl in internals_block.get_declarations() %}
{%- for variable in decl.get_variables() %}
{%- set variable_symbol = variable.get_scope().resolve_to_symbol(variable.get_complete_name(), SymbolKind.VARIABLE) %}
{%- include "directives/MemberInitialization.jinja2" %}
{%- endfor %}
{%- endfor %}
{%- endfor %}
{%- endfilter %}
}
}

Expand All @@ -94,22 +104,35 @@ __global__ void {{ neuronName }}_Update(int n_node, int i_node_0, float *var_arr
float *var = var_arr + n_var*i_neuron;
float *param = param_arr + n_param*i_neuron;

if ( refractory_step > 0.0 ) {
// neuron is absolute refractory
refractory_step -= 1.0;
}
else { // neuron is not refractory, so evolve V
V_m_rel = V_m_rel * P22 + I_syn_ex * P21ex + I_syn_in * P21in + I_e * P20;
}
// exponential decaying PSCs
I_syn_ex *= P11ex;
I_syn_in *= P11in;

if (V_m_rel >= Theta_rel ) { // threshold crossing
PushSpike(i_node_0 + i_neuron, 1.0);
V_m_rel = V_reset_rel;
refractory_step = (int)round(t_ref/NESTGPUTimeResolution);
}
{# if ( refractory_step > 0.0 ) {#}
{# // neuron is absolute refractory#}
{# refractory_step -= 1.0;#}
{# }#}
{# else { // neuron is not refractory, so evolve V#}
{# V_m_rel = V_m_rel * P22 + I_syn_ex * P21ex + I_syn_in * P21in + I_e * P20;#}
{# }#}
{# // exponential decaying PSCs#}
{# I_syn_ex *= P11ex;#}
{# I_syn_in *= P11in;#}
{##}
{# if (V_m_rel >= Theta_rel ) { // threshold crossing#}
{# PushSpike(i_node_0 + i_neuron, 1.0);#}
{# V_m_rel = V_reset_rel;#}
{# refractory_step = (int)round(t_ref/NESTGPUTimeResolution);#}
{# }#}
{%- if neuron.get_update_blocks() %}
{%- filter indent(2) %}
{%- for block in neuron.get_update_blocks() %}
{%- set ast = block.get_block() %}
{%- if ast.print_comment('*')|length > 1 %}
/*
{{ast.print_comment('*')}}
*/
{%- endif %}
{%- include "directives/Block.jinja2" %}
{%- endfor %}
{%- endfilter %}
{%- endif %}
}
}

Expand All @@ -136,29 +159,42 @@ int {{ neuronName }}::Init(int i_node_0, int n_node, int /*n_port*/,
scal_var_name_ = {{ neuronName }}_scal_var_name;
scal_param_name_ = {{ neuronName }}_scal_param_name;

SetScalParam(0, n_node, "tau_m", 10.0 ); // in ms
SetScalParam(0, n_node, "C_m", 250.0 ); // in pF
SetScalParam(0, n_node, "E_L", -70.0 ); // in mV
SetScalParam(0, n_node, "I_e", 0.0 ); // in pA
SetScalParam(0, n_node, "Theta_rel", -55.0 - (-70.0) ); // relative to E_L_
SetScalParam(0, n_node, "V_reset_rel", -70.0 - (-70.0) ); // relative to E_L_
SetScalParam(0, n_node, "tau_ex", 2.0 ); // in ms
SetScalParam(0, n_node, "tau_in", 2.0 ); // in ms
// SetScalParam(0, n_node, "rho", 0.01 ); // in 1/s
// SetScalParam(0, n_node, "delta", 0.0 ); // in mV
SetScalParam(0, n_node, "t_ref", 2.0 ); // in ms
SetScalParam(0, n_node, "den_delay", 0.0); // in ms
SetScalParam(0, n_node, "P20", 0.0);
SetScalParam(0, n_node, "P11ex", 0.0);
SetScalParam(0, n_node, "P11in", 0.0);
SetScalParam(0, n_node, "P21ex", 0.0);
SetScalParam(0, n_node, "P21in", 0.0);
SetScalParam(0, n_node, "P22", 0.0);

SetScalVar(0, n_node, "I_syn_ex", 0.0 );
SetScalVar(0, n_node, "I_syn_in", 0.0 );
SetScalVar(0, n_node, "V_m_rel", -70.0 - (-70.0) ); // in mV, relative to E_L
SetScalVar(0, n_node, "refractory_step", 0 );
{# SetScalParam(0, n_node, "tau_m", 10.0 ); // in ms#}
{# SetScalParam(0, n_node, "C_m", 250.0 ); // in pF#}
{# SetScalParam(0, n_node, "E_L", -70.0 ); // in mV#}
{# SetScalParam(0, n_node, "I_e", 0.0 ); // in pA#}
{# SetScalParam(0, n_node, "Theta_rel", -55.0 - (-70.0) ); // relative to E_L_#}
{# SetScalParam(0, n_node, "V_reset_rel", -70.0 - (-70.0) ); // relative to E_L_#}
{# SetScalParam(0, n_node, "tau_ex", 2.0 ); // in ms#}
{# SetScalParam(0, n_node, "tau_in", 2.0 ); // in ms#}
{# // SetScalParam(0, n_node, "rho", 0.01 ); // in 1/s#}
{# // SetScalParam(0, n_node, "delta", 0.0 ); // in mV#}
{# SetScalParam(0, n_node, "t_ref", 2.0 ); // in ms#}
{# SetScalParam(0, n_node, "den_delay", 0.0); // in ms#}
{# SetScalParam(0, n_node, "P20", 0.0);#}
{# SetScalParam(0, n_node, "P11ex", 0.0);#}
{# SetScalParam(0, n_node, "P11in", 0.0);#}
{# SetScalParam(0, n_node, "P21ex", 0.0);#}
{# SetScalParam(0, n_node, "P21in", 0.0);#}
{# SetScalParam(0, n_node, "P22", 0.0);#}
{##}
{# SetScalVar(0, n_node, "I_syn_ex", 0.0 );#}
{# SetScalVar(0, n_node, "I_syn_in", 0.0 );#}
{# SetScalVar(0, n_node, "V_m_rel", -70.0 - (-70.0) ); // in mV, relative to E_L#}
{# SetScalVar(0, n_node, "refractory_step", 0 );#}

{%- filter indent(2) %}
{%- for variable in neuron.get_parameter_symbols() %}
SetScalParam(0, n_node, {{ printer_no_origin.print(variable) }}, {{printer.print(variable.get_declaring_expression())}}); // as {{variable.get_type_symbol().print_symbol()}}
{%- endfor %}
{%- endfilter %}


{%- filter indent(2) %}
{%- for variable in neuron.get_internal_symbols() %}
SetScalParam(0, n_node, {{ printer_no_origin.print(variable) }}, 0.0);
{%- endfor %}
{%- endfilter %}

// multiplication factor of input signal is always 1 for all nodes
float input_weight = 1.0;
Expand All @@ -169,11 +205,11 @@ int {{ neuronName }}::Init(int i_node_0, int n_node, int /*n_port*/,
port_weight_port_step_ = 0;

// input spike signal is stored in I_syn_ex, I_syn_in
port_input_arr_ = GetVarArr() + GetScalVarIdx("I_syn_ex");
port_input_arr_ = GetVarArr() + GetScalVarIdx("I_kernel_exc__X__exc_spikes");
port_input_arr_step_ = n_var_;
port_input_port_step_ = 1;

den_delay_arr_ = GetParamArr() + GetScalParamIdx("den_delay");
{# den_delay_arr_ = GetParamArr() + GetScalParamIdx("den_delay");#}

return 0;
}
Expand Down

0 comments on commit 2db85a0

Please sign in to comment.