
Commit 79f2900

Making summary generation more robust (#987)
* Making summary generation more robust, generally summarize as the last message type input
* Remove assert
* Remove unused imports
* Extend timeout for test stability
* Unit test summary, explicitly handle some edge cases
* Small fix
1 parent a457184 commit 79f2900

File tree: 3 files changed, +77 -28 lines changed

temporalio/contrib/openai_agents/_temporal_model_stub.py
temporalio/contrib/openai_agents/workflow.py
tests/contrib/openai_agents/test_openai.py

temporalio/contrib/openai_agents/_temporal_model_stub.py

Lines changed: 34 additions & 25 deletions

@@ -8,7 +8,7 @@

 logger = logging.getLogger(__name__)

-from typing import Any, AsyncIterator, Sequence, Union, cast
+from typing import Any, AsyncIterator, Union, cast

 from agents import (
     AgentOutputSchema,
@@ -54,7 +54,7 @@ def __init__(
     async def get_response(
         self,
         system_instructions: Optional[str],
-        input: Union[str, list[TResponseInputItem], dict[str, str]],
+        input: Union[str, list[TResponseInputItem]],
         model_settings: ModelSettings,
         tools: list[Tool],
         output_schema: Optional[AgentOutputSchemaBase],
@@ -64,28 +64,6 @@ async def get_response(
         previous_response_id: Optional[str],
         prompt: Optional[ResponsePromptParam],
     ) -> ModelResponse:
-        def get_summary(
-            input: Union[str, list[TResponseInputItem], dict[str, str]],
-        ) -> str:
-            ### Activity summary shown in the UI
-            try:
-                max_size = 100
-                if isinstance(input, str):
-                    return input[:max_size]
-                elif isinstance(input, list):
-                    seq_input = cast(Sequence[Any], input)
-                    last_item = seq_input[-1]
-                    if isinstance(last_item, dict):
-                        return last_item.get("content", "")[:max_size]
-                    elif hasattr(last_item, "content"):
-                        return str(getattr(last_item, "content"))[:max_size]
-                    return str(last_item)[:max_size]
-                elif isinstance(input, dict):
-                    return input.get("content", "")[:max_size]
-            except Exception as e:
-                logger.error(f"Error getting summary: {e}")
-            return ""
-
         def make_tool_info(tool: Tool) -> ToolInput:
             if isinstance(tool, (FileSearchTool, WebSearchTool)):
                 return tool
@@ -150,7 +128,7 @@ def make_tool_info(tool: Tool) -> ToolInput:
         return await workflow.execute_activity_method(
             ModelActivity.invoke_model_activity,
             activity_input,
-            summary=self.model_params.summary_override or get_summary(input),
+            summary=self.model_params.summary_override or _extract_summary(input),
             task_queue=self.model_params.task_queue,
             schedule_to_close_timeout=self.model_params.schedule_to_close_timeout,
             schedule_to_start_timeout=self.model_params.schedule_to_start_timeout,
@@ -176,3 +154,34 @@ def stream_response(
         prompt: ResponsePromptParam | None,
     ) -> AsyncIterator[TResponseStreamEvent]:
         raise NotImplementedError("Temporal model doesn't support streams yet")
+
+
+def _extract_summary(input: Union[str, list[TResponseInputItem]]) -> str:
+    ### Activity summary shown in the UI
+    try:
+        max_size = 100
+        if isinstance(input, str):
+            return input[:max_size]
+        elif isinstance(input, list):
+            # Find all message inputs, which are reasonably summarizable
+            messages: list[TResponseInputItem] = [
+                item for item in input if item.get("type", "message") == "message"
+            ]
+            if not messages:
+                return ""
+
+            content: Any = messages[-1].get("content", "")
+
+            # In the case of multiple contents, take the last one
+            if isinstance(content, list):
+                if not content:
+                    return ""
+                content = content[-1]
+
+            # Take the text field from the content if present
+            if isinstance(content, dict) and content.get("text") is not None:
+                content = content.get("text")
+            return str(content)[:max_size]
+    except Exception as e:
+        logger.error(f"Error getting summary: {e}")
+    return ""

temporalio/contrib/openai_agents/workflow.py

Lines changed: 1 addition & 1 deletion

@@ -134,7 +134,7 @@ async def run_activity(ctx: RunContextWrapper[Any], input: str) -> Any:
             cancellation_type=cancellation_type,
             activity_id=activity_id,
             versioning_intent=versioning_intent,
-            summary=summary,
+            summary=summary or schema.description,
             priority=priority,
         )
         try:
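
This change means that when no explicit summary is supplied for a tool activity, the tool schema's description (in the agents SDK this typically comes from the wrapped function's docstring) is used as the activity summary instead. A hypothetical sketch of the fallback semantics (resolve_summary is illustrative only, not an API in this repo):

def resolve_summary(summary: str | None, schema_description: str | None) -> str | None:
    # Mirrors `summary or schema.description`: prefer an explicit summary,
    # otherwise fall back to the tool schema's description.
    return summary or schema_description

assert resolve_summary(None, "Fetches the weather") == "Fetches the weather"
assert resolve_summary("Explicit summary", "Fetches the weather") == "Explicit summary"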

tests/contrib/openai_agents/test_openai.py

Lines changed: 42 additions & 2 deletions

@@ -44,12 +44,16 @@
 )
 from openai import APIStatusError, AsyncOpenAI, BaseModel
 from openai.types.responses import (
+    EasyInputMessageParam,
     ResponseFunctionToolCall,
+    ResponseFunctionToolCallParam,
     ResponseFunctionWebSearch,
+    ResponseInputTextParam,
     ResponseOutputMessage,
     ResponseOutputText,
 )
 from openai.types.responses.response_function_web_search import ActionSearch
+from openai.types.responses.response_input_item_param import Message
 from openai.types.responses.response_prompt_param import ResponsePromptParam
 from pydantic import ConfigDict, Field, TypeAdapter

@@ -63,6 +67,7 @@
     TestModel,
     TestModelProvider,
 )
+from temporalio.contrib.openai_agents._temporal_model_stub import _extract_summary
 from temporalio.contrib.pydantic import pydantic_data_converter
 from temporalio.exceptions import ApplicationError, CancelledError
 from temporalio.testing import WorkflowEnvironment
@@ -680,7 +685,8 @@ async def test_research_workflow(client: Client, use_local_model: bool):
     new_config["plugins"] = [
         openai_agents.OpenAIAgentsPlugin(
             model_params=ModelActivityParameters(
-                start_to_close_timeout=timedelta(seconds=30)
+                start_to_close_timeout=timedelta(seconds=120),
+                schedule_to_close_timeout=timedelta(seconds=120),
             ),
             model_provider=TestModelProvider(TestResearchModel())
             if use_local_model
@@ -1687,7 +1693,7 @@ class WorkflowToolModel(StaticTestModel):
             id="",
             content=[
                 ResponseOutputText(
-                    text="",
+                    text="Workflow tool was used",
                     annotations=[],
                     type="output_text",
                 )
@@ -1938,3 +1944,37 @@ async def test_heartbeat(client: Client, env: WorkflowEnvironment):
         execution_timeout=timedelta(seconds=5.0),
     )
     await workflow_handle.result()
+
+
+def test_summary_extraction():
+    input: list[TResponseInputItem] = [
+        EasyInputMessageParam(
+            content="First message",
+            role="user",
+        )
+    ]
+
+    assert _extract_summary(input) == "First message"
+
+    input.append(
+        Message(
+            content=[
+                ResponseInputTextParam(
+                    text="Second message",
+                    type="input_text",
+                )
+            ],
+            role="user",
+        )
+    )
+    assert _extract_summary(input) == "Second message"
+
+    input.append(
+        ResponseFunctionToolCallParam(
+            arguments="",
+            call_id="",
+            name="",
+            type="function_call",
+        )
+    )
+    assert _extract_summary(input) == "Second message"
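
Beyond the cases covered by test_summary_extraction, the helper handles a few more edge cases by returning an empty string. A sketch (not part of the commit; plain dicts stand in for typed input items):

assert _extract_summary([]) == ""                                    # no input items at all
assert _extract_summary([{"type": "function_call"}]) == ""           # no message items to summarize
assert _extract_summary([{"type": "message", "content": []}]) == ""  # message with an empty content list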
