Skip to content

Entity Registry in SimulationParams #328

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
197 changes: 154 additions & 43 deletions flow360/component/simulation/framework/entity_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,34 +2,36 @@

import copy
from abc import ABCMeta
from collections import defaultdict
from typing import List, Optional, Union, get_args, get_origin

import numpy as np
import pydantic as pd

from flow360.component.simulation.framework.base_model import Flow360BaseModel
from flow360.component.simulation.framework.entity_registry import EntityRegistry
from flow360.log import log


class MergeConflictError(Exception):
    """Raised when two entities cannot be merged (name, class or attribute mismatch)."""


class EntityBase(Flow360BaseModel, metaclass=ABCMeta):
"""
Base class for dynamic entity types.

Attributes:
_entity_type (str): A string representing the specific type of the entity.
This should be set in subclasses to differentiate between entity types.
Warning:
This controls the granularity of the registry and must be unique for each entity type and it is **strongly recommended NOT** to change it as it will bring up compatability problems.
_auto_constructed (bool): A flag indicating whether the entity is automatically constructed
by assets using their metadata. This means that the entity is not directly
specified by the user and contains less information than user-defined entities.

name (str): The name of the entity, used for identification and retrieval.
private_attribute_registry_bucket_name (str):
A string representing the specific type of the entity.
This should be set in subclasses to differentiate between entity types.
Warning:
This controls the granularity of the registry and must be unique for each entity type. It is **strongly recommended NOT** to change it, as doing so will cause compatibility problems.

name (str):
The name of the entity instance, used for identification and retrieval.
"""

_entity_type: str = None
_auto_constructed = False
private_attribute_registry_bucket_name: str = "Invalid"
private_attribute_entity_type_name: str = "Invalid"
name: str = pd.Field(frozen=True)

def __init__(self, **data):
Expand All @@ -40,8 +42,14 @@ def __init__(self, **data):
data: Keyword arguments containing initial values for fields declared in the entity.
"""
super().__init__(**data)
if self.entity_type is None:
raise NotImplementedError("_entity_type is not defined in the entity class.")
if self.entity_bucket == "Invalid":
raise NotImplementedError(
f"private_attribute_registry_bucket_name is not defined in the entity class: {self.__class__.__name__}."
)
if self.entity_type == "Invalid":
raise NotImplementedError(
f"private_attribute_entity_type_name is not defined in the entity class: {self.__class__.__name__}."
)

def copy(self, update=None, **kwargs) -> EntityBase:
"""
Expand All @@ -64,21 +72,35 @@ def copy(self, update=None, **kwargs) -> EntityBase:
)
return super().copy(update=update, **kwargs)

def __eq__(self, other):
"""Defines the equality comparison for entities to support usage in UniqueItemList."""
if isinstance(other, EntityBase):
return (self.name + "-" + self.__class__.__name__) == (
other.name + "-" + other.__class__.__name__
)
return False

@property
def entity_bucket(self) -> str:
return self.private_attribute_registry_bucket_name

@entity_bucket.setter
def entity_bucket(self, value: str):
raise AttributeError("Cannot modify the bucket to which the entity belongs.")

@property
def entity_type(self) -> str:
return self._entity_type
return self.private_attribute_entity_type_name

@entity_type.setter
def entity_type(self, value: str):
raise AttributeError("Cannot modify _entity_type")
raise AttributeError("Cannot modify the name of entity class.")

@property
def auto_constructed(self) -> str:
return self._auto_constructed
def __str__(self) -> str:
return "\n".join([f" {attr}: {value}" for attr, value in self.__dict__.items()])

@auto_constructed.setter
def auto_constructed(self, value: str):
raise AttributeError("Cannot modify _auto_constructed")
def _is_generic(self):
return self.__class__.__name__.startswith("Generic")


class _CombinedMeta(type(Flow360BaseModel), type):
Expand All @@ -104,27 +126,102 @@ def __getitem__(cls, entity_types):
return new_cls


def __combine_bools(input_data):
    """
    Collapse the result of a comparison into a single Python ``bool``.

    Parameters:
        input_data: A plain ``bool``, a NumPy scalar bool (``np.bool_``, which is what
            comparing NumPy scalars yields), a NumPy ``ndarray`` of any shape, or any
            other iterable of truthy/falsy items.
    Returns:
        The boolean itself for scalar input; otherwise ``True`` only when every
        element is truthy (an empty array/iterable therefore gives ``True``).
    """
    # Scalar results: np.bool_ is neither `bool` nor iterable, so it must be
    # handled here instead of being handed to all().
    if isinstance(input_data, (bool, np.bool_)):
        return bool(input_data)
    # Flatten arrays so that multi-dimensional results are reduced element by element.
    if isinstance(input_data, np.ndarray):
        input_data = input_data.ravel()
    # Anything else is assumed to be an iterable of booleans.
    return all(input_data)


def _merge_objects(obj_old: EntityBase, obj_new: EntityBase) -> EntityBase:
    """
    Merge two same-name entities into one and return the merged entity.

    The merge happens in place on the higher-priority object. A non-generic entity
    always takes priority over a generic one, so if only `obj_new` is non-generic
    the two arguments are swapped before merging.

    Parameters:
        obj_old: The original object to merge into (ideally the non-generic one).
        obj_new: The object whose attributes are merged into `obj_old`.
    Returns:
        The higher-priority object, updated with the information of the other one.
    Raises:
        MergeConflictError: If the names differ, if both objects are non-generic but of
            different classes, or if an attribute has different non-None values
            on the two objects.
    """

    if obj_new.name != obj_old.name:
        raise MergeConflictError(
            "Make sure merge is intended as the names of two entities are different."
        )

    if obj_old._is_generic() and not obj_new._is_generic():
        # swap so that obj_old is **non-generic** and obj_new is **generic**
        obj_new, obj_old = obj_old, obj_new

    # Check the two objects are mergeable: two non-generic entities must share a class.
    if not obj_new._is_generic() and not obj_old._is_generic():
        if obj_new.__class__ != obj_old.__class__:
            raise MergeConflictError(
                f"Cannot merge objects of different class: {obj_old.__class__.__name__} and {obj_new.__class__.__name__}"
            )

    # Attributes that identify the entity itself and are never merged.
    identity_attrs = (
        "private_attribute_entity_type_name",
        "private_attribute_registry_bucket_name",
        "name",
    )
    # Read the declared fields from the class rather than from the instance.
    declared_fields = type(obj_old).model_fields

    for attr, value in obj_new.__dict__.items():
        if attr in identity_attrs:
            continue
        if attr in obj_old.__dict__:
            old_value = obj_old.__dict__[attr]
            # The values conflict unless *every* element compares equal. Reducing the
            # `==` result (instead of `!=`) makes arrays that differ in only some
            # elements count as a conflict, and two empty arrays count as equal.
            found_conflict = not __combine_bools(old_value == value)
            if found_conflict:
                if old_value is None:
                    # Populate obj_old with new info from lower priority object
                    obj_old.__dict__[attr] = value
                elif value is None:
                    # Ignore difference from lower priority object
                    continue
                else:
                    raise MergeConflictError(
                        f"Conflict on attribute '{attr}': {old_value} != {value}"
                    )
        # For an attribute coming from the new object, adopt its value as long as it is
        # a declared field of the old object's model.
        if attr in declared_fields:
            obj_old.__dict__[attr] = value

    return obj_old


def _remove_duplicate_entities(expanded_entities: List[EntityBase]):
    """
    Collapse all entities sharing a name into a single merged entity.

    In the expanded entity list from `_get_expanded_entities` we will very likely have
    generic entities which come from asset metadata. These entities may have counterparts
    given by users. We try to update the non-generic entities with the metadata contained
    within the generic ones.
    For example `entities = [my_mesh["*"], user_defined_zone]`. We need to keep the
    `user_defined_zone` while updating it with the boundaries coming from mesh metadata
    in the expanded list.

    Parameters:
        expanded_entities: Entities to de-duplicate; each must expose `name` and
            `_is_generic()`.
    Returns:
        A list with exactly one entity per distinct name, ordered by first appearance.
    Raises:
        MergeConflictError: Propagated from `_merge_objects` when entities with the
            same name cannot be merged.
    """
    entities_by_name = {}
    for entity in expanded_entities:
        entities_by_name.setdefault(entity.name, []).append(entity)

    unique_entities = []
    for same_name_entities in entities_by_name.values():
        # Step 1: pick the merge base, preferring a non-generic instance if any;
        # when every instance is generic the last one is used.
        base_index = next(
            (
                index
                for index, entity in enumerate(same_name_entities)
                if not entity._is_generic()
            ),
            len(same_name_entities) - 1,
        )
        merged = same_name_entities[base_index]
        # Step 2: fold every other instance into the base. Working by index (and never
        # removing from the list being iterated) keeps the base position valid and
        # avoids relying on `__eq__`, which cannot tell same-name/same-class
        # instances apart.
        for index, entity in enumerate(same_name_entities):
            if index != base_index:  # no merging into self
                merged = _merge_objects(merged, entity)
        unique_entities.append(merged)
    return unique_entities


Expand Down Expand Up @@ -240,7 +337,12 @@ def _check_duplicate_entity_in_list(cls, values):
seen.append(value)
return seen

def _get_expanded_entities(self, supplied_registry: EntityRegistry = None) -> List[EntityBase]:
def _get_expanded_entities(
self,
supplied_registry=None,
expect_supplied_registry: bool = True,
create_hard_copy: bool = True,
) -> List[EntityBase]:
"""
Processes `stored_entities` to resolve any naming patterns into actual entity
references, expanding and filtering based on the defined entity types. This ensures that
Expand All @@ -254,7 +356,7 @@ def _get_expanded_entities(self, supplied_registry: EntityRegistry = None) -> Li
Raises:
TypeError: If an entity does not match the expected type.
Returns:
Deep copy of the exapnded entities list.
Expanded entities list (a deep copy when `create_hard_copy` is True).
"""

entities = getattr(self, "stored_entities", [])
Expand All @@ -270,11 +372,14 @@ def _get_expanded_entities(self, supplied_registry: EntityRegistry = None) -> Li
if isinstance(entity, str):
# Expand from supplied registry
if supplied_registry is None:
raise ValueError(
f"Internal error, registry is not supplied for entity ({entity}) expansion."
)
if expect_supplied_registry == False:
continue
else:
raise ValueError(
f"Internal error, registry is not supplied for entity ({entity}) expansion."
)
# Expand based on naming pattern registered in the Registry
pattern_matched_entities = supplied_registry.find_by_name_pattern(entity)
pattern_matched_entities = supplied_registry.find_by_naming_pattern(entity)
# Filter pattern matched entities by valid types
expanded_entities.extend(
[
Expand All @@ -294,9 +399,15 @@ def _get_expanded_entities(self, supplied_registry: EntityRegistry = None) -> Li
f"Failed to find any matching entity with {entities}. Please check the input to entities."
)
# TODO: As suggested by Runda. We better prompt user what entities are actually used/expanded to avoid user input error. We need a switch to turn it on or off.
return copy.deepcopy(expanded_entities)
if create_hard_copy == True:
return copy.deepcopy(expanded_entities)
else:
return expanded_entities

def preprocess(self, supplied_registry: EntityRegistry = None, **kwargs):
"""Expand and overwrite self.stored_entities in preparation for submissin/serialization."""
def preprocess(self, supplied_registry=None, **kwargs):
"""
Expand and overwrite self.stored_entities in preparation for submission/serialization.
Should only be called as late as possible to incorporate all possible changes.
"""
self.stored_entities = self._get_expanded_entities(supplied_registry)
return super().preprocess(self, **kwargs)
Loading
Loading