diff --git a/temporalio/contrib/openai_agents/invoke_model_activity.py b/temporalio/contrib/openai_agents/invoke_model_activity.py index 10b7370cb..198517fe2 100644 --- a/temporalio/contrib/openai_agents/invoke_model_activity.py +++ b/temporalio/contrib/openai_agents/invoke_model_activity.py @@ -24,6 +24,7 @@ WebSearchTool, ) from agents.models.multi_provider import MultiProvider +from pydantic_core import to_jsonable_python from typing_extensions import Required, TypedDict from temporalio import activity, workflow @@ -139,7 +140,7 @@ async def empty_on_invoke_handoff( # workaround for https://github.com/pydantic/pydantic/issues/9541 # ValidatorIterator returned - input_json = json.dumps(input["input"], default=str) + input_json = json.dumps(to_jsonable_python(input["input"])) input_input = json.loads(input_json) def make_tool(tool: ToolInput) -> Tool: diff --git a/tests/contrib/openai_agents/test_openai.py b/tests/contrib/openai_agents/test_openai.py index 990849d62..d7fe0dc27 100644 --- a/tests/contrib/openai_agents/test_openai.py +++ b/tests/contrib/openai_agents/test_openai.py @@ -2,7 +2,7 @@ import uuid from dataclasses import dataclass from datetime import timedelta -from typing import Any, Optional, Union, no_type_check +from typing import Any, Optional, Union, cast, no_type_check import pytest from agents import ( @@ -96,6 +96,7 @@ def __init__( ) -> None: global response_index response_index = 0 + self.inputs: list[Union[str, list[TResponseInputItem]]] = [] super().__init__(model, openai_client) async def get_response( @@ -113,6 +114,7 @@ async def get_response( global response_index response = self.responses[response_index] response_index += 1 + self.inputs.append(input) return response @@ -843,11 +845,12 @@ async def test_agents_as_tools_workflow(client: Client, use_local_model: bool): model_params = ModelActivityParameters(start_to_close_timeout=timedelta(seconds=30)) with set_open_ai_agent_temporal_overrides(model_params): + model = AgentAsToolsModel( # type: ignore + "", openai_client=AsyncOpenAI(api_key="Fake key") + ) model_activity = ModelActivity( TestProvider( - AgentAsToolsModel( # type: ignore - "", openai_client=AsyncOpenAI(api_key="Fake key") - ) + model, ) if use_local_model else None @@ -900,6 +903,7 @@ async def test_agents_as_tools_workflow(client: Client, use_local_model: bool): .activity_task_completed_event_attributes.result.payloads[0] .data.decode() ) + assert isinstance(model.inputs[3][3]["content"], list) # type: ignore class AirlineAgentContext(BaseModel):