Skip to content

Commit 6fa9f31

Browse files
authored
build: use pydantic-compat library (#124)
* build: use pydantic-compat * more usage * more changes * use config support * add dep * fix typing
1 parent 91fb059 commit 6fa9f31

15 files changed

+78
-230
lines changed

.pre-commit-config.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,4 +36,5 @@ repos:
3636
additional_dependencies:
3737
- types-PyYAML
3838
- pydantic>=2
39+
- pydantic-compat
3940
files: "^src/"

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ classifiers = [
3232
"Typing :: Typed",
3333
]
3434
dynamic = ["version"]
35-
dependencies = ["pydantic >=1.7", "numpy"]
35+
dependencies = ["pydantic >=1.7", "numpy", "pydantic-compat >=0.0.1"]
3636

3737
# extras
3838
# https://peps.python.org/pep-0621/#dependencies-optional-dependencies

src/useq/__init__.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -45,8 +45,6 @@
4545
"ZTopBottom",
4646
]
4747

48-
from useq._pydantic_compat import model_rebuild
4948

50-
model_rebuild(MDAEvent, MDASequence=MDASequence)
51-
model_rebuild(Position, MDASequence=MDASequence)
52-
del model_rebuild
49+
MDAEvent.model_rebuild(MDASequence=MDASequence)
50+
Position.model_rebuild(MDASequence=MDASequence)

src/useq/_base_model.py

Lines changed: 16 additions & 65 deletions
Original file line numberDiff line numberDiff line change
@@ -18,10 +18,11 @@
1818

1919
import numpy as np
2020
from pydantic import BaseModel
21-
22-
from useq._pydantic_compat import PYDANTIC2, model_dump, model_fields
21+
from pydantic_compat import PydanticCompatMixin
2322

2423
if TYPE_CHECKING:
24+
from pydantic import ConfigDict
25+
2526
ReprArgs = Sequence[Tuple[Optional[str], Any]]
2627
IncEx = set[int] | set[str] | dict[int, Any] | dict[str, Any] | None
2728

@@ -31,21 +32,13 @@
3132
_Y = TypeVar("_Y", bound="UseqModel")
3233

3334

34-
class FrozenModel(BaseModel):
35-
if PYDANTIC2:
36-
model_config = {
37-
"populate_by_name": True,
38-
"extra": "ignore",
39-
"frozen": True,
40-
}
41-
42-
else:
43-
44-
class Config:
45-
allow_population_by_field_name = True
46-
extra = "ignore"
47-
frozen = True
48-
json_encoders: ClassVar[dict] = {MappingProxyType: dict}
35+
class FrozenModel(PydanticCompatMixin, BaseModel):
36+
model_config: ClassVar[ConfigDict] = {
37+
"populate_by_name": True,
38+
"extra": "ignore",
39+
"frozen": True,
40+
"json_encoders": {MappingProxyType: dict},
41+
}
4942

5043
def replace(self: _T, **kwargs: Any) -> _T:
5144
"""Return a new instance replacing specified kwargs with new values.
@@ -58,63 +51,22 @@ def replace(self: _T, **kwargs: Any) -> _T:
5851
will perform validation and casting on the new values, whereas `copy` assumes
5952
that all objects are valid and will not perform any validation or casting.
6053
"""
61-
state = model_dump(self, exclude={"uid"})
54+
state = self.model_dump(exclude={"uid"})
6255
return type(self)(**{**state, **kwargs})
6356

64-
if PYDANTIC2:
65-
# retain pydantic1's json method
66-
def json(
67-
self,
68-
*,
69-
indent: int | None = None, # type: ignore
70-
include: IncEx = None,
71-
exclude: IncEx = None, # type: ignore
72-
by_alias: bool = False,
73-
exclude_unset: bool = False,
74-
exclude_defaults: bool = False,
75-
exclude_none: bool = False, # type: ignore
76-
round_trip: bool = False,
77-
warnings: bool = True,
78-
) -> str:
79-
return super().model_dump_json(
80-
indent=indent,
81-
include=include,
82-
exclude=exclude,
83-
by_alias=by_alias,
84-
exclude_unset=exclude_unset,
85-
exclude_defaults=exclude_defaults,
86-
exclude_none=exclude_none,
87-
round_trip=round_trip,
88-
warnings=warnings,
89-
)
90-
91-
# we let this one be deprecated
92-
# def dict()
93-
94-
elif not TYPE_CHECKING:
95-
# Backport pydantic2 methods so that useq-0.1.0 can be used with pydantic1
96-
97-
def model_dump_json(self, **kwargs: Any) -> str:
98-
"""Backport of pydantic2's model_dump_json method."""
99-
return self.json(**kwargs)
100-
101-
def model_dump(self, **kwargs: Any) -> dict[str, Any]:
102-
"""Backport of pydantic2's model_dump_json method."""
103-
return self.dict(**kwargs)
104-
10557

10658
class UseqModel(FrozenModel):
10759
def __repr_args__(self) -> ReprArgs:
10860
"""Only show fields that are not None or equal to their default value."""
10961
return [
11062
(k, val)
11163
for k, val in super().__repr_args__()
112-
if k in model_fields(self)
64+
if k in self.model_fields
11365
and val
11466
!= (
11567
factory()
116-
if (factory := model_fields(self)[k].default_factory) is not None
117-
else model_fields(self)[k].default
68+
if (factory := self.model_fields[k].default_factory) is not None
69+
else self.model_fields[k].default
11870
)
11971
]
12072

@@ -133,7 +85,7 @@ def from_file(cls: Type[_Y], path: Union[str, Path]) -> _Y:
13385
else: # pragma: no cover
13486
raise ValueError(f"Unknown file type: {path.suffix}")
13587

136-
return cls.model_validate(obj) if PYDANTIC2 else cls.parse_obj(obj)
88+
return cls.model_validate(obj)
13789

13890
@classmethod
13991
def parse_file(cls: Type[_Y], path: Union[str, Path], **kwargs: Any) -> _Y:
@@ -180,8 +132,7 @@ def yaml(
180132
np.floating, lambda dumper, d: dumper.represent_float(float(d))
181133
)
182134

183-
data = model_dump(
184-
self,
135+
data = self.model_dump(
185136
include=include,
186137
exclude=exclude,
187138
by_alias=by_alias,

src/useq/_grid.py

Lines changed: 15 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -4,13 +4,24 @@
44
import math
55
from enum import Enum
66
from functools import partial
7-
from typing import Any, Callable, Iterator, NamedTuple, Optional, Sequence, Tuple, Union
7+
from typing import (
8+
Any,
9+
Callable,
10+
ClassVar,
11+
Iterator,
12+
NamedTuple,
13+
Optional,
14+
Sequence,
15+
Tuple,
16+
Union,
17+
)
818

919
import numpy as np
10-
from pydantic import Field
20+
from pydantic import ConfigDict, Field
21+
from pydantic_compat import field_validator
1122

1223
from useq._base_model import FrozenModel
13-
from useq._pydantic_compat import FROZEN, PYDANTIC2, field_validator
24+
from useq._pydantic_compat import FROZEN
1425

1526

1627
class RelativeTo(Enum):
@@ -123,13 +134,7 @@ class _GridPlan(FrozenModel):
123134
"""
124135

125136
# Overriding FrozenModel to make fov_width and fov_height mutable.
126-
if PYDANTIC2:
127-
model_config = {"validate_assignment": True, "frozen": False}
128-
else:
129-
130-
class Config:
131-
validate_assignment = True
132-
frozen = False
137+
model_config: ClassVar[ConfigDict] = {"validate_assignment": True, "frozen": False}
133138

134139
overlap: Tuple[float, float] = Field((0.0, 0.0), **FROZEN) # type: ignore
135140
mode: OrderMode = Field(OrderMode.row_wise_snake, **FROZEN) # type: ignore

src/useq/_hardware_autofocus.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55
from useq._actions import HardwareAutofocus
66
from useq._base_model import FrozenModel
77
from useq._mda_event import MDAEvent
8-
from useq._pydantic_compat import model_copy
98

109

1110
class AutoFocusPlan(FrozenModel):
@@ -46,7 +45,7 @@ def event(self, event: MDAEvent) -> Optional[MDAEvent]:
4645
if zplan and zplan.is_relative and "z" in event.index:
4746
updates["z_pos"] = event.z_pos - list(zplan)[event.index["z"]]
4847

49-
return model_copy(event, update=updates)
48+
return event.model_copy(update=updates)
5049

5150
def should_autofocus(self, event: MDAEvent) -> bool:
5251
"""Method that must be implemented by a subclass.

src/useq/_iter_sequence.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@
1313
)
1414
from useq._mda_event import Channel as EventChannel
1515
from useq._mda_event import MDAEvent
16-
from useq._pydantic_compat import model_construct, model_copy
1716
from useq._utils import AXES, Axis, _has_axes
1817
from useq._z import AnyZPlan # noqa: TCH001 # noqa: TCH001
1918

@@ -100,7 +99,7 @@ def iter_sequence(sequence: MDASequence) -> Iterator[MDAEvent]:
10099
for axis, idx in this_e.index.items()
101100
if idx != next_e.index[axis]
102101
):
103-
this_e = model_copy(this_e, update={"keep_shutter_open": True})
102+
this_e = this_e.model_copy(update={"keep_shutter_open": True})
104103
yield this_e
105104
this_e = next_e
106105
yield this_e
@@ -169,8 +168,8 @@ def _iter_sequence(
169168
if position and position.name:
170169
event_kwargs["pos_name"] = position.name
171170
if channel:
172-
event_kwargs["channel"] = model_construct(
173-
EventChannel, config=channel.config, group=channel.group
171+
event_kwargs["channel"] = EventChannel.model_construct(
172+
config=channel.config, group=channel.group
174173
)
175174
if channel.exposure is not None:
176175
event_kwargs["exposure"] = channel.exposure
@@ -205,8 +204,8 @@ def _iter_sequence(
205204
# if the sub-sequence doe not have an autofocus plan, we override it
206205
# with the parent sequence's autofocus plan
207206
if not sub_seq.autofocus_plan:
208-
sub_seq = model_copy(
209-
sub_seq, update={"autofocus_plan": autofocus_plan}
207+
sub_seq = sub_seq.model_copy(
208+
update={"autofocus_plan": autofocus_plan}
210209
)
211210

212211
# recurse into the sub-sequence
@@ -223,7 +222,7 @@ def _iter_sequence(
223222
elif position.sequence is not None and position.sequence.autofocus_plan:
224223
autofocus_plan = position.sequence.autofocus_plan
225224

226-
event = model_construct(MDAEvent, **event_kwargs)
225+
event = MDAEvent.model_construct(**event_kwargs)
227226
if autofocus_plan:
228227
af_event = autofocus_plan.event(event)
229228
if af_event:

src/useq/_mda_event.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,11 @@
1919

2020
from useq._actions import AcquireImage, AnyAction
2121
from useq._base_model import UseqModel
22-
from useq._pydantic_compat import PYDANTIC2, field_serializer
22+
23+
try:
24+
from pydantic import field_serializer
25+
except ImportError:
26+
field_serializer = None # type: ignore
2327

2428
if TYPE_CHECKING:
2529
from useq._mda_sequence import MDASequence
@@ -166,7 +170,7 @@ def to_pycromanager(self) -> "PycroManagerEvent":
166170

167171
return to_pycromanager(self)
168172

169-
if PYDANTIC2:
173+
if field_serializer is not None:
170174
_si = field_serializer("index", mode="plain")(lambda v: dict(v))
171175
_sx = field_serializer("x_pos", mode="plain")(_float_or_none)
172176
_sy = field_serializer("y_pos", mode="plain")(_float_or_none)

src/useq/_mda_sequence.py

Lines changed: 12 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -7,20 +7,14 @@
77

88
import numpy as np
99
from pydantic import Field, PrivateAttr
10+
from pydantic_compat import field_validator, model_validator
1011

1112
from useq._base_model import UseqModel
1213
from useq._channel import Channel
1314
from useq._grid import AnyGridPlan, GridPosition # noqa: TCH001
1415
from useq._hardware_autofocus import AnyAutofocusPlan, AxesBasedAF
1516
from useq._iter_sequence import iter_sequence
1617
from useq._position import Position
17-
from useq._pydantic_compat import (
18-
field_validator,
19-
model_construct,
20-
model_dump,
21-
model_validator,
22-
pydantic_1_style_root_dict,
23-
)
2418
from useq._time import AnyTimePlan # noqa: TCH001
2519
from useq._utils import AXES, Axis, TimeEstimate, estimate_sequence_duration
2620
from useq._z import AnyZPlan # noqa: TCH001
@@ -176,7 +170,7 @@ def _validate_channels(cls, value: Any) -> Tuple[Channel, ...]:
176170
if isinstance(v, Channel):
177171
channels.append(v)
178172
elif isinstance(v, str):
179-
channels.append(model_construct(Channel, config=v))
173+
channels.append(Channel.model_construct(config=v))
180174
elif isinstance(v, dict):
181175
channels.append(Channel(**v))
182176
else: # pragma: no cover
@@ -230,22 +224,17 @@ def _validate_axis_order(cls, v: Any) -> str:
230224
@model_validator(mode="after")
231225
@classmethod
232226
def _validate_mda(cls, values: Any) -> Any:
233-
# this strange bit here is to deal with the fact that in pydantic1
234-
# root_validator after returned a dict of {field_name -> validated_value}
235-
# but in pydantic2 it returns the complete validated model instance
236-
_values = pydantic_1_style_root_dict(cls, values)
237-
238-
if "axis_order" in _values:
227+
if values.axis_order:
239228
cls._check_order(
240-
_values["axis_order"],
241-
z_plan=_values.get("z_plan"),
242-
stage_positions=_values.get("stage_positions", ()),
243-
channels=_values.get("channels", ()),
244-
grid_plan=_values.get("grid_plan"),
245-
autofocus_plan=_values.get("autofocus_plan"),
229+
values.axis_order,
230+
z_plan=values.z_plan,
231+
stage_positions=values.stage_positions,
232+
channels=values.channels,
233+
grid_plan=values.grid_plan,
234+
autofocus_plan=values.autofocus_plan,
246235
)
247-
if "stage_positions" in _values:
248-
for p in _values["stage_positions"]:
236+
if values.stage_positions:
237+
for p in values.stage_positions:
249238
if hasattr(p, "sequence") and getattr(
250239
p.sequence, "keep_shutter_open_across", None
251240
): # pragma: no cover
@@ -259,7 +248,7 @@ def __eq__(self, other: Any) -> bool:
259248
"""Return `True` if two `MDASequences` are equal (uid is excluded)."""
260249
if isinstance(other, MDASequence):
261250
return bool(
262-
model_dump(self, exclude={"uid"}) == model_dump(other, exclude={"uid"})
251+
self.model_dump(exclude={"uid"}) == other.model_dump(exclude={"uid"})
263252
)
264253
else:
265254
return False

0 commit comments

Comments
 (0)