Skip to content

Commit 5b98c32

Browse files
authored
feat: Add get_data_parts() and get_file_parts() helper methods (#312)
Fixes #311 🦕
1 parent 6a0a7da commit 5b98c32

File tree

3 files changed

+174
-0
lines changed

3 files changed

+174
-0
lines changed

src/a2a/utils/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@
1717
create_task_obj,
1818
)
1919
from a2a.utils.message import (
20+
get_data_parts,
21+
get_file_parts,
2022
get_message_text,
2123
get_text_parts,
2224
new_agent_parts_message,
@@ -37,6 +39,8 @@
3739
'build_text_artifact',
3840
'completed_task',
3941
'create_task_obj',
42+
'get_data_parts',
43+
'get_file_parts',
4044
'get_message_text',
4145
'get_text_parts',
4246
'new_agent_parts_message',

src/a2a/utils/message.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,13 @@
22

33
import uuid
44

5+
from typing import Any
6+
57
from a2a.types import (
8+
DataPart,
9+
FilePart,
10+
FileWithBytes,
11+
FileWithUri,
612
Message,
713
Part,
814
Role,
@@ -70,6 +76,30 @@ def get_text_parts(parts: list[Part]) -> list[str]:
7076
return [part.root.text for part in parts if isinstance(part.root, TextPart)]
7177

7278

79+
def get_data_parts(parts: list[Part]) -> list[dict[str, Any]]:
80+
"""Extracts dictionary data from all DataPart objects in a list of Parts.
81+
82+
Args:
83+
parts: A list of `Part` objects.
84+
85+
Returns:
86+
A list of dictionaries containing the data from any `DataPart` objects found.
87+
"""
88+
return [part.root.data for part in parts if isinstance(part.root, DataPart)]
89+
90+
91+
def get_file_parts(parts: list[Part]) -> list[FileWithBytes | FileWithUri]:
92+
"""Extracts file data from all FilePart objects in a list of Parts.
93+
94+
Args:
95+
parts: A list of `Part` objects.
96+
97+
Returns:
98+
A list of `FileWithBytes` or `FileWithUri` objects containing the file data from any `FilePart` objects found.
99+
"""
100+
return [part.root.file for part in parts if isinstance(part.root, FilePart)]
101+
102+
73103
def get_message_text(message: Message, delimiter: str = '\n') -> str:
74104
"""Extracts and joins all text content from a Message's parts.
75105

tests/utils/test_message.py

Lines changed: 140 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,17 @@
44

55
from a2a.types import (
66
DataPart,
7+
FilePart,
8+
FileWithBytes,
9+
FileWithUri,
710
Message,
811
Part,
912
Role,
1013
TextPart,
1114
)
1215
from a2a.utils.message import (
16+
get_data_parts,
17+
get_file_parts,
1318
get_message_text,
1419
get_text_parts,
1520
new_agent_parts_message,
@@ -178,6 +183,141 @@ def test_get_text_parts_empty_list(self):
178183
assert result == []
179184

180185

186+
class TestGetDataParts:
187+
def test_get_data_parts_single_data_part(self):
188+
# Setup
189+
parts = [Part(root=DataPart(data={'key': 'value'}))]
190+
191+
# Exercise
192+
result = get_data_parts(parts)
193+
194+
# Verify
195+
assert result == [{'key': 'value'}]
196+
197+
def test_get_data_parts_multiple_data_parts(self):
198+
# Setup
199+
parts = [
200+
Part(root=DataPart(data={'key1': 'value1'})),
201+
Part(root=DataPart(data={'key2': 'value2'})),
202+
]
203+
204+
# Exercise
205+
result = get_data_parts(parts)
206+
207+
# Verify
208+
assert result == [{'key1': 'value1'}, {'key2': 'value2'}]
209+
210+
def test_get_data_parts_mixed_parts(self):
211+
# Setup
212+
parts = [
213+
Part(root=TextPart(text='some text')),
214+
Part(root=DataPart(data={'key1': 'value1'})),
215+
Part(root=DataPart(data={'key2': 'value2'})),
216+
]
217+
218+
# Exercise
219+
result = get_data_parts(parts)
220+
221+
# Verify
222+
assert result == [{'key1': 'value1'}, {'key2': 'value2'}]
223+
224+
def test_get_data_parts_no_data_parts(self):
225+
# Setup
226+
parts = [
227+
Part(root=TextPart(text='some text')),
228+
]
229+
230+
# Exercise
231+
result = get_data_parts(parts)
232+
233+
# Verify
234+
assert result == []
235+
236+
def test_get_data_parts_empty_list(self):
237+
# Setup
238+
parts = []
239+
240+
# Exercise
241+
result = get_data_parts(parts)
242+
243+
# Verify
244+
assert result == []
245+
246+
247+
class TestGetFileParts:
248+
def test_get_file_parts_single_file_part(self):
249+
# Setup
250+
file_with_uri = FileWithUri(
251+
uri='file://path/to/file', mimeType='text/plain'
252+
)
253+
parts = [Part(root=FilePart(file=file_with_uri))]
254+
255+
# Exercise
256+
result = get_file_parts(parts)
257+
258+
# Verify
259+
assert result == [file_with_uri]
260+
261+
def test_get_file_parts_multiple_file_parts(self):
262+
# Setup
263+
file_with_uri1 = FileWithUri(
264+
uri='file://path/to/file1', mimeType='text/plain'
265+
)
266+
file_with_bytes = FileWithBytes(
267+
bytes='ZmlsZSBjb250ZW50',
268+
mimeType='application/octet-stream', # 'file content'
269+
)
270+
parts = [
271+
Part(root=FilePart(file=file_with_uri1)),
272+
Part(root=FilePart(file=file_with_bytes)),
273+
]
274+
275+
# Exercise
276+
result = get_file_parts(parts)
277+
278+
# Verify
279+
assert result == [file_with_uri1, file_with_bytes]
280+
281+
def test_get_file_parts_mixed_parts(self):
282+
# Setup
283+
file_with_uri = FileWithUri(
284+
uri='file://path/to/file', mimeType='text/plain'
285+
)
286+
parts = [
287+
Part(root=TextPart(text='some text')),
288+
Part(root=FilePart(file=file_with_uri)),
289+
]
290+
291+
# Exercise
292+
result = get_file_parts(parts)
293+
294+
# Verify
295+
assert result == [file_with_uri]
296+
297+
def test_get_file_parts_no_file_parts(self):
298+
# Setup
299+
parts = [
300+
Part(root=TextPart(text='some text')),
301+
Part(root=DataPart(data={'key': 'value'})),
302+
]
303+
304+
# Exercise
305+
result = get_file_parts(parts)
306+
307+
# Verify
308+
assert result == []
309+
310+
def test_get_file_parts_empty_list(self):
311+
# Setup
312+
parts = []
313+
314+
# Exercise
315+
result = get_file_parts(parts)
316+
317+
# Verify
318+
assert result == []
319+
320+
181321
class TestGetMessageText:
182322
def test_get_message_text_single_part(self):
183323
# Setup

0 commit comments

Comments
 (0)