Skip to content

Implement the framework for entities. #258

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -325,4 +325,7 @@ $RECYCLE.BIN/

tmp/

/.vscode
/.vscode

# test residual
flow360/examples/cylinder/flow360mesh.json
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,14 @@

import hashlib
import json
from typing import List, Literal, Optional
from typing import Literal, Optional, Union, get_args

import pydantic as pd
import rich
import yaml
from pydantic import ConfigDict
from pydantic.fields import FieldInfo

from flow360.component.simulation.framework.entity_registry import EntityRegistry
from flow360.component.types import TYPE_TAG_STR
from flow360.error_messages import do_not_modify_file_manually_msg
from flow360.exceptions import Flow360FileError
Expand Down Expand Up @@ -39,14 +39,23 @@ def __init__(self, filename: str = None, **kwargs):

@classmethod
def _handle_dict(cls, **kwargs):
"""Handle dictionary input for the model."""
model_dict = kwargs
model_dict = cls._handle_dict_with_hash(model_dict)
return model_dict

@classmethod
def _handle_file(cls, filename: str = None, **kwargs):
"""Handle file input for the model.

Parameters
----------
filename : str
Full path to the .json or .yaml file to load the :class:`Flow360BaseModel` from.
**kwargs
Keyword arguments to be passed to the model."""
if filename is not None:
return cls.dict_from_file(filename=filename)
return cls._dict_from_file(filename=filename)
return kwargs

@classmethod
Expand All @@ -67,7 +76,7 @@ def __pydantic_init_subclass__(cls, **kwargs) -> None:
"""
model_config = ConfigDict(
##:: Pydantic kwargs
arbitrary_types_allowed=True,
arbitrary_types_allowed=True, # ?
extra="forbid",
frozen=False,
populate_by_name=True,
Expand Down Expand Up @@ -155,8 +164,7 @@ def copy(self, update=None, **kwargs) -> Flow360BaseModel:
"""Copy a Flow360BaseModel. With ``deep=True`` as default."""
if "deep" in kwargs and kwargs["deep"] is False:
raise ValueError("Can't do shallow copy of component, set `deep=True` in copy().")
kwargs.update({"deep": True})
new_copy = pd.BaseModel.model_copy(self, update=update, **kwargs)
new_copy = pd.BaseModel.model_copy(self, update=update, deep=True, **kwargs)
data = new_copy.model_dump()
return self.model_validate(data)

Expand Down Expand Up @@ -195,7 +203,7 @@ def from_file(cls, filename: str) -> Flow360BaseModel:
return cls(filename=filename)

@classmethod
def dict_from_file(cls, filename: str) -> dict:
def _dict_from_file(cls, filename: str) -> dict:
"""Loads a dictionary containing the model from a .json or .yaml file.

Parameters
Expand Down Expand Up @@ -263,7 +271,7 @@ def from_json(cls, filename: str, **parse_obj_kwargs) -> Flow360BaseModel:
-------
>>> params = Flow360BaseModel.from_json(filename='folder/flow360.json') # doctest: +SKIP
"""
model_dict = cls.dict_from_file(filename=filename)
model_dict = cls._dict_from_file(filename=filename)
return cls.model_validate(model_dict, **parse_obj_kwargs)

@classmethod
Expand Down Expand Up @@ -327,7 +335,7 @@ def from_yaml(cls, filename: str, **parse_obj_kwargs) -> Flow360BaseModel:
-------
>>> params = Flow360BaseModel.from_yaml(filename='folder/flow360.yaml') # doctest: +SKIP
"""
model_dict = cls.dict_from_file(filename=filename)
model_dict = cls._dict_from_file(filename=filename)
return cls.model_validate(model_dict, **parse_obj_kwargs)

@classmethod
Expand Down
264 changes: 264 additions & 0 deletions flow360/component/simulation/framework/entity_base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,264 @@
from __future__ import annotations

import copy
from abc import ABCMeta
from collections import defaultdict
from typing import List, Union, get_args

import pydantic as pd

from flow360.component.simulation.framework.base_model import Flow360BaseModel
from flow360.component.simulation.framework.entity_registry import EntityRegistry
from flow360.log import log


class EntityBase(Flow360BaseModel, metaclass=ABCMeta):
"""
Base class for dynamic entity types.

Attributes:
_entity_type (str): A string representing the specific type of the entity.
This should be set in subclasses to differentiate between entity types.
Note this controls the granularity of the registry.
_is_generic(bool): A flag indicating whether the entity is a generic entity (constructed from metadata).
name (str): The name of the entity, used for identification and retrieval.
"""

_entity_type: str = None
_is_generic = False
name: str = pd.Field(frozen=True)

def __init__(self, **data):
"""
Initializes a new entity and registers it in the global registry.

Parameters:
data: Keyword arguments containing initial values for fields declared in the entity.
"""
super().__init__(**data)
assert self._entity_type is not None, "_entity_type is not defined in the entity class."

def copy(self, update=None, **kwargs) -> EntityBase:
"""
Creates a copy of the entity with compulsory updates.

Parameters:
update: A dictionary containing the updated attributes to apply to the copied entity.
**kwargs: Additional arguments to pass to the copy constructor.

Returns:
A copy of the entity with the specified updates.
"""
if update is None:
raise ValueError(
"Change is necessary when calling .copy() as there cannot be two identical entities at the same time. Please use update parameter to change the entity attributes."
)
if "name" not in update or update["name"] == self.name:
raise ValueError(
"Copying an entity requires a new name to be specified. Please provide a new name in the update dictionary."
)
return super().copy(update=update, **kwargs)


class _CombinedMeta(type(Flow360BaseModel), type):
pass
Comment on lines +84 to +85
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

that is new to me, why it inherits from both type(Flow360BaseModel) and type?

Copy link
Collaborator Author

@benflexcompute benflexcompute May 21, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If it only inherit from type then python gives error: TypeError: metaclass conflict: the metaclass of a derived class must be a (non-strict) subclass of the metaclasses of all its bases. I had to use this workaround for this to work.



class _EntitiesListMeta(_CombinedMeta):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

_EntityListMeta

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

def __getitem__(cls, entity_types):
"""
Creates a new class with the specified entity types as a list of stored entities.
"""
if not isinstance(entity_types, tuple):
entity_types = (entity_types,)
union_type = Union[entity_types]
annotations = {"stored_entities": List[union_type]}
new_cls = type(
f"{cls.__name__}[{','.join([t.__name__ for t in entity_types])}]",
(cls,),
{"__annotations__": annotations},
)
Comment on lines +89 to +101
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It looks like a class factory. Looks very complex and anti pattern, why do we need it?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes it is. This is to support syntax like: EntityList[GenericVolume, Box, Cylinder, str] so developers can easily specify what are allowed for that particular "entities" field.

return new_cls


def _remove_duplicate_entities(expanded_entities: List[EntityBase]):
"""
In the expanded entity list from `_get_expanded_entities` we will very likely have generic entities
which comes from asset metadata. These entities may have counterparts given by users. We remove the
generic ones when they have duplicate counterparts because the counterparts will likely have more info.

For example `entities = [my_mesh["*"], user_defined_zone]`. We need to remove duplicates from the expanded list.
"""
all_entities = defaultdict(list)

for entity in expanded_entities:
all_entities[entity.name].append(entity)

for entity_list in all_entities.values():
if len(entity_list) > 1:
for entity in entity_list:
if entity._is_generic and len(entity_list) > 1:
entity_list.remove(entity)

assert len(entity_list) == 1

return [entity_list[0] for entity_list in all_entities.values()]


class EntityList(Flow360BaseModel, metaclass=_EntitiesListMeta):
"""
The type accepting a list of entities or (name, registry) pair.

Attributes:
stored_entities (List[Union[EntityBase, Tuple[str, registry]]]): List of stored entities, which can be
instances of `Box`, `Cylinder`, or strings representing naming patterns.

Methods:
_format_input_to_list(cls, input: List) -> dict: Class method that formats the input to a
dictionary with the key 'stored_entities'.
_check_duplicate_entity_in_list(cls, values): Class method that checks for duplicate entities
in the list of stored entities.
_get_expanded_entities(self): Method that processes the stored entities to resolve any naming
patterns into actual entity references, expanding and filtering based on the defined
entity types.

"""

stored_entities: List = pd.Field()

@classmethod
def _get_valid_entity_types(cls):
"""Get the list of types that the entity list can accept."""
entity_field_type = cls.__annotations__.get("stored_entities")
if (
entity_field_type is not None
and hasattr(entity_field_type, "__origin__")
and entity_field_type.__origin__ is list
):
valid_types = get_args(entity_field_type)[0]
if hasattr(valid_types, "__origin__") and valid_types.__origin__ is Union:
valid_types = get_args(valid_types)
else:
valid_types = (valid_types,)
return valid_types
raise TypeError("Internal error, the metaclass for EntityList is not properly set.")

@classmethod
def _valid_individual_input(cls, input):
"""Validate each individual element in a list or as standalone entity."""
if isinstance(input, str) or isinstance(input, EntityBase):
return input
else:
raise ValueError(
f"Type({type(input)}) of input to `entities` ({input}) is not valid. Expected str or entity instance."
)

@pd.model_validator(mode="before")
@classmethod
def _format_input_to_list(cls, input: Union[dict, list]):
"""
Flatten List[EntityBase] and put into stored_entities.
"""
# Note:
# 1. str comes from Param. These will be expanded before submission
# as the user may change Param which affects implicit entities (farfield existence patch for example).
# 2. The List[EntityBase], comes from the Assets.
# 3. EntityBase comes from direct specification of entity in the list.
formated_input = []
valid_types = cls._get_valid_entity_types()
if isinstance(input, list):
if input == []:
raise ValueError("Invalid input type to `entities`, list is empty.")
for item in input:
if isinstance(item, list): # Nested list comes from assets
[cls._valid_individual_input(individual) for individual in item]
formated_input.extend(
[
individual
for individual in item
if isinstance(individual, tuple(valid_types))
]
)
else:
cls._valid_individual_input(item)
if isinstance(item, tuple(valid_types)):
formated_input.append(item)
elif isinstance(input, dict):
return dict(stored_entities=input["stored_entities"])
else: # Single reference to an entity
cls._valid_individual_input(input)
if isinstance(item, tuple(valid_types)):
formated_input.append(item)
return dict(stored_entities=formated_input)

@pd.field_validator("stored_entities", mode="after")
@classmethod
def _check_duplicate_entity_in_list(cls, values):
seen = []
for value in values:
if value in seen:
if isinstance(value, EntityBase):
log.warning(f"Duplicate entity found, name: {value.name}")
else:
log.warning(f"Duplicate entity found: {value}")
continue
seen.append(value)
return seen

def _get_expanded_entities(self, supplied_registry: EntityRegistry = None) -> List[EntityBase]:
"""
Processes `stored_entities` to resolve any naming patterns into actual entity
references, expanding and filtering based on the defined entity types. This ensures that
all entities referenced either directly or by pattern are valid and registered.

**Warning**:
This method has to be called during preprocessing stage of Param when all settings have
been finalized. This ensures that all entities are registered in the registry (by assets or param).
Maybe we check hash or something to ensure consistency/integrity?

Raises:
TypeError: If an entity does not match the expected type.
Returns:
Deep copy of the exapnded entities list.
"""

entities = getattr(self, "stored_entities", [])

valid_types = self.__class__._get_valid_entity_types()

expanded_entities = []

for entity in entities:
if isinstance(entity, str):
# Expand from supplied registry
if supplied_registry is None:
raise ValueError(
f"Internal error, registry is not supplied for entity ({entity}) expansion. "
)
# Expand based on naming pattern registered in the Registry
pattern_matched_entities = supplied_registry.find_by_name_pattern(entity)
# Filter pattern matched entities by valid types
expanded_entities.extend(
[
e
for e in pattern_matched_entities
if isinstance(e, tuple(valid_types)) and e not in expanded_entities
]
)
elif entity not in expanded_entities:
# Direct entity references are simply appended if they are of a valid type
expanded_entities.append(entity)

expanded_entities = _remove_duplicate_entities(expanded_entities)

if expanded_entities == []:
raise ValueError(
f"Failed to find any matching entity with {entities}. Please check the input to entities."
)
# TODO: As suggested by Runda. We better prompt user what entities are actually used/expanded to avoid user input error. We need a switch to turn it on or off.
return copy.deepcopy(expanded_entities)

def preprocess(self, supplied_registry: EntityRegistry = None):
"""Expand and overwrite self.stored_entities in preparation for submissin/serialization."""
self.stored_entities = self._get_expanded_entities(supplied_registry)
return self
Loading
Loading