Skip to content

Commit

Permalink
Merge remote-tracking branch 'upstream/master' into static_syn
Browse files Browse the repository at this point in the history
  • Loading branch information
C.A.P. Linssen committed Aug 9, 2023
2 parents b6397bf + 6e1c66e commit 4ccf133
Show file tree
Hide file tree
Showing 10 changed files with 121 additions and 33 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/nestml-build.yml
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ jobs:
runs-on: ubuntu-latest
strategy:
matrix:
nest_branch: ["v2.20.2", "v3.0", "v3.1", "v3.2", "v3.3", "v3.4", "master"]
nest_branch: ["v2.20.2", "v3.0", "v3.1", "v3.2", "v3.3", "v3.4", "v3.5", "master"]
fail-fast: false
steps:
# Checkout the repository contents
Expand Down
15 changes: 15 additions & 0 deletions doc/nestml_language/nestml_language_concepts.rst
Original file line number Diff line number Diff line change
Expand Up @@ -512,6 +512,21 @@ The following functions are predefined in NESTML and can be used out of the box:
* - ``tanh``
- x
- Returns the hyperbolic tangent of x. The type of x and the return type are Real.
* - ``erf``
- x
- Returns the error function of x. The type of x and the return type are Real.
* - ``erfc``
- x
- Returns the complementary error function of x. The type of x and the return type are Real.
* - ``ceil``
- x
- Returns the ceil of x. The type of x and the return type are Real.
* - ``floor``
- x
- Returns the floor of x. The type of x and the return type are Real.
* - ``round``
- x
- Returns the rounded value of x. The type of x and the return type are Real.
* - ``random_normal``
- mean, std
- Returns a sample from a normal (Gaussian) distribution with parameters "mean" and "standard deviation"
Expand Down
29 changes: 16 additions & 13 deletions pynestml/codegeneration/nest_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,20 +60,23 @@ def detect_nest_version(cls) -> str:
pass
if "DataConnect" in dir(nest):
nest_version = "v2.20.2"
elif "kernel_status" not in dir(nest): # added in v3.1
nest_version = "v3.0"
elif "Kplus" in syn.get().keys(): # "Kplus" trace variable is made accessible via get_status() in master
nest_version = "master"
elif "prepared" in nest.GetKernelStatus().keys(): # "prepared" key was added after v3.3 release
nest_version = "v3.4"
elif "tau_u_bar_minus" in neuron.get().keys(): # added in v3.3
nest_version = "v3.3"
elif "tau_Ca" in vt.get().keys(): # removed in v3.2
nest_version = "v3.1"
nest_version = "v2.20.2"
else:
nest_version = "v3.2"
nest_version = "v" + nest.__version__
if nest_version.startswith("v3.5"):
if "post0.dev0" in nest_version:
nest_version = "master"
else:
if "kernel_status" not in dir(nest): # added in v3.1
nest_version = "v3.0"
elif "prepared" in nest.GetKernelStatus().keys(): # "prepared" key was added after v3.3 release
nest_version = "v3.4"
elif "tau_u_bar_minus" in neuron.get().keys(): # added in v3.3
nest_version = "v3.3"
elif "tau_Ca" in vt.get().keys(): # removed in v3.2
nest_version = "v3.1"
else:
nest_version = "v3.2"
except ModuleNotFoundError:
nest_version = ""
Expand Down
9 changes: 9 additions & 0 deletions pynestml/codegeneration/printers/cpp_function_call_printer.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,15 @@ def _print_function_call_format_string(self, function_call: ASTFunctionCall) ->
if function_name == PredefinedFunctions.ERFC:
return 'std::erfc({!s})'

if function_name == PredefinedFunctions.CEIL:
return 'std::ceil({!s})'

if function_name == PredefinedFunctions.FLOOR:
return 'std::floor({!s})'

if function_name == PredefinedFunctions.ROUND:
return 'std::round({!s})'

if function_name == PredefinedFunctions.EXPM1:
return 'numerics::expm1({!s})'

Expand Down
8 changes: 5 additions & 3 deletions pynestml/codegeneration/printers/nestml_printer.py
Original file line number Diff line number Diff line change
Expand Up @@ -371,14 +371,14 @@ def print_input_port(self, node: ASTInputPort) -> str:
ret += " " + self.print(node.get_datatype()) + " "
if node.has_size_parameter():
ret += "[" + node.get_size_parameter() + "]"
ret += "<-"
ret += "<- "
if node.has_input_qualifiers():
for qual in node.get_input_qualifiers():
ret += self.print(qual) + " "
if node.is_spike():
ret += "spike"
else:
ret += "current"
ret += "continuous"
ret += print_sl_comment(node.in_comment) + "\n"
return ret

Expand Down Expand Up @@ -442,7 +442,9 @@ def print_kernel(self, node: ASTKernel) -> str:

def print_output_block(self, node: ASTOutputBlock) -> str:
ret = print_ml_comments(node.pre_comments, self.indent, False)
ret += print_n_spaces(self.indent) + "output: " + ("spike" if node.is_spike() else "current")
ret += print_n_spaces(self.indent) + "output:\n"
ret += print_n_spaces(self.indent + 4)
ret += "spike" if node.is_spike() else "continuous"
ret += print_sl_comment(node.in_comment)
ret += "\n"
return ret
Expand Down
45 changes: 45 additions & 0 deletions pynestml/symbols/predefined_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,9 @@ class PredefinedFunctions:
MAX The callee name of the max function.
MIN The callee name of the min function.
ABS The callee name of the abs function.
CEIL The callee name of the ceil function.
FLOOR The callee name of the floor function.
ROUND The callee name of the round function.
INTEGRATE_ODES The callee name of the integrate_odes function.
CONVOLVE The callee name of the convolve function.
name2function A dict of function symbols as currently defined.
Expand Down Expand Up @@ -81,6 +84,9 @@ class PredefinedFunctions:
MAX = 'max'
MIN = 'min'
ABS = 'abs'
CEIL = 'ceil'
FLOOR = 'floor'
ROUND = 'round'
INTEGRATE_ODES = 'integrate_odes'
CONVOLVE = 'convolve'
DELIVER_SPIKE = 'deliver_spike'
Expand Down Expand Up @@ -116,6 +122,9 @@ def register_functions(cls):
cls.__register_max_function()
cls.__register_min_function()
cls.__register_abs_function()
cls.__register_ceil_function()
cls.__register_floor_function()
cls.__register_round_function()
cls.__register_integrated_odes_function()
cls.__register_convolve()
cls.__register_deliver_spike()
Expand Down Expand Up @@ -417,6 +426,42 @@ def __register_abs_function(cls):
element_reference=None, is_predefined=True)
cls.name2function[cls.ABS] = symbol

@classmethod
def __register_ceil_function(cls):
"""
Registers the ceil function.
"""
params = list()
params.append(PredefinedTypes.get_template_type(0))
symbol = FunctionSymbol(name=cls.CEIL, param_types=params,
return_type=PredefinedTypes.get_template_type(0),
element_reference=None, is_predefined=True)
cls.name2function[cls.CEIL] = symbol

@classmethod
def __register_floor_function(cls):
"""
Registers the floor function.
"""
params = list()
params.append(PredefinedTypes.get_template_type(0))
symbol = FunctionSymbol(name=cls.FLOOR, param_types=params,
return_type=PredefinedTypes.get_template_type(0),
element_reference=None, is_predefined=True)
cls.name2function[cls.FLOOR] = symbol

@classmethod
def __register_round_function(cls):
"""
Registers the round function.
"""
params = list()
params.append(PredefinedTypes.get_template_type(0))
symbol = FunctionSymbol(name=cls.ROUND, param_types=params,
return_type=PredefinedTypes.get_template_type(0),
element_reference=None, is_predefined=True)
cls.name2function[cls.ROUND] = symbol

@classmethod
def __register_integrated_odes_function(cls):
"""
Expand Down
4 changes: 2 additions & 2 deletions pynestml/symbols/variable_symbol.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,7 @@ def has_delay_parameter(self):
Returns whether this variable has a delay value associated with it.
:return: bool
"""
return self.delay_parameter is not None and type(self.delay_parameter) == str
return self.delay_parameter is not None and isinstance(self.delay_parameter, str)

def get_block_type(self):
"""
Expand Down Expand Up @@ -425,7 +425,7 @@ def equals(self, other):
:return: True if equal, otherwise False.
:rtype: bool
"""
return (type(self) != type(other)
return (isinstance(other, type(self))
and self.get_referenced_object() == other.get_referenced_object()
and self.get_symbol_name() == other.get_symbol_name()
and self.get_corresponding_scope() == other.get_corresponding_scope()
Expand Down
2 changes: 1 addition & 1 deletion pynestml/utils/ast_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1197,7 +1197,7 @@ def get_expr_from_kernel_var(cls, kernel: ASTKernel, var_name: str) -> Union[AST
"""
Get the expression using the kernel variable
"""
assert type(var_name) == str
assert isinstance(var_name, str)
for var, expr in zip(kernel.get_variables(), kernel.get_expressions()):
if var.get_complete_name() == var_name:
return expr
Expand Down
6 changes: 6 additions & 0 deletions tests/nest_tests/resources/MathFunctionTest.nestml
Original file line number Diff line number Diff line change
Expand Up @@ -30,10 +30,16 @@ neuron math_function_test:
log10_state real = 0.
erf_state real = 0.
erfc_state real = 0.
ceil_state real = 0.
floor_state real = 0.
round_state real = 0.

update:
ln_state = ln(x)
log10_state = log10(x)
erf_state = erf(x)
erfc_state = erfc(x)
ceil_state = ceil(x / 10.)
floor_state = floor(x / 10.)
round_state = round(x / 10.)
x = x + 1.
34 changes: 21 additions & 13 deletions tests/nest_tests/test_nest_math_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,35 +52,43 @@ def test_math_function(self):
nrn = nest.Create("math_function_test_nestml")
mm = nest.Create("multimeter")

ln_state_specifier = "ln_state"
log10_state_specifier = "log10_state"
erf_state_specifier = "erf_state"
erfc_state_specifier = "erfc_state"
nest.SetStatus(mm, {"record_from": ["x", ln_state_specifier, log10_state_specifier, erf_state_specifier, erfc_state_specifier]})
nest.SetStatus(mm, {"record_from": ["x", "ln_state", "log10_state", "erf_state", "erfc_state", "ceil_state", "floor_state", "round_state"]})

nest.Connect(mm, nrn)

nest.Simulate(100.)

if nest_version.startswith("v2"):
timevec = nest.GetStatus(mm, "events")[0]["x"]
ln_state_ts = nest.GetStatus(mm, "events")[0][ln_state_specifier]
log10_state_ts = nest.GetStatus(mm, "events")[0][log10_state_specifier]
erf_state_ts = nest.GetStatus(mm, "events")[0][erf_state_specifier]
erfc_state_ts = nest.GetStatus(mm, "events")[0][erfc_state_specifier]
ln_state_ts = nest.GetStatus(mm, "events")[0]["ln_state"]
log10_state_ts = nest.GetStatus(mm, "events")[0]["log10_state"]
erf_state_ts = nest.GetStatus(mm, "events")[0]["erf_state"]
erfc_state_ts = nest.GetStatus(mm, "events")[0]["erfc_state"]
ceil_state_ts = nest.GetStatus(mm, "events")[0]["ceil_state"]
floor_state_ts = nest.GetStatus(mm, "events")[0]["floor_state"]
round_state_ts = nest.GetStatus(mm, "events")[0]["round_state"]
else:
timevec = mm.get("events")["x"]
ln_state_ts = mm.get("events")[ln_state_specifier]
log10_state_ts = mm.get("events")[log10_state_specifier]
erf_state_ts = mm.get("events")[erf_state_specifier]
erfc_state_ts = mm.get("events")[erfc_state_specifier]
ln_state_ts = mm.get("events")["ln_state"]
log10_state_ts = mm.get("events")["log10_state"]
erf_state_ts = mm.get("events")["erf_state"]
erfc_state_ts = mm.get("events")["erfc_state"]
ceil_state_ts = mm.get("events")["ceil_state"]
floor_state_ts = mm.get("events")["floor_state"]
round_state_ts = mm.get("events")["round_state"]

ref_ln_state_ts = np.log(timevec - 1)
ref_log10_state_ts = np.log10(timevec - 1)
ref_erf_state_ts = sp.special.erf(timevec - 1)
ref_erfc_state_ts = sp.special.erfc(timevec - 1)
ref_ceil_state_ts = np.ceil((timevec - 1) / 10)
ref_floor_state_ts = np.floor((timevec - 1) / 10)
ref_round_state_ts = np.round((timevec - 1) / 10)

np.testing.assert_allclose(ln_state_ts, ref_ln_state_ts)
np.testing.assert_allclose(log10_state_ts, ref_log10_state_ts)
np.testing.assert_allclose(erf_state_ts, ref_erf_state_ts)
np.testing.assert_allclose(erfc_state_ts, ref_erfc_state_ts)
np.testing.assert_allclose(ceil_state_ts, ref_ceil_state_ts)
np.testing.assert_allclose(floor_state_ts, ref_floor_state_ts)
np.testing.assert_allclose(round_state_ts, ref_round_state_ts)

0 comments on commit 4ccf133

Please sign in to comment.