Skip to content

Expression validation bundle #1181

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 6 commits into
base: expressions
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions flow360/component/simulation/blueprint/core/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,3 +152,8 @@ def user_variable_names(self):
def clear(self):
"""Clear user variables from the context."""
self._values = {name: value for name, value in self._values.items() if "." in name}

@property
def registered_names(self):
"""Show the registered names in the context."""
return list(self._values.keys())
138 changes: 127 additions & 11 deletions flow360/component/simulation/user_code/core/types.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
# pylint: disable=too-many-lines
"""This module allows users to write serializable, evaluable symbolic code for use in simulation params"""

from __future__ import annotations
Expand Down Expand Up @@ -405,17 +406,58 @@ class UserVariable(Variable):
"""Class representing a user-defined symbolic variable"""

name: str = pd.Field(frozen=True)

type_name: Literal["UserVariable"] = pd.Field("UserVariable", frozen=True)

@pd.field_validator("name", mode="after")
@classmethod
def check_unscoped_name(cls, v):
"""Ensure that the variable name is not scoped. Only solver side variables can be scoped."""
if "." in v:
def check_valid_user_variable_name(cls, v):
"""Validate a variable identifier (ASCII only)."""
# Partial list of C++ keywords; extend as needed
RESERVED_SYNTAX_KEYWORDS = { # pylint:disable=invalid-name
"int",
"double",
"float",
"long",
"short",
"char",
"bool",
"void",
"class",
"for",
"while",
"if",
"else",
"return",
"namespace",
"template",
"typename",
"constexpr",
"virtual",
}
if not v:
raise ValueError("Identifier cannot be empty.")

# 2) First character must be letter or underscore
if not re.match(r"^[A-Za-z_]", v):
raise ValueError("Identifier must start with a letter (A-Z/a-z) or underscore (_).")

# 3) All characters must be letters, digits, or underscore
if re.search(r"[^A-Za-z0-9_]", v):
raise ValueError(
"User variable name cannot contain dots (scoped variables not supported)."
"Identifier can only contain letters, digits (0-9), or underscore (_)."
)

# 4) Not a C++ keyword
if v in RESERVED_SYNTAX_KEYWORDS:
raise ValueError(f"'{v}' is a reserved keyword.")

# 5) existing variable name:
solver_side_names = {
item.split(".")[-1] for item in default_context.registered_names if "." in item
}
if v in solver_side_names:
raise ValueError(f"'{v}' is a reserved solver side variable name.")

return v

def __hash__(self):
Expand All @@ -424,10 +466,17 @@ def __hash__(self):
"""
return hash(self.model_dump_json())

def in_unit(self, new_unit: Union[str, Unit] = None):
def in_units(
self,
new_unit: Union[
str, Literal["SI_unit_system", "CGS_unit_system", "Imperial_unit_system"], Unit
] = None,
):
"""Requesting the output of the variable to be in the given (new_unit) units."""
if isinstance(new_unit, Unit):
new_unit = str(new_unit)
if not isinstance(self.value, Expression):
raise ValueError("Cannot set output units for non expression value.")
self.value.output_units = new_unit
return self

Expand All @@ -448,7 +497,13 @@ def update_context(self):
default_context.set_alias(self.name, self.solver_name)
return self

def in_unit(self, new_name: str, new_unit: Union[str, Unit] = None):
def in_units(
self,
new_name: str,
new_unit: Union[
str, Literal["SI_unit_system", "CGS_unit_system", "Imperial_unit_system"], Unit
] = None,
):
"""
Return a UserVariable that will generate results in the new_unit.
If new_unit is not specified then the unit will be determined by the unit system.
Expand Down Expand Up @@ -526,6 +581,26 @@ def remove_leading_and_trailing_whitespace(cls, value: str) -> str:
"""Remove leading and trailing whitespace from the expression"""
return value.strip()

@pd.model_validator(mode="after")
def check_output_units_matches_dimensionality(self) -> str:
"""Check that the output units have the same dimensionality as the expression"""
print(f"self.output_units: {self.output_units}")
if not self.output_units:
return self
if self.output_units in ("SI_unit_system", "CGS_unit_system", "Imperial_unit_system"):
return self
output_units_dimensionality = u.Unit(self.output_units).dimensions
expression_dimensionality = self.dimensionality
print(f"output_units_dimensionality: {output_units_dimensionality}")
print(f"expression_dimensionality: {expression_dimensionality}")
if output_units_dimensionality != expression_dimensionality:
raise ValueError(
f"Output units '{self.output_units}' have different dimensionality "
f"{output_units_dimensionality} than the expression {expression_dimensionality}."
)

return self

def evaluate(
self,
context: EvaluationContext = None,
Expand Down Expand Up @@ -726,8 +801,15 @@ def __eq__(self, other):
def dimensionality(self):
"""The physical dimensionality of the expression."""
value = self.evaluate(raise_on_non_evaluable=False, force_evaluate=True)
assert isinstance(value, (unyt_array, unyt_quantity))
return value.units.dimensions
assert isinstance(
value, (unyt_array, unyt_quantity, list, Number)
), "Non unyt array so no dimensionality"
if isinstance(value, (unyt_array, unyt_quantity)):
return value.units.dimensions
if isinstance(value, list):
_check_list_items_are_same_dimensionality(value)
return value[0].units.dimensions
return None

@property
def length(self):
Expand Down Expand Up @@ -791,6 +873,33 @@ def get_unit_from_unit_system(expression: Expression, unit_system_name: str):
T = TypeVar("T")


def _check_list_items_are_same_dimensionality(value: list) -> bool:
print(f"Checking list items are same dimensionality: {value}")
print(
"results for the ifs:",
all(isinstance(item, Expression) for item in value),
all(isinstance(item, unyt_quantity) for item in value),
any(isinstance(item, Number) for item in value)
and any(isinstance(item, unyt_quantity) for item in value),
)
if all(isinstance(item, Expression) for item in value):
_check_list_items_are_same_dimensionality(
[item.evaluate(raise_on_non_evaluable=False, force_evaluate=True) for item in value]
)
return
if all(isinstance(item, unyt_quantity) for item in value):
# ensure all items have the same dimensionality
if not all(item.units.dimensions == value[0].units.dimensions for item in value):
raise ValueError("All items in the list must have the same dimensionality.")
return
# Also raise when some elements is Number and others are unyt_quantity
if any(isinstance(item, Number) for item in value) and any(
isinstance(item, unyt_quantity) for item in value
):
raise ValueError("List must contain only all unyt_quantities or all numbers.")
return


class ValueOrExpression(Expression, Generic[T]):
"""Model accepting both value and expressions"""

Expand All @@ -811,13 +920,20 @@ def _deserialize(value) -> Self:
if value.type_name == "number":
if value.units is not None:
# unyt objects
return unyt_array(value.value, value.units)
return unyt_array(value.value, value.units, dtype=np.float64)
return value.value
if value.type_name == "expression":
return expr_type(expression=value.expression, output_units=value.output_units)
except Exception: # pylint:disable=broad-exception-caught
pass

# Handle list of unyt_quantities:
if isinstance(value, list):
# Only checking when list[unyt_quantity]
_check_list_items_are_same_dimensionality(value)
if all(isinstance(item, (unyt_quantity, Number)) for item in value):
# try limiting the number of types we need to handle
return unyt_array(value, dtype=np.float64)
return value

def _serializer(value, info) -> dict:
Expand All @@ -834,7 +950,7 @@ def _serializer(value, info) -> dict:
if isinstance(evaluated, list):
# May result from Expression which is actually a list of expressions
try:
evaluated = u.unyt_array(evaluated)
evaluated = u.unyt_array(evaluated, dtype=np.float64)
except u.exceptions.IterableUnitCoercionError:
# Inconsistent units for components of list
pass
Expand Down
77 changes: 66 additions & 11 deletions tests/simulation/test_expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -897,7 +897,7 @@ def test_cross_function_use_case():

print("\n3 (Units defined in components)\n")
a.value = math.cross([3 * u.m, 2 * u.m, 1 * u.m], [2 * u.m, 2 * u.m, 1 * u.m])
assert a.value == [0 * u.m * u.m, -1 * u.m * u.m, 2 * u.m * u.m]
assert all(a.value == [0, -1, 2] * u.m * u.m)

print("\n4 Serialized version\n")
a.value = "math.cross([3, 2, 1] * u.m, solution.coordinate)"
Expand Down Expand Up @@ -1015,7 +1015,7 @@ def test_to_file_from_file_expression(
outputs=[
VolumeOutput(
output_fields=[
solution.mut.in_unit(new_name="mut_in_SI", new_unit="cm**2/min"),
solution.mut.in_units(new_name="mut_in_SI", new_unit="g/cm/min"),
constant_variable,
constant_array,
constant_unyt_quantity,
Expand Down Expand Up @@ -1043,14 +1043,15 @@ def test_udf_generator():
)
# Scalar output
result = user_variable_to_udf(
solution.mut.in_unit(new_name="mut_in_km", new_unit="km**2/s"), input_params=params
solution.mut.in_units(new_name="mut_in_km", new_unit="kg/km/s"), input_params=params
)
# velocity scale = 100 m/s, length scale = 10m, mut_scale = 1000 m**2/s -> 0.01 *km**2/s
assert result.expression == "mut_in_km = (mut * 0.001);"
# velocity scale = 100 m/s, length scale = 10m, density scale = 1000 kg/m**3
# mut_scale = Rho*L*V -> 1000*10*100 * kg/m/s == 1000*10*100*1000 * kg/km/s
assert result.expression == "mut_in_km = (mut * 1000000000.0);"

# Vector output
result = user_variable_to_udf(
solution.velocity.in_unit(new_name="velocity_in_SI", new_unit="m/s"), input_params=params
solution.velocity.in_units(new_name="velocity_in_SI", new_unit="m/s"), input_params=params
)
# velocity scale = 100 m/s,
assert (
Expand All @@ -1060,16 +1061,17 @@ def test_udf_generator():

vel_cross_vec = UserVariable(
name="vel_cross_vec", value=math.cross(solution.velocity, [1, 2, 3] * u.cm)
).in_unit(new_unit="m*km/s/s")
).in_units(new_unit="m*km/s")
result = user_variable_to_udf(vel_cross_vec, input_params=params)
# velocity scale = 100 m/s, length scale = 10m, scale = 1000m**2/s-->1 km*m/s
assert (
result.expression
== "double ___velocity[3];___velocity[0] = primitiveVars[1] * velocityScale;___velocity[1] = primitiveVars[2] * velocityScale;___velocity[2] = primitiveVars[3] * velocityScale;vel_cross_vec[0] = ((((___velocity[1] * 3) * 0.001) - ((___velocity[2] * 2) * 0.001)) * 10.0); vel_cross_vec[1] = ((((___velocity[2] * 1) * 0.001) - ((___velocity[0] * 3) * 0.001)) * 10.0); vel_cross_vec[2] = ((((___velocity[0] * 2) * 0.001) - ((___velocity[1] * 1) * 0.001)) * 10.0);"
== "double ___velocity[3];___velocity[0] = primitiveVars[1] * velocityScale;___velocity[1] = primitiveVars[2] * velocityScale;___velocity[2] = primitiveVars[3] * velocityScale;vel_cross_vec[0] = ((((___velocity[1] * 3) * 0.001) - ((___velocity[2] * 2) * 0.001)) * 1.0); vel_cross_vec[1] = ((((___velocity[2] * 1) * 0.001) - ((___velocity[0] * 3) * 0.001)) * 1.0); vel_cross_vec[2] = ((((___velocity[0] * 2) * 0.001) - ((___velocity[1] * 1) * 0.001)) * 1.0);"
)

vel_cross_vec = UserVariable(
name="vel_cross_vec", value=math.cross(solution.velocity, [1, 2, 3] * u.cm)
).in_unit(new_unit="CGS_unit_system")
).in_units(new_unit="CGS_unit_system")
assert vel_cross_vec.value.get_output_units(input_params=params) == u.cm**2 / u.s


Expand All @@ -1078,7 +1080,7 @@ def test_project_variables_serialization():
aaa = UserVariable(
name="aaa", value=[solution.velocity[0] + ccc, solution.velocity[1], solution.velocity[2]]
)
bbb = UserVariable(name="bbb", value=[aaa[0] + 14 * u.m / u.s, aaa[1], aaa[2]]).in_unit(
bbb = UserVariable(name="bbb", value=[aaa[0] + 14 * u.m / u.s, aaa[1], aaa[2]]).in_units(
new_unit="km/ms"
)

Expand Down Expand Up @@ -1146,14 +1148,67 @@ def test_project_variables_deserialization():


def test_overwriting_project_variables():
UserVariable(name="a", value=1)
a = UserVariable(name="a", value=1)

with pytest.raises(
ValueError,
match="Redeclaring user variable a with new value: 2.0. Previous value: 1.0",
):
UserVariable(name="a", value=2)

a.value = 2
assert a.value == 2


def test_unique_dimensionality():
with pytest.raises(
ValueError, match="All items in the list must have the same dimensionality."
):
UserVariable(name="a", value=[1 * u.m, 1 * u.s])

with pytest.raises(
ValueError, match="List must contain only all unyt_quantities or all numbers."
):
UserVariable(name="a", value=[1.0 * u.m, 1.0])

a = UserVariable(name="a", value=[1.0 * u.m, 1.0 * u.mm])
assert all(a.value == [1.0, 0.001] * u.m)


@pytest.mark.parametrize(
"bad_name, expected_msg",
[
("", "Identifier cannot be empty."),
("1stPlace", "Identifier must start with a letter (A-Z/a-z) or underscore (_)."),
("bad-name", "Identifier can only contain letters, digits (0-9), or underscore (_)."),
("has space", "Identifier can only contain letters, digits (0-9), or underscore (_)."),
(" leading", "Identifier must start with a letter (A-Z/a-z) or underscore (_)."),
("trailing ", "Identifier can only contain letters, digits (0-9), or underscore (_)."),
("tab\tname", "Identifier can only contain letters, digits (0-9), or underscore (_)."),
("new\nline", "Identifier can only contain letters, digits (0-9), or underscore (_)."),
("name$", "Identifier can only contain letters, digits (0-9), or underscore (_)."),
("class", "'class' is a reserved keyword."),
("namespace", "'namespace' is a reserved keyword."),
("template", "'template' is a reserved keyword."),
("temperature", "'temperature' is a reserved solver side variable name."),
("velocity", "'velocity' is a reserved solver side variable name."),
],
)
def test_invalid_names_raise(bad_name, expected_msg):
with pytest.raises(ValueError, match=re.escape(expected_msg)):
UserVariable(name=bad_name, value=0)


def test_output_units_dimensionality():
with pytest.raises(
ValueError,
match=re.escape(
"Output units 'ms' have different dimensionality (time) than the expression (length)."
),
):
a = UserVariable(name="a", value="1 * u.m")
a.in_units(new_unit="ms")


def test_whitelisted_callables():
def get_user_variable_names(module):
Expand Down
2 changes: 1 addition & 1 deletion tests/simulation/translator/test_output_translation.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@

@pytest.fixture()
def vel_in_km_per_hr():
return solution.velocity.in_unit(new_name="velocity_in_km_per_hr", new_unit=u.km / u.hr)
return solution.velocity.in_units(new_name="velocity_in_km_per_hr", new_unit=u.km / u.hr)


@pytest.fixture()
Expand Down
4 changes: 3 additions & 1 deletion tests/simulation/translator/test_solver_translator.py
Original file line number Diff line number Diff line change
Expand Up @@ -678,7 +678,9 @@ def test_param_with_user_variables():
output_fields=[
solution.Mach,
solution.velocity,
UserVariable(name="uuu", value=solution.velocity).in_unit(new_unit="km/ms"),
UserVariable(name="uuu", value=solution.velocity).in_units(
new_unit="km/ms"
),
my_var,
],
)
Expand Down