Skip to content

Making summary generation more robust #987

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 10 commits into from
Jul 24, 2025
59 changes: 34 additions & 25 deletions temporalio/contrib/openai_agents/_temporal_model_stub.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

logger = logging.getLogger(__name__)

from typing import Any, AsyncIterator, Sequence, Union, cast
from typing import Any, AsyncIterator, Union, cast

from agents import (
AgentOutputSchema,
Expand Down Expand Up @@ -54,7 +54,7 @@ def __init__(
async def get_response(
self,
system_instructions: Optional[str],
input: Union[str, list[TResponseInputItem], dict[str, str]],
input: Union[str, list[TResponseInputItem]],
model_settings: ModelSettings,
tools: list[Tool],
output_schema: Optional[AgentOutputSchemaBase],
Expand All @@ -64,28 +64,6 @@ async def get_response(
previous_response_id: Optional[str],
prompt: Optional[ResponsePromptParam],
) -> ModelResponse:
def get_summary(
input: Union[str, list[TResponseInputItem], dict[str, str]],
) -> str:
### Activity summary shown in the UI
try:
max_size = 100
if isinstance(input, str):
return input[:max_size]
elif isinstance(input, list):
seq_input = cast(Sequence[Any], input)
last_item = seq_input[-1]
if isinstance(last_item, dict):
return last_item.get("content", "")[:max_size]
elif hasattr(last_item, "content"):
return str(getattr(last_item, "content"))[:max_size]
return str(last_item)[:max_size]
elif isinstance(input, dict):
return input.get("content", "")[:max_size]
except Exception as e:
logger.error(f"Error getting summary: {e}")
return ""

def make_tool_info(tool: Tool) -> ToolInput:
if isinstance(tool, (FileSearchTool, WebSearchTool)):
return tool
Expand Down Expand Up @@ -150,7 +128,7 @@ def make_tool_info(tool: Tool) -> ToolInput:
return await workflow.execute_activity_method(
ModelActivity.invoke_model_activity,
activity_input,
summary=self.model_params.summary_override or get_summary(input),
summary=self.model_params.summary_override or _extract_summary(input),
task_queue=self.model_params.task_queue,
schedule_to_close_timeout=self.model_params.schedule_to_close_timeout,
schedule_to_start_timeout=self.model_params.schedule_to_start_timeout,
Expand All @@ -176,3 +154,34 @@ def stream_response(
prompt: ResponsePromptParam | None,
) -> AsyncIterator[TResponseStreamEvent]:
raise NotImplementedError("Temporal model doesn't support streams yet")


def _extract_summary(input: Union[str, list[TResponseInputItem]]) -> str:
### Activity summary shown in the UI
try:
max_size = 100
if isinstance(input, str):
return input[:max_size]
elif isinstance(input, list):
# Find all message inputs, which are reasonably summarizable
messages: list[TResponseInputItem] = [
item for item in input if item.get("type", "message") == "message"
]
if not messages:
return ""

content: Any = messages[-1].get("content", "")

# In the case of multiple contents, take the last one
if isinstance(content, list):
if not content:
return ""
content = content[-1]

# Take the text field from the content if present
if isinstance(content, dict) and content.get("text") is not None:
content = content.get("text")
return str(content)[:max_size]
except Exception as e:
logger.error(f"Error getting summary: {e}")
return ""
2 changes: 1 addition & 1 deletion temporalio/contrib/openai_agents/workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ async def run_activity(ctx: RunContextWrapper[Any], input: str) -> Any:
cancellation_type=cancellation_type,
activity_id=activity_id,
versioning_intent=versioning_intent,
summary=summary,
summary=summary or schema.description,
priority=priority,
)
try:
Expand Down
44 changes: 42 additions & 2 deletions tests/contrib/openai_agents/test_openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,12 +44,16 @@
)
from openai import APIStatusError, AsyncOpenAI, BaseModel
from openai.types.responses import (
EasyInputMessageParam,
ResponseFunctionToolCall,
ResponseFunctionToolCallParam,
ResponseFunctionWebSearch,
ResponseInputTextParam,
ResponseOutputMessage,
ResponseOutputText,
)
from openai.types.responses.response_function_web_search import ActionSearch
from openai.types.responses.response_input_item_param import Message
from openai.types.responses.response_prompt_param import ResponsePromptParam
from pydantic import ConfigDict, Field, TypeAdapter

Expand All @@ -63,6 +67,7 @@
TestModel,
TestModelProvider,
)
from temporalio.contrib.openai_agents._temporal_model_stub import _extract_summary
from temporalio.contrib.pydantic import pydantic_data_converter
from temporalio.exceptions import ApplicationError, CancelledError
from temporalio.testing import WorkflowEnvironment
Expand Down Expand Up @@ -680,7 +685,8 @@ async def test_research_workflow(client: Client, use_local_model: bool):
new_config["plugins"] = [
openai_agents.OpenAIAgentsPlugin(
model_params=ModelActivityParameters(
start_to_close_timeout=timedelta(seconds=30)
start_to_close_timeout=timedelta(seconds=120),
schedule_to_close_timeout=timedelta(seconds=120),
),
model_provider=TestModelProvider(TestResearchModel())
if use_local_model
Expand Down Expand Up @@ -1687,7 +1693,7 @@ class WorkflowToolModel(StaticTestModel):
id="",
content=[
ResponseOutputText(
text="",
text="Workflow tool was used",
annotations=[],
type="output_text",
)
Expand Down Expand Up @@ -1938,3 +1944,37 @@ async def test_heartbeat(client: Client, env: WorkflowEnvironment):
execution_timeout=timedelta(seconds=5.0),
)
await workflow_handle.result()


def test_summary_extraction():
    """Verify _extract_summary picks the last summarizable message.

    Covers: a plain string message, a structured content-part message,
    and a trailing non-message item (function call) that must be ignored.
    """
    first = EasyInputMessageParam(
        content="First message",
        role="user",
    )
    second = Message(
        content=[
            ResponseInputTextParam(
                text="Second message",
                type="input_text",
            )
        ],
        role="user",
    )
    tool_call = ResponseFunctionToolCallParam(
        arguments="",
        call_id="",
        name="",
        type="function_call",
    )

    conversation: list[TResponseInputItem] = [first]
    assert _extract_summary(conversation) == "First message"

    conversation.append(second)
    assert _extract_summary(conversation) == "Second message"

    # Non-message items are skipped; the last message still wins.
    conversation.append(tool_call)
    assert _extract_summary(conversation) == "Second message"