From f4021dc664b2b311ee217b6fa140ecef67af50a0 Mon Sep 17 00:00:00 2001 From: Sylphia Windy Date: Mon, 8 Sep 2025 18:34:25 +0800 Subject: [PATCH] fix: generic type args should be serialized for de-serialization, closes #6102 --- .../langgraph/checkpoint/serde/jsonplus.py | 59 ++++++++++++++++++- libs/checkpoint/tests/test_jsonplus.py | 31 ++++++++++ 2 files changed, 88 insertions(+), 2 deletions(-) diff --git a/libs/checkpoint/langgraph/checkpoint/serde/jsonplus.py b/libs/checkpoint/langgraph/checkpoint/serde/jsonplus.py index 45980c64eb..4cd6a6bb0f 100644 --- a/libs/checkpoint/langgraph/checkpoint/serde/jsonplus.py +++ b/libs/checkpoint/langgraph/checkpoint/serde/jsonplus.py @@ -255,8 +255,57 @@ def loads_typed(self, data: tuple[str, bytes]) -> Any: EXT_NUMPY_ARRAY = 6 +def _try_get_pydantic_v2_generic_type_info( + pydantic_type: Any, +) -> tuple[str, dict]: + from collections import defaultdict + + info_dict: dict[tuple[str, str], Any] = defaultdict(list) + + generics = getattr(pydantic_type, "__pydantic_generic_metadata__", {}) + origin_class = generics.get("origin") + type_args = generics.get("args", []) + + if not origin_class or not type_args: + return "", info_dict + + for arg in type_args: + child_generics = getattr(arg, "__pydantic_generic_metadata__", {}) + + if child_generics.get("origin"): + origin_cls_name, sub_dict = _try_get_pydantic_v2_generic_type_info(arg) + info_dict[(origin_class.__module__, origin_cls_name)].append(sub_dict) + else: + info_dict[(origin_class.__module__, origin_class.__name__)].append( + (arg.__module__, arg.__name__) + ) + + return origin_class.__name__, info_dict + + +def _build_generic_pydantic_v2_type( + module_name: str, orign_cls_name: str, type_info_dict: dict +) -> Any: + generic_type_arg_info = type_info_dict[(module_name, orign_cls_name)] + origin_cls = getattr(importlib.import_module(module_name), orign_cls_name) + generic_type_args = [] + for arg_info in generic_type_arg_info: + if isinstance(arg_info, dict): + m, c = 
next(iter(arg_info)) + type_arg = _build_generic_pydantic_v2_type(m, c, arg_info) + else: + type_arg = getattr(importlib.import_module(arg_info[0]), arg_info[1]) + generic_type_args.append(type_arg) + cls = origin_cls.__class_getitem__(tuple(generic_type_args)) + + return cls + + def _msgpack_default(obj: Any) -> str | ormsgpack.Ext: if hasattr(obj, "model_dump") and callable(obj.model_dump): # pydantic v2 + origin_cls_name, generic_type_info = _try_get_pydantic_v2_generic_type_info( + obj.__class__ + ) return ormsgpack.Ext( EXT_PYDANTIC_V2, _msgpack_enc( @@ -265,6 +314,8 @@ def _msgpack_default(obj: Any) -> str | ormsgpack.Ext: obj.__class__.__name__, obj.model_dump(), "model_validate_json", + origin_cls_name, + generic_type_info, ), ), ) @@ -544,8 +595,12 @@ def _msgpack_ext_hook(code: int, data: bytes) -> Any: tup = ormsgpack.unpackb( data, ext_hook=_msgpack_ext_hook, option=ormsgpack.OPT_NON_STR_KEYS ) - # module, name, kwargs, method - cls = getattr(importlib.import_module(tup[0]), tup[1]) + # module, name, kwargs, method, origin_cls_name, generic_type_info + if len(tup) > 4 and tup[4] and isinstance(tup[5], dict) and len(tup[5]) > 0: + # serialized data has generic type info + cls = _build_generic_pydantic_v2_type(tup[0], tup[4], tup[5]) + else: + cls = getattr(importlib.import_module(tup[0]), tup[1]) try: return cls(**tup[2]) except Exception: diff --git a/libs/checkpoint/tests/test_jsonplus.py b/libs/checkpoint/tests/test_jsonplus.py index 348492ae0b..4192c163b8 100644 --- a/libs/checkpoint/tests/test_jsonplus.py +++ b/libs/checkpoint/tests/test_jsonplus.py @@ -8,6 +8,7 @@ from decimal import Decimal from enum import Enum from ipaddress import IPv4Address +from typing import Generic, TypeVar from zoneinfo import ZoneInfo import dataclasses_json @@ -24,6 +25,14 @@ ) from langgraph.store.base import Item +TInner = TypeVar("TInner", bound=BaseModel) + + +class MyPydanticGeneric(BaseModel, Generic[TInner]): + foo: str + bar: int + inner: TInner + class 
InnerPydantic(BaseModel): hello: str @@ -469,3 +478,25 @@ def test_serde_jsonplus_pandas_series(series: pd.Series) -> None: result = serde.loads_typed(dumped) assert result.equals(series) + + +def test_serde_jsonplus_with_pydantic_generic() -> None: + instance = MyPydanticGeneric[MyPydanticGeneric[InnerPydantic]]( + foo="foo", + bar=1, + inner=MyPydanticGeneric[InnerPydantic]( + foo="inner-foo", bar=2, inner=InnerPydantic(hello="hello") + ), + ) + + serde = JsonPlusSerializer() + dumped = serde.dumps_typed(instance) + assert dumped[0] == "msgpack" + result = serde.loads_typed(dumped) + assert instance == result + + # json mode is not affected by the fix + serde = JsonPlusSerializer(__unpack_ext_hook__=_msgpack_ext_hook_to_json) + + json_result = serde.loads_typed(dumped) + assert json_result == instance.model_dump()