diff --git a/src/andromede/expression/visitor.py b/src/andromede/expression/visitor.py index 351be9e5..02d54651 100644 --- a/src/andromede/expression/visitor.py +++ b/src/andromede/expression/visitor.py @@ -133,45 +133,31 @@ def visit(root: ExpressionNode, visitor: ExpressionVisitor[T]) -> T: """ Utility method to dispatch calls to the right method of a visitor. """ - if isinstance(root, LiteralNode): - return visitor.literal(root) - elif isinstance(root, NegationNode): - return visitor.negation(root) - elif isinstance(root, VariableNode): - return visitor.variable(root) - elif isinstance(root, ParameterNode): - return visitor.parameter(root) - elif isinstance(root, ComponentParameterNode): - return visitor.comp_parameter(root) - elif isinstance(root, ComponentVariableNode): - return visitor.comp_variable(root) - elif isinstance(root, ProblemParameterNode): - return visitor.pb_parameter(root) - elif isinstance(root, ProblemVariableNode): - return visitor.pb_variable(root) - elif isinstance(root, AdditionNode): - return visitor.addition(root) - elif isinstance(root, MultiplicationNode): - return visitor.multiplication(root) - elif isinstance(root, DivisionNode): - return visitor.division(root) - elif isinstance(root, ComparisonNode): - return visitor.comparison(root) - elif isinstance(root, TimeShiftNode): - return visitor.time_shift(root) - elif isinstance(root, TimeEvalNode): - return visitor.time_eval(root) - elif isinstance(root, TimeSumNode): - return visitor.time_sum(root) - elif isinstance(root, AllTimeSumNode): - return visitor.all_time_sum(root) - elif isinstance(root, ScenarioOperatorNode): - return visitor.scenario_operator(root) - elif isinstance(root, PortFieldNode): - return visitor.port_field(root) - elif isinstance(root, PortFieldAggregatorNode): - return visitor.port_field_aggregator(root) - raise ValueError(f"Unknown expression node type {root.__class__}") + TYPES = { + LiteralNode: "literal", + NegationNode: "negation", + VariableNode: "variable", + ParameterNode: "parameter", + ComponentParameterNode: "comp_parameter", + ComponentVariableNode: "comp_variable", + ProblemParameterNode: "pb_parameter", + ProblemVariableNode: "pb_variable", + AdditionNode: "addition", + MultiplicationNode: "multiplication", + DivisionNode: "division", + ComparisonNode: "comparison", + TimeShiftNode: "time_shift", + TimeEvalNode: "time_eval", + TimeSumNode: "time_sum", + AllTimeSumNode: "all_time_sum", + ScenarioOperatorNode: "scenario_operator", + PortFieldNode: "port_field", + PortFieldAggregatorNode: "port_field_aggregator", + } + if type(root) in TYPES: + return getattr(visitor, TYPES[type(root)])(root) + else: + raise ValueError(f"Unknown expression node type {root.__class__}") class SupportsOperations(Protocol[T]):