Skip to content

Commit 5929841

Browse files
authored
fix: RedisStorage: MemoryRecord Enum values cause JSON serialization … (#3192)
1 parent 39fa004 commit 5929841

File tree

5 files changed

+53
-25
lines changed

5 files changed

+53
-25
lines changed

camel/memories/blocks/chat_history_block.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -96,7 +96,10 @@ def retrieve(
9696
if (
9797
record_dicts
9898
and record_dicts[0]['role_at_backend']
99-
in {OpenAIBackendRole.SYSTEM, OpenAIBackendRole.DEVELOPER}
99+
in {
100+
OpenAIBackendRole.SYSTEM.value,
101+
OpenAIBackendRole.DEVELOPER.value,
102+
}
100103
)
101104
else 0
102105
)

camel/memories/records.py

Lines changed: 43 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,8 @@
1515
# Enables postponed evaluation of annotations (for string-based type hints)
1616
from __future__ import annotations
1717

18+
import inspect
1819
import time
19-
from dataclasses import asdict
2020
from typing import Any, ClassVar, Dict
2121
from uuid import UUID, uuid4
2222

@@ -63,37 +63,62 @@ class MemoryRecord(BaseModel):
6363
"FunctionCallingMessage": FunctionCallingMessage,
6464
}
6565

66+
# Cache for constructor parameters (performance optimization)
67+
_constructor_params_cache: ClassVar[Dict[str, set]] = {}
68+
69+
@classmethod
70+
def _get_constructor_params(cls, message_cls) -> set:
71+
"""Get constructor parameters for a message class with caching."""
72+
cls_name = message_cls.__name__
73+
if cls_name not in cls._constructor_params_cache:
74+
sig = inspect.signature(message_cls.__init__)
75+
cls._constructor_params_cache[cls_name] = set(
76+
sig.parameters.keys()
77+
) - {'self'}
78+
return cls._constructor_params_cache[cls_name]
79+
6680
@classmethod
6781
def from_dict(cls, record_dict: Dict[str, Any]) -> "MemoryRecord":
6882
r"""Reconstruct a :obj:`MemoryRecord` from the input dict.
6983
7084
Args:
7185
record_dict(Dict[str, Any]): A dict generated by :meth:`to_dict`.
7286
"""
73-
from camel.types import (
74-
OpenAIBackendRole,
75-
RoleType,
76-
)
87+
from camel.types import OpenAIBackendRole, RoleType
7788

7889
message_cls = cls._MESSAGE_TYPES[record_dict["message"]["__class__"]]
79-
kwargs: Dict = record_dict["message"].copy()
80-
kwargs.pop("__class__")
81-
82-
# Convert role_type string back to RoleType enum if it's a string
83-
if "role_type" in kwargs and isinstance(kwargs["role_type"], str):
84-
kwargs["role_type"] = RoleType(kwargs["role_type"])
85-
86-
reconstructed_message = message_cls(**kwargs)
87-
88-
# Convert role_at_backend string back to OpenAIBackendRole enum if
89-
# it's a string
90+
data = record_dict["message"].copy()
91+
data.pop("__class__")
92+
93+
# Convert role_type string to enum
94+
if "role_type" in data and isinstance(data["role_type"], str):
95+
data["role_type"] = RoleType(data["role_type"])
96+
97+
# Get valid constructor parameters (cached)
98+
valid_params = cls._get_constructor_params(message_cls)
99+
100+
# Separate constructor args from extra fields
101+
kwargs = {k: v for k, v in data.items() if k in valid_params}
102+
extra_fields = {k: v for k, v in data.items() if k not in valid_params}
103+
104+
# Handle meta_dict properly: merge existing meta_dict with extra fields
105+
existing_meta = kwargs.get("meta_dict", {}) or {}
106+
if extra_fields:
107+
# Extra fields take precedence, but preserve existing meta_dict
108+
# structure
109+
merged_meta = {**existing_meta, **extra_fields}
110+
kwargs["meta_dict"] = merged_meta
111+
elif not existing_meta:
112+
kwargs["meta_dict"] = None
113+
114+
# Convert role_at_backend
90115
role_at_backend = record_dict["role_at_backend"]
91116
if isinstance(role_at_backend, str):
92117
role_at_backend = OpenAIBackendRole(role_at_backend)
93118

94119
return cls(
95120
uuid=UUID(record_dict["uuid"]),
96-
message=reconstructed_message,
121+
message=message_cls(**kwargs),
97122
role_at_backend=role_at_backend,
98123
extra_info=record_dict["extra_info"],
99124
timestamp=record_dict["timestamp"],
@@ -108,7 +133,7 @@ def to_dict(self) -> Dict[str, Any]:
108133
"uuid": str(self.uuid),
109134
"message": {
110135
"__class__": self.message.__class__.__name__,
111-
**asdict(self.message),
136+
**self.message.to_dict(),
112137
},
113138
"role_at_backend": self.role_at_backend.value
114139
if hasattr(self.role_at_backend, "value")

camel/messages/base.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -554,7 +554,7 @@ def to_dict(self) -> Dict:
554554
"""
555555
return {
556556
"role_name": self.role_name,
557-
"role_type": self.role_type.name,
557+
"role_type": self.role_type.value,
558558
**(self.meta_dict or {}),
559559
"content": self.content,
560560
}

test/messages/test_chat_message.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ def test_chat_message(chat_message: BaseMessage) -> None:
6363
dictionary = chat_message.to_dict()
6464
reference_dict: Dict[str, Any] = {
6565
"role_name": role_name,
66-
"role_type": role_type.name,
66+
"role_type": role_type.value,
6767
"content": content,
6868
}
6969
assert dictionary == reference_dict
@@ -83,7 +83,7 @@ def test_assistant_chat_message(assistant_chat_message: BaseMessage) -> None:
8383
dictionary = assistant_chat_message.to_dict()
8484
reference_dict: Dict[str, Any] = {
8585
"role_name": role_name,
86-
"role_type": role_type.name,
86+
"role_type": role_type.value,
8787
"content": content,
8888
}
8989
assert dictionary == reference_dict
@@ -103,7 +103,7 @@ def test_user_chat_message(user_chat_message: BaseMessage) -> None:
103103
dictionary = user_chat_message.to_dict()
104104
reference_dict: Dict[str, Any] = {
105105
"role_name": role_name,
106-
"role_type": role_type.name,
106+
"role_type": role_type.value,
107107
"content": content,
108108
}
109109
assert dictionary == reference_dict

test/messages/test_message_base.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,7 @@ def test_extract_text_and_code_prompts():
8888
def test_base_message_to_dict(base_message: BaseMessage) -> None:
8989
expected_dict = {
9090
"role_name": "test_user",
91-
"role_type": "USER",
91+
"role_type": "user",
9292
"key": "value",
9393
"content": "test content",
9494
}
@@ -132,7 +132,7 @@ def test_base_message():
132132
dictionary = message.to_dict()
133133
assert dictionary == {
134134
"role_name": role_name,
135-
"role_type": role_type.name,
135+
"role_type": role_type.value,
136136
**(meta_dict or {}),
137137
"content": content,
138138
}

0 commit comments

Comments
 (0)