|
6 | 6 | import openai # use the official client for correctness check
|
7 | 7 | import pytest
|
8 | 8 |
|
9 |
| - |
10 |
| -MODEL_NAME = "Qwen/Qwen3-0.6B" |
11 |
| - |
12 |
| -@pytest.mark.asyncio |
13 |
| -@pytest.mark.parametrize("model_name", [MODEL_NAME]) |
14 |
| -@pytest.mark.parametrize("tool_choice", ["auto", "required"]) |
15 |
| -async def test_function_tool_use(client: openai.AsyncOpenAI, model_name: str, |
16 |
| - tool_choice: str): |
17 |
| - tools = [ |
18 |
| - { |
19 |
| - "type": "function", |
20 |
| - "function": { |
21 |
| - "name": "get_current_weather", |
22 |
| - "description": "Get the current weather in a given location", |
23 |
| - "parameters": { |
| 9 | +MODEL_NAME = "Qwen/Qwen3-1.7B" |
| 10 | +tools = [ |
| 11 | + { |
| 12 | + "type": "function", |
| 13 | + "name": "get_current_weather", |
| 14 | + "description": "Get the current weather in a given location", |
| 15 | + "parameters": { |
| 16 | + "type": "object", |
| 17 | + "properties": { |
| 18 | + "city": { |
| 19 | + "type": "string", |
| 20 | + "description": |
| 21 | + "The city to find the weather for, e.g. 'Vienna'", |
| 22 | + "default": "Vienna", |
| 23 | + }, |
| 24 | + "country": { |
| 25 | + "type": |
| 26 | + "string", |
| 27 | + "description": |
| 28 | + "The country that the city is in, e.g. 'Austria'", |
| 29 | + }, |
| 30 | + "unit": { |
| 31 | + "type": "string", |
| 32 | + "description": "The unit to fetch the temperature in", |
| 33 | + "enum": ["celsius", "fahrenheit"], |
| 34 | + }, |
| 35 | + "options": { |
| 36 | + "$ref": "#/$defs/WeatherOptions", |
| 37 | + "description": "Optional parameters for weather query", |
| 38 | + }, |
| 39 | + }, |
| 40 | + "required": ["country", "unit"], |
| 41 | + "$defs": { |
| 42 | + "WeatherOptions": { |
| 43 | + "title": "WeatherOptions", |
24 | 44 | "type": "object",
|
| 45 | + "additionalProperties": False, |
25 | 46 | "properties": {
|
26 |
| - "city": { |
27 |
| - "type": "string", |
28 |
| - "description": |
29 |
| - "The city to find the weather for, e.g. 'Vienna'", |
30 |
| - "default": "Vienna", |
31 |
| - }, |
32 |
| - "country": { |
33 |
| - "type": |
34 |
| - "string", |
35 |
| - "description": |
36 |
| - "The country that the city is in, e.g. 'Austria'", |
37 |
| - }, |
38 | 47 | "unit": {
|
39 | 48 | "type": "string",
|
40 |
| - "description": |
41 |
| - "The unit to fetch the temperature in", |
42 | 49 | "enum": ["celsius", "fahrenheit"],
|
| 50 | + "default": "celsius", |
| 51 | + "description": "Temperature unit", |
| 52 | + "title": "Temperature Unit", |
43 | 53 | },
|
44 |
| - "options": { |
45 |
| - "$ref": "#/$defs/WeatherOptions", |
| 54 | + "include_forecast": { |
| 55 | + "type": "boolean", |
| 56 | + "default": False, |
46 | 57 | "description":
|
47 |
| - "Optional parameters for weather query", |
| 58 | + "Whether to include a 24-hour forecast", |
| 59 | + "title": "Include Forecast", |
48 | 60 | },
|
49 |
| - }, |
50 |
| - "required": ["country", "unit"], |
51 |
| - "$defs": { |
52 |
| - "WeatherOptions": { |
53 |
| - "title": "WeatherOptions", |
54 |
| - "type": "object", |
55 |
| - "additionalProperties": False, |
56 |
| - "properties": { |
57 |
| - "unit": { |
58 |
| - "type": "string", |
59 |
| - "enum": ["celsius", "fahrenheit"], |
60 |
| - "default": "celsius", |
61 |
| - "description": "Temperature unit", |
62 |
| - "title": "Temperature Unit", |
63 |
| - }, |
64 |
| - "include_forecast": { |
65 |
| - "type": "boolean", |
66 |
| - "default": False, |
67 |
| - "description": |
68 |
| - "Whether to include a 24-hour forecast", |
69 |
| - "title": "Include Forecast", |
70 |
| - }, |
71 |
| - "language": { |
72 |
| - "type": "string", |
73 |
| - "default": "zh-CN", |
74 |
| - "description": "Language of the response", |
75 |
| - "title": "Language", |
76 |
| - "enum": ["zh-CN", "en-US", "ja-JP"], |
77 |
| - }, |
78 |
| - }, |
| 61 | + "language": { |
| 62 | + "type": "string", |
| 63 | + "default": "zh-CN", |
| 64 | + "description": "Language of the response", |
| 65 | + "title": "Language", |
| 66 | + "enum": ["zh-CN", "en-US", "ja-JP"], |
79 | 67 | },
|
80 | 68 | },
|
81 | 69 | },
|
82 | 70 | },
|
83 | 71 | },
|
84 |
| - { |
85 |
| - "type": "function", |
86 |
| - "function": { |
87 |
| - "name": "get_forecast", |
88 |
| - "description": "Get the weather forecast for a given location", |
89 |
| - "parameters": { |
90 |
| - "type": "object", |
91 |
| - "properties": { |
92 |
| - "city": { |
93 |
| - "type": "string", |
94 |
| - "description": |
95 |
| - "The city to get the forecast for, e.g. 'Vienna'", |
96 |
| - "default": "Vienna", |
97 |
| - }, |
98 |
| - "country": { |
99 |
| - "type": |
100 |
| - "string", |
101 |
| - "description": |
102 |
| - "The country that the city is in, e.g. 'Austria'", |
103 |
| - }, |
104 |
| - "days": { |
105 |
| - "type": |
106 |
| - "integer", |
107 |
| - "description": |
108 |
| - "Number of days to get the forecast for (1-7)", |
109 |
| - }, |
110 |
| - "unit": { |
111 |
| - "type": "string", |
112 |
| - "description": |
113 |
| - "The unit to fetch the temperature in", |
114 |
| - "enum": ["celsius", "fahrenheit"], |
115 |
| - }, |
116 |
| - }, |
117 |
| - "required": ["country", "days", "unit"], |
| 72 | + }, |
| 73 | + { |
| 74 | + "type": "function", |
| 75 | + "name": "get_forecast", |
| 76 | + "description": "Get the weather forecast for a given location", |
| 77 | + "parameters": { |
| 78 | + "type": "object", |
| 79 | + "properties": { |
| 80 | + "city": { |
| 81 | + "type": "string", |
| 82 | + "description": |
| 83 | + "The city to get the forecast for, e.g. 'Vienna'", |
| 84 | + "default": "Vienna", |
| 85 | + }, |
| 86 | + "country": { |
| 87 | + "type": |
| 88 | + "string", |
| 89 | + "description": |
| 90 | + "The country that the city is in, e.g. 'Austria'", |
| 91 | + }, |
| 92 | + "days": { |
| 93 | + "type": "integer", |
| 94 | + "description": |
| 95 | + "Number of days to get the forecast for (1-7)", |
| 96 | + }, |
| 97 | + "unit": { |
| 98 | + "type": "string", |
| 99 | + "description": "The unit to fetch the temperature in", |
| 100 | + "enum": ["celsius", "fahrenheit"], |
118 | 101 | },
|
119 | 102 | },
|
| 103 | + "required": ["country", "days", "unit"], |
120 | 104 | },
|
121 |
| - ] |
| 105 | + }, |
| 106 | +] |
| 107 | + |
122 | 108 |
|
| 109 | +@pytest.mark.asyncio |
| 110 | +@pytest.mark.parametrize("model_name", [MODEL_NAME]) |
| 111 | +@pytest.mark.parametrize("tool_choice", ["auto", "required"]) |
| 112 | +async def test_function_tool_use(client: openai.AsyncOpenAI, model_name: str, |
| 113 | + tool_choice: str): |
123 | 114 | prompt = [{
|
124 |
| - "role": |
125 |
| - "user", |
| 115 | + "role": "user", |
126 | 116 | "content":
|
127 | 117 | "Can you tell me what the current weather is in Berlin and the "\
|
128 | 118 | "forecast for the next 5 days, in fahrenheit?",
|
129 | 119 | },]
|
130 |
| - response = client.responses.create( |
| 120 | + response = await client.responses.create( |
131 | 121 | model=model_name,
|
132 | 122 | input=prompt,
|
133 | 123 | tools=tools,
|
134 | 124 | tool_choice=tool_choice,
|
135 | 125 | )
|
136 |
| - |
| 126 | + |
137 | 127 | assert len(response.output) >= 1
|
138 | 128 | tool_call = response.output[0]
|
139 |
| - |
| 129 | + |
140 | 130 | assert tool_call.type == "function_call"
|
141 | 131 | assert json.loads(tool_call.arguments) is not None
|
142 | 132 |
|
| 133 | + |
143 | 134 | @pytest.mark.asyncio
|
144 |
| -async def test_named_tool_use(client: openai.AsyncOpenAI, sample_json_schema): |
145 |
| - pass |
| 135 | +async def test_named_tool_use(client: openai.AsyncOpenAI): |
| 136 | + |
| 137 | + def get_weather(latitude: float, longitude: float) -> str: |
| 138 | + """ |
| 139 | + Mock function to simulate getting weather data. |
| 140 | + In a real application, this would call an external weather API. |
| 141 | + """ |
| 142 | + return f"Current temperature at ({latitude}, {longitude}) is 20°C." |
| 143 | + |
| 144 | + tools = [{ |
| 145 | + "type": "function", |
| 146 | + "name": "get_weather", |
| 147 | + "description": |
| 148 | + "Get current temperature for provided coordinates in celsius.", |
| 149 | + "parameters": { |
| 150 | + "type": "object", |
| 151 | + "properties": { |
| 152 | + "latitude": { |
| 153 | + "type": "number" |
| 154 | + }, |
| 155 | + "longitude": { |
| 156 | + "type": "number" |
| 157 | + } |
| 158 | + }, |
| 159 | + "required": ["latitude", "longitude"], |
| 160 | + "additionalProperties": False |
| 161 | + }, |
| 162 | + "strict": True |
| 163 | + }] |
| 164 | + |
| 165 | + input_messages = [{ |
| 166 | + "role": "user", |
| 167 | + "content": "What's the weather like in Paris today?" |
| 168 | + }] |
| 169 | + |
| 170 | + response = await client.responses.create(model=MODEL_NAME, |
| 171 | + input=input_messages, |
| 172 | + tools=tools, |
| 173 | + tool_choice={ |
| 174 | + "type": "function", |
| 175 | + "name": "get_weather" |
| 176 | + }) |
| 177 | + assert len(response.output) == 1 |
| 178 | + tool_call = response.output[0] |
| 179 | + assert tool_call.type == "function_call" |
| 180 | + assert tool_call.name == "get_weather" |
| 181 | + args = json.loads(tool_call.arguments) |
| 182 | + assert args["latitude"] is not None |
| 183 | + assert args["longitude"] is not None |
| 184 | + # call the tool |
| 185 | + result = get_weather(args["latitude"], args["longitude"]) |
| 186 | + input_messages.append(tool_call) # append model's function call message |
| 187 | + input_messages.append({ # append result message |
| 188 | + "type": "function_call_output", |
| 189 | + "call_id": tool_call.call_id, |
| 190 | + "output": str(result) |
| 191 | + }) |
| 192 | + # create a new response with the tool call result |
| 193 | + response_2 = await client.responses.create(model=MODEL_NAME, |
| 194 | + input=input_messages) |
| 195 | + # check the output |
| 196 | + assert len(response_2.output_text) > 0 |
0 commit comments