Skip to content

Disables vector arithmetics for variables #1158

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jun 12, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 8 additions & 2 deletions flow360/component/simulation/services.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""Simulation services module."""

# pylint: disable=duplicate-code
# pylint: disable=duplicate-code, too-many-lines
import json
import os
from enum import Enum
Expand Down Expand Up @@ -825,7 +825,10 @@ def validate_expression(variables: list[dict], expressions: list[str]):
try:
expression_object = Expression(expression=expression)
result = expression_object.evaluate(raise_on_non_evaluable=False)
if np.isnan(result):
if isinstance(result, (list, np.ndarray)):
if np.isnan(result).all():
pass
elif isinstance(result, Number) and np.isnan(result):
pass
elif isinstance(result, Number):
value = result
Expand All @@ -840,6 +843,9 @@ def validate_expression(variables: list[dict], expressions: list[str]):
value = float(result[0])
else:
value = tuple(result.tolist())

# Test symbolically
expression_object.evaluate(raise_on_non_evaluable=False, force_evaluate=False)
except pd.ValidationError as err:
errors.extend(err.errors())
except Exception as err: # pylint: disable=broad-exception-caught
Expand Down
33 changes: 33 additions & 0 deletions flow360/component/simulation/user_code/core/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,25 @@ def validate(cls, value: Any):
AnyNumericType = Union[float, UnytArray, list]


def check_vector_arithmetic(func):
"""Decorator to check if vector arithmetic is being attempted and raise an error if so."""

def wrapper(self, other):
vector_arithmetic = False
if isinstance(other, unyt_array) and other.shape != ():
vector_arithmetic = True
elif isinstance(other, list):
vector_arithmetic = True
if vector_arithmetic:
raise ValueError(
f"Vector operation ({func.__name__} between {self.name} and {other}) not "
"supported for variables. Please write expression for each component."
)
return func(self, other)

return wrapper


class Variable(Flow360BaseModel):
"""Base class representing a symbolic variable"""

Expand All @@ -145,16 +164,19 @@ class Variable(Flow360BaseModel):

model_config = pd.ConfigDict(validate_assignment=True, extra="allow")

@check_vector_arithmetic
def __add__(self, other):
(arg, parenthesize) = _convert_argument(other)
str_arg = arg if not parenthesize else f"({arg})"
return Expression(expression=f"{self.name} + {str_arg}")

@check_vector_arithmetic
def __sub__(self, other):
(arg, parenthesize) = _convert_argument(other)
str_arg = arg if not parenthesize else f"({arg})"
return Expression(expression=f"{self.name} - {str_arg}")

@check_vector_arithmetic
def __mul__(self, other):
if isinstance(other, Number) and other == 0:
return Expression(expression="0")
Expand All @@ -163,21 +185,25 @@ def __mul__(self, other):
str_arg = arg if not parenthesize else f"({arg})"
return Expression(expression=f"{self.name} * {str_arg}")

@check_vector_arithmetic
def __truediv__(self, other):
(arg, parenthesize) = _convert_argument(other)
str_arg = arg if not parenthesize else f"({arg})"
return Expression(expression=f"{self.name} / {str_arg}")

@check_vector_arithmetic
def __floordiv__(self, other):
(arg, parenthesize) = _convert_argument(other)
str_arg = arg if not parenthesize else f"({arg})"
return Expression(expression=f"{self.name} // {str_arg}")

@check_vector_arithmetic
def __mod__(self, other):
(arg, parenthesize) = _convert_argument(other)
str_arg = arg if not parenthesize else f"({arg})"
return Expression(expression=f"{self.name} % {str_arg}")

@check_vector_arithmetic
def __pow__(self, other):
(arg, parenthesize) = _convert_argument(other)
str_arg = arg if not parenthesize else f"({arg})"
Expand All @@ -192,16 +218,19 @@ def __pos__(self):
def __abs__(self):
return Expression(expression=f"abs({self.name})")

@check_vector_arithmetic
def __radd__(self, other):
(arg, parenthesize) = _convert_argument(other)
str_arg = arg if not parenthesize else f"({arg})"
return Expression(expression=f"{str_arg} + {self.name}")

@check_vector_arithmetic
def __rsub__(self, other):
(arg, parenthesize) = _convert_argument(other)
str_arg = arg if not parenthesize else f"({arg})"
return Expression(expression=f"{str_arg} - {self.name}")

@check_vector_arithmetic
def __rmul__(self, other):
if isinstance(other, Number) and other == 0:
return Expression(expression="0")
Expand All @@ -210,21 +239,25 @@ def __rmul__(self, other):
str_arg = arg if not parenthesize else f"({arg})"
return Expression(expression=f"{str_arg} * {self.name}")

@check_vector_arithmetic
def __rtruediv__(self, other):
(arg, parenthesize) = _convert_argument(other)
str_arg = arg if not parenthesize else f"({arg})"
return Expression(expression=f"{str_arg} / {self.name}")

@check_vector_arithmetic
def __rfloordiv__(self, other):
(arg, parenthesize) = _convert_argument(other)
str_arg = arg if not parenthesize else f"({arg})"
return Expression(expression=f"{str_arg} // {self.name}")

@check_vector_arithmetic
def __rmod__(self, other):
(arg, parenthesize) = _convert_argument(other)
str_arg = arg if not parenthesize else f"({arg})"
return Expression(expression=f"{str_arg} % {self.name}")

@check_vector_arithmetic
def __rpow__(self, other):
(arg, parenthesize) = _convert_argument(other)
str_arg = arg if not parenthesize else f"({arg})"
Expand Down
31 changes: 23 additions & 8 deletions tests/simulation/test_expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,11 @@
ReferenceGeometry,
Surface,
)
from flow360.component.simulation.services import ValidationCalledBy, validate_model
from flow360.component.simulation.services import (
ValidationCalledBy,
validate_expression,
validate_model,
)
from flow360.component.simulation.translator.solver_translator import (
user_variable_to_udf,
)
Expand Down Expand Up @@ -701,6 +705,24 @@ class TestModel(Flow360BaseModel):
11,
) # Python 3.9 report error on col 11, error message is also different

with pytest.raises(
ValueError,
match=re.escape(
"Vector operation (__add__ between solution.velocity and [1 2 3] cm/ms) not supported for variables. Please write expression for each component."
),
):
UserVariable(name="x", value=solution.velocity + [1, 2, 3] * u.cm / u.ms)

errors, _, _ = validate_expression(
variables=[], expressions=["solution.velocity + [1, 2, 3] * u.cm / u.ms"]
)
assert len(errors) == 1
assert errors[0]["type"] == "value_error"
assert (
"Vector operation (__add__ between solution.velocity and [1 2 3] cm/ms) not supported for variables. Please write expression for each component."
in errors[0]["msg"]
)


def test_solver_translation():
timestepping_unsteady = Unsteady(steps=12, step_size=0.1 * u.s)
Expand Down Expand Up @@ -1078,13 +1100,6 @@ def test_udf_generator():
).in_unit(new_unit="CGS_unit_system")
assert vel_cross_vec.value.get_output_units(input_params=params) == u.cm**2 / u.s

# DOES NOT WORK
# vel_plus_vec = UserVariable(
# name="vel_cross_vec", value=solution.velocity + [1, 2, 3] * u.cm / u.ms
# ).in_unit(new_unit="cm/s")
# result = user_variable_to_udf(vel_plus_vec, input_params=params)
# print("4>>> result.expression", result.expression)


def test_project_variables():
aaa = UserVariable(name="aaa", value=solution.velocity + 12 * u.m / u.s)
Expand Down