Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion temporalio/contrib/openai_agents/_heartbeat_decorator.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,10 @@ async def wrapper(*args, **kwargs):
if heartbeat_task:
heartbeat_task.cancel()
# Wait for heartbeat cancellation to complete
await heartbeat_task
try:
await heartbeat_task
except asyncio.CancelledError:
pass

return cast(F, wrapper)

Expand Down
81 changes: 80 additions & 1 deletion tests/contrib/openai_agents/test_openai.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
import asyncio
import os
import uuid
from dataclasses import dataclass
from datetime import timedelta
from typing import Any, Optional, Union, no_type_check
from typing import Any, AsyncIterator, Optional, Union, no_type_check

import nexusrpc
import pytest
Expand All @@ -14,6 +15,7 @@
InputGuardrailTripwireTriggered,
ItemHelpers,
MessageOutputItem,
Model,
ModelResponse,
ModelSettings,
ModelTracing,
Expand All @@ -35,6 +37,7 @@
HandoffOutputItem,
ToolCallItem,
ToolCallOutputItem,
TResponseStreamEvent,
)
from openai import AsyncOpenAI, BaseModel
from openai.types.responses import (
Expand Down Expand Up @@ -1778,3 +1781,79 @@ async def test_workflow_method_tools(client: Client):
execution_timeout=timedelta(seconds=10),
)
await workflow_handle.result()


class WaitModel(Model):
async def get_response(
self,
system_instructions: Union[str, None],
input: Union[str, list[TResponseInputItem]],
model_settings: ModelSettings,
tools: list[Tool],
output_schema: Union[AgentOutputSchemaBase, None],
handoffs: list[Handoff],
tracing: ModelTracing,
*,
previous_response_id: Union[str, None],
prompt: Union[ResponsePromptParam, None] = None,
) -> ModelResponse:
activity.logger.info("Waiting")
await asyncio.sleep(5.0)
activity.logger.info("Returning")
return ModelResponse(
output=[
ResponseOutputMessage(
id="",
content=[
ResponseOutputText(
text="test", annotations=[], type="output_text"
)
],
role="assistant",
status="completed",
type="message",
)
],
usage=Usage(),
response_id=None,
)

def stream_response(
self,
system_instructions: Optional[str],
input: Union[str, list[TResponseInputItem]],
model_settings: ModelSettings,
tools: list[Tool],
output_schema: Optional[AgentOutputSchemaBase],
handoffs: list[Handoff],
tracing: ModelTracing,
*,
previous_response_id: Optional[str],
prompt: Optional[ResponsePromptParam],
) -> AsyncIterator[TResponseStreamEvent]:
raise NotImplementedError()


async def test_heartbeat(client: Client):
new_config = client.config()
new_config["data_converter"] = pydantic_data_converter
client = Client(**new_config)

with set_open_ai_agent_temporal_overrides(
model_params=ModelActivityParameters(heartbeat_timeout=timedelta(seconds=2))
):
model_activity = ModelActivity(TestModelProvider(WaitModel()))
async with new_worker(
client,
HelloWorldAgent,
activities=[model_activity.invoke_model_activity],
interceptors=[OpenAIAgentsTracingInterceptor()],
) as worker:
workflow_handle = await client.start_workflow(
HelloWorldAgent.run,
"Tell me about recursion in programming.",
id=f"workflow-tool-{uuid.uuid4()}",
task_queue=worker.task_queue,
execution_timeout=timedelta(seconds=10),
)
await workflow_handle.result()
Loading