Skip to content

Commit

Permalink
Add ceil, floor and round functions (#929)
Browse files Browse the repository at this point in the history
  • Loading branch information
clinssen authored Aug 8, 2023
1 parent c0b21ea commit 94714a9
Show file tree
Hide file tree
Showing 5 changed files with 96 additions and 13 deletions.
15 changes: 15 additions & 0 deletions doc/nestml_language/nestml_language_concepts.rst
Original file line number Diff line number Diff line change
Expand Up @@ -512,6 +512,21 @@ The following functions are predefined in NESTML and can be used out of the box:
* - ``tanh``
- x
- Returns the hyperbolic tangent of x. The type of x and the return type are Real.
* - ``erf``
- x
- Returns the error function of x. The type of x and the return type are Real.
* - ``erfc``
- x
- Returns the complementary error function of x. The type of x and the return type are Real.
* - ``ceil``
- x
- Returns the ceil of x. The type of x and the return type are Real.
* - ``floor``
- x
- Returns the floor of x. The type of x and the return type are Real.
* - ``round``
- x
- Returns the rounded value of x. The type of x and the return type are Real.
* - ``random_normal``
- mean, std
- Returns a sample from a normal (Gaussian) distribution with parameters "mean" and "standard deviation"
Expand Down
9 changes: 9 additions & 0 deletions pynestml/codegeneration/printers/cpp_function_call_printer.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,15 @@ def _print_function_call_format_string(self, function_call: ASTFunctionCall) ->
if function_name == PredefinedFunctions.ERFC:
return 'std::erfc({!s})'

if function_name == PredefinedFunctions.CEIL:
return 'std::ceil({!s})'

if function_name == PredefinedFunctions.FLOOR:
return 'std::floor({!s})'

if function_name == PredefinedFunctions.ROUND:
return 'std::round({!s})'

if function_name == PredefinedFunctions.EXPM1:
return 'numerics::expm1({!s})'

Expand Down
45 changes: 45 additions & 0 deletions pynestml/symbols/predefined_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,9 @@ class PredefinedFunctions:
MAX The callee name of the max function.
MIN The callee name of the min function.
ABS The callee name of the abs function.
CEIL The callee name of the ceil function.
FLOOR The callee name of the floor function.
ROUND The callee name of the round function.
INTEGRATE_ODES The callee name of the integrate_odes function.
CONVOLVE The callee name of the convolve function.
name2function A dict of function symbols as currently defined.
Expand Down Expand Up @@ -81,6 +84,9 @@ class PredefinedFunctions:
MAX = 'max'
MIN = 'min'
ABS = 'abs'
CEIL = 'ceil'
FLOOR = 'floor'
ROUND = 'round'
INTEGRATE_ODES = 'integrate_odes'
CONVOLVE = 'convolve'
DELIVER_SPIKE = 'deliver_spike'
Expand Down Expand Up @@ -116,6 +122,9 @@ def register_functions(cls):
cls.__register_max_function()
cls.__register_min_function()
cls.__register_abs_function()
cls.__register_ceil_function()
cls.__register_floor_function()
cls.__register_round_function()
cls.__register_integrated_odes_function()
cls.__register_convolve()
cls.__register_deliver_spike()
Expand Down Expand Up @@ -417,6 +426,42 @@ def __register_abs_function(cls):
element_reference=None, is_predefined=True)
cls.name2function[cls.ABS] = symbol

@classmethod
def __register_ceil_function(cls):
"""
Registers the ceil function.
"""
params = list()
params.append(PredefinedTypes.get_template_type(0))
symbol = FunctionSymbol(name=cls.CEIL, param_types=params,
return_type=PredefinedTypes.get_template_type(0),
element_reference=None, is_predefined=True)
cls.name2function[cls.CEIL] = symbol

@classmethod
def __register_floor_function(cls):
"""
Registers the floor function.
"""
params = list()
params.append(PredefinedTypes.get_template_type(0))
symbol = FunctionSymbol(name=cls.FLOOR, param_types=params,
return_type=PredefinedTypes.get_template_type(0),
element_reference=None, is_predefined=True)
cls.name2function[cls.FLOOR] = symbol

@classmethod
def __register_round_function(cls):
"""
Registers the round function.
"""
params = list()
params.append(PredefinedTypes.get_template_type(0))
symbol = FunctionSymbol(name=cls.ROUND, param_types=params,
return_type=PredefinedTypes.get_template_type(0),
element_reference=None, is_predefined=True)
cls.name2function[cls.ROUND] = symbol

@classmethod
def __register_integrated_odes_function(cls):
"""
Expand Down
6 changes: 6 additions & 0 deletions tests/nest_tests/resources/MathFunctionTest.nestml
Original file line number Diff line number Diff line change
Expand Up @@ -30,10 +30,16 @@ neuron math_function_test:
log10_state real = 0.
erf_state real = 0.
erfc_state real = 0.
ceil_state real = 0.
floor_state real = 0.
round_state real = 0.

update:
ln_state = ln(x)
log10_state = log10(x)
erf_state = erf(x)
erfc_state = erfc(x)
ceil_state = ceil(x / 10.)
floor_state = floor(x / 10.)
round_state = round(x / 10.)
x = x + 1.
34 changes: 21 additions & 13 deletions tests/nest_tests/test_nest_math_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,35 +52,43 @@ def test_math_function(self):
nrn = nest.Create("math_function_test_nestml")
mm = nest.Create("multimeter")

ln_state_specifier = "ln_state"
log10_state_specifier = "log10_state"
erf_state_specifier = "erf_state"
erfc_state_specifier = "erfc_state"
nest.SetStatus(mm, {"record_from": ["x", ln_state_specifier, log10_state_specifier, erf_state_specifier, erfc_state_specifier]})
nest.SetStatus(mm, {"record_from": ["x", "ln_state", "log10_state", "erf_state", "erfc_state", "ceil_state", "floor_state", "round_state"]})

nest.Connect(mm, nrn)

nest.Simulate(100.)

if nest_version.startswith("v2"):
timevec = nest.GetStatus(mm, "events")[0]["x"]
ln_state_ts = nest.GetStatus(mm, "events")[0][ln_state_specifier]
log10_state_ts = nest.GetStatus(mm, "events")[0][log10_state_specifier]
erf_state_ts = nest.GetStatus(mm, "events")[0][erf_state_specifier]
erfc_state_ts = nest.GetStatus(mm, "events")[0][erfc_state_specifier]
ln_state_ts = nest.GetStatus(mm, "events")[0]["ln_state"]
log10_state_ts = nest.GetStatus(mm, "events")[0]["log10_state"]
erf_state_ts = nest.GetStatus(mm, "events")[0]["erf_state"]
erfc_state_ts = nest.GetStatus(mm, "events")[0]["erfc_state"]
ceil_state_ts = nest.GetStatus(mm, "events")[0]["ceil_state"]
floor_state_ts = nest.GetStatus(mm, "events")[0]["floor_state"]
round_state_ts = nest.GetStatus(mm, "events")[0]["round_state"]
else:
timevec = mm.get("events")["x"]
ln_state_ts = mm.get("events")[ln_state_specifier]
log10_state_ts = mm.get("events")[log10_state_specifier]
erf_state_ts = mm.get("events")[erf_state_specifier]
erfc_state_ts = mm.get("events")[erfc_state_specifier]
ln_state_ts = mm.get("events")["ln_state"]
log10_state_ts = mm.get("events")["log10_state"]
erf_state_ts = mm.get("events")["erf_state"]
erfc_state_ts = mm.get("events")["erfc_state"]
ceil_state_ts = mm.get("events")["ceil_state"]
floor_state_ts = mm.get("events")["floor_state"]
round_state_ts = mm.get("events")["round_state"]

ref_ln_state_ts = np.log(timevec - 1)
ref_log10_state_ts = np.log10(timevec - 1)
ref_erf_state_ts = sp.special.erf(timevec - 1)
ref_erfc_state_ts = sp.special.erfc(timevec - 1)
ref_ceil_state_ts = np.ceil((timevec - 1) / 10)
ref_floor_state_ts = np.floor((timevec - 1) / 10)
ref_round_state_ts = np.round((timevec - 1) / 10)

np.testing.assert_allclose(ln_state_ts, ref_ln_state_ts)
np.testing.assert_allclose(log10_state_ts, ref_log10_state_ts)
np.testing.assert_allclose(erf_state_ts, ref_erf_state_ts)
np.testing.assert_allclose(erfc_state_ts, ref_erfc_state_ts)
np.testing.assert_allclose(ceil_state_ts, ref_ceil_state_ts)
np.testing.assert_allclose(floor_state_ts, ref_floor_state_ts)
np.testing.assert_allclose(round_state_ts, ref_round_state_ts)

0 comments on commit 94714a9

Please sign in to comment.