Skip to content

Update cached model to store all constructor function arguments #318

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 33 additions & 2 deletions flow360/component/simulation/framework/cached_model_base.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,35 @@
import abc
from typing import Any, Dict
import inspect
from functools import wraps
from typing import Any, Callable, Dict

import pydantic as pd

from flow360.component.simulation.framework.base_model import Flow360BaseModel
from flow360.component.types import TYPE_TAG_STR


class CachedModelBase(Flow360BaseModel, metaclass=abc.ABCMeta):
@classmethod
def model_constructor(cls, func: Callable) -> Callable:
@classmethod
@wraps(func)
def wrapper(cls, *args, **kwargs):
sig = inspect.signature(func)
result = func(cls, *args, **kwargs)
defaults = {
k: v.default
for k, v in sig.parameters.items()
if v.default is not inspect.Parameter.empty
}
result._cached = result.__annotations__["_cached"](
**{**result._cached.model_dump(), **defaults, **kwargs}
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What happens if **result._cached.model_dump(), **defaults, **kwargs have same keyword args?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

cache will be updated. This order should be correct: existing cache < defaults < kwargs

)
result._cached.constructor = func.__name__
return result

return wrapper

def __init__(self, **data):
cached = data.pop("_cached", None)
super().__init__(**data)
Expand All @@ -15,9 +38,17 @@ def __init__(self, **data):
self._cached = self.__annotations__["_cached"].model_validate(cached)
except:
pass
else:
defaults = {name: field.default for name, field in self.model_fields.items()}
defaults.pop(TYPE_TAG_STR)
self._cached = self.__annotations__["_cached"](
**{**defaults, **data}, constructor="default"
)

@pd.model_serializer(mode="wrap")
def serialize_model(self, handler) -> Dict[str, Any]:
serialize_self = handler(self)
serialize_self["_cached"] = self._cached.model_dump() if self._cached else None
serialize_self["_cached"] = (
self._cached.model_dump(exclude_none=True) if self._cached else None
)
return serialize_self
62 changes: 46 additions & 16 deletions flow360/component/simulation/operating_condition.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,12 @@ class ThermalStateCache(Flow360BaseModel):
"""[INTERNAL] Cache for thermal state inputs"""

# pylint: disable=no-member
constructor: Optional[str] = None
altitude: Optional[LengthType.Positive] = None
temperature_offset: Optional[TemperatureType] = None
temperature: Optional[TemperatureType.Positive] = None
density: Optional[DensityType.Positive] = None
material: Optional[FluidMaterialTypes] = None


class ThermalState(CachedModelBase):
Expand All @@ -55,7 +59,8 @@ class ThermalState(CachedModelBase):
material: FluidMaterialTypes = pd.Field(Air(), frozen=True)
_cached: ThermalStateCache = ThermalStateCache()

@classmethod
# pylint: disable=no-self-argument, not-callable, unused-argument
@CachedModelBase.model_constructor
@pd.validate_call
def from_standard_atmosphere(
cls, altitude: LengthType.Positive = 0 * u.m, temperature_offset: TemperatureType = 0 * u.K
Expand All @@ -71,7 +76,6 @@ def from_standard_atmosphere(
temperature=temperature,
material=Air(),
)
state._cached = ThermalStateCache(altitude=altitude, temperature_offset=temperature_offset)

return state

Expand Down Expand Up @@ -119,15 +123,39 @@ def mu_ref(self, mesh_unit: LengthType.Positive) -> pd.PositiveFloat:
return (self.dynamic_viscosity / (self.speed_of_sound * self.density * mesh_unit)).v.item()


class GenericReferenceCondition(Flow360BaseModel):
class GenericReferenceConditionCache(Flow360BaseModel):
"""[INTERNAL] Cache for GenericReferenceCondition inputs"""

constructor: Optional[str] = None
velocity_magnitude: Optional[VelocityType.Positive] = None
thermal_state: Optional[ThermalState] = None
mach: Optional[pd.PositiveFloat] = None


class AerospaceConditionCache(Flow360BaseModel):
"""[INTERNAL] Cache for AerospaceCondition inputs"""

constructor: Optional[str] = None
alpha: Optional[AngleType] = None
beta: Optional[AngleType] = None
reference_velocity_magnitude: Optional[VelocityType.Positive] = None
velocity_magnitude: Optional[VelocityType.NonNegative] = None
thermal_state: Optional[ThermalState] = pd.Field(None, alias="atmosphere")
mach: Optional[pd.NonNegativeFloat] = None
reference_mach: Optional[pd.PositiveFloat] = None


class GenericReferenceCondition(CachedModelBase):
"""
Operating condition defines the physical (non-geometrical) reference values for the problem.
"""

velocity_magnitude: VelocityType.Positive
thermal_state: ThermalState = ThermalState()
_cached: GenericReferenceConditionCache = GenericReferenceConditionCache()

@classmethod
# pylint: disable=no-self-argument, not-callable
@CachedModelBase.model_constructor
@pd.validate_call
def from_mach(
cls,
Expand All @@ -144,61 +172,63 @@ def mach(self) -> pd.PositiveFloat:
return self.velocity_magnitude / self.thermal_state.speed_of_sound


class AerospaceCondition(Flow360BaseModel):
class AerospaceCondition(CachedModelBase):
"""A specialized GenericReferenceCondition for aerospace applications."""

# pylint: disable=fixme
# TODO: add units for angles
# TODO: valildate reference_velocity_magnitude defined if velocity_magnitude=0
alpha: AngleType = 0 * u.deg
beta: AngleType = 0 * u.deg
velocity_magnitude: VelocityType.NonNegative
thermal_state: ThermalState = pd.Field(ThermalState(), alias="atmosphere")
reference_velocity_magnitude: Optional[VelocityType.Positive] = None
_cached: AerospaceConditionCache = AerospaceConditionCache()

# pylint: disable=too-many-arguments
@classmethod
# pylint: disable=too-many-arguments, no-self-argument, not-callable
@CachedModelBase.model_constructor
@pd.validate_call
def from_mach(
cls,
mach: pd.PositiveFloat,
alpha: AngleType = 0 * u.deg,
beta: AngleType = 0 * u.deg,
atmosphere: ThermalState = ThermalState(),
thermal_state: ThermalState = ThermalState(),
reference_mach: Optional[pd.PositiveFloat] = None,
):
"""Constructs a `AerospaceCondition` from Mach number and thermal state."""

velocity_magnitude = mach * atmosphere.speed_of_sound
velocity_magnitude = mach * thermal_state.speed_of_sound

reference_velocity_magnitude = (
reference_mach * atmosphere.speed_of_sound if reference_mach else None
reference_mach * thermal_state.speed_of_sound if reference_mach else None
)
return cls(
velocity_magnitude=velocity_magnitude,
alpha=alpha,
beta=beta,
atmosphere=atmosphere,
thermal_state=thermal_state,
reference_velocity_magnitude=reference_velocity_magnitude,
)

@classmethod
# pylint: disable=no-self-argument, not-callable
@CachedModelBase.model_constructor
@pd.validate_call
def from_stationary(
cls,
reference_velocity_magnitude: VelocityType.Positive,
atmosphere: ThermalState = ThermalState(),
thermal_state: ThermalState = ThermalState(),
):
"""Constructs a `AerospaceCondition` for stationary conditions."""
return cls(
velocity_magnitude=0 * u.m / u.s,
atmosphere=atmosphere,
thermal_state=thermal_state,
reference_velocity_magnitude=reference_velocity_magnitude,
)

@property
def mach(self) -> pd.PositiveFloat:
"""Computes Mach number."""
return self.velocity_magnitude / self.atmosphere.speed_of_sound
return self.velocity_magnitude / self.thermal_state.speed_of_sound

# pylint: disable=fixme
# TODO: Add after model validation that reference_velocity_magnitude is set when velocity_magnitude is 0
Expand Down
33 changes: 28 additions & 5 deletions tests/simulation/framework/test_cached_model.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
import json
import os
import tempfile
from typing import Optional

import pydantic as pd
Expand All @@ -15,16 +18,21 @@


class TempThermalStateCache(Flow360BaseModel):
constructor: Optional[str] = None
altitude: Optional[LengthType.Positive] = None
temperature_offset: Optional[TemperatureType] = None
some_value: Optional[float] = None
temperature: Optional[TemperatureType.Positive] = None
density: Optional[DensityType.Positive] = None


class TempThermalState(CachedModelBase):
temperature: TemperatureType.Positive = pd.Field(288.15 * u.K, frozen=True)
density: DensityType.Positive = pd.Field(1.225 * u.kg / u.m**3, frozen=True)
some_value: float = 0.1
_cached: TempThermalStateCache = TempThermalStateCache()

@classmethod
@CachedModelBase.model_constructor
def from_standard_atmosphere(
cls, altitude: LengthType.Positive = 0 * u.m, temperature_offset: TemperatureType = 0 * u.K
):
Expand All @@ -35,9 +43,7 @@ def from_standard_atmosphere(
density=density,
temperature=temperature,
)
state._cached = TempThermalStateCache(
altitude=altitude, temperature_offset=temperature_offset
)

return state

@property
Expand All @@ -56,4 +62,21 @@ def test_cache_model():
some_value=1230,
thermal_state=TempThermalState.from_standard_atmosphere(altitude=100 * u.m),
)
to_file_from_file_test(operating_condition)

with tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False) as temp_file:
temp_file_name = temp_file.name

try:
operating_condition.to_file(temp_file_name)
with open(temp_file_name) as fp:
model_dict = json.load(fp)
assert model_dict["thermal_state"]["_cached"]["some_value"] == 0.1
assert (
model_dict["thermal_state"]["_cached"]["constructor"] == "from_standard_atmosphere"
)
loaded_model = TempOperatingCondition(**model_dict)
assert loaded_model == operating_condition
assert loaded_model.thermal_state._cached.altitude == 100 * u.m
assert loaded_model.thermal_state._cached.temperature_offset == 0 * u.K
finally:
os.remove(temp_file_name)
2 changes: 1 addition & 1 deletion tests/simulation/params/test_simulation_params.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ def get_the_param():
mach=0.8,
alpha=30 * u.deg,
beta=20 * u.deg,
atmosphere=ThermalState(temperature=300 * u.K, density=1 * u.g / u.cm**3),
thermal_state=ThermalState(temperature=300 * u.K, density=1 * u.g / u.cm**3),
reference_mach=0.5,
),
models=[
Expand Down
Loading