Skip to content

Added util function to get the unit from expression #1157

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Jun 11, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 2 additions & 21 deletions flow360/component/simulation/translator/solver_translator.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

# pylint: disable=too-many-lines
from numbers import Number
from typing import Literal, Type, Union
from typing import Type, Union

import numpy as np
import unyt as u
Expand Down Expand Up @@ -571,28 +571,9 @@ def _compute_coefficient_and_offset(source_unit: u.Unit, target_unit: u.Unit):

return coefficient, offset

def _get_output_unit(expression: Expression, input_params: SimulationParams):
if not expression.output_units:
# Derive the default output unit based on the value's dimensionality and current unit system
current_unit_system_name: Literal["SI", "Imperial", "CGS"] = (
input_params.unit_system.name
)
numerical_value = expression.evaluate(raise_on_non_evaluable=False, force_evaluate=True)
if not isinstance(numerical_value, (u.unyt_array, u.unyt_quantity)):
# Pure dimensionless constant
return None
if current_unit_system_name == "SI":
return numerical_value.in_base("mks").units
if current_unit_system_name == "Imperial":
return numerical_value.in_base("imperial").units
if current_unit_system_name == "CGS":
return numerical_value.in_base("cgs").units

return u.Unit(expression.output_units)

expression: Expression = variable.value

requested_unit: Union[u.Unit, None] = _get_output_unit(expression, input_params)
requested_unit: Union[u.Unit, None] = expression.get_output_units(input_params=input_params)
if requested_unit is None:
# Number constant output requested
coefficient = 1
Expand Down
52 changes: 51 additions & 1 deletion flow360/component/simulation/user_code/core/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

import numpy as np
import pydantic as pd
import unyt as u
from pydantic import BeforeValidator, Discriminator, PlainSerializer, Tag
from pydantic_core import InitErrorDetails, core_schema
from typing_extensions import Self
Expand Down Expand Up @@ -289,8 +290,10 @@ def __hash__(self):
"""
return hash(self.model_dump_json())

def in_unit(self, new_unit: str = None):
def in_unit(self, new_unit: Union[str, Unit] = None):
"""Requesting the output of the variable to be in the given (new_unit) units."""
if isinstance(new_unit, Unit):
new_unit = str(new_unit)
self.value.output_units = new_unit
return self

Expand Down Expand Up @@ -579,6 +582,53 @@ def length(self):
return len(value)
return 1 if isinstance(value, unyt_quantity) else value.shape[0]

def get_output_units(self, input_params=None):
"""
Get the output units of the expression.

- If self.output_units is None, derive the default output unit based on the
value's dimensionality and current unit system.

- If self.output_units is valid u.Unit string, deserialize it and return it.

- If self.output_units is valid unit system name, derive the default output
unit based on the value's dimensionality and the **given** unit system.

- If expression is a number constant, return None.

- Else raise ValueError.
"""

def get_unit_from_unit_system(expression: Expression, unit_system_name: str):
"""Derive the default output unit based on the value's dimensionality and current unit system"""
numerical_value = expression.evaluate(raise_on_non_evaluable=False, force_evaluate=True)
if not isinstance(numerical_value, (u.unyt_array, u.unyt_quantity, list)):
# Pure dimensionless constant
return None
if isinstance(numerical_value, list):
numerical_value = numerical_value[0]

if unit_system_name in ("SI", "SI_unit_system"):
return numerical_value.in_base("mks").units
if unit_system_name in ("Imperial", "Imperial_unit_system"):
return numerical_value.in_base("imperial").units
if unit_system_name in ("CGS", "CGS_unit_system"):
return numerical_value.in_base("cgs").units
raise ValueError(f"[Internal] Invalid unit system: {unit_system_name}")

try:
return u.Unit(self.output_units)
except u.exceptions.UnitParseError as e:
if input_params is None:
raise ValueError(
"[Internal] input_params required when output_units is not valid u.Unit string"
) from e
if not self.output_units:
unit_system_name: Literal["SI", "Imperial", "CGS"] = input_params.unit_system.name
else:
unit_system_name = self.output_units
return get_unit_from_unit_system(self, unit_system_name)


T = TypeVar("T")

Expand Down
5 changes: 5 additions & 0 deletions tests/simulation/test_expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -1073,6 +1073,11 @@ def test_udf_generator():
== "vel_cross_vec[0] = ((((solution.velocity[1] * 3) * 0.001) - ((solution.velocity[2] * 2) * 0.001)) * 10.0); vel_cross_vec[1] = ((((solution.velocity[2] * 1) * 0.001) - ((solution.velocity[0] * 3) * 0.001)) * 10.0); vel_cross_vec[2] = ((((solution.velocity[0] * 2) * 0.001) - ((solution.velocity[1] * 1) * 0.001)) * 10.0);"
)

vel_cross_vec = UserVariable(
name="vel_cross_vec", value=math.cross(solution.velocity, [1, 2, 3] * u.cm)
).in_unit(new_unit="CGS_unit_system")
assert vel_cross_vec.value.get_output_units(input_params=params) == u.cm**2 / u.s

# DOES NOT WORK
# vel_plus_vec = UserVariable(
# name="vel_cross_vec", value=solution.velocity + [1, 2, 3] * u.cm / u.ms
Expand Down