Skip to content

Commit ebcf642

Browse files
committed
Roll back making ProcessorPart inherit from genai.Part
The initial transition was made in the belief that frameworks could unify on a shared content representation and avoid unnecessary conversions. However, that requires the representation to be flexible enough to accommodate the needs of its users. Unfortunately, this attempt has failed, and adding anything not directly needed for Gemini models to genai.Part is nearly impossible. PiperOrigin-RevId: 828318675
1 parent 039dda6 commit ebcf642

File tree

3 files changed

+88
-53
lines changed

3 files changed

+88
-53
lines changed

genai_processors/content_api.py

Lines changed: 71 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -26,10 +26,9 @@
2626
from genai_processors import mime_types
2727
from google.genai import types as genai_types
2828
import PIL.Image
29-
import pydantic
3029

3130

32-
class ProcessorPart(genai_types.Part):
31+
class ProcessorPart:
3332
"""A wrapper around `Part` with additional metadata.
3433
3534
Represents a single piece of content that can be processed by an agentic
@@ -39,11 +38,6 @@ class ProcessorPart(genai_types.Part):
3938
belongs to, the MIME type of the content, and arbitrary metadata.
4039
"""
4140

42-
_metadata: dict[str, Any] = pydantic.PrivateAttr(default_factory=dict)
43-
_role: str = pydantic.PrivateAttr(default='')
44-
_substream_name: str = pydantic.PrivateAttr(default='')
45-
_mimetype: str = pydantic.PrivateAttr(default='')
46-
4741
def __init__(
4842
self,
4943
value: 'ProcessorPartTypes',
@@ -56,43 +50,48 @@ def __init__(
5650
"""Constructs a ProcessorPart.
5751
5852
Args:
59-
value: The content to use to construct the ProcessorPart. Any keyword
60-
arguments after this one overrides any properties in value.
53+
value: The content to use to construct the ProcessorPart.
6154
role: Optional. The producer of the content. In Genai models, must be
6255
either 'user' or 'model', but the user can set their own semantics.
6356
Useful to set for multi-turn conversations, otherwise can be empty.
6457
substream_name: (Optional) ProcessorPart stream can be split into multiple
6558
independent streams. They may have specific semantics, e.g. a song and
6659
its lyrics, or can be just alternative responses. Prefer using a default
67-
substream with an empty name.
60+
substream with an empty name. If the `ProcessorPart` is created using
61+
another `ProcessorPart`, this ProcessorPart inherits the existing
62+
substream_name, unless it is overridden in this argument.
6863
mimetype: Mime type of the data.
6964
metadata: (Optional) Auxiliary information about the part. If the
7065
`ProcessorPart` is created using another `ProcessorPart` or a
7166
`content_pb2.Part`, this ProcessorPart inherits the existing metadata,
7267
unless it is overridden in this argument.
7368
"""
69+
super().__init__()
70+
self._metadata = {}
71+
7472
match value:
73+
case genai_types.Part():
74+
self._part = value
7575
case ProcessorPart():
76-
super().__init__(**value.model_dump(exclude_unset=True))
76+
self._part = value.part
7777
role = role or value.role
7878
substream_name = substream_name or value.substream_name
7979
mimetype = mimetype or value.mimetype
8080
self._metadata = value.metadata
81-
case genai_types.Part():
82-
super().__init__(**value.model_dump(exclude_unset=True))
81+
self._metadata.update(metadata or {})
8382
case str():
84-
super().__init__(text=value)
83+
self._part = genai_types.Part(text=value)
8584
case bytes():
8685
if not mimetype:
8786
raise ValueError(
8887
'MIME type must be specified when constructing a ProcessorPart'
8988
' from bytes.'
9089
)
9190
if is_text(mimetype):
92-
super().__init__(text=value.decode('utf-8'))
91+
self._part = genai_types.Part(text=value.decode('utf-8'))
9392
else:
94-
super().__init__(
95-
inline_data=genai_types.Blob(data=value, mime_type=mimetype)
93+
self._part = genai_types.Part.from_bytes(
94+
data=value, mime_type=mimetype
9695
)
9796
case PIL.Image.Image():
9897
if mimetype:
@@ -114,10 +113,8 @@ def __init__(
114113
mimetype = f'image/{suffix}'
115114
bytes_io = io.BytesIO()
116115
value.save(bytes_io, suffix.upper())
117-
super().__init__(
118-
inline_data=genai_types.Blob(
119-
data=bytes_io.getvalue(), mime_type=mimetype
120-
)
116+
self._part = genai_types.Part.from_bytes(
117+
data=bytes_io.getvalue(), mime_type=mimetype
121118
)
122119
case _:
123120
raise ValueError(f"Can't construct ProcessorPart from {type(value)}.")
@@ -130,10 +127,10 @@ def __init__(
130127
if mimetype:
131128
self._mimetype = mimetype
132129
# Otherwise, if MIME type is specified using inline data, use that.
133-
elif self.inline_data and self.inline_data.mime_type:
134-
self._mimetype = self.inline_data.mime_type
130+
elif self._part.inline_data and self._part.inline_data.mime_type:
131+
self._mimetype = self._part.inline_data.mime_type
135132
# Otherwise, if text is not empty, assume 'text/plain' MIME type.
136-
elif self.text:
133+
elif self._part.text:
137134
self._mimetype = 'text/plain'
138135
else:
139136
self._mimetype = ''
@@ -147,23 +144,24 @@ def __repr__(self) -> str:
147144
if self.role:
148145
optional_args += f', role={self.role!r}'
149146
return (
150-
f'ProcessorPart({self.to_json_dict()!r},'
147+
f'ProcessorPart({self.part.to_json_dict()!r},'
151148
f' mimetype={self.mimetype!r}{optional_args})'
152149
)
153150

154151
def __eq__(self, other: Any) -> bool:
155152
if not isinstance(other, ProcessorPart):
156153
return False
157-
return self.__dict__ == other.__dict__
154+
return (
155+
self._part == other._part
156+
and self._role.lower() == other._role.lower()
157+
and self._substream_name.lower() == other._substream_name.lower()
158+
and self._metadata == other._metadata
159+
)
158160

159161
@property
160162
def part(self) -> genai_types.Part:
161-
"""Returns the underlying Genai Part.
162-
163-
DEPRECATED: Use the ProcessorPart itself, it now inherits from genai.Part.
164-
This property is provided for backward compatibility reasons.
165-
"""
166-
return self
163+
"""Returns the underlying Genai Part."""
164+
return self._part
167165

168166
@property
169167
def role(self) -> str:
@@ -189,10 +187,10 @@ def bytes(self) -> bytes | None:
189187
Text encoded into bytes or bytes from inline data if the underlying part
190188
is a Blob.
191189
"""
192-
if self.text:
190+
if self.part.text:
193191
return self.text.encode()
194-
if isinstance(self.inline_data, genai_types.Blob):
195-
return self.inline_data.data
192+
if isinstance(self.part.inline_data, genai_types.Blob):
193+
return self.part.inline_data.data
196194
return None
197195

198196
@property
@@ -215,6 +213,25 @@ def mimetype(self) -> str:
215213
"""
216214
return self._mimetype or 'text/plain'
217215

216+
@property
217+
def text(self) -> str:
218+
"""Returns part text as string.
219+
220+
Returns:
221+
The text of the part.
222+
223+
Raises:
224+
ValueError if part has no text.
225+
"""
226+
if not mime_types.is_text(self.mimetype):
227+
raise ValueError('Part is not text.')
228+
return self.part.text or ''
229+
230+
@text.setter
231+
def text(self, value: str) -> None:
232+
"""Sets part to a text part."""
233+
self._part = genai_types.Part(text=value)
234+
218235
@property
219236
def metadata(self) -> dict[str, Any]:
220237
"""Returns metadata."""
@@ -229,6 +246,16 @@ def get_metadata(self, key: str, default=None) -> Any:
229246
"""Returns metadata for a given key."""
230247
return self._metadata.get(key, default)
231248

249+
@property
250+
def function_call(self) -> genai_types.FunctionCall | None:
251+
"""Returns function call."""
252+
return self.part.function_call
253+
254+
@property
255+
def function_response(self) -> genai_types.FunctionResponse | None:
256+
"""Returns function response."""
257+
return self.part.function_response
258+
232259
@property
233260
def tool_cancellation(self) -> str | None:
234261
"""Returns an id of a function call to be cancelled.
@@ -239,13 +266,13 @@ def tool_cancellation(self) -> str | None:
239266
The id of the function call to be cancelled or None if this part is not a
240267
tool cancellation from the model.
241268
"""
242-
if not self.function_response:
269+
if not self.part.function_response:
243270
return None
244-
if self.function_response.name != 'tool_cancellation':
271+
if self.part.function_response.name != 'tool_cancellation':
245272
return None
246-
if not self.function_response.response:
273+
if not self.part.function_response.response:
247274
return None
248-
return self.function_response.response.get('function_call_id', None)
275+
return self.part.function_response.response.get('function_call_id', None)
249276

250277
T = TypeVar('T')
251278

@@ -274,8 +301,8 @@ def pil_image(self) -> PIL.Image.Image:
274301
if not mime_types.is_image(self.mimetype):
275302
raise ValueError(f'Part is not an image. Mime type is {self.mimetype}.')
276303
bytes_io = io.BytesIO()
277-
if self.inline_data is not None:
278-
bytes_io.write(self.inline_data.data)
304+
if self.part.inline_data is not None:
305+
bytes_io.write(self.part.inline_data.data)
279306
bytes_io.seek(0)
280307
return PIL.Image.open(bytes_io)
281308

@@ -452,7 +479,7 @@ def to_dict(self) -> dict[str, Any]:
452479
```
453480
"""
454481
return {
455-
'part': self.model_dump(mode='json', exclude_none=True),
482+
'part': self.part.model_dump(mode='json', exclude_none=True),
456483
'role': self.role,
457484
'substream_name': self.substream_name,
458485
'mimetype': self.mimetype,
@@ -826,9 +853,10 @@ def to_genai_contents(
826853
"""
827854
processor_content = ProcessorContent(content)
828855
contents = []
829-
for role, content_parts in itertools.groupby(
856+
for role, content_part in itertools.groupby(
830857
processor_content, lambda p: p.role
831858
):
859+
content_parts = [p.part for p in content_part]
832860
contents.append(genai_types.Content(parts=content_parts, role=role))
833861

834862
return contents

genai_processors/tests/content_api_test.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -72,14 +72,15 @@ def test_from_bytes_with_text_mimetype(self):
7272
mimetype = 'text/plain'
7373
part = content_api.ProcessorPart(bytes_data, mimetype=mimetype)
7474
self.assertEqual(part.text, 'hello')
75+
self.assertEqual(part.part.text, 'hello')
7576

7677
def test_from_bytes_with_non_text_mimetype(self):
7778
bytes_data = b'hello'
7879
mimetype = 'application/octet-stream'
7980
part = content_api.ProcessorPart(bytes_data, mimetype=mimetype)
8081
self.assertEqual(part.bytes, bytes_data)
81-
self.assertEqual(part.inline_data.data, bytes_data)
82-
self.assertEqual(part.inline_data.mime_type, mimetype)
82+
self.assertEqual(part.part.inline_data.data, bytes_data)
83+
self.assertEqual(part.part.inline_data.mime_type, mimetype)
8384

8485
def test_eq_part_and_non_part(self):
8586
part = content_api.ProcessorPart('foo')
@@ -416,17 +417,17 @@ def test_to_genai_contents(self):
416417
expected_genai_contents = [
417418
genai_types.Content(
418419
parts=[
419-
content_api.ProcessorPart('part1'),
420-
content_api.ProcessorPart('part2'),
420+
genai_types.Part(text='part1'),
421+
genai_types.Part(text='part2'),
421422
],
422423
role='user',
423424
),
424425
genai_types.Content(
425-
parts=[content_api.ProcessorPart('part3')],
426+
parts=[genai_types.Part(text='part3')],
426427
role='model',
427428
),
428429
genai_types.Content(
429-
parts=[content_api.ProcessorPart('part4')],
430+
parts=[genai_types.Part(text='part4')],
430431
role='user',
431432
),
432433
]
@@ -442,7 +443,7 @@ def test_to_genai_contents_single_part(self):
442443
genai_contents = content_api.to_genai_contents(parts)
443444
expected_genai_contents = [
444445
genai_types.Content(
445-
parts=[content_api.ProcessorPart('part1')],
446+
parts=[genai_types.Part(text='part1')],
446447
role='user',
447448
),
448449
]

genai_processors/tests/function_calling_test.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,9 @@ async def test_one_function_call(self):
8282
output,
8383
[
8484
content_api.ProcessorPart.from_function_call(
85-
name='get_weather', args={'location': 'London'}
85+
name='get_weather',
86+
args={'location': 'London'},
87+
substream_name=function_calling.FUNCTION_CALL_SUBTREAM_NAME,
8688
),
8789
content_api.ProcessorPart.from_function_response(
8890
name='get_weather',
@@ -179,7 +181,9 @@ async def test_max_function_calls(self):
179181
output,
180182
[
181183
content_api.ProcessorPart.from_function_call(
182-
name='get_time', args={}
184+
name='get_time',
185+
args={},
186+
substream_name=function_calling.FUNCTION_CALL_SUBTREAM_NAME,
183187
),
184188
content_api.ProcessorPart.from_function_response(
185189
name='get_time',
@@ -226,7 +230,9 @@ async def test_failing_function(self):
226230
output,
227231
[
228232
content_api.ProcessorPart.from_function_call(
229-
name='failing_function', args={}
233+
name='failing_function',
234+
args={},
235+
substream_name=function_calling.FUNCTION_CALL_SUBTREAM_NAME,
230236
),
231237
content_api.ProcessorPart.from_function_response(
232238
name='failing_function',

0 commit comments

Comments
 (0)