Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 57 additions & 2 deletions libs/checkpoint/langgraph/checkpoint/serde/jsonplus.py
Original file line number Diff line number Diff line change
Expand Up @@ -255,8 +255,57 @@ def loads_typed(self, data: tuple[str, bytes]) -> Any:
EXT_NUMPY_ARRAY = 6


def _try_get_pydantic_v2_generic_type_info(
pydantic_type: Any,
) -> tuple[str, dict]:
from collections import defaultdict

info_dict: dict[tuple[str, str], Any] = defaultdict(list)

generics = getattr(pydantic_type, "__pydantic_generic_metadata__", {})
origin_class = generics.get("origin")
type_args = generics.get("args", [])

if not origin_class or not type_args:
return "", info_dict

for arg in type_args:
child_generics = getattr(arg, "__pydantic_generic_metadata__", {})

if child_generics.get("origin"):
origin_cls_name, sub_dict = _try_get_pydantic_v2_generic_type_info(arg)
info_dict[(origin_class.__module__, origin_cls_name)].append(sub_dict)
else:
info_dict[(origin_class.__module__, origin_class.__name__)].append(
(arg.__module__, arg.__name__)
)

return origin_class.__name__, info_dict


def _build_generic_pydantic_v2_type(
module_name: str, orign_cls_name: str, type_info_dict: dict
) -> Any:
generic_type_arg_info = type_info_dict[(module_name, orign_cls_name)]
origin_cls = getattr(importlib.import_module(module_name), orign_cls_name)
generic_type_args = []
for arg_info in generic_type_arg_info:
if isinstance(arg_info, dict):
m, c = next(iter(arg_info))
type_arg = _build_generic_pydantic_v2_type(m, c, arg_info)
else:
type_arg = getattr(importlib.import_module(arg_info[0]), arg_info[1])
generic_type_args.append(type_arg)
cls = origin_cls.__class_getitem__(tuple(generic_type_args))

return cls


def _msgpack_default(obj: Any) -> str | ormsgpack.Ext:
if hasattr(obj, "model_dump") and callable(obj.model_dump): # pydantic v2
origin_cls_name, generic_type_info = _try_get_pydantic_v2_generic_type_info(
obj.__class__
)
return ormsgpack.Ext(
EXT_PYDANTIC_V2,
_msgpack_enc(
Expand All @@ -265,6 +314,8 @@ def _msgpack_default(obj: Any) -> str | ormsgpack.Ext:
obj.__class__.__name__,
obj.model_dump(),
"model_validate_json",
origin_cls_name,
generic_type_info,
),
),
)
Expand Down Expand Up @@ -544,8 +595,12 @@ def _msgpack_ext_hook(code: int, data: bytes) -> Any:
tup = ormsgpack.unpackb(
data, ext_hook=_msgpack_ext_hook, option=ormsgpack.OPT_NON_STR_KEYS
)
# module, name, kwargs, method
cls = getattr(importlib.import_module(tup[0]), tup[1])
# module, name, kwargs, method, generic_type_info, origin_cls_name
if len(tup) > 4 and tup[4] and isinstance(tup[5], dict) and len(tup[5]) > 0:
# serialized data has generic type info
cls = _build_generic_pydantic_v2_type(tup[0], tup[4], tup[5])
else:
cls = getattr(importlib.import_module(tup[0]), tup[1])
try:
return cls(**tup[2])
except Exception:
Expand Down
31 changes: 31 additions & 0 deletions libs/checkpoint/tests/test_jsonplus.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from decimal import Decimal
from enum import Enum
from ipaddress import IPv4Address
from typing import Generic, TypeVar
from zoneinfo import ZoneInfo

import dataclasses_json
Expand All @@ -24,6 +25,14 @@
)
from langgraph.store.base import Item

TInner = TypeVar("TInner", bound=BaseModel)


class MyPydanticGeneric(BaseModel, Generic[TInner]):
foo: str
bar: int
inner: TInner


class InnerPydantic(BaseModel):
hello: str
Expand Down Expand Up @@ -469,3 +478,25 @@ def test_serde_jsonplus_pandas_series(series: pd.Series) -> None:
result = serde.loads_typed(dumped)

assert result.equals(series)


def test_serde_jsonplus_with_pydantic_generic() -> None:
instance = MyPydanticGeneric[MyPydanticGeneric[InnerPydantic]](
foo="foo",
bar=1,
inner=MyPydanticGeneric[InnerPydantic](
foo="inner-foo", bar=2, inner=InnerPydantic(hello="hello")
),
)

serde = JsonPlusSerializer()
dumped = serde.dumps_typed(instance)
assert dumped[0] == "msgpack"
result = serde.loads_typed(dumped)
assert instance == result

# json mode is not affected by the fix
serde = JsonPlusSerializer(__unpack_ext_hook__=_msgpack_ext_hook_to_json)

json_result = serde.loads_typed(dumped)
assert json_result == instance.model_dump()