Skip to content

Commit 703b165

Browse files
aelisseekibergus
authored and committed
Add a to_genai_contents function to content_api to convert ProcessorContent to genai types.
PiperOrigin-RevId: 818528069
1 parent 9a5e34d commit 703b165

File tree

3 files changed

+81
-25
lines changed

3 files changed

+81
-25
lines changed

genai_processors/content_api.py

Lines changed: 35 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
import dataclasses
1919
import functools
2020
import io
21+
import itertools
2122
import json
2223
from typing import Any, TypeVar
2324

@@ -413,15 +414,13 @@ def from_dict(cls, *, data: dict[str, Any]) -> 'ProcessorPart':
413414
414415
Args:
415416
data: A JSON-compatible dictionary containing the serialized data for the
416-
ProcessorPart.
417-
418-
It is expected to have the following keys:
419-
* 'part' (dict): A dictionary representing the underlying
420-
`google.genai.types.Part` object.
421-
* 'role' (str): The role of the part (e.g., 'user', 'model').
422-
* 'substream_name' (str): The substream name.
423-
* 'mimetype' (str): The MIME type of the part.
424-
* 'metadata' (dict[str, Any]): Auxiliary metadata.
417+
ProcessorPart. It is expected to have the following keys:
418+
* 'part' (dict): A dictionary representing the underlying
419+
`google.genai.types.Part` object.
420+
* 'role' (str): The role of the part (e.g., 'user', 'model').
421+
* 'substream_name' (str): The substream name.
422+
* 'mimetype' (str): The MIME type of the part.
423+
* 'metadata' (dict[str, Any]): Auxiliary metadata.
425424
426425
Returns:
427426
A new ProcessorPart instance.
@@ -438,7 +437,7 @@ def from_dict(cls, *, data: dict[str, Any]) -> 'ProcessorPart':
438437
reconstructed = ProcessorPart.from_dict(data=part_as_dict)
439438
print(reconstructed)
440439
```
441-
"""
440+
""" # fmt: skip
442441
return cls(
443442
genai_types.Part.model_validate(data['part']),
444443
role=data.get('role', ''),
@@ -830,3 +829,29 @@ def to_genai_part(
830829
raise ValueError(
831830
f'Unsupported type for to_genai_part: {type(part_content)}'
832831
)
832+
833+
834+
def to_genai_contents(
    content: ProcessorContentTypes,
) -> list[genai_types.Content]:
  """Groups processor content into role-homogeneous Genai `Content` objects.

  The input is wrapped in a `ProcessorContent` and walked in order; every run
  of adjacent parts sharing the same role becomes one `genai_types.Content`.
  Parts with the same role that are separated by a part with a different role
  end up in separate `Content` objects.

  Args:
    content: Anything accepted by the `ProcessorContent` constructor, e.g. a
      list of `ProcessorPartTypes`.

  Returns:
    One `genai_types.Content` per run of same-role parts, in input order. Each
    one carries the run's role unchanged (no default is substituted for an
    empty role) and the underlying `part` of every part in the run. Empty
    input yields an empty list.
  """
  # groupby only merges *adjacent* items, which keeps the turn order intact.
  role_runs = itertools.groupby(
      ProcessorContent(content), key=lambda p: p.role
  )
  return [
      genai_types.Content(parts=[item.part for item in run], role=run_role)
      for run_role, run in role_runs
  ]

genai_processors/core/genai_model.py

Lines changed: 3 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,7 @@
5656
from typing import Any
5757
from genai_processors import content_api
5858
from genai_processors import processor
59+
from genai_processors import streams
5960
from genai_processors.core import constrained_decoding
6061
from google.genai import client
6162
from google.genai import types as genai_types
@@ -164,26 +165,13 @@ async def _generate_from_api(
164165
self, content: AsyncIterable[content_api.ProcessorPartTypes]
165166
) -> AsyncIterable[content_api.ProcessorPartTypes]:
166167
"""Internal method to call the GenAI API and stream results."""
167-
turn = genai_types.Content(parts=[])
168-
contents = []
169-
async for content_part in content:
170-
content_part = content_api.ProcessorPart(content_part)
171-
if turn.role and content_part.role != turn.role:
172-
contents.append(turn)
173-
turn = genai_types.Content(parts=[])
174-
175-
turn.role = content_part.role or 'user'
176-
turn.parts.append(content_api.to_genai_part(content_part)) # pylint: disable=attribute-error
177-
178-
if turn.role:
179-
contents.append(turn)
180-
168+
contents = await streams.gather_stream(content)
181169
if not contents:
182170
return
183171

184172
async for res in await self._client.aio.models.generate_content_stream(
185173
model=self._model_name,
186-
contents=contents,
174+
contents=content_api.to_genai_contents(contents),
187175
config=self._generate_content_config,
188176
):
189177
res: genai_types.GenerateContentResponse = res

genai_processors/tests/content_api_test.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -407,5 +407,48 @@ def test_eq_content_and_non_content(self):
407407
content = content_api.ProcessorContent('foo')
408408
self.assertNotEqual(content, object())
409409

410+
def test_to_genai_contents(self):
  """Adjacent same-role parts merge; a role change starts a new Content."""
  first_user_turn = genai_types.Content(
      parts=[
          genai_types.Part(text='part1'),
          genai_types.Part(text='part2'),
      ],
      role='user',
  )
  model_turn = genai_types.Content(
      parts=[genai_types.Part(text='part3')],
      role='model',
  )
  # Same role as the first turn, but not adjacent to it, so it stays separate.
  second_user_turn = genai_types.Content(
      parts=[genai_types.Part(text='part4')],
      role='user',
  )

  text_and_role = (
      ('part1', 'user'),
      ('part2', 'user'),
      ('part3', 'model'),
      ('part4', 'user'),
  )
  inputs = [
      content_api.ProcessorPart(text, role=role)
      for text, role in text_and_role
  ]

  self.assertEqual(
      content_api.to_genai_contents(inputs),
      [first_user_turn, model_turn, second_user_turn],
  )
436+
437+
def test_to_genai_contents_empty(self):
  """An empty input produces no Content objects."""
  self.assertEmpty(content_api.to_genai_contents([]))
440+
441+
def test_to_genai_contents_single_part(self):
  """A lone part becomes a single Content carrying that part and its role."""
  only_part = content_api.ProcessorPart('part1', role='user')
  expected_turn = genai_types.Content(
      parts=[genai_types.Part(text='part1')],
      role='user',
  )
  self.assertEqual(
      content_api.to_genai_contents([only_part]),
      [expected_turn],
  )
451+
452+
410453
# Run this module's tests through absltest when it is executed as a script.
if __name__ == '__main__':
411454
absltest.main()

0 commit comments

Comments
 (0)