Skip to content

Commit e581000

Browse files
feat: add to_dict and from_dict to StreamingChunk (#9608)
* add to_dict and from_dict to StreamingChunk * bugfix: set index field * release notes + minor bugfix in tests * ensure full serialization + update tests * support de/serialization for all involved dataclasses * remove unnecessary import * code cleanup + only allow dicts in from_dict * update release notes * add more tests
1 parent 976cb86 commit e581000

File tree

4 files changed

+245
-1
lines changed

4 files changed

+245
-1
lines changed

haystack/dataclasses/chat_message.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,26 @@ class ToolCall:
6060
arguments: Dict[str, Any]
6161
id: Optional[str] = None # noqa: A003
6262

63+
def to_dict(self) -> Dict[str, Any]:
    """
    Serialize this ToolCall into a plain dictionary.

    The conversion is delegated to `dataclasses.asdict`, which builds the
    dictionary from the dataclass fields.

    :returns: A dictionary with keys 'tool_name', 'arguments', and 'id'.
    """
    serialized: Dict[str, Any] = asdict(self)
    return serialized
70+
71+
@classmethod
def from_dict(cls, data: Dict[str, Any]) -> "ToolCall":
    """
    Creates a new ToolCall object from a dictionary.

    :param data:
        The dictionary to build the ToolCall object. Its items are forwarded to
        the constructor as keyword arguments.
    :returns:
        The created object.
    """
    # Build through `cls` instead of the hard-coded class name so that calling
    # `from_dict` on a subclass returns an instance of that subclass.
    return cls(**data)
82+
6383

6484
@dataclass
6585
class ToolCallResult:
@@ -75,6 +95,31 @@ class ToolCallResult:
7595
origin: ToolCall
7696
error: bool
7797

98+
def to_dict(self) -> Dict[str, Any]:
    """
    Serialize this ToolCallResult into a plain dictionary.

    The conversion is delegated to `dataclasses.asdict`, which also converts
    nested dataclass values into dictionaries.

    :returns: A dictionary with keys 'result', 'origin', and 'error'.
    """
    serialized: Dict[str, Any] = asdict(self)
    return serialized
105+
106+
@classmethod
def from_dict(cls, data: Dict[str, Any]) -> "ToolCallResult":
    """
    Creates a ToolCallResult from a dictionary.

    :param data:
        The dictionary to build the ToolCallResult object. It must contain the
        keys `result`, `origin` and `error`; `origin` is itself deserialized
        with `ToolCall.from_dict`.
    :returns:
        The created object.
    :raises ValueError:
        If any of the required keys is missing from `data`.
    """
    required_fields = ("result", "origin", "error")
    if not all(name in data for name in required_fields):
        raise ValueError(
            "Fields `result`, `origin`, `error` are required for ToolCallResult deserialization. "
            f"Received dictionary with keys {list(data.keys())}"
        )
    # Build through `cls` instead of the hard-coded class name so that calling
    # `from_dict` on a subclass returns an instance of that subclass.
    return cls(result=data["result"], origin=ToolCall.from_dict(data["origin"]), error=data["error"])
122+
78123

79124
@dataclass
80125
class TextContent:

haystack/dataclasses/streaming_chunk.py

Lines changed: 78 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
#
33
# SPDX-License-Identifier: Apache-2.0
44

5-
from dataclasses import dataclass, field
5+
from dataclasses import asdict, dataclass, field
66
from typing import Any, Awaitable, Callable, Dict, List, Literal, Optional, Union, overload
77

88
from haystack.core.component import Component
@@ -30,6 +30,24 @@ class ToolCallDelta:
3030
arguments: Optional[str] = field(default=None)
3131
id: Optional[str] = field(default=None) # noqa: A003
3232

33+
def to_dict(self) -> Dict[str, Any]:
    """
    Serialize this ToolCallDelta into a plain dictionary.

    The conversion is delegated to `dataclasses.asdict`.

    :returns: A dictionary with keys 'index', 'tool_name', 'arguments', and 'id'.
    """
    serialized: Dict[str, Any] = asdict(self)
    return serialized
40+
41+
@classmethod
def from_dict(cls, data: Dict[str, Any]) -> "ToolCallDelta":
    """
    Creates a ToolCallDelta from a serialized representation.

    :param data: Dictionary containing ToolCallDelta's attributes. Its items are
        forwarded to the constructor as keyword arguments.
    :returns: A ToolCallDelta instance.
    """
    # Build through `cls` instead of the hard-coded class name so that calling
    # `from_dict` on a subclass returns an instance of that subclass.
    return cls(**data)
50+
3351

3452
@dataclass
3553
class ComponentInfo:
@@ -58,6 +76,24 @@ def from_component(cls, component: Component) -> "ComponentInfo":
5876
component_name = getattr(component, "__component_name__", None)
5977
return cls(type=component_type, name=component_name)
6078

79+
def to_dict(self) -> Dict[str, Any]:
    """
    Serialize this ComponentInfo into a plain dictionary.

    The conversion is delegated to `dataclasses.asdict`.

    :returns: A dictionary with keys 'type' and 'name'.
    """
    serialized: Dict[str, Any] = asdict(self)
    return serialized
86+
87+
@classmethod
def from_dict(cls, data: Dict[str, Any]) -> "ComponentInfo":
    """
    Creates a ComponentInfo from a serialized representation.

    :param data: Dictionary containing ComponentInfo's attributes. Its items are
        forwarded to the constructor as keyword arguments.
    :returns: A ComponentInfo instance.
    """
    # Build through `cls`, as `from_component` does, so that calling `from_dict`
    # on a subclass returns an instance of that subclass.
    return cls(**data)
96+
6197

6298
@dataclass
6399
class StreamingChunk:
@@ -102,6 +138,47 @@ def __post_init__(self):
102138
if (self.tool_calls or self.tool_call_result) and self.index is None:
103139
raise ValueError("If `tool_call`, or `tool_call_result` is set, `index` must also be set.")
104140

141+
def to_dict(self) -> Dict[str, Any]:
    """
    Serialize this StreamingChunk into a dictionary.

    Nested `component_info`, `tool_calls` and `tool_call_result` values are
    serialized through their own `to_dict` methods; when one of them is unset
    (or otherwise falsy) the corresponding entry is `None`.

    :returns: Serialized dictionary representation of the calling object.
    """
    component_info = None
    if self.component_info:
        component_info = self.component_info.to_dict()

    tool_calls = None
    if self.tool_calls:
        tool_calls = [delta.to_dict() for delta in self.tool_calls]

    tool_call_result = None
    if self.tool_call_result:
        tool_call_result = self.tool_call_result.to_dict()

    return {
        "content": self.content,
        "meta": self.meta,
        "component_info": component_info,
        "index": self.index,
        "tool_calls": tool_calls,
        "tool_call_result": tool_call_result,
        "start": self.start,
        "finish_reason": self.finish_reason,
    }
157+
158+
@classmethod
def from_dict(cls, data: Dict[str, Any]) -> "StreamingChunk":
    """
    Creates a deserialized StreamingChunk instance from a serialized representation.

    Only `content` is required. Missing optional keys fall back to `None`
    (`component_info`, `index`, `tool_calls`, `tool_call_result`,
    `finish_reason`), an empty dict (`meta`) or `False` (`start`).

    :param data: Dictionary containing the StreamingChunk's attributes.
    :returns: A StreamingChunk instance.
    :raises ValueError: If `content` is missing from `data`.
    """
    if "content" not in data:
        raise ValueError("Missing required field `content` in StreamingChunk deserialization.")

    component_info = data.get("component_info")
    tool_calls = data.get("tool_calls")
    tool_call_result = data.get("tool_call_result")

    # Build through `cls` instead of the hard-coded class name so that calling
    # `from_dict` on a subclass returns an instance of that subclass.
    return cls(
        content=data["content"],
        # An explicit `"meta": None` (e.g. a JSON null) is treated like a missing key.
        meta=data.get("meta") or {},
        component_info=ComponentInfo.from_dict(component_info) if component_info else None,
        index=data.get("index"),
        tool_calls=[ToolCallDelta.from_dict(tc) for tc in tool_calls] if tool_calls else None,
        tool_call_result=ToolCallResult.from_dict(tool_call_result) if tool_call_result else None,
        start=data.get("start", False),
        finish_reason=data.get("finish_reason"),
    )
181+
105182

106183
SyncStreamingCallbackT = Callable[[StreamingChunk], None]
107184
AsyncStreamingCallbackT = Callable[[StreamingChunk], Awaitable[None]]
Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
---
2+
features:
3+
- |
4+
Add `to_dict` and `from_dict` methods to the classes StreamingChunk, ToolCallResult, ToolCall, ComponentInfo, and ToolCallDelta, making them consistent with our other dataclasses, which already provide serialization and deserialization methods.

test/dataclasses/test_streaming_chunk.py

Lines changed: 118 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -135,3 +135,121 @@ def test_finish_reason_tool_call_results():
135135
assert chunk.finish_reason == "tool_call_results"
136136
assert chunk.meta["finish_reason"] == "tool_call_results"
137137
assert chunk.content == ""
138+
139+
140+
def test_to_dict_tool_call_result():
    """Check that StreamingChunk.to_dict serializes a chunk carrying a tool_call_result."""
    info = ComponentInfo.from_component(ExampleComponent())
    origin_call = ToolCall(id="123", tool_name="test_tool", arguments={"arg1": "value1"})
    result = ToolCallResult(result="output", origin=origin_call, error=False)

    chunk = StreamingChunk(
        content="",
        meta={"key": "value"},
        index=0,
        component_info=info,
        tool_call_result=result,
        finish_reason="tool_call_results",
    )

    serialized = chunk.to_dict()
    serialized_result = serialized["tool_call_result"]
    serialized_origin = serialized_result["origin"]

    # Top-level fields
    assert serialized["content"] == ""
    assert serialized["meta"] == {"key": "value"}
    assert serialized["index"] == 0
    assert serialized["component_info"]["type"] == "test_streaming_chunk.ExampleComponent"

    # Nested tool call result and its originating tool call
    assert serialized_result["result"] == "output"
    assert serialized_result["error"] is False
    assert serialized_origin["id"] == "123"
    assert serialized_origin["arguments"]["arg1"] == "value1"

    assert serialized["finish_reason"] == "tool_call_results"
168+
169+
170+
def test_to_dict_tool_calls():
    """Check that StreamingChunk.to_dict serializes a chunk carrying tool_calls."""
    info = ComponentInfo.from_component(ExampleComponent())
    first_delta = ToolCallDelta(id="123", tool_name="test_tool", arguments='{"arg1": "value1"}', index=0)
    second_delta = ToolCallDelta(id="456", tool_name="another_tool", arguments='{"arg2": "value2"}', index=1)

    chunk = StreamingChunk(
        content="",
        meta={"key": "value"},
        index=0,
        component_info=info,
        tool_calls=[first_delta, second_delta],
        finish_reason="tool_calls",
    )

    serialized = chunk.to_dict()
    serialized_calls = serialized["tool_calls"]

    # Top-level fields
    assert serialized["content"] == ""
    assert serialized["meta"] == {"key": "value"}
    assert serialized["index"] == 0
    assert serialized["component_info"]["type"] == "test_streaming_chunk.ExampleComponent"

    # Each delta is serialized in order
    assert len(serialized_calls) == 2
    assert serialized_calls[0]["id"] == "123"
    assert serialized_calls[0]["index"] == 0
    assert serialized_calls[1]["id"] == "456"
    assert serialized_calls[1]["index"] == 1

    assert serialized["finish_reason"] == "tool_calls"
200+
201+
202+
def test_from_dict_tool_call_result():
    """Check that StreamingChunk.from_dict rebuilds a chunk carrying a tool_call_result."""
    payload = {
        "content": "",
        "meta": {"key": "value"},
        "index": 0,
        "component_info": {"type": "test_streaming_chunk.ExampleComponent", "name": "test_component"},
        "tool_call_result": {
            "result": "output",
            "origin": {"id": "123", "tool_name": "test_tool", "arguments": {"arg1": "value1"}},
            "error": False,
        },
        "finish_reason": "tool_call_results",
    }

    chunk = StreamingChunk.from_dict(payload)
    info = chunk.component_info
    result = chunk.tool_call_result

    # Top-level fields
    assert chunk.content == ""
    assert chunk.meta == {"key": "value"}
    assert chunk.index == 0

    # Nested objects are rebuilt from their dictionaries
    assert info.type == "test_streaming_chunk.ExampleComponent"
    assert info.name == "test_component"
    assert result.result == "output"
    assert result.error is False
    assert result.origin.id == "123"
230+
231+
232+
def test_from_dict_tool_calls():
    """Check that StreamingChunk.from_dict rebuilds a chunk carrying tool_calls."""
    payload = {
        "content": "",
        "meta": {"key": "value"},
        "index": 0,
        "component_info": {"type": "test_streaming_chunk.ExampleComponent", "name": "test_component"},
        "tool_calls": [{"id": "123", "tool_name": "test_tool", "arguments": '{"arg1": "value1"}', "index": 0}],
        "finish_reason": "tool_calls",
    }

    chunk = StreamingChunk.from_dict(payload)
    info = chunk.component_info
    first_call = chunk.tool_calls[0]

    # Top-level fields
    assert chunk.content == ""
    assert chunk.meta == {"key": "value"}
    assert chunk.index == 0

    # Nested objects are rebuilt from their dictionaries
    assert info.type == "test_streaming_chunk.ExampleComponent"
    assert info.name == "test_component"
    assert first_call.tool_name == "test_tool"
    assert first_call.index == 0

    assert chunk.finish_reason == "tool_calls"

0 commit comments

Comments
 (0)