Skip to content

Commit

Permalink
run context condition checks only once, after model parsing
Browse files Browse the repository at this point in the history
  • Loading branch information
C.A.P. Linssen committed Oct 1, 2024
1 parent e6565b5 commit 3f1e731
Show file tree
Hide file tree
Showing 36 changed files with 1,248 additions and 1,181 deletions.
38 changes: 8 additions & 30 deletions pynestml/cocos/co_co_all_variables_defined.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,11 +41,10 @@ class CoCoAllVariablesDefined(CoCo):
"""

@classmethod
def check_co_co(cls, node: ASTModel, after_ast_rewrite: bool = False):
def check_co_co(cls, node: ASTModel):
"""
Checks if this coco applies for the handed over neuron. Models which contain undefined variables are not correct.
:param node: a single neuron instance.
:param after_ast_rewrite: indicates whether this coco is checked after the code generator has done rewriting of the abstract syntax tree. If True, checks are not as rigorous. Use False where possible.
"""
# for each variable in all expressions, check if the variable has been defined previously
expression_collector_visitor = ASTExpressionCollectorVisitor()
Expand All @@ -62,32 +61,6 @@ def check_co_co(cls, node: ASTModel, after_ast_rewrite: bool = False):

# test if the symbol has been defined at least
if symbol is None:
if after_ast_rewrite: # after ODE-toolbox transformations, convolutions are replaced by state variables, so cannot perform this check properly
symbol2 = node.get_scope().resolve_to_symbol(var.get_name(), SymbolKind.VARIABLE)
if symbol2 is not None:
# an inline expression defining this variable name (ignoring differential order) exists
if "__X__" in str(symbol2): # if this variable was the result of a convolution...
continue
else:
# for kernels, also allow derivatives of that kernel to appear

inline_expr_names = []
inline_exprs = []
for equations_block in node.get_equations_blocks():
inline_expr_names.extend([inline_expr.variable_name for inline_expr in equations_block.get_inline_expressions()])
inline_exprs.extend(equations_block.get_inline_expressions())

if var.get_name() in inline_expr_names:
inline_expr_idx = inline_expr_names.index(var.get_name())
inline_expr = inline_exprs[inline_expr_idx]
from pynestml.utils.ast_utils import ASTUtils
if ASTUtils.inline_aliases_convolution(inline_expr):
symbol2 = node.get_scope().resolve_to_symbol(var.get_name(), SymbolKind.VARIABLE)
if symbol2 is not None:
# actually, no problem detected, skip error
# XXX: TODO: check that differential order is less than or equal to that of the kernel
continue

# check if this symbol is actually a type, e.g. "mV" in the expression "(1 + 2) * mV"
symbol2 = var.get_scope().resolve_to_symbol(var.get_complete_name(), SymbolKind.TYPE)
if symbol2 is not None:
Expand All @@ -106,9 +79,14 @@ def check_co_co(cls, node: ASTModel, after_ast_rewrite: bool = False):
# in this case its ok if it is recursive or defined later on
continue

if symbol.is_predefined:
continue

if symbol.block_type == BlockType.LOCAL and symbol.get_referenced_object().get_source_position().before(var.get_source_position()):
continue

# check if it has been defined before usage, except for predefined symbols, input ports and variables added by the AST transformation functions
if (not symbol.is_predefined) \
and symbol.block_type != BlockType.INPUT \
if symbol.block_type != BlockType.INPUT \
and not symbol.get_referenced_object().get_source_position().is_added_source_position():
# except for parameters, those can be defined after
if ((not symbol.get_referenced_object().get_source_position().before(var.get_source_position()))
Expand Down
1 change: 1 addition & 0 deletions pynestml/cocos/co_co_function_unique.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,4 +65,5 @@ def check_co_co(cls, model: ASTModel):
log_level=LoggingLevel.ERROR,
message=message, code=code)
checked.append(funcA)

checked_funcs_names.append(func.get_name())
6 changes: 3 additions & 3 deletions pynestml/cocos/co_co_illegal_expression.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,13 @@
#
# You should have received a copy of the GNU General Public License
# along with NEST. If not, see <http://www.gnu.org/licenses/>.
from pynestml.meta_model.ast_inline_expression import ASTInlineExpression

from pynestml.utils.ast_source_location import ASTSourceLocation
from pynestml.meta_model.ast_declaration import ASTDeclaration
from pynestml.cocos.co_co import CoCo
from pynestml.meta_model.ast_declaration import ASTDeclaration
from pynestml.meta_model.ast_inline_expression import ASTInlineExpression
from pynestml.symbols.error_type_symbol import ErrorTypeSymbol
from pynestml.symbols.predefined_types import PredefinedTypes
from pynestml.utils.ast_source_location import ASTSourceLocation
from pynestml.utils.logger import LoggingLevel, Logger
from pynestml.utils.logging_helper import LoggingHelper
from pynestml.utils.messages import Messages
Expand Down
49 changes: 36 additions & 13 deletions pynestml/cocos/co_co_no_kernels_except_in_convolve.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,14 @@
from typing import List

from pynestml.cocos.co_co import CoCo
from pynestml.meta_model.ast_declaration import ASTDeclaration
from pynestml.meta_model.ast_external_variable import ASTExternalVariable
from pynestml.meta_model.ast_function_call import ASTFunctionCall
from pynestml.meta_model.ast_kernel import ASTKernel
from pynestml.meta_model.ast_model import ASTModel
from pynestml.meta_model.ast_node import ASTNode
from pynestml.meta_model.ast_variable import ASTVariable
from pynestml.symbols.predefined_functions import PredefinedFunctions
from pynestml.symbols.symbol import SymbolKind
from pynestml.utils.logger import Logger, LoggingLevel
from pynestml.utils.messages import Messages
Expand Down Expand Up @@ -89,24 +92,44 @@ def visit_variable(self, node: ASTNode):
if not (isinstance(node, ASTExternalVariable) and node.get_alternate_name()):
code, message = Messages.get_no_variable_found(kernelName)
Logger.log_message(node=self.__neuron_node, code=code, message=message, log_level=LoggingLevel.ERROR)

continue

if not symbol.is_kernel():
continue

if node.get_complete_name() == kernelName:
parent = node.get_parent()
if parent is not None:
parent = node
correct = False
while parent is not None and not isinstance(parent, ASTModel):
parent = parent.get_parent()
assert parent is not None

if isinstance(parent, ASTDeclaration):
for lhs_var in parent.get_variables():
if kernelName == lhs_var.get_complete_name():
# kernel name appears on lhs of declaration, assume it is initial state
correct = True
parent = None # break out of outer loop
break

if isinstance(parent, ASTKernel):
continue
grandparent = parent.get_parent()
if grandparent is not None and isinstance(grandparent, ASTFunctionCall):
grandparent_func_name = grandparent.get_name()
if grandparent_func_name == 'convolve':
continue
code, message = Messages.get_kernel_outside_convolve(kernelName)
Logger.log_message(code=code,
message=message,
log_level=LoggingLevel.ERROR,
error_position=node.get_source_position())
# kernel name is used inside kernel definition, e.g. for a node ``g``, it appears in ``kernel g'' = -1/tau**2 * g - 2/tau * g'``
correct = True
break

if isinstance(parent, ASTFunctionCall):
func_name = parent.get_name()
if func_name == PredefinedFunctions.CONVOLVE:
# kernel name is used inside convolve call
correct = True

if not correct:
code, message = Messages.get_kernel_outside_convolve(kernelName)
Logger.log_message(code=code,
message=message,
log_level=LoggingLevel.ERROR,
error_position=node.get_source_position())


class KernelCollectingVisitor(ASTVisitor):
Expand Down
3 changes: 0 additions & 3 deletions pynestml/cocos/co_co_v_comp_exists.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,9 +43,6 @@ def check_co_co(cls, neuron: ASTModel):
Models which are supposed to be compartmental but do not contain
state variable called v_comp are not correct.
:param neuron: a single neuron instance.
:param after_ast_rewrite: indicates whether this coco is checked
after the code generator has done rewriting of the abstract syntax tree.
If True, checks are not as rigorous. Use False where possible.
"""
from pynestml.codegeneration.nest_compartmental_code_generator import NESTCompartmentalCodeGenerator

Expand Down
13 changes: 9 additions & 4 deletions pynestml/cocos/co_cos_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@
from pynestml.cocos.co_co_priorities_correctly_specified import CoCoPrioritiesCorrectlySpecified
from pynestml.meta_model.ast_model import ASTModel
from pynestml.frontend.frontend_configuration import FrontendConfiguration
from pynestml.utils.logger import Logger


class CoCosManager:
Expand Down Expand Up @@ -123,12 +124,12 @@ def check_state_variables_initialized(cls, model: ASTModel):
CoCoStateVariablesInitialized.check_co_co(model)

@classmethod
def check_variables_defined_before_usage(cls, model: ASTModel, after_ast_rewrite: bool) -> None:
def check_variables_defined_before_usage(cls, model: ASTModel) -> None:
"""
Checks that all variables are defined before being used.
:param model: a single model.
"""
CoCoAllVariablesDefined.check_co_co(model, after_ast_rewrite)
CoCoAllVariablesDefined.check_co_co(model)

@classmethod
def check_v_comp_requirement(cls, neuron: ASTModel):
Expand Down Expand Up @@ -402,17 +403,19 @@ def check_input_port_size_type(cls, model: ASTModel):
CoCoVectorInputPortsCorrectSizeType.check_co_co(model)

@classmethod
def post_symbol_table_builder_checks(cls, model: ASTModel, after_ast_rewrite: bool = False):
def check_cocos(cls, model: ASTModel, after_ast_rewrite: bool = False):
"""
Checks all context conditions.
:param model: a single model object.
"""
Logger.set_current_node(model)

cls.check_each_block_defined_at_most_once(model)
cls.check_function_defined(model)
cls.check_variables_unique_in_scope(model)
cls.check_inline_expression_not_assigned_to(model)
cls.check_state_variables_initialized(model)
cls.check_variables_defined_before_usage(model, after_ast_rewrite)
cls.check_variables_defined_before_usage(model)
if FrontendConfiguration.get_target_platform().upper() == 'NEST_COMPARTMENTAL':
# XXX: TODO: refactor this out; define a ``cocos_from_target_name()`` in the frontend instead.
cls.check_v_comp_requirement(model)
Expand Down Expand Up @@ -452,3 +455,5 @@ def post_symbol_table_builder_checks(cls, model: ASTModel, after_ast_rewrite: bo
cls.check_co_co_priorities_correctly_specified(model)
cls.check_resolution_func_legally_used(model)
cls.check_input_port_size_type(model)

Logger.set_current_node(None)
4 changes: 2 additions & 2 deletions pynestml/codegeneration/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,12 @@
# along with NEST. If not, see <http://www.gnu.org/licenses/>.

from __future__ import annotations
import subprocess
import os

from typing import Any, Mapping, Optional

from abc import ABCMeta, abstractmethod
import os
import subprocess

from pynestml.exceptions.invalid_target_exception import InvalidTargetException
from pynestml.frontend.frontend_configuration import FrontendConfiguration
Expand Down
7 changes: 5 additions & 2 deletions pynestml/codegeneration/nest_code_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
import pynestml

from pynestml.cocos.co_co_nest_synapse_delay_not_assigned_to import CoCoNESTSynapseDelayNotAssignedTo
from pynestml.cocos.co_cos_manager import CoCosManager
from pynestml.codegeneration.code_generator import CodeGenerator
from pynestml.codegeneration.code_generator_utils import CodeGeneratorUtils
from pynestml.codegeneration.nest_assignments_helper import NestAssignmentsHelper
Expand Down Expand Up @@ -374,6 +375,9 @@ def analyse_neuron(self, neuron: ASTModel) -> Tuple[Dict[str, ASTAssignment], Di
if not used_in_eq:
self.non_equations_state_variables[neuron.get_name()].append(var)

# cache state variables before symbol table update for the sake of delay variables
state_vars_before_update = neuron.get_state_symbols()

ASTUtils.remove_initial_values_for_kernels(neuron)
kernels = ASTUtils.remove_kernel_definitions_from_equations_block(neuron)
ASTUtils.update_initial_values_for_odes(neuron, [analytic_solver, numeric_solver])
Expand All @@ -388,7 +392,6 @@ def analyse_neuron(self, neuron: ASTModel) -> Tuple[Dict[str, ASTAssignment], Di
neuron = ASTUtils.add_declarations_to_internals(
neuron, self.analytic_solver[neuron.get_name()]["propagators"])

state_vars_before_update = neuron.get_state_symbols()
self.update_symbol_table(neuron)

# Update the delay parameter parameters after symbol table update
Expand Down Expand Up @@ -898,8 +901,8 @@ def update_symbol_table(self, neuron) -> None:
"""
SymbolTable.delete_model_scope(neuron.get_name())
symbol_table_visitor = ASTSymbolTableVisitor()
symbol_table_visitor.after_ast_rewrite_ = True
neuron.accept(symbol_table_visitor)
CoCosManager.check_cocos(neuron, after_ast_rewrite=True)
SymbolTable.add_model_scope(neuron.get_name(), neuron.get_scope())

def get_spike_update_expressions(self, neuron: ASTModel, kernel_buffers, solver_dicts, delta_factors) -> Tuple[Dict[str, ASTAssignment], Dict[str, ASTAssignment]]:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -740,8 +740,8 @@ def update_symbol_table(self, neuron, kernel_buffers):
"""
SymbolTable.delete_model_scope(neuron.get_name())
symbol_table_visitor = ASTSymbolTableVisitor()
symbol_table_visitor.after_ast_rewrite_ = True
neuron.accept(symbol_table_visitor)
CoCosManager.check_cocos(neuron, after_ast_rewrite=True)
SymbolTable.add_model_scope(neuron.get_name(), neuron.get_scope())

def _get_ast_variable(self, neuron, var_name) -> Optional[ASTVariable]:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,6 @@ def setup_printers(self):

# GSL printers
self._gsl_variable_printer = PythonSteppingFunctionVariablePrinter(None)
print("In Python code generator: created self._gsl_variable_printer = " + str(self._gsl_variable_printer))
self._gsl_function_call_printer = PythonSteppingFunctionFunctionCallPrinter(None)
self._gsl_printer = PythonExpressionPrinter(simple_expression_printer=PythonSimpleExpressionPrinter(variable_printer=self._gsl_variable_printer,
constant_printer=self._constant_printer,
Expand Down
3 changes: 2 additions & 1 deletion pynestml/codegeneration/spinnaker_code_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,6 @@ def setup_printers(self):

# GSL printers
self._gsl_variable_printer = PythonSteppingFunctionVariablePrinter(None)
print("In Python code generator: created self._gsl_variable_printer = " + str(self._gsl_variable_printer))
self._gsl_function_call_printer = PythonSteppingFunctionFunctionCallPrinter(None)
self._gsl_printer = PythonExpressionPrinter(simple_expression_printer=SpinnakerPythonSimpleExpressionPrinter(
variable_printer=self._gsl_variable_printer,
Expand Down Expand Up @@ -216,6 +215,7 @@ def generate_code(self, models: Sequence[ASTModel]) -> None:
for model in models:
cloned_model = model.clone()
cloned_model.accept(ASTSymbolTableVisitor())
CoCosManager.check_cocos(cloned_model)
cloned_models.append(cloned_model)

self.codegen_cpp.generate_code(cloned_models)
Expand All @@ -224,6 +224,7 @@ def generate_code(self, models: Sequence[ASTModel]) -> None:
for model in models:
cloned_model = model.clone()
cloned_model.accept(ASTSymbolTableVisitor())
CoCosManager.check_cocos(cloned_model)
cloned_models.append(cloned_model)

self.codegen_py.generate_code(cloned_models)
4 changes: 2 additions & 2 deletions pynestml/frontend/frontend_configuration.py
Original file line number Diff line number Diff line change
Expand Up @@ -244,8 +244,8 @@ def handle_module_name(cls, module_name):

@classmethod
def handle_target_platform(cls, target_platform: Optional[str]):
if target_platform is None or target_platform.upper() == 'NONE':
target_platform = '' # make sure `target_platform` is always a string
if target_platform is None:
target_platform = "NONE" # make sure `target_platform` is always a string

from pynestml.frontend.pynestml_frontend import get_known_targets

Expand Down
Loading

0 comments on commit 3f1e731

Please sign in to comment.