From 2db85a06780c8abbd6e997e647e9469334e184c2 Mon Sep 17 00:00:00 2001
From: Pooja Babu
Date: Thu, 29 Jun 2023 16:55:24 +0200
Subject: [PATCH] Update templates
---
.../@NEURON_NAME@.cu.jinja2 | 174 +++++++++++-------
1 file changed, 105 insertions(+), 69 deletions(-)
diff --git a/pynestml/codegeneration/resources_nest_gpu/@NEURON_NAME@.cu.jinja2 b/pynestml/codegeneration/resources_nest_gpu/@NEURON_NAME@.cu.jinja2
index 1e076dbca..98e99778a 100644
--- a/pynestml/codegeneration/resources_nest_gpu/@NEURON_NAME@.cu.jinja2
+++ b/pynestml/codegeneration/resources_nest_gpu/@NEURON_NAME@.cu.jinja2
@@ -45,28 +45,28 @@ extern __constant__ float NESTGPUTimeResolution;
#define {{ printer_no_origin.print(variable) }} param[i_{{ printer_no_origin.print(variable) }}]
{%- endfor %}
-__device__
-double propagator_32( double tau_syn, double tau, double C, double h )
-{
- const double P32_linear = 1.0 / ( 2.0 * C * tau * tau ) * h * h
- * ( tau_syn - tau ) * exp( -h / tau );
- const double P32_singular = h / C * exp( -h / tau );
- const double P32 =
- -tau / ( C * ( 1.0 - tau / tau_syn ) ) * exp( -h / tau_syn )
- * expm1( h * ( 1.0 / tau_syn - 1.0 / tau ) );
-
- const double dev_P32 = fabs( P32 - P32_singular );
-
- if ( tau == tau_syn || ( fabs( tau - tau_syn ) < 0.1 && dev_P32 > 2.0
- * fabs( P32_linear ) ) )
- {
- return P32_singular;
- }
- else
- {
- return P32;
- }
-}
+{#__device__#}
+{#double propagator_32( double tau_syn, double tau, double C, double h )#}
+{#{#}
+{# const double P32_linear = 1.0 / ( 2.0 * C * tau * tau ) * h * h#}
+{# * ( tau_syn - tau ) * exp( -h / tau );#}
+{# const double P32_singular = h / C * exp( -h / tau );#}
+{# const double P32 =#}
+{# -tau / ( C * ( 1.0 - tau / tau_syn ) ) * exp( -h / tau_syn )#}
+{# * expm1( h * ( 1.0 / tau_syn - 1.0 / tau ) );#}
+{##}
+{# const double dev_P32 = fabs( P32 - P32_singular );#}
+{##}
+{# if ( tau == tau_syn || ( fabs( tau - tau_syn ) < 0.1 && dev_P32 > 2.0#}
+{# * fabs( P32_linear ) ) )#}
+{# {#}
+{# return P32_singular;#}
+{# }#}
+{# else#}
+{# {#}
+{# return P32;#}
+{# }#}
+{#}#}
__global__ void {{ neuronName }}_Calibrate(int n_node, float *param_arr,
@@ -76,12 +76,22 @@ __global__ void {{ neuronName }}_Calibrate(int n_node, float *param_arr,
if (i_neuron < n_node) {
float *param = param_arr + n_param*i_neuron;
- P11ex = exp( -h / tau_ex );
- P11in = exp( -h / tau_in );
- P22 = exp( -h / tau_m );
- P21ex = (float)propagator_32( tau_ex, tau_m, C_m, h );
- P21in = (float)propagator_32( tau_in, tau_m, C_m, h );
- P20 = tau_m / C_m * ( 1.0 - P22 );
+{# P11ex = exp( -h / tau_ex );#}
+{# P11in = exp( -h / tau_in );#}
+{# P22 = exp( -h / tau_m );#}
+{# P21ex = (float)propagator_32( tau_ex, tau_m, C_m, h );#}
+{# P21in = (float)propagator_32( tau_in, tau_m, C_m, h );#}
+{# P20 = tau_m / C_m * ( 1.0 - P22 );#}
+{%- filter indent(4,True) %}
+{%- for internals_block in neuron.get_internals_blocks() %}
+{%- for decl in internals_block.get_declarations() %}
+{%- for variable in decl.get_variables() %}
+{%- set variable_symbol = variable.get_scope().resolve_to_symbol(variable.get_complete_name(), SymbolKind.VARIABLE) %}
+{%- include "directives/MemberInitialization.jinja2" %}
+{%- endfor %}
+{%- endfor %}
+{%- endfor %}
+{%- endfilter %}
}
}
@@ -94,22 +104,35 @@ __global__ void {{ neuronName }}_Update(int n_node, int i_node_0, float *var_arr
float *var = var_arr + n_var*i_neuron;
float *param = param_arr + n_param*i_neuron;
- if ( refractory_step > 0.0 ) {
- // neuron is absolute refractory
- refractory_step -= 1.0;
- }
- else { // neuron is not refractory, so evolve V
- V_m_rel = V_m_rel * P22 + I_syn_ex * P21ex + I_syn_in * P21in + I_e * P20;
- }
- // exponential decaying PSCs
- I_syn_ex *= P11ex;
- I_syn_in *= P11in;
-
- if (V_m_rel >= Theta_rel ) { // threshold crossing
- PushSpike(i_node_0 + i_neuron, 1.0);
- V_m_rel = V_reset_rel;
- refractory_step = (int)round(t_ref/NESTGPUTimeResolution);
- }
+{# if ( refractory_step > 0.0 ) {#}
+{# // neuron is absolute refractory#}
+{# refractory_step -= 1.0;#}
+{# }#}
+{# else { // neuron is not refractory, so evolve V#}
+{# V_m_rel = V_m_rel * P22 + I_syn_ex * P21ex + I_syn_in * P21in + I_e * P20;#}
+{# }#}
+{# // exponential decaying PSCs#}
+{# I_syn_ex *= P11ex;#}
+{# I_syn_in *= P11in;#}
+{##}
+{# if (V_m_rel >= Theta_rel ) { // threshold crossing#}
+{# PushSpike(i_node_0 + i_neuron, 1.0);#}
+{# V_m_rel = V_reset_rel;#}
+{# refractory_step = (int)round(t_ref/NESTGPUTimeResolution);#}
+{# }#}
+{%- if neuron.get_update_blocks() %}
+{%- filter indent(2) %}
+{%- for block in neuron.get_update_blocks() %}
+{%- set ast = block.get_block() %}
+{%- if ast.print_comment('*')|length > 1 %}
+/*
+ {{ast.print_comment('*')}}
+ */
+{%- endif %}
+{%- include "directives/Block.jinja2" %}
+{%- endfor %}
+{%- endfilter %}
+{%- endif %}
}
}
@@ -136,29 +159,42 @@ int {{ neuronName }}::Init(int i_node_0, int n_node, int /*n_port*/,
scal_var_name_ = {{ neuronName }}_scal_var_name;
scal_param_name_ = {{ neuronName }}_scal_param_name;
- SetScalParam(0, n_node, "tau_m", 10.0 ); // in ms
- SetScalParam(0, n_node, "C_m", 250.0 ); // in pF
- SetScalParam(0, n_node, "E_L", -70.0 ); // in mV
- SetScalParam(0, n_node, "I_e", 0.0 ); // in pA
- SetScalParam(0, n_node, "Theta_rel", -55.0 - (-70.0) ); // relative to E_L_
- SetScalParam(0, n_node, "V_reset_rel", -70.0 - (-70.0) ); // relative to E_L_
- SetScalParam(0, n_node, "tau_ex", 2.0 ); // in ms
- SetScalParam(0, n_node, "tau_in", 2.0 ); // in ms
- // SetScalParam(0, n_node, "rho", 0.01 ); // in 1/s
- // SetScalParam(0, n_node, "delta", 0.0 ); // in mV
- SetScalParam(0, n_node, "t_ref", 2.0 ); // in ms
- SetScalParam(0, n_node, "den_delay", 0.0); // in ms
- SetScalParam(0, n_node, "P20", 0.0);
- SetScalParam(0, n_node, "P11ex", 0.0);
- SetScalParam(0, n_node, "P11in", 0.0);
- SetScalParam(0, n_node, "P21ex", 0.0);
- SetScalParam(0, n_node, "P21in", 0.0);
- SetScalParam(0, n_node, "P22", 0.0);
-
- SetScalVar(0, n_node, "I_syn_ex", 0.0 );
- SetScalVar(0, n_node, "I_syn_in", 0.0 );
- SetScalVar(0, n_node, "V_m_rel", -70.0 - (-70.0) ); // in mV, relative to E_L
- SetScalVar(0, n_node, "refractory_step", 0 );
+{# SetScalParam(0, n_node, "tau_m", 10.0 ); // in ms#}
+{# SetScalParam(0, n_node, "C_m", 250.0 ); // in pF#}
+{# SetScalParam(0, n_node, "E_L", -70.0 ); // in mV#}
+{# SetScalParam(0, n_node, "I_e", 0.0 ); // in pA#}
+{# SetScalParam(0, n_node, "Theta_rel", -55.0 - (-70.0) ); // relative to E_L_#}
+{# SetScalParam(0, n_node, "V_reset_rel", -70.0 - (-70.0) ); // relative to E_L_#}
+{# SetScalParam(0, n_node, "tau_ex", 2.0 ); // in ms#}
+{# SetScalParam(0, n_node, "tau_in", 2.0 ); // in ms#}
+{# // SetScalParam(0, n_node, "rho", 0.01 ); // in 1/s#}
+{# // SetScalParam(0, n_node, "delta", 0.0 ); // in mV#}
+{# SetScalParam(0, n_node, "t_ref", 2.0 ); // in ms#}
+{# SetScalParam(0, n_node, "den_delay", 0.0); // in ms#}
+{# SetScalParam(0, n_node, "P20", 0.0);#}
+{# SetScalParam(0, n_node, "P11ex", 0.0);#}
+{# SetScalParam(0, n_node, "P11in", 0.0);#}
+{# SetScalParam(0, n_node, "P21ex", 0.0);#}
+{# SetScalParam(0, n_node, "P21in", 0.0);#}
+{# SetScalParam(0, n_node, "P22", 0.0);#}
+{##}
+{# SetScalVar(0, n_node, "I_syn_ex", 0.0 );#}
+{# SetScalVar(0, n_node, "I_syn_in", 0.0 );#}
+{# SetScalVar(0, n_node, "V_m_rel", -70.0 - (-70.0) ); // in mV, relative to E_L#}
+{# SetScalVar(0, n_node, "refractory_step", 0 );#}
+
+{%- filter indent(2) %}
+{%- for variable in neuron.get_parameter_symbols() %}
+ SetScalParam(0, n_node, {{ printer_no_origin.print(variable) }}, {{printer.print(variable.get_declaring_expression())}}); // as {{variable.get_type_symbol().print_symbol()}}
+{%- endfor %}
+{%- endfilter %}
+
+
+{%- filter indent(2) %}
+{%- for variable in neuron.get_internal_symbols() %}
+ SetScalParam(0, n_node, {{ printer_no_origin.print(variable) }}, 0.0);
+{%- endfor %}
+{%- endfilter %}
// multiplication factor of input signal is always 1 for all nodes
float input_weight = 1.0;
@@ -169,11 +205,11 @@ int {{ neuronName }}::Init(int i_node_0, int n_node, int /*n_port*/,
port_weight_port_step_ = 0;
// input spike signal is stored in I_syn_ex, I_syn_in
- port_input_arr_ = GetVarArr() + GetScalVarIdx("I_syn_ex");
+ port_input_arr_ = GetVarArr() + GetScalVarIdx("I_kernel_exc__X__exc_spikes");
port_input_arr_step_ = n_var_;
port_input_port_step_ = 1;
- den_delay_arr_ = GetParamArr() + GetScalParamIdx("den_delay");
+{# den_delay_arr_ = GetParamArr() + GetScalParamIdx("den_delay");#}
return 0;
}