Skip to content

Implement the framework for entities. #258

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
154 changes: 125 additions & 29 deletions flow360/component/simulation/framework/entity_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,26 +2,30 @@

import copy
from abc import ABCMeta
from collections import defaultdict
from typing import List, Union, get_args

import pydantic as pd

from flow360.component.simulation.framework.base_model import Flow360BaseModel
from flow360.component.simulation.framework.entity_registry import EntityRegistry
from flow360.log import log


class EntityBase(Flow360BaseModel, metaclass=ABCMeta):
"""
Base class for dynamic entity types with automatic registration upon instantiation.
Base class for dynamic entity types.

Attributes:
_entity_type (str): A string representing the specific type of the entity.
This should be set in subclasses to differentiate between entity types.
Note this controls the granularity of the registry.
_is_generic(bool): A flag indicating whether the entity is a generic entity (constructed from metadata).
name (str): The name of the entity, used for identification and retrieval.
"""

_entity_type: str = None
_is_generic = False
name: str = pd.Field(frozen=True)

def __init__(self, **data):
Expand All @@ -33,7 +37,6 @@ def __init__(self, **data):
"""
super().__init__(**data)
assert self._entity_type is not None, "_entity_type is not defined in the entity class."
EntityRegistry.register(self)

def copy(self, update=None, **kwargs) -> EntityBase:
"""
Expand All @@ -48,13 +51,13 @@ def copy(self, update=None, **kwargs) -> EntityBase:
"""
if update is None:
raise ValueError(
"Change is necessary when copying an entity as there cannot be two identical entities at the same time. Please use update parameter to change the entity attributes."
"Change is necessary when calling .copy() as there cannot be two identical entities at the same time. Please use update parameter to change the entity attributes."
)
if "name" not in update or update["name"] == self.name:
raise ValueError(
"Copying an entity requires a new name to be specified. Please provide a new name in the update dictionary."
)
super().copy(update=update, **kwargs)
return super().copy(update=update, **kwargs)


class _CombinedMeta(type(Flow360BaseModel), type):
Expand All @@ -78,57 +81,141 @@ def __getitem__(cls, entity_types):
return new_cls


def _remove_duplicate_entities(expanded_entities: List[EntityBase]):
"""
In the expanded entity list from `_get_expanded_entities` we will very likely have generic entities
which comes from asset metadata. These entities may have counterparts given by users. We remove the
generic ones when they have duplicate counterparts because the counterparts will likely have more info.

For example `entities = [my_mesh["*"], user_defined_zone]`. We need to remove duplicates from the expanded list.
"""
all_entities = defaultdict(list)

for entity in expanded_entities:
all_entities[entity.name].append(entity)

for entity_list in all_entities.values():
if len(entity_list) > 1:
for entity in entity_list:
if entity._is_generic and len(entity_list) > 1:
entity_list.remove(entity)

assert len(entity_list) == 1

return [entity_list[0] for entity_list in all_entities.values()]


class EntityList(Flow360BaseModel, metaclass=_EntitiesListMeta):
"""
Represents a collection of volume entities in the Flow360 simulation.
The type accepting a list of entities or (name, registry) pair.

Attributes:
stored_entities (List[Union[Box, Cylinder, str]]): List of stored entities, which can be
stored_entities (List[Union[EntityBase, Tuple[str, registry]]]): List of stored entities, which can be
instances of `Box`, `Cylinder`, or strings representing naming patterns.

Methods:
format_input_to_list(cls, input: List) -> dict: Class method that formats the input to a
_format_input_to_list(cls, input: List) -> dict: Class method that formats the input to a
dictionary with the key 'stored_entities'.
check_duplicate_entity_in_list(cls, values): Class method that checks for duplicate entities
_check_duplicate_entity_in_list(cls, values): Class method that checks for duplicate entities
in the list of stored entities.
get_expanded_entities(self): Method that processes the stored entities to resolve any naming
_get_expanded_entities(self): Method that processes the stored entities to resolve any naming
patterns into actual entity references, expanding and filtering based on the defined
entity types.

"""

stored_entities: List = pd.Field()

@pd.model_validator(mode="before")
@classmethod
def format_input_to_list(cls, input: List) -> dict:
def _get_valid_entity_types(cls):
"""Get the list of types that the entity list can accept."""
entity_field_type = cls.__annotations__.get("stored_entities")
if (
entity_field_type is not None
and hasattr(entity_field_type, "__origin__")
and entity_field_type.__origin__ is list
):
valid_types = get_args(entity_field_type)[0]
if hasattr(valid_types, "__origin__") and valid_types.__origin__ is Union:
valid_types = get_args(valid_types)
else:
valid_types = (valid_types,)
return valid_types
raise TypeError("Internal error, the metaclass for EntityList is not properly set.")

@classmethod
def _valid_individual_input(cls, input):
"""Validate each individual element in a list or as standalone entity."""
if isinstance(input, str) or isinstance(input, EntityBase):
return dict(stored_entities=[input])
elif isinstance(input, list):
return input
else:
raise ValueError(
f"Type({type(input)}) of input to `entities` ({input}) is not valid. Expected str or entity instance."
)

@pd.model_validator(mode="before")
@classmethod
def _format_input_to_list(cls, input: Union[dict, list]):
"""
Flatten List[EntityBase] and put into stored_entities.
"""
# Note:
# 1. str comes from Param. These will be expanded before submission
# as the user may change Param which affects implicit entities (farfield existence patch for example).
# 2. The List[EntityBase], comes from the Assets.
# 3. EntityBase comes from direct specification of entity in the list.
formated_input = []
valid_types = cls._get_valid_entity_types()
if isinstance(input, list):
if input == []:
raise ValueError("Invalid input type to `entities`, list is empty.")
return dict(stored_entities=input)
else:
raise ValueError(f"Invalid input type to `entities`: {type(input)}")
for item in input:
if isinstance(item, list): # Nested list comes from assets
[cls._valid_individual_input(individual) for individual in item]
formated_input.extend(
[
individual
for individual in item
if isinstance(individual, tuple(valid_types))
]
)
else:
cls._valid_individual_input(item)
if isinstance(item, tuple(valid_types)):
formated_input.append(item)
elif isinstance(input, dict):
return dict(stored_entities=input["stored_entities"])
else: # Single reference to an entity
cls._valid_individual_input(input)
if isinstance(item, tuple(valid_types)):
formated_input.append(item)
return dict(stored_entities=formated_input)

@pd.field_validator("stored_entities", mode="after")
@classmethod
def check_duplicate_entity_in_list(cls, values):
def _check_duplicate_entity_in_list(cls, values):
seen = []
for value in values:
if value in seen:
if isinstance(value, EntityBase):
raise ValueError(f"Duplicate entity found, name: {value.name}")
raise ValueError(f"Duplicate entity found: {value}")
log.warning(f"Duplicate entity found, name: {value.name}")
else:
log.warning(f"Duplicate entity found: {value}")
continue
seen.append(value)
return values
return seen

def get_expanded_entities(self):
def _get_expanded_entities(self, supplied_registry: EntityRegistry = None) -> List[EntityBase]:
"""
Processes `stored_entities` to resolve any naming patterns into actual entity
references, expanding and filtering based on the defined entity types. This ensures that
all entities referenced either directly or by pattern are valid and registered.

**Warning**:
This method has to be called during preprocessing stage of Param when all settings have
been finalized. This ensures that all entities are registered in the registry (by assets or param).
Maybe we check hash or something to ensure consistency/integrity?

Raises:
TypeError: If an entity does not match the expected type.
Returns:
Expand All @@ -137,19 +224,19 @@ def get_expanded_entities(self):

entities = getattr(self, "stored_entities", [])

entity_field_type = self.__annotations__["stored_entities"]
if hasattr(entity_field_type, "__origin__") and entity_field_type.__origin__ is list:
valid_types = get_args(entity_field_type)[0]
if hasattr(valid_types, "__origin__") and valid_types.__origin__ is Union:
valid_types = get_args(valid_types)
else:
valid_types = (valid_types,)
valid_types = self.__class__._get_valid_entity_types()

expanded_entities = []

for entity in entities:
if isinstance(entity, str):
# Expand from supplied registry
if supplied_registry is None:
raise ValueError(
f"Internal error, registry is not supplied for entity ({entity}) expansion. "
)
# Expand based on naming pattern registered in the Registry
pattern_matched_entities = EntityRegistry.find_by_name_pattern(entity)
pattern_matched_entities = supplied_registry.find_by_name_pattern(entity)
# Filter pattern matched entities by valid types
expanded_entities.extend(
[
Expand All @@ -161,8 +248,17 @@ def get_expanded_entities(self):
elif entity not in expanded_entities:
# Direct entity references are simply appended if they are of a valid type
expanded_entities.append(entity)

expanded_entities = _remove_duplicate_entities(expanded_entities)

if expanded_entities == []:
raise ValueError(
f"Failed to find any matching entity with {entities}. Please check the input to entities."
)
# TODO: As suggested by Runda. We better prompt user what entities are actually used/expanded to avoid user input error. We need a switch to turn it on or off.
return copy.deepcopy(expanded_entities)

def preprocess(self, supplied_registry: EntityRegistry = None):
"""Expand and overwrite self.stored_entities in preparation for submissin/serialization."""
self.stored_entities = self._get_expanded_entities(supplied_registry)
return self
37 changes: 18 additions & 19 deletions flow360/component/simulation/framework/entity_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,19 +33,22 @@ class EntityRegistry:
_registry (Dict[str, List[EntityBase]]): A dictionary that maps entity types to lists of instances.
"""

_registry = {}
_registry = None

@classmethod
def register(cls, entity):
def __init__(self) -> None:
self._registry = {}

def register(self, entity):
"""
Registers an entity in the registry under its type.

Parameters:
entity (EntityBase): The entity instance to register.
"""
if entity._entity_type not in cls._registry:
cls._registry[entity._entity_type] = []
for existing_entity in cls._registry[entity._entity_type]:
if entity._entity_type not in self._registry:
self._registry[entity._entity_type] = []

for existing_entity in self._registry[entity._entity_type]:
if existing_entity.name == entity.name:
# Same type and same name. Now we try to update existing entity with new values.
try:
Expand All @@ -60,10 +63,9 @@ def register(cls, entity):
log.debug("Merge failed unexpectly: %s", e)
raise

cls._registry[entity._entity_type].append(entity)
self._registry[entity._entity_type].append(entity)

@classmethod
def get_entities(cls, entity_type):
def get_all_entities_of_given_type(self, entity_type):
"""
Retrieves all entities of a specified type.

Expand All @@ -73,10 +75,9 @@ def get_entities(cls, entity_type):
Returns:
List[EntityBase]: A list of registered entities of the specified type.
"""
return cls._registry.get(entity_type._entity_type.default, [])
return self._registry.get(entity_type._entity_type.default, [])

@classmethod
def find_by_name_pattern(cls, pattern: str):
def find_by_name_pattern(self, pattern: str):
"""
Finds all registered entities whose names match a given pattern.

Expand All @@ -88,23 +89,21 @@ def find_by_name_pattern(cls, pattern: str):
"""
matched_entities = []
regex = re.compile(pattern.replace("*", ".*"))
for entity_list in cls._registry.values():
for entity_list in self._registry.values():
matched_entities.extend(filter(lambda x: regex.match(x.name), entity_list))
return matched_entities

@classmethod
def show(cls):
def show(self):
"""
Prints a list of all registered entities, grouped by type.
"""
for entity_type, entities in cls._registry.items():
for entity_type, entities in self._registry.items():
print(f"Entities of type '{entity_type}':")
for entity in entities:
print(f" - {entity}")

@classmethod
def clear(cls):
def clear(self):
"""
Clears all entities from the registry.
"""
cls._registry.clear()
self._registry.clear()
Loading
Loading