diff --git a/doc/nestml_language/nestml_language_concepts.rst b/doc/nestml_language/nestml_language_concepts.rst index 242181f7e..7b948d30f 100644 --- a/doc/nestml_language/nestml_language_concepts.rst +++ b/doc/nestml_language/nestml_language_concepts.rst @@ -512,6 +512,21 @@ The following functions are predefined in NESTML and can be used out of the box: * - ``tanh`` - x - Returns the hyperbolic tangent of x. The type of x and the return type are Real. + * - ``erf`` + - x + - Returns the error function of x. The type of x and the return type are Real. + * - ``erfc`` + - x + - Returns the complementary error function of x. The type of x and the return type are Real. + * - ``ceil`` + - x + - Returns the ceil of x. The type of x and the return type are Real. + * - ``floor`` + - x + - Returns the floor of x. The type of x and the return type are Real. + * - ``round`` + - x + - Returns the rounded value of x. The type of x and the return type are Real. * - ``random_normal`` - mean, std - Returns a sample from a normal (Gaussian) distribution with parameters "mean" and "standard deviation" diff --git a/pynestml/codegeneration/printers/cpp_function_call_printer.py b/pynestml/codegeneration/printers/cpp_function_call_printer.py index e07a60a1f..126c60529 100644 --- a/pynestml/codegeneration/printers/cpp_function_call_printer.py +++ b/pynestml/codegeneration/printers/cpp_function_call_printer.py @@ -115,6 +115,15 @@ def _print_function_call_format_string(self, function_call: ASTFunctionCall) -> if function_name == PredefinedFunctions.ERFC: return 'std::erfc({!s})' + if function_name == PredefinedFunctions.CEIL: + return 'std::ceil({!s})' + + if function_name == PredefinedFunctions.FLOOR: + return 'std::floor({!s})' + + if function_name == PredefinedFunctions.ROUND: + return 'std::round({!s})' + if function_name == PredefinedFunctions.EXPM1: return 'numerics::expm1({!s})' diff --git a/pynestml/symbols/predefined_functions.py b/pynestml/symbols/predefined_functions.py index 0541e1444..45e7b3302 100644 --- a/pynestml/symbols/predefined_functions.py +++ b/pynestml/symbols/predefined_functions.py @@ -53,6 +53,9 @@ class PredefinedFunctions: MAX The callee name of the max function. MIN The callee name of the min function. ABS The callee name of the abs function. + CEIL The callee name of the ceil function. + FLOOR The callee name of the floor function. + ROUND The callee name of the round function. INTEGRATE_ODES The callee name of the integrate_odes function. CONVOLVE The callee name of the convolve function. name2function A dict of function symbols as currently defined. @@ -81,6 +84,9 @@ class PredefinedFunctions: MAX = 'max' MIN = 'min' ABS = 'abs' + CEIL = 'ceil' + FLOOR = 'floor' + ROUND = 'round' INTEGRATE_ODES = 'integrate_odes' CONVOLVE = 'convolve' DELIVER_SPIKE = 'deliver_spike' @@ -116,6 +122,9 @@ def register_functions(cls): cls.__register_max_function() cls.__register_min_function() cls.__register_abs_function() + cls.__register_ceil_function() + cls.__register_floor_function() + cls.__register_round_function() cls.__register_integrated_odes_function() cls.__register_convolve() cls.__register_deliver_spike() @@ -417,6 +426,42 @@ def __register_abs_function(cls): element_reference=None, is_predefined=True) cls.name2function[cls.ABS] = symbol + @classmethod + def __register_ceil_function(cls): + """ + Registers the ceil function. + """ + params = list() + params.append(PredefinedTypes.get_template_type(0)) + symbol = FunctionSymbol(name=cls.CEIL, param_types=params, + return_type=PredefinedTypes.get_template_type(0), + element_reference=None, is_predefined=True) + cls.name2function[cls.CEIL] = symbol + + @classmethod + def __register_floor_function(cls): + """ + Registers the floor function. + """ + params = list() + params.append(PredefinedTypes.get_template_type(0)) + symbol = FunctionSymbol(name=cls.FLOOR, param_types=params, + return_type=PredefinedTypes.get_template_type(0), + element_reference=None, is_predefined=True) + cls.name2function[cls.FLOOR] = symbol + + @classmethod + def __register_round_function(cls): + """ + Registers the round function. + """ + params = list() + params.append(PredefinedTypes.get_template_type(0)) + symbol = FunctionSymbol(name=cls.ROUND, param_types=params, + return_type=PredefinedTypes.get_template_type(0), + element_reference=None, is_predefined=True) + cls.name2function[cls.ROUND] = symbol + @classmethod def __register_integrated_odes_function(cls): """ diff --git a/tests/nest_tests/resources/MathFunctionTest.nestml b/tests/nest_tests/resources/MathFunctionTest.nestml index 0c97e0f9f..9bc650fea 100644 --- a/tests/nest_tests/resources/MathFunctionTest.nestml +++ b/tests/nest_tests/resources/MathFunctionTest.nestml @@ -30,10 +30,16 @@ neuron math_function_test: log10_state real = 0. erf_state real = 0. erfc_state real = 0. + ceil_state real = 0. + floor_state real = 0. + round_state real = 0. update: ln_state = ln(x) log10_state = log10(x) erf_state = erf(x) erfc_state = erfc(x) + ceil_state = ceil(x / 10.) + floor_state = floor(x / 10.) + round_state = round(x / 10.) x = x + 1. diff --git a/tests/nest_tests/test_nest_math_function.py b/tests/nest_tests/test_nest_math_function.py index 982b1f315..32938dd3f 100644 --- a/tests/nest_tests/test_nest_math_function.py +++ b/tests/nest_tests/test_nest_math_function.py @@ -52,11 +52,7 @@ def test_math_function(self): nrn = nest.Create("math_function_test_nestml") mm = nest.Create("multimeter") - ln_state_specifier = "ln_state" - log10_state_specifier = "log10_state" - erf_state_specifier = "erf_state" - erfc_state_specifier = "erfc_state" - nest.SetStatus(mm, {"record_from": ["x", ln_state_specifier, log10_state_specifier, erf_state_specifier, erfc_state_specifier]}) + nest.SetStatus(mm, {"record_from": ["x", "ln_state", "log10_state", "erf_state", "erfc_state", "ceil_state", "floor_state", "round_state"]}) nest.Connect(mm, nrn) @@ -64,23 +60,35 @@ def test_math_function(self): if nest_version.startswith("v2"): timevec = nest.GetStatus(mm, "events")[0]["x"] - ln_state_ts = nest.GetStatus(mm, "events")[0][ln_state_specifier] - log10_state_ts = nest.GetStatus(mm, "events")[0][log10_state_specifier] - erf_state_ts = nest.GetStatus(mm, "events")[0][erf_state_specifier] - erfc_state_ts = nest.GetStatus(mm, "events")[0][erfc_state_specifier] + ln_state_ts = nest.GetStatus(mm, "events")[0]["ln_state"] + log10_state_ts = nest.GetStatus(mm, "events")[0]["log10_state"] + erf_state_ts = nest.GetStatus(mm, "events")[0]["erf_state"] + erfc_state_ts = nest.GetStatus(mm, "events")[0]["erfc_state"] + ceil_state_ts = nest.GetStatus(mm, "events")[0]["ceil_state"] + floor_state_ts = nest.GetStatus(mm, "events")[0]["floor_state"] + round_state_ts = nest.GetStatus(mm, "events")[0]["round_state"] else: timevec = mm.get("events")["x"] - ln_state_ts = mm.get("events")[ln_state_specifier] - log10_state_ts = mm.get("events")[log10_state_specifier] - erf_state_ts = mm.get("events")[erf_state_specifier] - erfc_state_ts = mm.get("events")[erfc_state_specifier] + ln_state_ts = mm.get("events")["ln_state"] + log10_state_ts = mm.get("events")["log10_state"] + erf_state_ts = mm.get("events")["erf_state"] + erfc_state_ts = mm.get("events")["erfc_state"] + ceil_state_ts = mm.get("events")["ceil_state"] + floor_state_ts = mm.get("events")["floor_state"] + round_state_ts = mm.get("events")["round_state"] ref_ln_state_ts = np.log(timevec - 1) ref_log10_state_ts = np.log10(timevec - 1) ref_erf_state_ts = sp.special.erf(timevec - 1) ref_erfc_state_ts = sp.special.erfc(timevec - 1) + ref_ceil_state_ts = np.ceil((timevec - 1) / 10) + ref_floor_state_ts = np.floor((timevec - 1) / 10) + ref_round_state_ts = np.round((timevec - 1) / 10) np.testing.assert_allclose(ln_state_ts, ref_ln_state_ts) np.testing.assert_allclose(log10_state_ts, ref_log10_state_ts) np.testing.assert_allclose(erf_state_ts, ref_erf_state_ts) np.testing.assert_allclose(erfc_state_ts, ref_erfc_state_ts) + np.testing.assert_allclose(ceil_state_ts, ref_ceil_state_ts) + np.testing.assert_allclose(floor_state_ts, ref_floor_state_ts) + np.testing.assert_allclose(round_state_ts, ref_round_state_ts)