Skip to content

Commit 1fd5e66

Browse files
authored
[ENH] Implement GeoModel serialization with Pydantic (#1027)
# Description Refactored GeoModel to use Pydantic for serialization and deserialization. This change enables saving and loading models to/from JSON format while handling large numpy arrays efficiently. Key changes: - Converted GeoModel from a dataclass to a Pydantic BaseModel - Added serialization support for numpy arrays in grid, surface points, and orientations - Implemented context-based injection for large binary data during deserialization - Added from_args constructor to maintain backward compatibility - Updated model initialization across the codebase - Added gitignore entry for test/temp directory - Added tests for model serialization Relates to #serialization-support # Checklist - [x] My code uses type hinting for function and method arguments and return values. - [x] I have created tests which cover my code. - [x] The test code either 1. demonstrates at least one valuable use case (e.g. integration tests) or 2. verifies that outputs are as expected for given inputs (e.g. unit tests). - [x] New tests pass locally with my changes.
2 parents 57072cb + aff365e commit 1fd5e66

File tree

20 files changed

+801
-173
lines changed

20 files changed

+801
-173
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -181,3 +181,4 @@ examples/tutorials/z_other_tutorials/json_io/multiple_series_faults_computed.jso
181181
# Generated JSON files from examples
182182
examples/tutorials/z_other_tutorials/json_io/combination_model.json
183183
examples/tutorials/z_other_tutorials/json_io/combination_model_computed.json
184+
/test/temp/

docs/developers_notes/dev_log/2025_05.md

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,10 @@
55

66
# TODO:
77
-[ ] Saving and loading models
8-
- [ ] Make tests passing for InterpOptions serializable
9-
- [ ] Try to make BaseModel dirctly to StructuralFrame
10-
- [ ] Try to make Serializable directly to Grid
8+
- [x] Make tests passing for InterpOptions serializable
9+
- [x] Dealing with large numpy arrays
10+
- [x] Trying to have a better implementation for deserializing complex fields
11+
- [ ] Make save and load function
1112
-[ ] Better api for nugget effect optimization
1213

1314
## Saving models

examples/examples/real/mik.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -305,7 +305,7 @@
305305
all_surface_points_coords: gp.data.SurfacePointsTable = structural_frame.surface_points_copy
306306
extent_from_data = all_surface_points_coords.xyz.min(axis=0), all_surface_points_coords.xyz.max(axis=0)
307307
# Initialize GeoModel
308-
geo_model = gp.data.GeoModel(
308+
geo_model = gp.data.GeoModel.from_args(
309309
name="Stratigraphic Pile",
310310
structural_frame=structural_frame,
311311
grid=gp.data.Grid(

gempy/API/initialization_API.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,7 @@ def create_geomodel(
9292
case _:
9393
pass
9494

95-
geo_model: GeoModel = GeoModel(
95+
geo_model: GeoModel = GeoModel.from_args(
9696
name=project_name,
9797
structural_frame=structural_frame,
9898
grid=grid,

gempy/core/data/_data_points_helpers.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,13 @@
1-
from typing import Sequence
1+
import hashlib
2+
3+
from typing import Sequence
24

35
import numpy as np
46

57

68
def structural_element_hasher(i: int, name: str, hash_length: int = 8) -> int:
79
# Get the last 'hash_length' digits from the hash
8-
name_hash = abs(hash(name)) % (10 ** hash_length)
9-
10+
name_hash = int(hashlib.md5(name.encode('utf-8')).hexdigest(), 16) % (10 ** hash_length)
1011
return i * (10 ** hash_length) + name_hash
1112

1213

gempy/core/data/encoders/__init__.py

Whitespace-only changes.
Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
from contextlib import contextmanager
2+
3+
from contextvars import ContextVar
4+
5+
import numpy as np
6+
from pydantic import BeforeValidator
7+
8+
9+
def convert_to_arrays(values, keys):
10+
for key in keys:
11+
if key in values and not isinstance(values[key], np.ndarray):
12+
values[key] = np.array(values[key])
13+
return values
14+
15+
16+
def validate_numpy_array(v):
17+
return np.array(v) if v is not None else None
18+
19+
20+
def instantiate_if_necessary(data: dict, key: str, type: type) -> None:
21+
"""
22+
Creates instances of the specified type for a dictionary key if the key exists and its
23+
current type does not match the specified type. This function modifies the dictionary
24+
in place by converting the value associated with the key into an instance of the given
25+
type.
26+
27+
This is typically used when a dictionary contains data that needs to be represented as
28+
objects of a specific class type.
29+
30+
Args:
31+
data (dict): The dictionary containing the key-value pair to inspect and possibly
32+
convert.
33+
key (str): The key in the dictionary whose value should be inspected and converted
34+
if necessary.
35+
type (type): The type to which the value of `key` should be converted, if it is not
36+
already an instance of the type.
37+
38+
Returns:
39+
None
40+
"""
41+
if key in data and not isinstance(data[key], type):
42+
data[key] = type(**data[key])
43+
44+
numpy_array_short_validator = BeforeValidator(validate_numpy_array)
45+
46+
# First, create a context variable
47+
loading_model_context = ContextVar('loading_model_context', default={})
48+
49+
@contextmanager
50+
def loading_model_injection(surface_points_binary: np.ndarray, orientations_binary: np.ndarray):
51+
token = loading_model_context.set({
52+
'surface_points_binary': surface_points_binary,
53+
'orientations_binary' : orientations_binary
54+
})
55+
try:
56+
yield
57+
finally:
58+
loading_model_context.reset(token)
59+
Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
import numpy as np
2+
3+
4+
def encode_numpy_array(array: np.ndarray):
5+
# Check length
6+
if array.size > 10:
7+
return []
8+
return array.tolist()

0 commit comments

Comments
 (0)