Skip to content

Commit 90f6365

Browse files
committed
Fix: #29 Enhance serialization functions to handle nested dataclasses and unexpected fields in messages
1 parent 0243f13 commit 90f6365

File tree

3 files changed

+73
-10
lines changed

3 files changed

+73
-10
lines changed

CHANGELOG.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
1717
- Remove dependency on community driver
1818
- Use `iris` module instead
1919

20+
### Fixed
21+
- Fix dataclass message serialization
22+
- Go back to best effort serialization type check is not forced
23+
2024
## [3.4.0] - 2025-03-24
2125

2226
### Added

src/iop/_serialization.py

Lines changed: 15 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import inspect
55
import pickle
66
import json
7-
from dataclasses import asdict, is_dataclass
7+
from dataclasses import is_dataclass
88
from typing import Any, Dict, Type
99

1010
from . import _iris
@@ -149,12 +149,20 @@ def process_field(value: Any, field_type: Type) -> Any:
149149
def dataclass_to_dict(instance: Any) -> Dict:
150150
"""Converts a class instance to a dictionary.
151151
Handles non attended fields."""
152-
dikt = asdict(instance)
153-
# assign any extra fields
154-
for k, v in vars(instance).items():
155-
if k not in dikt:
156-
dikt[k] = v
157-
return dikt
152+
result = {}
153+
for field in instance.__dict__:
154+
value = getattr(instance, field)
155+
if is_dataclass(value):
156+
result[field] = dataclass_to_dict(value)
157+
elif isinstance(value, list):
158+
result[field] = [dataclass_to_dict(i) if is_dataclass(i) else i for i in value]
159+
elif isinstance(value, dict):
160+
result[field] = {k: dataclass_to_dict(v) if is_dataclass(v) else v for k, v in value.items()}
161+
elif hasattr(value, '__dict__'):
162+
result[field] = dataclass_to_dict(value)
163+
else:
164+
result[field] = value
165+
return result
158166

159167
# Maintain backwards compatibility
160168
serialize_pickle_message = lambda msg: MessageSerializer.serialize(msg, use_pickle=True)

src/tests/test_serialization.py

Lines changed: 54 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,10 @@
1515
deserialize_pickle_message,
1616
)
1717

18+
class NonDataclass:
19+
def __init__(self, value):
20+
self.value = value
21+
1822
@dataclass
1923
class Object:
2024
value: str
@@ -33,6 +37,9 @@ class FullMessge:
3337
uid: uuid.UUID
3438
data: bytes
3539
items: list # Changed from df to a simple list
40+
list_obj: list[Object] = None
41+
dict_obj: dict[str, Object] = None
42+
optional_obj: Optional[Object] = None
3643

3744
@dataclass
3845
class MyObject:
@@ -47,8 +54,6 @@ class Msg:
4754
my_obj: MyObject
4855

4956
def test_message_serialization():
50-
51-
5257
msg = Msg(text="hello", number=42, my_obj=None)
5358

5459
my_obj = MyObject(value="test", foo=None)
@@ -71,6 +76,46 @@ def test_message_serialization():
7176
assert result.number == msg.number
7277
assert result.my_obj == my_obj
7378

79+
def test_unexpexted_obj_serialization():
80+
# Create an invalid message
81+
msg = Msg(text="hello", number=42, my_obj=None)
82+
msg.my_obj = NonDataclass(value="test")
83+
84+
# Test serialization
85+
serial = serialize_message(msg)
86+
assert type(serial).__module__.startswith('iris') and serial._IsA("IOP.Message")
87+
assert serial.classname == f"{Msg.__module__}.{Msg.__name__}"
88+
89+
# Test deserialization
90+
result = deserialize_message(serial)
91+
assert isinstance(result, Msg)
92+
assert result.text == msg.text
93+
assert result.number == msg.number
94+
assert result.my_obj.value == msg.my_obj.value
95+
96+
97+
def test_unexpected_fields():
98+
# Create a message with unexpected fields
99+
msg = Msg(text="hello", number=42, my_obj=None)
100+
msg.unexpected_field = "unexpected"
101+
102+
my_obj = MyObject(value="test", foo=None)
103+
my_obj.unexpected_field = "unexpected"
104+
msg.my_obj = my_obj
105+
106+
# Test serialization
107+
serial = serialize_message(msg)
108+
assert type(serial).__module__.startswith('iris') and serial._IsA("IOP.Message")
109+
assert serial.classname == f"{Msg.__module__}.{Msg.__name__}"
110+
111+
# Test deserialization
112+
result = deserialize_message(serial)
113+
assert isinstance(result, Msg)
114+
assert result.text == msg.text
115+
assert result.number == msg.number
116+
assert result.unexpected_field == msg.unexpected_field
117+
assert result.my_obj == my_obj
118+
74119

75120
def test_json_serialization():
76121
# Create test data
@@ -127,7 +172,10 @@ def test_pickle_serialization():
127172
dec=decimal.Decimal("3.14"),
128173
uid=uuid.uuid4(),
129174
data=b'hello world',
130-
items=[{'col1': 1, 'col2': 'a'}, {'col1': 2, 'col2': 'b'}]
175+
items=[{'col1': 1, 'col2': 'a'}, {'col1': 2, 'col2': 'b'}],
176+
list_obj=[Object(value="item1"), Object(value="item2")],
177+
dict_obj={'key1': Object(value="item1"), 'key2': Object(value="item2")},
178+
optional_obj=Object(value="optional")
131179
)
132180

133181
# Test serialization
@@ -150,6 +198,9 @@ def test_pickle_serialization():
150198
assert result.uid == msg.uid
151199
assert result.data == msg.data
152200
assert result.items == msg.items
201+
assert result.list_obj == msg.list_obj
202+
assert result.dict_obj == msg.dict_obj
203+
assert result.optional_obj == msg.optional_obj
153204

154205
def test_invalid_message_deserialization():
155206
# Create an invalid message without classname

0 commit comments

Comments
 (0)