|
13 | 13 | BaseConversationalTask,
|
14 | 14 | BaseTextGenerationTask,
|
15 | 15 | TaskProviderHelper,
|
| 16 | + filter_none, |
16 | 17 | recursive_merge,
|
17 | 18 | )
|
18 | 19 | from huggingface_hub.inference._providers.black_forest_labs import BlackForestLabsTextToImageTask
|
@@ -1152,6 +1153,98 @@ def test_prepare_payload(self):
|
1152 | 1153 | "model": "test-provider-id",
|
1153 | 1154 | }
|
1154 | 1155 |
|
| 1156 | + @pytest.mark.parametrize( |
| 1157 | + "raw_messages, expected_messages", |
| 1158 | + [ |
| 1159 | + ( |
| 1160 | + [ |
| 1161 | + { |
| 1162 | + "role": "assistant", |
| 1163 | + "content": "", |
| 1164 | + "tool_calls": None, |
| 1165 | + } |
| 1166 | + ], |
| 1167 | + [ |
| 1168 | + { |
| 1169 | + "role": "assistant", |
| 1170 | + "content": "", |
| 1171 | + } |
| 1172 | + ], |
| 1173 | + ), |
| 1174 | + ( |
| 1175 | + [ |
| 1176 | + { |
| 1177 | + "role": "assistant", |
| 1178 | + "content": None, |
| 1179 | + "tool_calls": [ |
| 1180 | + { |
| 1181 | + "id": "call_1", |
| 1182 | + "type": "function", |
| 1183 | + "function": { |
| 1184 | + "name": "get_current_weather", |
| 1185 | + "arguments": '{"location": "San Francisco, CA", "unit": "celsius"}', |
| 1186 | + }, |
| 1187 | + }, |
| 1188 | + ], |
| 1189 | + }, |
| 1190 | + { |
| 1191 | + "role": "tool", |
| 1192 | + "content": "pong", |
| 1193 | + "tool_call_id": "abc123", |
| 1194 | + "name": "dummy_tool", |
| 1195 | + "tool_calls": None, |
| 1196 | + }, |
| 1197 | + ], |
| 1198 | + [ |
| 1199 | + { |
| 1200 | + "role": "assistant", |
| 1201 | + "tool_calls": [ |
| 1202 | + { |
| 1203 | + "id": "call_1", |
| 1204 | + "type": "function", |
| 1205 | + "function": { |
| 1206 | + "name": "get_current_weather", |
| 1207 | + "arguments": '{"location": "San Francisco, CA", "unit": "celsius"}', |
| 1208 | + }, |
| 1209 | + } |
| 1210 | + ], |
| 1211 | + }, |
| 1212 | + { |
| 1213 | + "role": "tool", |
| 1214 | + "content": "pong", |
| 1215 | + "tool_call_id": "abc123", |
| 1216 | + "name": "dummy_tool", |
| 1217 | + }, |
| 1218 | + ], |
| 1219 | + ), |
| 1220 | + ], |
| 1221 | + ) |
| 1222 | + def test_prepare_payload_filters_messages(self, raw_messages, expected_messages): |
| 1223 | + helper = BaseConversationalTask(provider="test-provider", base_url="https://api.test.com") |
| 1224 | + |
| 1225 | + parameters = { |
| 1226 | + "temperature": 0.2, |
| 1227 | + "max_tokens": None, |
| 1228 | + "top_p": None, |
| 1229 | + } |
| 1230 | + |
| 1231 | + payload = helper._prepare_payload_as_dict( |
| 1232 | + inputs=raw_messages, |
| 1233 | + parameters=parameters, |
| 1234 | + provider_mapping_info=InferenceProviderMapping( |
| 1235 | + provider="test-provider", |
| 1236 | + hf_model_id="test-model", |
| 1237 | + providerId="test-provider-id", |
| 1238 | + task="conversational", |
| 1239 | + status="live", |
| 1240 | + ), |
| 1241 | + ) |
| 1242 | + |
| 1243 | + assert payload["messages"] == expected_messages |
| 1244 | + assert payload["temperature"] == 0.2 |
| 1245 | + assert "max_tokens" not in payload |
| 1246 | + assert "top_p" not in payload |
| 1247 | + |
1155 | 1248 |
|
1156 | 1249 | class TestBaseTextGenerationTask:
|
1157 | 1250 | def test_prepare_route(self):
|
@@ -1236,6 +1329,36 @@ def test_recursive_merge(dict1: Dict, dict2: Dict, expected: Dict):
|
1236 | 1329 | assert dict2 == initial_dict2
|
1237 | 1330 |
|
1238 | 1331 |
|
| 1332 | +@pytest.mark.parametrize( |
| 1333 | + "data, expected", |
| 1334 | + [ |
| 1335 | + ({}, {}), # empty dictionary remains empty |
| 1336 | + ({"a": 1, "b": None, "c": 3}, {"a": 1, "c": 3}), # remove None at root level |
| 1337 | + ({"a": None, "b": {"x": None, "y": 2}}, {"b": {"y": 2}}), # remove nested None |
| 1338 | + ({"a": {"b": {"c": None}}}, {}), # remove empty nested dict |
| 1339 | + ( |
| 1340 | + {"a": "", "b": {"x": {"y": None}, "z": 0}, "c": []}, # do not remove 0, [] and "" values |
| 1341 | + {"a": "", "b": {"z": 0}, "c": []}, |
| 1342 | + ), |
| 1343 | + ( |
| 1344 | + {"a": [0, 1, None]}, # do not remove None in lists |
| 1345 | + {"a": [0, 1, None]}, |
| 1346 | + ), |
| 1347 | + # dicts inside list are cleaned, list level None kept |
| 1348 | + ({"a": [{"x": None, "y": 1}, None]}, {"a": [{"y": 1}, None]}), |
| 1349 | + # remove every None that is the value of a dict key |
| 1350 | + ( |
| 1351 | + [None, {"x": None, "y": 5}, [None, 6]], |
| 1352 | + [None, {"y": 5}, [None, 6]], |
| 1353 | + ), |
| 1354 | + ({"a": [None, {"x": None}]}, {"a": [None, {}]}), |
| 1355 | + ], |
| 1356 | +) |
| 1357 | +def test_filter_none(data: Dict, expected: Dict): |
| 1358 | + """Test that filter_none removes None values from nested dictionaries.""" |
| 1359 | + assert filter_none(data) == expected |
| 1360 | + |
| 1361 | + |
1239 | 1362 | def test_get_provider_helper_auto(mocker):
|
1240 | 1363 | """Test the 'auto' provider selection logic."""
|
1241 | 1364 |
|
|
0 commit comments