Skip to content

Commit 3a27fe8

Browse files
authored
Remove wrapping function decorator (#948)
* Remove wrapping function decorator * Linting
1 parent e17146f commit 3a27fe8

File tree

2 files changed

+83
-70
lines changed

2 files changed

+83
-70
lines changed

temporalio/contrib/openai_agents/temporal_openai_agents.py

Lines changed: 0 additions & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -266,69 +266,3 @@ async def run_activity(ctx: RunContextWrapper[Any], input: str) -> Any:
266266
on_invoke_tool=run_activity,
267267
strict_json_schema=True,
268268
)
269-
270-
@classmethod
271-
@overload
272-
def tool(
273-
cls,
274-
*,
275-
name_override: Union[str, None] = None,
276-
description_override: Union[str, None] = None,
277-
docstring_style: Union[DocstringStyle, None] = None,
278-
use_docstring_info: bool = True,
279-
failure_error_function: Union[
280-
ToolErrorFunction, None
281-
] = default_tool_error_function,
282-
strict_mode: bool = True,
283-
is_enabled: Union[
284-
bool, Callable[[RunContextWrapper[Any], Agent[Any]], MaybeAwaitable[bool]]
285-
] = True,
286-
) -> Callable[[ToolFunction[ToolParams]], FunctionTool]: ...
287-
288-
@classmethod
289-
@overload
290-
def tool(
291-
cls,
292-
func: ToolFunction[ToolParams],
293-
*,
294-
name_override: Union[str, None] = None,
295-
description_override: Union[str, None] = None,
296-
docstring_style: Union[DocstringStyle, None] = None,
297-
use_docstring_info: bool = True,
298-
failure_error_function: Union[
299-
ToolErrorFunction, None
300-
] = default_tool_error_function,
301-
strict_mode: bool = True,
302-
is_enabled: Union[
303-
bool, Callable[[RunContextWrapper[Any], Agent[Any]], MaybeAwaitable[bool]]
304-
] = True,
305-
) -> FunctionTool: ...
306-
307-
@classmethod
308-
def tool(
309-
cls,
310-
func: Union[ToolFunction[ToolParams], None] = None,
311-
*,
312-
name_override: Union[str, None] = None,
313-
description_override: Union[str, None] = None,
314-
docstring_style: Union[DocstringStyle, None] = None,
315-
use_docstring_info: bool = True,
316-
failure_error_function: Union[
317-
ToolErrorFunction, None
318-
] = default_tool_error_function,
319-
strict_mode: bool = True,
320-
is_enabled: Union[
321-
bool, Callable[[RunContextWrapper[Any], Agent[Any]], MaybeAwaitable[bool]]
322-
] = True,
323-
) -> Union[FunctionTool, Callable[[ToolFunction[ToolParams]], FunctionTool]]:
324-
"""A temporal specific wrapper for OpenAI's @function_tool. This exists to ensure the user is aware that the function tool is workflow level code and must be deterministic."""
325-
return function_tool(
326-
func, # type: ignore
327-
name_override=name_override,
328-
description_override=description_override,
329-
docstring_style=docstring_style,
330-
use_docstring_info=use_docstring_info,
331-
failure_error_function=failure_error_function,
332-
strict_mode=strict_mode,
333-
is_enabled=is_enabled,
334-
)

tests/contrib/openai_agents/test_openai.py

Lines changed: 83 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,6 @@
1313
InputGuardrailTripwireTriggered,
1414
ItemHelpers,
1515
MessageOutputItem,
16-
Model,
17-
ModelProvider,
1816
ModelResponse,
1917
ModelSettings,
2018
ModelTracing,
@@ -25,6 +23,7 @@
2523
Tool,
2624
TResponseInputItem,
2725
Usage,
26+
function_tool,
2827
handoff,
2928
input_guardrail,
3029
output_guardrail,
@@ -779,7 +778,7 @@ class AirlineAgentContext(BaseModel):
779778
flight_number: Optional[str] = None
780779

781780

782-
@openai_agents.workflow.tool(
781+
@function_tool(
783782
name_override="faq_lookup_tool",
784783
description_override="Lookup frequently asked questions.",
785784
)
@@ -801,7 +800,7 @@ async def faq_lookup_tool(question: str) -> str:
801800
return "I'm sorry, I don't know the answer to that question."
802801

803802

804-
@openai_agents.workflow.tool
803+
@function_tool
805804
async def update_seat(
806805
context: RunContextWrapper[AirlineAgentContext],
807806
confirmation_number: str,
@@ -1503,3 +1502,83 @@ async def test_output_guardrail(client: Client, use_local_model: bool):
15031502

15041503
if use_local_model:
15051504
assert not result
1505+
1506+
1507+
class WorkflowToolModel(StaticTestModel):
1508+
responses = [
1509+
ModelResponse(
1510+
output=[
1511+
ResponseFunctionToolCall(
1512+
arguments="{}",
1513+
call_id="call",
1514+
name="run_tool",
1515+
type="function_call",
1516+
id="id",
1517+
status="completed",
1518+
)
1519+
],
1520+
usage=Usage(),
1521+
response_id=None,
1522+
),
1523+
ModelResponse(
1524+
output=[
1525+
ResponseOutputMessage(
1526+
id="",
1527+
content=[
1528+
ResponseOutputText(
1529+
text="",
1530+
annotations=[],
1531+
type="output_text",
1532+
)
1533+
],
1534+
role="assistant",
1535+
status="completed",
1536+
type="message",
1537+
)
1538+
],
1539+
usage=Usage(),
1540+
response_id=None,
1541+
),
1542+
]
1543+
1544+
1545+
@workflow.defn
1546+
class WorkflowToolWorkflow:
1547+
@workflow.run
1548+
async def run(self) -> None:
1549+
agent: Agent = Agent(
1550+
name="Assistant",
1551+
instructions="You are a helpful assistant.",
1552+
tools=[function_tool(self.run_tool)],
1553+
)
1554+
await Runner.run(
1555+
agent,
1556+
"My phone number is 650-123-4567. Where do you think I live?",
1557+
)
1558+
1559+
async def run_tool(self):
1560+
print("Tool ran with self:", self)
1561+
workflow.logger.info("Tool ran with self: %s", self)
1562+
return None
1563+
1564+
1565+
async def test_workflow_method_tools(client: Client):
1566+
new_config = client.config()
1567+
new_config["data_converter"] = pydantic_data_converter
1568+
client = Client(**new_config)
1569+
1570+
with set_open_ai_agent_temporal_overrides():
1571+
model_activity = ModelActivity(TestModelProvider(WorkflowToolModel()))
1572+
async with new_worker(
1573+
client,
1574+
WorkflowToolWorkflow,
1575+
activities=[model_activity.invoke_model_activity],
1576+
interceptors=[OpenAIAgentsTracingInterceptor()],
1577+
) as worker:
1578+
workflow_handle = await client.start_workflow(
1579+
WorkflowToolWorkflow.run,
1580+
id=f"workflow-tool-{uuid.uuid4()}",
1581+
task_queue=worker.task_queue,
1582+
execution_timeout=timedelta(seconds=10),
1583+
)
1584+
await workflow_handle.result()

0 commit comments

Comments
 (0)