From 2c3b9cbce80b809d57f4ad3826c3f65dd5b82c91 Mon Sep 17 00:00:00 2001 From: Celina Hanouti Date: Wed, 4 Jun 2025 17:03:56 +0100 Subject: [PATCH 01/18] add hf inference providers support --- .../pydantic_ai/models/huggingface.py | 432 ++++++++++++++++++ .../pydantic_ai/providers/__init__.py | 4 + .../pydantic_ai/providers/huggingface.py | 72 +++ pydantic_ai_slim/pyproject.toml | 2 + pyproject.toml | 2 +- uv.lock | 34 +- 6 files changed, 539 insertions(+), 7 deletions(-) create mode 100644 pydantic_ai_slim/pydantic_ai/models/huggingface.py create mode 100644 pydantic_ai_slim/pydantic_ai/providers/huggingface.py diff --git a/pydantic_ai_slim/pydantic_ai/models/huggingface.py b/pydantic_ai_slim/pydantic_ai/models/huggingface.py new file mode 100644 index 000000000..c34d741a3 --- /dev/null +++ b/pydantic_ai_slim/pydantic_ai/models/huggingface.py @@ -0,0 +1,432 @@ +from __future__ import annotations as _annotations + +import base64 +from collections.abc import AsyncIterable, AsyncIterator +from contextlib import asynccontextmanager +from dataclasses import dataclass, field +from datetime import datetime, timezone +from typing import Literal, cast, overload + +from typing_extensions import assert_never + +from pydantic_ai.providers import Provider, infer_provider + +from .. import ModelHTTPError, UnexpectedModelBehavior, _utils, usage +from .._utils import guard_tool_call_id as _guard_tool_call_id, now_utc as _now_utc +from ..messages import ( + AudioUrl, + BinaryContent, + DocumentUrl, + ImageUrl, + ModelMessage, + ModelRequest, + ModelResponse, + ModelResponsePart, + ModelResponseStreamEvent, + RetryPromptPart, + SystemPromptPart, + TextPart, + ToolCallPart, + ToolReturnPart, + UserPromptPart, + VideoUrl, +) +from ..settings import ModelSettings +from ..tools import ToolDefinition +from . import Model, ModelRequestParameters, StreamedResponse, check_allow_model_requests + +try: + import aiohttp + from huggingface_hub import ( + AsyncInferenceClient, + ChatCompletionInputMessage, + ChatCompletionInputMessageChunk, + ChatCompletionInputTool, + ChatCompletionInputToolCall, + ChatCompletionInputURL, + ChatCompletionOutput, + ChatCompletionOutputMessage, + ChatCompletionStreamOutput, + InferenceTimeoutError, + ) + +except ImportError as _import_error: + raise ImportError( + 'Please install `huggingface_hub` to use Hugging Face Inference Providers, ' + 'you can use the `huggingface` optional group — `pip install "pydantic-ai-slim[huggingface]"`' + ) from _import_error + +__all__ = ( + 'HuggingFaceModel', + 'HuggingFaceModelSettings', +) + + +HFSystemPromptRole = Literal['system', 'user'] + + +class HuggingFaceModelSettings(ModelSettings, total=False): + """Settings used for a Hugging Face model request. + + ALL FIELDS MUST BE `huggingface_` PREFIXED SO YOU CAN MERGE THEM WITH OTHER MODELS. + """ + + # This class is a placeholder for any future huggingface-specific settings + + +@dataclass(init=False) +class HuggingFaceModel(Model): + """A model that uses Hugging Face Inference Providers. + + Internally, this uses the [HF Python client](https://github.com/huggingface/huggingface_hub) to interact with the API. + + Apart from `__init__`, all methods are private or match those of the base class. 
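+
+    A minimal usage sketch (illustrative; assumes `HF_TOKEN` is set in the
+    environment so the default `huggingface` provider can authenticate):
+
+        from pydantic_ai import Agent
+        from pydantic_ai.models.huggingface import HuggingFaceModel
+
+        agent = Agent(HuggingFaceModel('Qwen/Qwen2.5-72B-Instruct'))
+        result = agent.run_sync('What is the capital of France?')
+        print(result.output)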
+    """
+
+    client: AsyncInferenceClient = field(repr=False)
+
+    _model_name: str = field(repr=False)
+    _system: str = field(default='huggingface', repr=False)
+
+    def __init__(
+        self,
+        model_name: str,
+        *,
+        provider: Literal['huggingface'] | Provider[AsyncInferenceClient] = 'huggingface',
+    ):
+        """Initialize a Hugging Face model.
+
+        Args:
+            model_name: The name of the model to use. You can browse available models [here](https://huggingface.co/models?pipeline_tag=text-generation&inference_provider=all&sort=trending).
+            provider: The provider to use for Hugging Face Inference Providers. Can be either the string 'huggingface' or an
+                instance of `Provider[AsyncInferenceClient]`. Defaults to 'huggingface'.
+        """
+        self._model_name = model_name
+        self._provider = provider
+        if isinstance(provider, str):
+            provider = infer_provider(provider)
+        self.client = provider.client
+
+    async def request(
+        self,
+        messages: list[ModelMessage],
+        model_settings: ModelSettings | None,
+        model_request_parameters: ModelRequestParameters,
+    ) -> ModelResponse:
+        check_allow_model_requests()
+        response = await self._completions_create(
+            messages, False, cast(HuggingFaceModelSettings, model_settings or {}), model_request_parameters
+        )
+        model_response = self._process_response(response)
+        model_response.usage.requests = 1
+        return model_response
+
+    @asynccontextmanager
+    async def request_stream(
+        self,
+        messages: list[ModelMessage],
+        model_settings: ModelSettings | None,
+        model_request_parameters: ModelRequestParameters,
+    ) -> AsyncIterator[StreamedResponse]:
+        check_allow_model_requests()
+        response = await self._completions_create(
+            messages, True, cast(HuggingFaceModelSettings, model_settings or {}), model_request_parameters
+        )
+        yield await self._process_streamed_response(response)
+
+    @property
+    def model_name(self) -> str:
+        """The model name."""
+        return self._model_name
+
+    @property
+    def system(self) -> str:
+        """The system / model provider."""
+        return self._system
+
+    @overload
+    async def _completions_create(
+        self,
+        messages: list[ModelMessage],
+        stream: Literal[True],
+        model_settings: HuggingFaceModelSettings,
+        model_request_parameters: ModelRequestParameters,
+    ) -> AsyncIterable[ChatCompletionStreamOutput]: ...
+
+    @overload
+    async def _completions_create(
+        self,
+        messages: list[ModelMessage],
+        stream: Literal[False],
+        model_settings: HuggingFaceModelSettings,
+        model_request_parameters: ModelRequestParameters,
+    ) -> ChatCompletionOutput: ...
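+    # Note: the two `@overload` declarations above exist only to narrow the return
+    # type for type checkers: `stream=True` yields an
+    # `AsyncIterable[ChatCompletionStreamOutput]`, while `stream=False` returns a
+    # single `ChatCompletionOutput`. The one implementation below serves both.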
+ + async def _completions_create( + self, + messages: list[ModelMessage], + stream: bool, + model_settings: HuggingFaceModelSettings, + model_request_parameters: ModelRequestParameters, + ) -> ChatCompletionOutput | AsyncIterable[ChatCompletionStreamOutput]: + tools = self._get_tools(model_request_parameters) + + if not tools: + tool_choice: Literal['none', 'required', 'auto'] | None = None + elif not model_request_parameters.allow_text_output: + tool_choice = 'required' + else: + tool_choice = 'auto' + + hf_messages = await self._map_messages(messages) + + try: + return await self.client.chat.completions.create( # type: ignore + model=self._model_name, + messages=hf_messages, # type: ignore + tools=tools, + tool_choice=tool_choice or None, + stream=stream, + stop=model_settings.get('stop_sequences', None), + temperature=model_settings.get('temperature', None), + top_p=model_settings.get('top_p', None), + seed=model_settings.get('seed', None), + presence_penalty=model_settings.get('presence_penalty', None), + frequency_penalty=model_settings.get('frequency_penalty', None), + logit_bias=model_settings.get('logit_bias', None), # type: ignore + logprobs=model_settings.get('logprobs', None), + top_logprobs=model_settings.get('top_logprobs', None), + extra_body=model_settings.get('extra_body'), # type: ignore + ) + except (InferenceTimeoutError, aiohttp.ClientResponseError) as e: + if isinstance(e, aiohttp.ClientResponseError): + raise ModelHTTPError( + status_code=e.status, + model_name=self.model_name, + body=e.response_error_payload, # type: ignore + ) from e + raise # pragma: lax no cover + + def _process_response(self, response: ChatCompletionOutput) -> ModelResponse: + """Process a non-streamed response, and prepare a message to return.""" + if response.created: + timestamp = datetime.fromtimestamp(response.created, tz=timezone.utc) + else: + timestamp = _now_utc() + + choice = response.choices[0] + items: list[ModelResponsePart] = [] + + if choice.message.content is not None: + items.append(TextPart(choice.message.content)) + if choice.message.tool_calls is not None: + for c in choice.message.tool_calls: + items.append(ToolCallPart(c.function.name, c.function.arguments, tool_call_id=c.id)) + return ModelResponse( + items, + usage=_map_usage(response), + model_name=response.model, + timestamp=timestamp, + vendor_id=response.id, + ) + + async def _process_streamed_response(self, response: AsyncIterable[ChatCompletionStreamOutput]) -> StreamedResponse: + """Process a streamed response, and prepare a streaming response to return.""" + peekable_response = _utils.PeekableAsyncStream(response) + first_chunk = await peekable_response.peek() + if isinstance(first_chunk, _utils.Unset): + raise UnexpectedModelBehavior( # pragma: no cover + 'Streamed response ended without content or tool calls' + ) + + return HuggingFaceStreamedResponse( + _model_name=self._model_name, + _response=peekable_response, + _timestamp=datetime.fromtimestamp(first_chunk.created, tz=timezone.utc), + ) + + def _get_tools(self, model_request_parameters: ModelRequestParameters) -> list[ChatCompletionInputTool]: + tools = [self._map_tool_definition(r) for r in model_request_parameters.function_tools] + if model_request_parameters.output_tools: + tools += [self._map_tool_definition(r) for r in model_request_parameters.output_tools] + return tools + + async def _map_messages( + self, messages: list[ModelMessage] + ) -> list[ChatCompletionInputMessage | ChatCompletionOutputMessage]: + """Just maps a `pydantic_ai.Message` to a 
`huggingface_hub.ChatCompletionInputMessage`.""" + hf_messages: list[ChatCompletionInputMessage | ChatCompletionOutputMessage] = [] + for message in messages: + if isinstance(message, ModelRequest): + async for item in self._map_user_message(message): + hf_messages.append(item) + elif isinstance(message, ModelResponse): + texts: list[str] = [] + tool_calls: list[ChatCompletionInputToolCall] = [] + for item in message.parts: + if isinstance(item, TextPart): + texts.append(item.content) + elif isinstance(item, ToolCallPart): + tool_calls.append(self._map_tool_call(item)) + else: + assert_never(item) + message_param = ChatCompletionInputMessage(role='assistant') # type: ignore + if texts: + # Note: model responses from this model should only have one text item, so the following + # shouldn't merge multiple texts into one unless you switch models between runs: + message_param['content'] = '\n\n'.join(texts) + if tool_calls: + message_param['tool_calls'] = tool_calls + hf_messages.append(message_param) + else: + assert_never(message) + if instructions := self._get_instructions(messages): + hf_messages.insert(0, ChatCompletionInputMessage(content=instructions, role='system')) # type: ignore + return hf_messages + + @staticmethod + def _map_tool_call(t: ToolCallPart) -> ChatCompletionInputToolCall: + return ChatCompletionInputToolCall.parse_obj_as_instance( # type: ignore + { + 'id': _guard_tool_call_id(t=t), + 'type': 'function', + 'function': { + 'name': t.tool_name, + 'arguments': t.args_as_json_str(), + }, + } + ) + + @staticmethod + def _map_tool_definition(f: ToolDefinition) -> ChatCompletionInputTool: + tool_param: ChatCompletionInputTool = ChatCompletionInputTool.parse_obj_as_instance( # type: ignore + { + 'type': 'function', + 'function': { + 'name': f.name, + 'description': f.description, + 'parameters': f.parameters_json_schema, + }, + } + ) + if f.strict: + tool_param['function']['strict'] = f.strict + return tool_param + + async def _map_user_message( + self, message: ModelRequest + ) -> AsyncIterable[ChatCompletionInputMessage | ChatCompletionOutputMessage]: + for part in message.parts: + if isinstance(part, SystemPromptPart): + yield ChatCompletionInputMessage.parse_obj_as_instance({'role': 'system', 'content': part.content}) # type: ignore + elif isinstance(part, UserPromptPart): + yield await self._map_user_prompt(part) + elif isinstance(part, ToolReturnPart): + yield ChatCompletionOutputMessage.parse_obj_as_instance( # type: ignore + { + 'role': 'tool', + 'tool_call_id': _guard_tool_call_id(t=part), + 'content': part.model_response_str(), + } + ) + elif isinstance(part, RetryPromptPart): + if part.tool_name is None: + yield ChatCompletionInputMessage.parse_obj_as_instance( # type: ignore + {'role': 'user', 'content': part.model_response()} + ) + else: + yield ChatCompletionInputMessage.parse_obj_as_instance( # type: ignore + { + 'role': 'tool', + 'tool_call_id': _guard_tool_call_id(t=part), + 'content': part.model_response(), + } + ) + else: + assert_never(part) + + @staticmethod + async def _map_user_prompt(part: UserPromptPart) -> ChatCompletionInputMessage: + content: str | list[ChatCompletionInputMessage] + if isinstance(part.content, str): + content = part.content + else: + content = [] + for item in part.content: + if isinstance(item, str): + content.append(ChatCompletionInputMessageChunk(type='text', text=item)) # type: ignore + elif isinstance(item, ImageUrl): + url = ChatCompletionInputURL(url=item.url) # type: ignore + 
content.append(ChatCompletionInputMessageChunk(type='image_url', image_url=url)) # type: ignore + elif isinstance(item, BinaryContent): + base64_encoded = base64.b64encode(item.data).decode('utf-8') + if item.is_image: + url = ChatCompletionInputURL(url=f'data:{item.media_type};base64,{base64_encoded}') # type: ignore + content.append(ChatCompletionInputMessageChunk(type='image_url', image_url=url)) # type: ignore + else: # pragma: no cover + raise RuntimeError(f'Unsupported binary content type: {item.media_type}') + elif isinstance(item, AudioUrl): + raise NotImplementedError('AudioUrl is not supported for Hugging Face') + elif isinstance(item, DocumentUrl): + raise NotImplementedError('DocumentUrl is not supported for Hugging Face') + elif isinstance(item, VideoUrl): # pragma: no cover + raise NotImplementedError('VideoUrl is not supported for Hugging Face') + else: + assert_never(item) + return ChatCompletionInputMessage(role='user', content=content) # type: ignore + + +@dataclass +class HuggingFaceStreamedResponse(StreamedResponse): + """Implementation of `StreamedResponse` for Hugging Face models.""" + + _model_name: str + _response: AsyncIterable[ChatCompletionStreamOutput] + _timestamp: datetime + + async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]: + async for chunk in self._response: + self._usage += _map_usage(chunk) + + try: + choice = chunk.choices[0] + except IndexError: + continue + + # Handle the text part of the response + content = choice.delta.content + if content is not None: + yield self._parts_manager.handle_text_delta(vendor_part_id='content', content=content) + + for dtc in choice.delta.tool_calls or []: + maybe_event = self._parts_manager.handle_tool_call_delta( + vendor_part_id=dtc.index, + tool_name=dtc.function.name, + args=dtc.function.arguments, + tool_call_id=dtc.id, + ) + if maybe_event is not None: + yield maybe_event + + @property + def model_name(self) -> str: + """Get the model name of the response.""" + return self._model_name + + @property + def timestamp(self) -> datetime: + """Get the timestamp of the response.""" + return self._timestamp + + +def _map_usage(response: ChatCompletionOutput | ChatCompletionStreamOutput) -> usage.Usage: + response_usage = response.usage + if response_usage is None: + return usage.Usage() + + return usage.Usage( + request_tokens=response_usage.prompt_tokens, + response_tokens=response_usage.completion_tokens, + total_tokens=response_usage.total_tokens, + details=None, + ) diff --git a/pydantic_ai_slim/pydantic_ai/providers/__init__.py b/pydantic_ai_slim/pydantic_ai/providers/__init__.py index 379bbbc5d..dedeb6d8f 100644 --- a/pydantic_ai_slim/pydantic_ai/providers/__init__.py +++ b/pydantic_ai_slim/pydantic_ai/providers/__init__.py @@ -89,5 +89,9 @@ def infer_provider(provider: str) -> Provider[Any]: from .cohere import CohereProvider return CohereProvider() + elif provider == 'huggingface': + from .huggingface import HuggingFaceProvider + + return HuggingFaceProvider() else: # pragma: no cover raise ValueError(f'Unknown provider: {provider}') diff --git a/pydantic_ai_slim/pydantic_ai/providers/huggingface.py b/pydantic_ai_slim/pydantic_ai/providers/huggingface.py new file mode 100644 index 000000000..182a1bc83 --- /dev/null +++ b/pydantic_ai_slim/pydantic_ai/providers/huggingface.py @@ -0,0 +1,72 @@ +from __future__ import annotations as _annotations + +import os + +from mistralai import httpx + +try: + from huggingface_hub import AsyncInferenceClient +except ImportError as _import_error: # 
pragma: no cover
+    raise ImportError(
+        'Please install the `huggingface_hub` package to use the HuggingFace provider, '
+        "you can use the `huggingface` optional group — `pip install 'pydantic-ai-slim[huggingface]'`"
+    ) from _import_error
+
+from . import Provider
+
+
+class HuggingFaceProvider(Provider[AsyncInferenceClient]):
+    """Provider for HuggingFace API."""
+
+    @property
+    def name(self) -> str:
+        return 'huggingface'
+
+    @property
+    def base_url(self) -> str:
+        return self.client.model  # type: ignore
+
+    @property
+    def client(self) -> AsyncInferenceClient:
+        return self._client
+
+    def __init__(
+        self,
+        base_url: str | None = None,
+        api_key: str | None = None,
+        hf_client: AsyncInferenceClient | None = None,
+        http_client: httpx.AsyncClient | None = None,
+        provider: str | None = None,
+    ) -> None:
+        """Create a new Hugging Face provider.
+
+        Args:
+            base_url: The base url for the Hugging Face requests. If not provided, it will default to the HF Inference API base url.
+            api_key: The API key to use for authentication, if not provided, the `HF_TOKEN` environment variable
+                will be used if available.
+            hf_client: An existing
+                [`AsyncInferenceClient`](https://huggingface.co/docs/huggingface_hub/v0.29.3/en/package_reference/inference_client#huggingface_hub.AsyncInferenceClient)
+                client to use. If not provided, a new instance will be created.
+            http_client: (currently ignored) An existing `httpx.AsyncClient` to use for making HTTP requests.
+            provider: Name of the provider to use for inference. Available providers are listed in the [HF Inference Providers documentation](https://huggingface.co/docs/inference-providers/index#partners).
+                Defaults to "auto", which selects the first provider available for the model, sorted by the user's order in https://hf.co/settings/inference-providers.
+                If `base_url` is passed, then `provider` is not used.
+        """
+        api_key = api_key or os.environ.get('HF_TOKEN')
+
+        if api_key is None:
+            raise ValueError(
+                'Set the `HF_TOKEN` environment variable or pass it via `HuggingFaceProvider(api_key=...)`'
+                'to use the HuggingFace provider.'
+ ) + + if http_client is not None: + raise ValueError('`http_client` is ignored for HuggingFace provider, please use `hf_client` instead') + + if base_url is not None and provider is not None: + raise ValueError('Cannot provide both `base_url` and `provider`') + + if hf_client is None: + self._client = AsyncInferenceClient(api_key=api_key, provider=provider, base_url=base_url) # type: ignore + else: + self._client = hf_client diff --git a/pydantic_ai_slim/pyproject.toml b/pydantic_ai_slim/pyproject.toml index 631cc196d..6188867c2 100644 --- a/pydantic_ai_slim/pyproject.toml +++ b/pydantic_ai_slim/pyproject.toml @@ -69,6 +69,7 @@ anthropic = ["anthropic>=0.52.0"] groq = ["groq>=0.15.0"] mistral = ["mistralai>=1.2.5"] bedrock = ["boto3>=1.35.74"] +huggingface = ["huggingface-hub>=0.32.0", "aiohttp"] # Tools duckduckgo = ["duckduckgo-search>=7.0.0"] tavily = ["tavily-python>=0.5.0"] @@ -81,6 +82,7 @@ evals = ["pydantic-evals=={{ version }}"] # A2A a2a = ["fasta2a=={{ version }}"] + [dependency-groups] dev = [ "anyio>=4.5.0", diff --git a/pyproject.toml b/pyproject.toml index 04b206432..74404b1f7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -46,7 +46,7 @@ requires-python = ">=3.9" [tool.hatch.metadata.hooks.uv-dynamic-versioning] dependencies = [ - "pydantic-ai-slim[openai,vertexai,google,groq,anthropic,mistral,cohere,bedrock,cli,mcp,evals,a2a]=={{ version }}", + "pydantic-ai-slim[openai,vertexai,google,groq,anthropic,mistral,cohere,bedrock,huggingface,cli,mcp,evals,a2a]=={{ version }}", ] [tool.hatch.metadata.hooks.uv-dynamic-versioning.optional-dependencies] diff --git a/uv.lock b/uv.lock index 1da96b32e..3c503af2f 100644 --- a/uv.lock +++ b/uv.lock @@ -1312,6 +1312,21 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/95/04/ff642e65ad6b90db43e668d70ffb6736436c7ce41fcc549f4e9472234127/h11-0.14.0-py3-none-any.whl", hash = "sha256:e3fe4ac4b851c468cc8363d500db52c2ead036020723024a109d37346efaa761", size = 58259, upload-time = "2022-09-25T15:39:59.68Z" }, ] +[[package]] +name = "hf-xet" +version = "1.1.3" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/75/dc/dc091aeeb671e71cbec30e84963f9c0202c17337b24b0a800e7d205543e8/hf_xet-1.1.3.tar.gz", hash = "sha256:a5f09b1dd24e6ff6bcedb4b0ddab2d81824098bb002cf8b4ffa780545fa348c3", size = 488127, upload-time = "2025-06-04T00:47:27.456Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/9b/1f/bc01a4c0894973adebbcd4aa338a06815c76333ebb3921d94dcbd40dae6a/hf_xet-1.1.3-cp37-abi3-macosx_10_12_x86_64.whl", hash = "sha256:c3b508b5f583a75641aebf732853deb058953370ce8184f5dabc49f803b0819b", size = 2256929, upload-time = "2025-06-04T00:47:21.206Z" }, + { url = "https://files.pythonhosted.org/packages/78/07/6ef50851b5c6b45b77a6e018fa299c69a2db3b8bbd0d5af594c0238b1ceb/hf_xet-1.1.3-cp37-abi3-macosx_11_0_arm64.whl", hash = "sha256:b788a61977fbe6b5186e66239e2a329a3f0b7e7ff50dad38984c0c74f44aeca1", size = 2153719, upload-time = "2025-06-04T00:47:19.302Z" }, + { url = "https://files.pythonhosted.org/packages/52/48/e929e6e3db6e4758c2adf0f2ca2c59287f1b76229d8bdc1a4c9cfc05212e/hf_xet-1.1.3-cp37-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:fd2da210856444a34aad8ada2fc12f70dabed7cc20f37e90754d1d9b43bc0534", size = 4820519, upload-time = "2025-06-04T00:47:17.244Z" }, + { url = "https://files.pythonhosted.org/packages/28/2e/03f89c5014a5aafaa9b150655f811798a317036646623bdaace25f485ae8/hf_xet-1.1.3-cp37-abi3-manylinux_2_28_aarch64.whl", hash = 
"sha256:8203f52827e3df65981984936654a5b390566336956f65765a8aa58c362bb841", size = 4964121, upload-time = "2025-06-04T00:47:15.17Z" }, + { url = "https://files.pythonhosted.org/packages/47/8b/5cd399a92b47d98086f55fc72d69bc9ea5e5c6f27a9ed3e0cdd6be4e58a3/hf_xet-1.1.3-cp37-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:30c575a5306f8e6fda37edb866762140a435037365eba7a17ce7bd0bc0216a8b", size = 5283017, upload-time = "2025-06-04T00:47:23.239Z" }, + { url = "https://files.pythonhosted.org/packages/53/e3/2fcec58d2fcfd25ff07feb876f466cfa11f8dcf9d3b742c07fe9dd51ee0a/hf_xet-1.1.3-cp37-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:7c1a6aa6abed1f696f8099aa9796ca04c9ee778a58728a115607de9cc4638ff1", size = 4970349, upload-time = "2025-06-04T00:47:25.383Z" }, + { url = "https://files.pythonhosted.org/packages/53/bf/10ca917e335861101017ff46044c90e517b574fbb37219347b83be1952f6/hf_xet-1.1.3-cp37-abi3-win_amd64.whl", hash = "sha256:b578ae5ac9c056296bb0df9d018e597c8dc6390c5266f35b5c44696003cde9f3", size = 2310934, upload-time = "2025-06-04T00:47:29.632Z" }, +] + [[package]] name = "httpcore" version = "1.0.7" @@ -1351,20 +1366,21 @@ wheels = [ [[package]] name = "huggingface-hub" -version = "0.29.1" +version = "0.32.4" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "filelock" }, { name = "fsspec" }, + { name = "hf-xet", marker = "platform_machine == 'aarch64' or platform_machine == 'amd64' or platform_machine == 'arm64' or platform_machine == 'x86_64'" }, { name = "packaging" }, { name = "pyyaml" }, { name = "requests" }, { name = "tqdm" }, { name = "typing-extensions" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/22/37/797d6476f13e5ef6af5fc48a5d641d32b39c37e166ccf40c3714c5854a85/huggingface_hub-0.29.1.tar.gz", hash = "sha256:9524eae42077b8ff4fc459ceb7a514eca1c1232b775276b009709fe2a084f250", size = 389776, upload-time = "2025-02-20T09:24:59.839Z" } +sdist = { url = "https://files.pythonhosted.org/packages/60/c8/4f7d270285c46324fd66f62159eb16739aa5696f422dba57678a8c6b78e9/huggingface_hub-0.32.4.tar.gz", hash = "sha256:f61d45cd338736f59fb0e97550b74c24ee771bcc92c05ae0766b9116abe720be", size = 424494, upload-time = "2025-06-03T09:59:46.105Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/ae/05/75b90de9093de0aadafc868bb2fa7c57651fd8f45384adf39bd77f63980d/huggingface_hub-0.29.1-py3-none-any.whl", hash = "sha256:352f69caf16566c7b6de84b54a822f6238e17ddd8ae3da4f8f2272aea5b198d5", size = 468049, upload-time = "2025-02-20T09:24:57.962Z" }, + { url = "https://files.pythonhosted.org/packages/67/8b/222140f3cfb6f17b0dd8c4b9a0b36bd4ebefe9fb0098ba35d6960abcda0f/huggingface_hub-0.32.4-py3-none-any.whl", hash = "sha256:37abf8826b38d971f60d3625229221c36e53fe58060286db9baf619cfbf39767", size = 512101, upload-time = "2025-06-03T09:59:44.099Z" }, ] [[package]] @@ -2896,7 +2912,7 @@ wheels = [ name = "pydantic-ai" source = { editable = "." 
} dependencies = [ - { name = "pydantic-ai-slim", extra = ["a2a", "anthropic", "bedrock", "cli", "cohere", "evals", "google", "groq", "mcp", "mistral", "openai", "vertexai"] }, + { name = "pydantic-ai-slim", extra = ["a2a", "anthropic", "bedrock", "cli", "cohere", "evals", "google", "groq", "huggingface", "mcp", "mistral", "openai", "vertexai"] }, ] [package.optional-dependencies] @@ -2930,7 +2946,7 @@ lint = [ requires-dist = [ { name = "logfire", marker = "extra == 'logfire'", specifier = ">=3.11.0" }, { name = "pydantic-ai-examples", marker = "extra == 'examples'", editable = "examples" }, - { name = "pydantic-ai-slim", extras = ["a2a", "anthropic", "bedrock", "cli", "cohere", "evals", "google", "groq", "mcp", "mistral", "openai", "vertexai"], editable = "pydantic_ai_slim" }, + { name = "pydantic-ai-slim", extras = ["a2a", "anthropic", "bedrock", "cli", "cohere", "evals", "google", "groq", "huggingface", "mcp", "mistral", "openai", "vertexai"], editable = "pydantic_ai_slim" }, ] provides-extras = ["examples", "logfire"] @@ -3029,6 +3045,10 @@ google = [ groq = [ { name = "groq" }, ] +huggingface = [ + { name = "aiohttp" }, + { name = "huggingface-hub" }, +] logfire = [ { name = "logfire" }, ] @@ -3071,6 +3091,7 @@ dev = [ [package.metadata] requires-dist = [ + { name = "aiohttp", marker = "extra == 'huggingface'" }, { name = "anthropic", marker = "extra == 'anthropic'", specifier = ">=0.52.0" }, { name = "argcomplete", marker = "extra == 'cli'", specifier = ">=3.5.0" }, { name = "boto3", marker = "extra == 'bedrock'", specifier = ">=1.35.74" }, @@ -3084,6 +3105,7 @@ requires-dist = [ { name = "griffe", specifier = ">=1.3.2" }, { name = "groq", marker = "extra == 'groq'", specifier = ">=0.15.0" }, { name = "httpx", specifier = ">=0.27" }, + { name = "huggingface-hub", marker = "extra == 'huggingface'", specifier = ">=0.32.0" }, { name = "logfire", marker = "extra == 'logfire'", specifier = ">=3.11.0" }, { name = "mcp", marker = "python_full_version >= '3.10' and extra == 'mcp'", specifier = ">=1.9.0" }, { name = "mistralai", marker = "extra == 'mistral'", specifier = ">=1.2.5" }, @@ -3098,7 +3120,7 @@ requires-dist = [ { name = "tavily-python", marker = "extra == 'tavily'", specifier = ">=0.5.0" }, { name = "typing-inspection", specifier = ">=0.4.0" }, ] -provides-extras = ["a2a", "anthropic", "bedrock", "cli", "cohere", "duckduckgo", "evals", "google", "groq", "logfire", "mcp", "mistral", "openai", "tavily", "vertexai"] +provides-extras = ["a2a", "anthropic", "bedrock", "cli", "cohere", "duckduckgo", "evals", "google", "groq", "huggingface", "logfire", "mcp", "mistral", "openai", "tavily", "vertexai"] [package.metadata.requires-dev] dev = [ From 537a657931bd668ee07363a401338f2a7408eb6b Mon Sep 17 00:00:00 2001 From: Celina Hanouti Date: Wed, 4 Jun 2025 17:22:45 +0100 Subject: [PATCH 02/18] update dependencies --- pydantic_ai_slim/pyproject.toml | 2 +- uv.lock | 11 +++++++---- 2 files changed, 8 insertions(+), 5 deletions(-) diff --git a/pydantic_ai_slim/pyproject.toml b/pydantic_ai_slim/pyproject.toml index 6188867c2..bdafe2df9 100644 --- a/pydantic_ai_slim/pyproject.toml +++ b/pydantic_ai_slim/pyproject.toml @@ -69,7 +69,7 @@ anthropic = ["anthropic>=0.52.0"] groq = ["groq>=0.15.0"] mistral = ["mistralai>=1.2.5"] bedrock = ["boto3>=1.35.74"] -huggingface = ["huggingface-hub>=0.32.0", "aiohttp"] +huggingface = ["huggingface-hub[inference]>=0.32.0"] # Tools duckduckgo = ["duckduckgo-search>=7.0.0"] tavily = ["tavily-python>=0.5.0"] diff --git a/uv.lock b/uv.lock index 
3c503af2f..20af5addd 100644 --- a/uv.lock +++ b/uv.lock @@ -1383,6 +1383,11 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/67/8b/222140f3cfb6f17b0dd8c4b9a0b36bd4ebefe9fb0098ba35d6960abcda0f/huggingface_hub-0.32.4-py3-none-any.whl", hash = "sha256:37abf8826b38d971f60d3625229221c36e53fe58060286db9baf619cfbf39767", size = 512101, upload-time = "2025-06-03T09:59:44.099Z" }, ] +[package.optional-dependencies] +inference = [ + { name = "aiohttp" }, +] + [[package]] name = "idna" version = "3.10" @@ -3046,8 +3051,7 @@ groq = [ { name = "groq" }, ] huggingface = [ - { name = "aiohttp" }, - { name = "huggingface-hub" }, + { name = "huggingface-hub", extra = ["inference"] }, ] logfire = [ { name = "logfire" }, @@ -3091,7 +3095,6 @@ dev = [ [package.metadata] requires-dist = [ - { name = "aiohttp", marker = "extra == 'huggingface'" }, { name = "anthropic", marker = "extra == 'anthropic'", specifier = ">=0.52.0" }, { name = "argcomplete", marker = "extra == 'cli'", specifier = ">=3.5.0" }, { name = "boto3", marker = "extra == 'bedrock'", specifier = ">=1.35.74" }, @@ -3105,7 +3108,7 @@ requires-dist = [ { name = "griffe", specifier = ">=1.3.2" }, { name = "groq", marker = "extra == 'groq'", specifier = ">=0.15.0" }, { name = "httpx", specifier = ">=0.27" }, - { name = "huggingface-hub", marker = "extra == 'huggingface'", specifier = ">=0.32.0" }, + { name = "huggingface-hub", extras = ["inference"], marker = "extra == 'huggingface'", specifier = ">=0.32.0" }, { name = "logfire", marker = "extra == 'logfire'", specifier = ">=3.11.0" }, { name = "mcp", marker = "python_full_version >= '3.10' and extra == 'mcp'", specifier = ">=1.9.0" }, { name = "mistralai", marker = "extra == 'mistral'", specifier = ">=1.2.5" }, From af602a509040e2496e81924b3c038c8c254c16e9 Mon Sep 17 00:00:00 2001 From: Celina Hanouti Date: Wed, 4 Jun 2025 17:23:38 +0100 Subject: [PATCH 03/18] nit --- pydantic_ai_slim/pyproject.toml | 1 - 1 file changed, 1 deletion(-) diff --git a/pydantic_ai_slim/pyproject.toml b/pydantic_ai_slim/pyproject.toml index bdafe2df9..e0fd6c6f3 100644 --- a/pydantic_ai_slim/pyproject.toml +++ b/pydantic_ai_slim/pyproject.toml @@ -82,7 +82,6 @@ evals = ["pydantic-evals=={{ version }}"] # A2A a2a = ["fasta2a=={{ version }}"] - [dependency-groups] dev = [ "anyio>=4.5.0", From 1f3f7a21d0b315498564542d7f5867a5d3a4283b Mon Sep 17 00:00:00 2001 From: Celina Hanouti Date: Wed, 4 Jun 2025 17:24:34 +0100 Subject: [PATCH 04/18] update docstring --- pydantic_ai_slim/pydantic_ai/providers/huggingface.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pydantic_ai_slim/pydantic_ai/providers/huggingface.py b/pydantic_ai_slim/pydantic_ai/providers/huggingface.py index 182a1bc83..3d301340f 100644 --- a/pydantic_ai_slim/pydantic_ai/providers/huggingface.py +++ b/pydantic_ai_slim/pydantic_ai/providers/huggingface.py @@ -16,7 +16,7 @@ class HuggingFaceProvider(Provider[AsyncInferenceClient]): - """Provider for HuggingFace API.""" + """Provider for Hugging Face.""" @property def name(self) -> str: From bea050c69abbb2abfb51ddc4338b088f137ec346 Mon Sep 17 00:00:00 2001 From: Celina Hanouti Date: Mon, 9 Jun 2025 13:16:57 +0100 Subject: [PATCH 05/18] add tests --- .../pydantic_ai/models/huggingface.py | 13 +- .../pydantic_ai/providers/huggingface.py | 10 +- tests/conftest.py | 14 + .../test_hf_model_instructions.yaml | 125 ++++ .../test_request_simple_success_with_vcr.yaml | 126 ++++ tests/models/test_huggingface.py | 692 ++++++++++++++++++ tests/providers/test_huggingface.py | 61 ++ 7 files 
changed, 1034 insertions(+), 7 deletions(-) create mode 100644 tests/models/cassettes/test_huggingface/test_hf_model_instructions.yaml create mode 100644 tests/models/cassettes/test_huggingface/test_request_simple_success_with_vcr.yaml create mode 100644 tests/models/test_huggingface.py create mode 100644 tests/providers/test_huggingface.py diff --git a/pydantic_ai_slim/pydantic_ai/models/huggingface.py b/pydantic_ai_slim/pydantic_ai/models/huggingface.py index c34d741a3..1dc1db73c 100644 --- a/pydantic_ai_slim/pydantic_ai/models/huggingface.py +++ b/pydantic_ai_slim/pydantic_ai/models/huggingface.py @@ -49,6 +49,7 @@ ChatCompletionStreamOutput, InferenceTimeoutError, ) + from huggingface_hub.errors import HfHubHTTPError except ImportError as _import_error: raise ImportError( @@ -198,13 +199,19 @@ async def _completions_create( top_logprobs=model_settings.get('top_logprobs', None), extra_body=model_settings.get('extra_body'), # type: ignore ) - except (InferenceTimeoutError, aiohttp.ClientResponseError) as e: + except (InferenceTimeoutError, aiohttp.ClientResponseError, HfHubHTTPError) as e: if isinstance(e, aiohttp.ClientResponseError): raise ModelHTTPError( status_code=e.status, model_name=self.model_name, body=e.response_error_payload, # type: ignore ) from e + elif isinstance(e, HfHubHTTPError): + raise ModelHTTPError( + status_code=e.response.status_code, + model_name=self.model_name, + body=e.response.content, + ) from e raise # pragma: lax no cover def _process_response(self, response: ChatCompletionOutput) -> ModelResponse: @@ -401,8 +408,8 @@ async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]: for dtc in choice.delta.tool_calls or []: maybe_event = self._parts_manager.handle_tool_call_delta( vendor_part_id=dtc.index, - tool_name=dtc.function.name, - args=dtc.function.arguments, + tool_name=dtc.function and dtc.function.name, # type: ignore + args=dtc.function and dtc.function.arguments, tool_call_id=dtc.id, ) if maybe_event is not None: diff --git a/pydantic_ai_slim/pydantic_ai/providers/huggingface.py b/pydantic_ai_slim/pydantic_ai/providers/huggingface.py index 3d301340f..e18a60d16 100644 --- a/pydantic_ai_slim/pydantic_ai/providers/huggingface.py +++ b/pydantic_ai_slim/pydantic_ai/providers/huggingface.py @@ -2,7 +2,9 @@ import os -from mistralai import httpx +from httpx import AsyncClient + +from pydantic_ai.exceptions import UserError try: from huggingface_hub import AsyncInferenceClient @@ -35,13 +37,13 @@ def __init__( base_url: str | None = None, api_key: str | None = None, hf_client: AsyncInferenceClient | None = None, - http_client: httpx.AsyncClient | None = None, + http_client: AsyncClient | None = None, provider: str | None = None, ) -> None: """Create a new Hugging Face provider. Args: - base_url: The base url for the Hugging Face requests. If not provided, it will default to the HF Inference API base url. + base_url: The base url for the Hugging Face requests. api_key: The API key to use for authentication, if not provided, the `HF_TOKEN` environment variable will be used if available. hf_client: An existing @@ -55,7 +57,7 @@ def __init__( api_key = api_key or os.environ.get('HF_TOKEN') if api_key is None: - raise ValueError( + raise UserError( 'Set the `HF_TOKEN` environment variable or pass it via `HuggingFaceProvider(api_key=...)`' 'to use the HuggingFace provider.' 
            )
diff --git a/tests/conftest.py b/tests/conftest.py
index 65c104718..db8459fc8 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -281,6 +281,11 @@ def openrouter_api_key() -> str:
     return os.getenv('OPENROUTER_API_KEY', 'mock-api-key')
 
 
+@pytest.fixture(scope='session')
+def huggingface_api_key() -> str:
+    return os.getenv('HF_TOKEN') or os.getenv('HUGGINGFACE_API_KEY') or 'mock-api-key'
+
+
 @pytest.fixture(scope='session')
 def bedrock_provider():
     try:
@@ -309,6 +314,7 @@ def model(
     groq_api_key: str,
     co_api_key: str,
     gemini_api_key: str,
+    huggingface_api_key: str,
     bedrock_provider: BedrockProvider,
 ) -> Model:  # pragma: lax no cover
     try:
@@ -346,6 +352,14 @@ def model(
         from pydantic_ai.models.bedrock import BedrockConverseModel
 
         return BedrockConverseModel('us.amazon.nova-micro-v1:0', provider=bedrock_provider)
+    elif request.param == 'huggingface':
+        from pydantic_ai.models.huggingface import HuggingFaceModel
+        from pydantic_ai.providers.huggingface import HuggingFaceProvider
+
+        return HuggingFaceModel(
+            'Qwen/Qwen2.5-72B-Instruct',
+            provider=HuggingFaceProvider(provider='nebius', api_key=huggingface_api_key),
+        )
     else:
         raise ValueError(f'Unknown model: {request.param}')
     except ImportError:
diff --git a/tests/models/cassettes/test_huggingface/test_hf_model_instructions.yaml b/tests/models/cassettes/test_huggingface/test_hf_model_instructions.yaml
new file mode 100644
index 000000000..f621f4c4f
--- /dev/null
+++ b/tests/models/cassettes/test_huggingface/test_hf_model_instructions.yaml
@@ -0,0 +1,125 @@
+interactions:
+- request:
+    body: null
+    headers:
+      accept:
+      - '*/*'
+      accept-encoding:
+      - gzip, deflate
+      connection:
+      - keep-alive
+    method: GET
+    uri: https://huggingface.co/api/models/Qwen/Qwen2.5-72B-Instruct?expand=inferenceProviderMapping
+  response:
+    headers:
+      access-control-allow-origin:
+      - https://huggingface.co
+      access-control-expose-headers:
+      - 
X-Repo-Commit,X-Request-Id,X-Error-Code,X-Error-Message,X-Total-Count,ETag,Link,Accept-Ranges,Content-Range,X-Linked-Size,X-Linked-ETag,X-Xet-Hash + connection: + - keep-alive + content-length: + - '560' + content-type: + - application/json + cross-origin-opener-policy: + - same-origin + referrer-policy: + - strict-origin-when-cross-origin + strict-transport-security: + - max-age=31536000; includeSubDomains + vary: + - Origin + parsed_body: + choices: + - finish_reason: stop + index: 0 + logprobs: null + message: + audio: null + content: Paris + function_call: null + reasoning_content: null + refusal: null + role: assistant + tool_calls: [] + stop_reason: null + created: 1749227878 + id: chatcmpl-54246cfb4fa046e88a984020c4efab20 + model: Qwen/Qwen2.5-72B-Instruct-fast + object: chat.completion + prompt_logprobs: null + service_tier: null + system_fingerprint: null + usage: + completion_tokens: 2 + completion_tokens_details: null + prompt_tokens: 26 + prompt_tokens_details: null + total_tokens: 28 + status: + code: 200 + message: OK +version: 1 diff --git a/tests/models/cassettes/test_huggingface/test_request_simple_success_with_vcr.yaml b/tests/models/cassettes/test_huggingface/test_request_simple_success_with_vcr.yaml new file mode 100644 index 000000000..c9a3b50f2 --- /dev/null +++ b/tests/models/cassettes/test_huggingface/test_request_simple_success_with_vcr.yaml @@ -0,0 +1,126 @@ +interactions: +- request: + body: null + headers: + accept: + - '*/*' + accept-encoding: + - gzip, deflate + connection: + - keep-alive + method: GET + uri: https://huggingface.co/api/models/Qwen/Qwen2.5-72B-Instruct?expand=inferenceProviderMapping + response: + headers: + access-control-allow-origin: + - https://huggingface.co + access-control-expose-headers: + - X-Repo-Commit,X-Request-Id,X-Error-Code,X-Error-Message,X-Total-Count,ETag,Link,Accept-Ranges,Content-Range,X-Linked-Size,X-Linked-ETag,X-Xet-Hash + connection: + - keep-alive + content-length: + - '800' + content-type: + - application/json; charset=utf-8 + cross-origin-opener-policy: + - same-origin + etag: + - W/"320-IoLwHc4XKGzRoHW0ok1gY7tY/NI" + referrer-policy: + - strict-origin-when-cross-origin + vary: + - Origin + parsed_body: + _id: 66e81cefd1b1391042d0e47e + id: Qwen/Qwen2.5-72B-Instruct + inferenceProviderMapping: + featherless-ai: + providerId: Qwen/Qwen2.5-72B-Instruct + status: error + task: conversational + fireworks-ai: + providerId: accounts/fireworks/models/qwen2p5-72b-instruct + status: live + task: conversational + hf-inference: + providerId: Qwen/Qwen2.5-72B-Instruct + status: live + task: conversational + hyperbolic: + providerId: Qwen/Qwen2.5-72B-Instruct + status: live + task: conversational + nebius: + providerId: Qwen/Qwen2.5-72B-Instruct-fast + status: live + task: conversational + novita: + providerId: qwen/qwen-2.5-72b-instruct + status: live + task: conversational + together: + providerId: Qwen/Qwen2.5-72B-Instruct-Turbo + status: live + task: conversational + status: + code: 200 + message: OK +- request: + body: null + headers: {} + method: POST + uri: https://router.huggingface.co/nebius/v1/chat/completions + response: + headers: + access-control-allow-credentials: + - 'true' + access-control-allow-origin: + - '*' + access-control-expose-headers: + - X-Repo-Commit,X-Request-Id,X-Error-Code,X-Error-Message,X-Total-Count,ETag,Link,Accept-Ranges,Content-Range,X-Linked-Size,X-Linked-ETag,X-Xet-Hash + connection: + - keep-alive + content-length: + - '700' + content-type: + - application/json + cross-origin-opener-policy: + 
- same-origin + referrer-policy: + - strict-origin-when-cross-origin + strict-transport-security: + - max-age=31536000; includeSubDomains + vary: + - Origin + parsed_body: + choices: + - finish_reason: stop + index: 0 + logprobs: null + message: + audio: null + content: Hello! It's great to meet you. How can I assist you today? Whether you have questions, need information, + or just want to chat, I'm here to help! + function_call: null + reasoning_content: null + refusal: null + role: assistant + tool_calls: [] + stop_reason: null + created: 1749226637 + id: chatcmpl-f5783ce357b4415b8d59dbbf5b3cf9bf + model: Qwen/Qwen2.5-72B-Instruct-fast + object: chat.completion + prompt_logprobs: null + service_tier: null + system_fingerprint: null + usage: + completion_tokens: 37 + completion_tokens_details: null + prompt_tokens: 30 + prompt_tokens_details: null + total_tokens: 67 + status: + code: 200 + message: OK +version: 1 diff --git a/tests/models/test_huggingface.py b/tests/models/test_huggingface.py new file mode 100644 index 000000000..378ab675a --- /dev/null +++ b/tests/models/test_huggingface.py @@ -0,0 +1,692 @@ +from __future__ import annotations as _annotations + +import json +from collections.abc import Sequence +from dataclasses import dataclass +from datetime import datetime, timezone +from functools import cached_property +from typing import Any, Literal, Union, cast +from unittest.mock import Mock + +import pytest +from huggingface_hub import ( + ChatCompletionStreamOutputChoice, + ChatCompletionStreamOutputDelta, + ChatCompletionStreamOutputDeltaToolCall, + ChatCompletionStreamOutputFunction, + ChatCompletionStreamOutputUsage, +) +from inline_snapshot import snapshot +from typing_extensions import TypedDict + +from pydantic_ai import Agent, ModelRetry, UnexpectedModelBehavior +from pydantic_ai.exceptions import ModelHTTPError +from pydantic_ai.messages import ( + BinaryContent, + ImageUrl, + ModelRequest, + ModelResponse, + RetryPromptPart, + SystemPromptPart, + TextPart, + ToolCallPart, + ToolReturnPart, + UserPromptPart, +) +from pydantic_ai.result import Usage +from pydantic_ai.tools import RunContext + +from ..conftest import IsDatetime, IsNow, raise_if_exception, try_import +from .mock_async_stream import MockAsyncStream + +with try_import() as imports_successful: + from huggingface_hub import ( + AsyncInferenceClient, + ChatCompletionInputMessage, + ChatCompletionOutput, + ChatCompletionOutputComplete, + ChatCompletionOutputFunctionDefinition, + ChatCompletionOutputMessage, + ChatCompletionOutputToolCall, + ChatCompletionOutputUsage, + ChatCompletionStreamOutput, + ) + from huggingface_hub.errors import HfHubHTTPError + + from pydantic_ai.models.huggingface import HuggingFaceModel + from pydantic_ai.providers.huggingface import HuggingFaceProvider + + MockChatCompletion = Union[ChatCompletionOutput, Exception] + MockStreamEvent = Union[ChatCompletionStreamOutput, Exception] + +pytestmark = [ + pytest.mark.skipif(not imports_successful(), reason='huggingface_hub not installed'), + pytest.mark.anyio, +] + + +@dataclass +class MockHuggingFace: + completions: MockChatCompletion | Sequence[MockChatCompletion] | None = None + stream: Sequence[MockStreamEvent] | Sequence[Sequence[MockStreamEvent]] | None = None + index: int = 0 + + @cached_property + def chat(self) -> Any: + completions = type('Completions', (), {'create': self.chat_completions_create}) + return type('Chat', (), {'completions': completions}) + + @classmethod + def create_mock(cls, completions: MockChatCompletion | 
Sequence[MockChatCompletion]) -> AsyncInferenceClient: + return cast(AsyncInferenceClient, cls(completions=completions)) + + @classmethod + def create_stream_mock( + cls, stream: Sequence[MockStreamEvent] | Sequence[Sequence[MockStreamEvent]] + ) -> AsyncInferenceClient: + return cast(AsyncInferenceClient, cls(stream=stream)) + + async def chat_completions_create( + self, *_args: Any, stream: bool = False, **_kwargs: Any + ) -> ChatCompletionOutput | MockAsyncStream[MockStreamEvent]: + if stream or self.stream: + assert self.stream is not None, 'you can only use `stream=True` if `stream` is provided' + if isinstance(self.stream[0], Sequence): + response = MockAsyncStream(iter(cast(list[MockStreamEvent], self.stream[self.index]))) + else: + response = MockAsyncStream(iter(cast(list[MockStreamEvent], self.stream))) + else: + assert self.completions is not None, 'you can only use `stream=False` if `completions` are provided' + if isinstance(self.completions, Sequence): + raise_if_exception(self.completions[self.index]) + response = cast(ChatCompletionOutput, self.completions[self.index]) + else: + raise_if_exception(self.completions) + response = cast(ChatCompletionOutput, self.completions) + self.index += 1 + return response + + +def completion_message( + message: ChatCompletionInputMessage | ChatCompletionOutputMessage, *, usage: ChatCompletionOutputUsage | None = None +) -> ChatCompletionOutput: + choices = [ChatCompletionOutputComplete(finish_reason='stop', index=0, message=message)] # type:ignore + return ChatCompletionOutput.parse_obj_as_instance( # type: ignore + { + 'id': '123', + 'choices': choices, + 'created': 1704067200, # 2024-01-01 + 'model': 'hf-model', + 'object': 'chat.completion', + 'usage': usage, + } + ) + + +async def test_simple_completion(allow_model_requests: None): + c = completion_message(ChatCompletionInputMessage(content='world', role='assistant')) # type:ignore + mock_client = MockHuggingFace.create_mock(c) + model = HuggingFaceModel( + 'Qwen/Qwen2.5-72B-Instruct', provider=HuggingFaceProvider(provider='nebius', hf_client=mock_client, api_key='x') + ) + agent = Agent(model) + + result = await agent.run('hello') + assert result.output == 'world' + messages = result.all_messages() + request = messages[0] + response = messages[1] + assert request.parts[0].content == 'hello' # type: ignore + assert response == ModelResponse( + parts=[TextPart(content='world')], + usage=Usage(requests=1), + model_name='hf-model', + timestamp=datetime(2024, 1, 1, 0, 0, tzinfo=timezone.utc), + vendor_id='123', + ) + + +async def test_request_simple_usage(allow_model_requests: None): + c = completion_message(ChatCompletionInputMessage(content='world', role='assistant')) # type:ignore + mock_client = MockHuggingFace.create_mock(c) + model = HuggingFaceModel( + 'Qwen/Qwen2.5-72B-Instruct', provider=HuggingFaceProvider(provider='nebius', hf_client=mock_client, api_key='x') + ) + agent = Agent(model) + + result = await agent.run('Hello') + assert result.output == 'world' + assert result.usage() == snapshot(Usage(requests=1)) + + +async def test_request_structured_response(allow_model_requests: None): + tool_call = ChatCompletionOutputToolCall.parse_obj_as_instance( # type:ignore + { + 'function': ChatCompletionOutputFunctionDefinition.parse_obj_as_instance( # type:ignore + { + 'name': 'final_result', + 'arguments': '{"response": [1, 2, 123]}', + } + ), + 'id': '123', + 'type': 'function', + } + ) + message = ChatCompletionOutputMessage.parse_obj_as_instance( # type:ignore + { + 'content': 
None, + 'role': 'assistant', + 'tool_calls': [tool_call], + } + ) + c = completion_message(message) + + mock_client = MockHuggingFace.create_mock(c) + model = HuggingFaceModel( + 'Qwen/Qwen2.5-72B-Instruct', provider=HuggingFaceProvider(provider='nebius', hf_client=mock_client, api_key='x') + ) + agent = Agent(model, output_type=list[int]) + + result = await agent.run('Hello') + assert result.output == [1, 2, 123] + messages = result.all_messages() + assert messages[0].parts[0].content == 'Hello' # type: ignore + assert messages[1] == ModelResponse( + parts=[ + ToolCallPart( + tool_name='final_result', + args='{"response": [1, 2, 123]}', + tool_call_id='123', + ) + ], + usage=Usage(requests=1), + model_name='hf-model', + timestamp=datetime(2024, 1, 1, tzinfo=timezone.utc), + vendor_id='123', + ) + + +async def test_stream_completion(allow_model_requests: None): + stream = [text_chunk('hello '), text_chunk('world', finish_reason='stop')] + mock_client = MockHuggingFace.create_stream_mock(stream) + model = HuggingFaceModel('hf-model', provider=HuggingFaceProvider(hf_client=mock_client, api_key='x')) + agent = Agent(model) + + async with agent.run_stream('') as result: + assert [c async for c in result.stream_text(debounce_by=None)] == snapshot(['hello ', 'hello world']) + + +async def test_request_tool_call(allow_model_requests: None): + tool_call_1 = ChatCompletionOutputToolCall.parse_obj_as_instance( # type:ignore + { + 'function': ChatCompletionOutputFunctionDefinition.parse_obj_as_instance( # type:ignore + { + 'name': 'get_location', + 'arguments': '{"loc_name": "San Fransisco"}', + } + ), + 'id': '1', + 'type': 'function', + } + ) + usage_1 = ChatCompletionOutputUsage.parse_obj_as_instance( # type:ignore + { + 'prompt_tokens': 1, + 'completion_tokens': 1, + 'total_tokens': 2, + } + ) + tool_call_2 = ChatCompletionOutputToolCall.parse_obj_as_instance( # type:ignore + { + 'function': ChatCompletionOutputFunctionDefinition.parse_obj_as_instance( # type:ignore + { + 'name': 'get_location', + 'arguments': '{"loc_name": "London"}', + } + ), + 'id': '2', + 'type': 'function', + } + ) + usage_2 = ChatCompletionOutputUsage.parse_obj_as_instance( # type:ignore + { + 'prompt_tokens': 2, + 'completion_tokens': 1, + 'total_tokens': 3, + } + ) + responses = [ + completion_message( + ChatCompletionOutputMessage.parse_obj_as_instance( # type:ignore + { + 'content': None, + 'role': 'assistant', + 'tool_calls': [tool_call_1], + } + ), + usage=usage_1, + ), + completion_message( + ChatCompletionOutputMessage.parse_obj_as_instance( # type:ignore + { + 'content': None, + 'role': 'assistant', + 'tool_calls': [tool_call_2], + } + ), + usage=usage_2, + ), + completion_message( + ChatCompletionOutputMessage.parse_obj_as_instance( # type:ignore + { + 'content': 'final response', + 'role': 'assistant', + } + ), + ), + ] + mock_client = MockHuggingFace.create_mock(responses) + model = HuggingFaceModel('hf-model', provider=HuggingFaceProvider(hf_client=mock_client, api_key='x')) + agent = Agent(model, system_prompt='this is the system prompt') + + @agent.tool_plain + async def get_location(loc_name: str) -> str: + if loc_name == 'London': + return json.dumps({'lat': 51, 'lng': 0}) + else: + raise ModelRetry('Wrong location, please try again') + + result = await agent.run('Hello') + assert result.output == 'final response' + assert result.all_messages() == snapshot( + [ + ModelRequest( + parts=[ + SystemPromptPart(content='this is the system prompt', timestamp=IsNow(tz=timezone.utc)), + 
UserPromptPart(content='Hello', timestamp=IsNow(tz=timezone.utc)), + ] + ), + ModelResponse( + parts=[ + ToolCallPart( + tool_name='get_location', + args='{"loc_name": "San Fransisco"}', + tool_call_id='1', + ) + ], + usage=Usage(requests=1, request_tokens=1, response_tokens=1, total_tokens=2), + model_name='hf-model', + timestamp=datetime(2024, 1, 1, 0, 0, tzinfo=timezone.utc), + vendor_id='123', + ), + ModelRequest( + parts=[ + RetryPromptPart( + content='Wrong location, please try again', + tool_name='get_location', + tool_call_id='1', + timestamp=IsNow(tz=timezone.utc), + ) + ] + ), + ModelResponse( + parts=[ + ToolCallPart( + tool_name='get_location', + args='{"loc_name": "London"}', + tool_call_id='2', + ) + ], + usage=Usage(requests=1, request_tokens=2, response_tokens=1, total_tokens=3), + model_name='hf-model', + timestamp=datetime(2024, 1, 1, 0, 0, tzinfo=timezone.utc), + vendor_id='123', + ), + ModelRequest( + parts=[ + ToolReturnPart( + tool_name='get_location', + content='{"lat": 51, "lng": 0}', + tool_call_id='2', + timestamp=IsNow(tz=timezone.utc), + ) + ] + ), + ModelResponse( + parts=[TextPart(content='final response')], + usage=Usage(requests=1), + model_name='hf-model', + timestamp=datetime(2024, 1, 1, 0, 0, tzinfo=timezone.utc), + vendor_id='123', + ), + ] + ) + + +FinishReason = Literal['stop', 'length', 'tool_calls', 'content_filter', 'function_call'] + + +def chunk( + delta: list[ChatCompletionStreamOutputDelta], finish_reason: FinishReason | None = None +) -> ChatCompletionStreamOutput: + return ChatCompletionStreamOutput.parse_obj_as_instance( # type: ignore + { + 'id': 'x', + 'choices': [ + ChatCompletionStreamOutputChoice(index=index, delta=delta, finish_reason=finish_reason) # type: ignore + for index, delta in enumerate(delta) + ], + 'created': 1704067200, # 2024-01-01 + 'model': 'hf-model', + 'object': 'chat.completion.chunk', + 'usage': ChatCompletionStreamOutputUsage(completion_tokens=1, prompt_tokens=2, total_tokens=3), # type: ignore + } + ) + + +def text_chunk(text: str, finish_reason: FinishReason | None = None) -> ChatCompletionStreamOutput: + return chunk([ChatCompletionStreamOutputDelta(content=text, role='assistant')], finish_reason=finish_reason) # type: ignore + + +async def test_stream_text(allow_model_requests: None): + stream = [text_chunk('hello '), text_chunk('world'), chunk([])] + mock_client = MockHuggingFace.create_stream_mock(stream) + m = HuggingFaceModel('hf-model', provider=HuggingFaceProvider(hf_client=mock_client)) + agent = Agent(m) + + async with agent.run_stream('') as result: + assert not result.is_complete + assert [c async for c in result.stream_text(debounce_by=None)] == snapshot(['hello ', 'hello world']) + assert result.is_complete + assert result.usage() == snapshot(Usage(requests=1, request_tokens=6, response_tokens=3, total_tokens=9)) + + +async def test_stream_text_finish_reason(allow_model_requests: None): + stream = [ + text_chunk('hello '), + text_chunk('world'), + text_chunk('.', finish_reason='stop'), + ] + mock_client = MockHuggingFace.create_stream_mock(stream) + m = HuggingFaceModel('hf-model', provider=HuggingFaceProvider(hf_client=mock_client)) + agent = Agent(m) + + async with agent.run_stream('') as result: + assert not result.is_complete + assert [c async for c in result.stream_text(debounce_by=None)] == snapshot( + ['hello ', 'hello world', 'hello world.'] + ) + assert result.is_complete + + +def struc_chunk( + tool_name: str | None, tool_arguments: str | None, finish_reason: FinishReason | None = None +) -> 
ChatCompletionStreamOutput: + return chunk( + [ + ChatCompletionStreamOutputDelta.parse_obj_as_instance( # type: ignore + { + 'role': 'assistant', + 'tool_calls': [ + ChatCompletionStreamOutputDeltaToolCall.parse_obj_as_instance( # type: ignore + { + 'index': 0, + 'function': ChatCompletionStreamOutputFunction.parse_obj_as_instance( # type: ignore + { + 'name': tool_name, + 'arguments': tool_arguments, + } + ), + } + ) + ], + } + ), + ], + finish_reason=finish_reason, + ) + + +class MyTypedDict(TypedDict, total=False): + first: str + second: str + + +async def test_stream_structured(allow_model_requests: None): + stream = [ + chunk([ChatCompletionStreamOutputDelta(role='assistant')]), # type: ignore + chunk([ChatCompletionStreamOutputDelta(role='assistant', tool_calls=[])]), # type: ignore + chunk( + [ + ChatCompletionStreamOutputDelta( + role='assistant', # type: ignore + tool_calls=[ # type: ignore + ChatCompletionStreamOutputDeltaToolCall(id='0', type='function', index=0, function=None) # type: ignore + ], + ) + ] + ), + chunk( + [ + ChatCompletionStreamOutputDelta( + role='assistant', # type: ignore + tool_calls=[ # type: ignore + ChatCompletionStreamOutputDeltaToolCall(id='0', type='function', index=0, function=None) # type: ignore + ], + ) + ] + ), + struc_chunk('final_result', None), + chunk( + [ + ChatCompletionStreamOutputDelta( + role='assistant', # type: ignore + tool_calls=[ # type: ignore + ChatCompletionStreamOutputDeltaToolCall(id='0', type='function', index=0, function=None) # type: ignore + ], + ) + ] + ), + struc_chunk(None, '{"first": "One'), + struc_chunk(None, '", "second": "Two"'), + struc_chunk(None, '}'), + chunk([]), + ] + mock_client = MockHuggingFace.create_stream_mock(stream) + m = HuggingFaceModel('hf-model', provider=HuggingFaceProvider(hf_client=mock_client)) + agent = Agent(m, output_type=MyTypedDict) + + async with agent.run_stream('') as result: + assert not result.is_complete + assert [dict(c) async for c in result.stream(debounce_by=None)] == snapshot( + [ + {}, + {'first': 'One'}, + {'first': 'One', 'second': 'Two'}, + {'first': 'One', 'second': 'Two'}, + {'first': 'One', 'second': 'Two'}, + ] + ) + assert result.is_complete + assert result.usage() == snapshot(Usage(requests=1, request_tokens=20, response_tokens=10, total_tokens=30)) + # double check usage matches stream count + assert result.usage().response_tokens == len(stream) + + +async def test_stream_structured_finish_reason(allow_model_requests: None): + stream = [ + struc_chunk('final_result', None), + struc_chunk(None, '{"first": "One'), + struc_chunk(None, '", "second": "Two"'), + struc_chunk(None, '}'), + struc_chunk(None, None, finish_reason='stop'), + ] + mock_client = MockHuggingFace.create_stream_mock(stream) + m = HuggingFaceModel('hf-model', provider=HuggingFaceProvider(hf_client=mock_client)) + agent = Agent(m, output_type=MyTypedDict) + + async with agent.run_stream('') as result: + assert not result.is_complete + assert [dict(c) async for c in result.stream(debounce_by=None)] == snapshot( + [ + {'first': 'One'}, + {'first': 'One', 'second': 'Two'}, + {'first': 'One', 'second': 'Two'}, + {'first': 'One', 'second': 'Two'}, + {'first': 'One', 'second': 'Two'}, + ] + ) + assert result.is_complete + + +async def test_no_content(allow_model_requests: None): + stream = [ + chunk([ChatCompletionStreamOutputDelta(role='assistant')]), # type: ignore + chunk([ChatCompletionStreamOutputDelta(role='assistant')]), # type: ignore + ] + mock_client = MockHuggingFace.create_stream_mock(stream) + m = 
HuggingFaceModel('hf-model', provider=HuggingFaceProvider(hf_client=mock_client)) + agent = Agent(m, output_type=MyTypedDict) + + with pytest.raises(UnexpectedModelBehavior, match='Received empty model response'): + async with agent.run_stream(''): + pass + + +async def test_no_delta(allow_model_requests: None): + stream = [ + chunk([]), + text_chunk('hello '), + text_chunk('world'), + ] + mock_client = MockHuggingFace.create_stream_mock(stream) + m = HuggingFaceModel('hf-model', provider=HuggingFaceProvider(hf_client=mock_client)) + agent = Agent(m) + + async with agent.run_stream('') as result: + assert not result.is_complete + assert [c async for c in result.stream_text(debounce_by=None)] == snapshot(['hello ', 'hello world']) + assert result.is_complete + assert result.usage() == snapshot(Usage(requests=1, request_tokens=6, response_tokens=3, total_tokens=9)) + + +async def test_image_url_input(allow_model_requests: None): + c = completion_message(ChatCompletionInputMessage(content='world', role='assistant')) # type:ignore + mock_client = MockHuggingFace.create_mock(c) + m = HuggingFaceModel('hf-model', provider=HuggingFaceProvider(hf_client=mock_client)) + agent = Agent(m) + + result = await agent.run( + [ + 'hello', + ImageUrl(url='https://t3.ftcdn.net/jpg/00/85/79/92/360_F_85799278_0BBGV9OAdQDTLnKwAPBCcg1J7QtiieJY.jpg'), + ] + ) + assert result.all_messages() == snapshot( + [ + ModelRequest( + parts=[ + UserPromptPart( + content=[ + 'hello', + ImageUrl( + url='https://t3.ftcdn.net/jpg/00/85/79/92/360_F_85799278_0BBGV9OAdQDTLnKwAPBCcg1J7QtiieJY.jpg' + ), + ], + timestamp=IsNow(tz=timezone.utc), + ) + ] + ), + ModelResponse( + parts=[TextPart(content='world')], + usage=Usage(requests=1), + model_name='hf-model', + timestamp=datetime(2024, 1, 1, 0, 0, tzinfo=timezone.utc), + vendor_id='123', + ), + ] + ) + + +async def test_image_as_binary_content_input(allow_model_requests: None): + c = completion_message(ChatCompletionInputMessage(content='world', role='assistant')) # type: ignore + mock_client = MockHuggingFace.create_mock(c) + m = HuggingFaceModel('hf-model', provider=HuggingFaceProvider(hf_client=mock_client)) + agent = Agent(m) + + base64_content = ( + b'/9j/4AAQSkZJRgABAQEAYABgAAD/4QBYRXhpZgAATU0AKgAAAAgAA1IBAAEAAAABAAAAPgIBAAEAAAABAAAARgMBAAEAAAABAAAA' + b'WgAAAAAAAAAE' + ) + + result = await agent.run(['hello', BinaryContent(data=base64_content, media_type='image/jpeg')]) + assert result.all_messages() == snapshot( + [ + ModelRequest( + parts=[ + UserPromptPart( + content=['hello', BinaryContent(data=base64_content, media_type='image/jpeg')], + timestamp=IsNow(tz=timezone.utc), + ) + ] + ), + ModelResponse( + parts=[TextPart(content='world')], + usage=Usage(requests=1), + model_name='hf-model', + timestamp=datetime(2024, 1, 1, 0, 0, tzinfo=timezone.utc), + vendor_id='123', + ), + ] + ) + + +def test_model_status_error(allow_model_requests: None) -> None: + error = HfHubHTTPError(message='test_error', response=Mock(status_code=500, content={'error': 'test error'})) + mock_client = MockHuggingFace.create_mock(error) + m = HuggingFaceModel('not_a_model', provider=HuggingFaceProvider(hf_client=mock_client)) + agent = Agent(m) + with pytest.raises(ModelHTTPError) as exc_info: + agent.run_sync('hello') + assert str(exc_info.value) == snapshot("status_code: 500, model_name: not_a_model, body: {'error': 'test error'}") + + +@pytest.mark.vcr() +async def test_request_simple_success_with_vcr(allow_model_requests: None, huggingface_api_key: str): + m = HuggingFaceModel( + 
'Qwen/Qwen2.5-72B-Instruct', provider=HuggingFaceProvider(provider='nebius', api_key=huggingface_api_key) + ) + agent = Agent(m) + result = await agent.run('hello') + assert result.output == snapshot( + "Hello! It's great to meet you. How can I assist you today? Whether you have questions, need information, or just want to chat, I'm here to help!" + ) + + +@pytest.mark.vcr() +async def test_hf_model_instructions(allow_model_requests: None, huggingface_api_key: str): + m = HuggingFaceModel( + 'Qwen/Qwen2.5-72B-Instruct', provider=HuggingFaceProvider(provider='nebius', api_key=huggingface_api_key) + ) + + def simple_instructions(ctx: RunContext): + return 'You are a helpful assistant.' + + agent = Agent(m, instructions=simple_instructions) + + result = await agent.run('What is the capital of France?') + assert result.all_messages() == snapshot( + [ + ModelRequest( + parts=[UserPromptPart(content='What is the capital of France?', timestamp=IsDatetime())], + instructions='You are a helpful assistant.', + ), + ModelResponse( + parts=[TextPart(content='Paris')], + usage=Usage(requests=1, request_tokens=26, response_tokens=2, total_tokens=28), + model_name='Qwen/Qwen2.5-72B-Instruct-fast', + timestamp=IsDatetime(), + vendor_id='chatcmpl-54246cfb4fa046e88a984020c4efab20', + ), + ] + ) diff --git a/tests/providers/test_huggingface.py b/tests/providers/test_huggingface.py new file mode 100644 index 000000000..5b52bfe23 --- /dev/null +++ b/tests/providers/test_huggingface.py @@ -0,0 +1,61 @@ +from __future__ import annotations as _annotations + +import re + +import httpx +import pytest + +from pydantic_ai.exceptions import UserError + +from ..conftest import TestEnv, try_import + +with try_import() as imports_successful: + from huggingface_hub import AsyncInferenceClient + + from pydantic_ai.providers.huggingface import HuggingFaceProvider + + +pytestmark = pytest.mark.skipif(not imports_successful(), reason='huggingface_hub not installed') + + +def test_huggingface_provider(): + hf_client = AsyncInferenceClient(api_key='api-key') + provider = HuggingFaceProvider(api_key='api-key', hf_client=hf_client) + assert provider.name == 'huggingface' + assert isinstance(provider.client, AsyncInferenceClient) + assert provider.client.token == 'api-key' + + +def test_huggingface_provider_need_api_key(env: TestEnv) -> None: + env.remove('HF_TOKEN') + with pytest.raises( + UserError, + match=re.escape( + 'Set the `HF_TOKEN` environment variable or pass it via `HuggingFaceProvider(api_key=...)`' + 'to use the HuggingFace provider.' 
+ ), + ): + HuggingFaceProvider() + + +def test_huggingface_provider_pass_http_client() -> None: + http_client = httpx.AsyncClient() + with pytest.raises( + ValueError, + match=re.escape('`http_client` is ignored for HuggingFace provider, please use `hf_client` instead'), + ): + HuggingFaceProvider(http_client=http_client, api_key='api-key') + + +def test_huggingface_provider_pass_hf_client() -> None: + hf_client = AsyncInferenceClient(api_key='api-key') + provider = HuggingFaceProvider(hf_client=hf_client) + assert provider.client == hf_client + + +def test_hf_provider_with_base_url() -> None: + # Test with environment variable for base_url + provider = HuggingFaceProvider( + hf_client=AsyncInferenceClient(api_key='test-api-key', base_url='https://router.huggingface.co/nebius/v1'), + ) + assert provider.base_url == 'https://router.huggingface.co/nebius/v1' From 40aef2e6f97f1581d5bb4c419852c6a5af4e698f Mon Sep 17 00:00:00 2001 From: Celina Hanouti Date: Mon, 9 Jun 2025 13:49:44 +0100 Subject: [PATCH 06/18] add docs and known models for hf --- docs/models/huggingface.md | 84 +++++++++++++++++++ .../pydantic_ai/models/__init__.py | 14 +++- .../pydantic_ai/models/huggingface.py | 23 ++++- tests/models/test_model_names.py | 3 + 4 files changed, 121 insertions(+), 3 deletions(-) create mode 100644 docs/models/huggingface.md diff --git a/docs/models/huggingface.md b/docs/models/huggingface.md new file mode 100644 index 000000000..8edeb8cc3 --- /dev/null +++ b/docs/models/huggingface.md @@ -0,0 +1,84 @@ +# Hugging Face + + +## Install + +To use `HuggingFace`, you need to either install `pydantic-ai`, or install `pydantic-ai-slim` with the `huggingface` optional group: + +```bash +pip/uv-add "pydantic-ai-slim[huggingface]" +``` + +## Configuration + +To use `HuggingFaceModel` through their main API, go to [Inference Providers documentation](https://huggingface.co/docs/inference-providers/pricing) for all the details, and you can generate a Hugging Face Token here: https://huggingface.co/settings/tokens. + +## Environment variable + +Once you have a HuggingFace Token, you can set it as an environment variable: + +```bash +export HF_TOKEN='your-hf-token' +``` + +You can then use `HuggingFaceModel` by name: + +```python +from pydantic_ai import Agent + +agent = Agent('huggingface:Qwen/Qwen3-235B-A22B') +... +``` + +Or initialise the model directly with just the model name: + +```python +from pydantic_ai import Agent +from pydantic_ai.models.huggingface import HuggingFaceModel + +model = HuggingFaceModel('Qwen/Qwen3-235B-A22B') +agent = Agent(model) +... +``` + +By default, the `HuggingFaceModel` uses the `HuggingFaceProvider` that will select automatically the first of the inference providers (Cerebras, Together AI, Cohere..etc) available for the model, sorted by your preferred order in https://hf.co/settings/inference-providers. + +## Configure the provider + +If you want to pass parameters in code to the provider, you can programmatically instantiate the +[HuggingFaceProvider][pydantic_ai.providers.huggingface.HuggingFaceProvider] and pass it to the model: + +```python +from pydantic_ai import Agent +from pydantic_ai.models.huggingface import HuggingFaceModel +from pydantic_ai.providers.huggingface import HuggingFaceProvider + +model = HuggingFaceModel('Qwen/Qwen3-235B-A22B', provider=HuggingFaceProvider(api_key='your-api-key', provider="nebius")) +agent = Agent(model) +... 
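+# NOTE: 'nebius' is just an example; any provider name listed in the
+# HF Inference Providers documentation can be passed instead.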
+``` + +## Custom Hugging Face Client + +`HuggingFaceProvider` also accepts a custom `AsyncInferenceClient` client via the `hf_client` parameter, so you can customise the `headers`, `bill_to` (billing to an HF organization you're a member of), `base_url` etc. as defined in the [Hugging Face Hub python library docs](https://huggingface.co/docs/huggingface_hub/package_reference/inference_client). + +```python +from huggingface_hub import AsyncInferenceClient + +from pydantic_ai import Agent +from pydantic_ai.models.huggingface import HuggingFaceModel +from pydantic_ai.providers.huggingface import HuggingFaceProvider + +client = AsyncInferenceClient( + bill_to="openai", + api_key='your-api-key', + provider="fireworks-ai", +) + +model = HuggingFaceModel( + 'Qwen/Qwen3-235B-A22B', + provider=HuggingFaceProvider(hf_client=client), +) +agent = Agent(model) +... +``` diff --git a/pydantic_ai_slim/pydantic_ai/models/__init__.py b/pydantic_ai_slim/pydantic_ai/models/__init__.py index 92b919181..e42be9dfc 100644 --- a/pydantic_ai_slim/pydantic_ai/models/__init__.py +++ b/pydantic_ai_slim/pydantic_ai/models/__init__.py @@ -204,6 +204,14 @@ 'groq:llama-3.2-3b-preview', 'groq:llama-3.2-11b-vision-preview', 'groq:llama-3.2-90b-vision-preview', + 'huggingface:Qwen/QwQ-32B', + 'huggingface:Qwen/Qwen2.5-72B-Instruct', + 'huggingface:Qwen/Qwen3-235B-A22B', + 'huggingface:Qwen/Qwen3-32B', + 'huggingface:deepseek-ai/DeepSeek-R1', + 'huggingface:meta-llama/Llama-3.3-70B-Instruct', + 'huggingface:meta-llama/Llama-4-Maverick-17B-128E-Instruct', + 'huggingface:meta-llama/Llama-4-Scout-17B-16E-Instruct', 'mistral:codestral-latest', 'mistral:mistral-large-latest', 'mistral:mistral-moderation-latest', @@ -485,7 +493,7 @@ def override_allow_model_requests(allow_model_requests: bool) -> Iterator[None]: ALLOW_MODEL_REQUESTS = old_value # pyright: ignore[reportConstantRedefinition] -def infer_model(model: Model | KnownModelName | str) -> Model: +def infer_model(model: Model | KnownModelName | str) -> Model: # noqa: C901 """Infer the model from the name.""" if isinstance(model, Model): return model @@ -539,6 +547,10 @@ def infer_model(model: Model | KnownModelName | str) -> Model: from .bedrock import BedrockConverseModel return BedrockConverseModel(model_name, provider=provider) + elif provider == 'huggingface': + from .huggingface import HuggingFaceModel + + return HuggingFaceModel(model_name, provider=provider) else: raise UserError(f'Unknown model: {model}') # pragma: no cover diff --git a/pydantic_ai_slim/pydantic_ai/models/huggingface.py b/pydantic_ai_slim/pydantic_ai/models/huggingface.py index 1dc1db73c..5de42a270 100644 --- a/pydantic_ai_slim/pydantic_ai/models/huggingface.py +++ b/pydantic_ai_slim/pydantic_ai/models/huggingface.py @@ -5,7 +5,7 @@ from contextlib import asynccontextmanager from dataclasses import dataclass, field from datetime import datetime, timezone -from typing import Literal, cast, overload +from typing import Literal, Union, cast, overload from typing_extensions import assert_never @@ -65,6 +65,25 @@ HFSystemPromptRole = Literal['system', 'user'] +LatestHuggingFaceModelNames = Literal[ + 'deepseek-ai/DeepSeek-R1', + 'meta-llama/Llama-3.3-70B-Instruct', + 'meta-llama/Llama-4-Maverick-17B-128E-Instruct', + 'meta-llama/Llama-4-Scout-17B-16E-Instruct', + 'Qwen/QwQ-32B', + 'Qwen/Qwen2.5-72B-Instruct', + 'Qwen/Qwen3-235B-A22B', + 'Qwen/Qwen3-32B', +] +"""Latest Hugging Face models.""" + + +HuggingFaceModelName = Union[str, LatestHuggingFaceModelNames] +"""Possible Hugging Face model names. 
+ +You can browse available models [here](https://huggingface.co/models?pipeline_tag=text-generation&inference_provider=all&sort=trending). +""" + class HuggingFaceModelSettings(ModelSettings, total=False): """Settings used for a Hugging Face model request. @@ -136,7 +155,7 @@ async def request_stream( yield await self._process_streamed_response(response) @property - def model_name(self) -> str: + def model_name(self) -> HuggingFaceModelName: """The model name.""" return self._model_name diff --git a/tests/models/test_model_names.py b/tests/models/test_model_names.py index 63eb7c650..da53d4dd9 100644 --- a/tests/models/test_model_names.py +++ b/tests/models/test_model_names.py @@ -14,6 +14,7 @@ from pydantic_ai.models.cohere import CohereModelName from pydantic_ai.models.gemini import GeminiModelName from pydantic_ai.models.groq import GroqModelName + from pydantic_ai.models.huggingface import HuggingFaceModelName from pydantic_ai.models.mistral import MistralModelName from pydantic_ai.models.openai import OpenAIModelName @@ -44,6 +45,7 @@ def get_model_names(model_name_type: Any) -> Iterator[str]: ] bedrock_names = [f'bedrock:{n}' for n in get_model_names(BedrockModelName)] deepseek_names = ['deepseek:deepseek-chat', 'deepseek:deepseek-reasoner'] + huggingface_names = [f'huggingface:{n}' for n in get_model_names(HuggingFaceModelName)] extra_names = ['test'] generated_names = sorted( @@ -55,6 +57,7 @@ def get_model_names(model_name_type: Any) -> Iterator[str]: + openai_names + bedrock_names + deepseek_names + + huggingface_names + extra_names ) From 7a4b9a4ce3f0f0a5b8816d3e38251db37b52e502 Mon Sep 17 00:00:00 2001 From: Celina Hanouti Date: Mon, 9 Jun 2025 14:10:47 +0100 Subject: [PATCH 07/18] fix imports in test --- tests/models/test_huggingface.py | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/tests/models/test_huggingface.py b/tests/models/test_huggingface.py index 378ab675a..1fd8b63ea 100644 --- a/tests/models/test_huggingface.py +++ b/tests/models/test_huggingface.py @@ -9,13 +9,6 @@ from unittest.mock import Mock import pytest -from huggingface_hub import ( - ChatCompletionStreamOutputChoice, - ChatCompletionStreamOutputDelta, - ChatCompletionStreamOutputDeltaToolCall, - ChatCompletionStreamOutputFunction, - ChatCompletionStreamOutputUsage, -) from inline_snapshot import snapshot from typing_extensions import TypedDict @@ -50,6 +43,11 @@ ChatCompletionOutputToolCall, ChatCompletionOutputUsage, ChatCompletionStreamOutput, + ChatCompletionStreamOutputChoice, + ChatCompletionStreamOutputDelta, + ChatCompletionStreamOutputDeltaToolCall, + ChatCompletionStreamOutputFunction, + ChatCompletionStreamOutputUsage, ) from huggingface_hub.errors import HfHubHTTPError From a1530818c6328161f0d892a352c1fec894901e6e Mon Sep 17 00:00:00 2001 From: Celina Hanouti Date: Mon, 9 Jun 2025 14:17:54 +0100 Subject: [PATCH 08/18] fix tests --- tests/models/test_huggingface.py | 18 +++++++++--------- tests/providers/test_huggingface.py | 2 +- 2 files changed, 10 insertions(+), 10 deletions(-) diff --git a/tests/models/test_huggingface.py b/tests/models/test_huggingface.py index 1fd8b63ea..731728b7b 100644 --- a/tests/models/test_huggingface.py +++ b/tests/models/test_huggingface.py @@ -389,7 +389,7 @@ def text_chunk(text: str, finish_reason: FinishReason | None = None) -> ChatComp async def test_stream_text(allow_model_requests: None): stream = [text_chunk('hello '), text_chunk('world'), chunk([])] mock_client = MockHuggingFace.create_stream_mock(stream) - m = 
HuggingFaceModel('hf-model', provider=HuggingFaceProvider(hf_client=mock_client)) + m = HuggingFaceModel('hf-model', provider=HuggingFaceProvider(hf_client=mock_client, api_key='x')) agent = Agent(m) async with agent.run_stream('') as result: @@ -406,7 +406,7 @@ async def test_stream_text_finish_reason(allow_model_requests: None): text_chunk('.', finish_reason='stop'), ] mock_client = MockHuggingFace.create_stream_mock(stream) - m = HuggingFaceModel('hf-model', provider=HuggingFaceProvider(hf_client=mock_client)) + m = HuggingFaceModel('hf-model', provider=HuggingFaceProvider(hf_client=mock_client, api_key='x')) agent = Agent(m) async with agent.run_stream('') as result: @@ -491,7 +491,7 @@ async def test_stream_structured(allow_model_requests: None): chunk([]), ] mock_client = MockHuggingFace.create_stream_mock(stream) - m = HuggingFaceModel('hf-model', provider=HuggingFaceProvider(hf_client=mock_client)) + m = HuggingFaceModel('hf-model', provider=HuggingFaceProvider(hf_client=mock_client, api_key='x')) agent = Agent(m, output_type=MyTypedDict) async with agent.run_stream('') as result: @@ -520,7 +520,7 @@ async def test_stream_structured_finish_reason(allow_model_requests: None): struc_chunk(None, None, finish_reason='stop'), ] mock_client = MockHuggingFace.create_stream_mock(stream) - m = HuggingFaceModel('hf-model', provider=HuggingFaceProvider(hf_client=mock_client)) + m = HuggingFaceModel('hf-model', provider=HuggingFaceProvider(hf_client=mock_client, api_key='x')) agent = Agent(m, output_type=MyTypedDict) async with agent.run_stream('') as result: @@ -543,7 +543,7 @@ async def test_no_content(allow_model_requests: None): chunk([ChatCompletionStreamOutputDelta(role='assistant')]), # type: ignore ] mock_client = MockHuggingFace.create_stream_mock(stream) - m = HuggingFaceModel('hf-model', provider=HuggingFaceProvider(hf_client=mock_client)) + m = HuggingFaceModel('hf-model', provider=HuggingFaceProvider(hf_client=mock_client, api_key='x')) agent = Agent(m, output_type=MyTypedDict) with pytest.raises(UnexpectedModelBehavior, match='Received empty model response'): @@ -558,7 +558,7 @@ async def test_no_delta(allow_model_requests: None): text_chunk('world'), ] mock_client = MockHuggingFace.create_stream_mock(stream) - m = HuggingFaceModel('hf-model', provider=HuggingFaceProvider(hf_client=mock_client)) + m = HuggingFaceModel('hf-model', provider=HuggingFaceProvider(hf_client=mock_client, api_key='x')) agent = Agent(m) async with agent.run_stream('') as result: @@ -571,7 +571,7 @@ async def test_no_delta(allow_model_requests: None): async def test_image_url_input(allow_model_requests: None): c = completion_message(ChatCompletionInputMessage(content='world', role='assistant')) # type:ignore mock_client = MockHuggingFace.create_mock(c) - m = HuggingFaceModel('hf-model', provider=HuggingFaceProvider(hf_client=mock_client)) + m = HuggingFaceModel('hf-model', provider=HuggingFaceProvider(hf_client=mock_client, api_key='x')) agent = Agent(m) result = await agent.run( @@ -609,7 +609,7 @@ async def test_image_url_input(allow_model_requests: None): async def test_image_as_binary_content_input(allow_model_requests: None): c = completion_message(ChatCompletionInputMessage(content='world', role='assistant')) # type: ignore mock_client = MockHuggingFace.create_mock(c) - m = HuggingFaceModel('hf-model', provider=HuggingFaceProvider(hf_client=mock_client)) + m = HuggingFaceModel('hf-model', provider=HuggingFaceProvider(hf_client=mock_client, api_key='x')) agent = Agent(m) base64_content = ( @@ -642,7 
+642,7 @@ async def test_image_as_binary_content_input(allow_model_requests: None): def test_model_status_error(allow_model_requests: None) -> None: error = HfHubHTTPError(message='test_error', response=Mock(status_code=500, content={'error': 'test error'})) mock_client = MockHuggingFace.create_mock(error) - m = HuggingFaceModel('not_a_model', provider=HuggingFaceProvider(hf_client=mock_client)) + m = HuggingFaceModel('not_a_model', provider=HuggingFaceProvider(hf_client=mock_client, api_key='x')) agent = Agent(m) with pytest.raises(ModelHTTPError) as exc_info: agent.run_sync('hello') diff --git a/tests/providers/test_huggingface.py b/tests/providers/test_huggingface.py index 5b52bfe23..4df8d8f19 100644 --- a/tests/providers/test_huggingface.py +++ b/tests/providers/test_huggingface.py @@ -56,6 +56,6 @@ def test_huggingface_provider_pass_hf_client() -> None: def test_hf_provider_with_base_url() -> None: # Test with environment variable for base_url provider = HuggingFaceProvider( - hf_client=AsyncInferenceClient(api_key='test-api-key', base_url='https://router.huggingface.co/nebius/v1'), + hf_client=AsyncInferenceClient(base_url='https://router.huggingface.co/nebius/v1'), api_key='test-api-key' ) assert provider.base_url == 'https://router.huggingface.co/nebius/v1' From 2f0ec5189dcd0607e24a2b61f15e745a4f835bef Mon Sep 17 00:00:00 2001 From: Celina Hanouti Date: Mon, 9 Jun 2025 14:20:06 +0100 Subject: [PATCH 09/18] fix provider test --- tests/providers/test_huggingface.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/providers/test_huggingface.py b/tests/providers/test_huggingface.py index 4df8d8f19..970c9d636 100644 --- a/tests/providers/test_huggingface.py +++ b/tests/providers/test_huggingface.py @@ -49,7 +49,7 @@ def test_huggingface_provider_pass_http_client() -> None: def test_huggingface_provider_pass_hf_client() -> None: hf_client = AsyncInferenceClient(api_key='api-key') - provider = HuggingFaceProvider(hf_client=hf_client) + provider = HuggingFaceProvider(hf_client=hf_client, api_key='api-key') assert provider.client == hf_client From 69aee552602f6c382677a50d52410a83aaa4a9f0 Mon Sep 17 00:00:00 2001 From: Celina Hanouti Date: Mon, 9 Jun 2025 14:21:16 +0100 Subject: [PATCH 10/18] adapt cli test --- tests/test_cli.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/test_cli.py b/tests/test_cli.py index 024116249..8efc0da00 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -144,6 +144,7 @@ def test_list_models(capfd: CaptureFixture[str]): 'cohere', 'deepseek', 'heroku', + 'huggingface', ) models = {line.strip().split(' ')[0] for line in output[3:]} for provider in providers: From f68dacea3ba1315abe9e463548313c635b351334 Mon Sep 17 00:00:00 2001 From: Celina Hanouti Date: Mon, 9 Jun 2025 14:26:51 +0100 Subject: [PATCH 11/18] re-record vcr cassettes --- .../test_hf_model_instructions.yaml | 70 +------------------ .../test_request_simple_success_with_vcr.yaml | 14 ++-- tests/models/test_huggingface.py | 4 +- 3 files changed, 11 insertions(+), 77 deletions(-) diff --git a/tests/models/cassettes/test_huggingface/test_hf_model_instructions.yaml b/tests/models/cassettes/test_huggingface/test_hf_model_instructions.yaml index f621f4c4f..11bcb7596 100644 --- a/tests/models/cassettes/test_huggingface/test_hf_model_instructions.yaml +++ b/tests/models/cassettes/test_huggingface/test_hf_model_instructions.yaml @@ -1,70 +1,4 @@ interactions: -- request: - body: null - headers: - accept: - - '*/*' - accept-encoding: - - gzip, deflate - connection: - - 
keep-alive - method: GET - uri: https://huggingface.co/api/models/Qwen/Qwen2.5-72B-Instruct?expand=inferenceProviderMapping - response: - headers: - access-control-allow-origin: - - https://huggingface.co - access-control-expose-headers: - - X-Repo-Commit,X-Request-Id,X-Error-Code,X-Error-Message,X-Total-Count,ETag,Link,Accept-Ranges,Content-Range,X-Linked-Size,X-Linked-ETag,X-Xet-Hash - connection: - - keep-alive - content-length: - - '800' - content-type: - - application/json; charset=utf-8 - cross-origin-opener-policy: - - same-origin - etag: - - W/"320-IoLwHc4XKGzRoHW0ok1gY7tY/NI" - referrer-policy: - - strict-origin-when-cross-origin - vary: - - Origin - parsed_body: - _id: 66e81cefd1b1391042d0e47e - id: Qwen/Qwen2.5-72B-Instruct - inferenceProviderMapping: - featherless-ai: - providerId: Qwen/Qwen2.5-72B-Instruct - status: error - task: conversational - fireworks-ai: - providerId: accounts/fireworks/models/qwen2p5-72b-instruct - status: live - task: conversational - hf-inference: - providerId: Qwen/Qwen2.5-72B-Instruct - status: live - task: conversational - hyperbolic: - providerId: Qwen/Qwen2.5-72B-Instruct - status: live - task: conversational - nebius: - providerId: Qwen/Qwen2.5-72B-Instruct-fast - status: live - task: conversational - novita: - providerId: qwen/qwen-2.5-72b-instruct - status: live - task: conversational - together: - providerId: Qwen/Qwen2.5-72B-Instruct-Turbo - status: live - task: conversational - status: - code: 200 - message: OK - request: body: null headers: {} @@ -106,8 +40,8 @@ interactions: role: assistant tool_calls: [] stop_reason: null - created: 1749227878 - id: chatcmpl-54246cfb4fa046e88a984020c4efab20 + created: 1749475551 + id: chatcmpl-6fa46f85f4f04beda9c936d5996b22a8 model: Qwen/Qwen2.5-72B-Instruct-fast object: chat.completion prompt_logprobs: null diff --git a/tests/models/cassettes/test_huggingface/test_request_simple_success_with_vcr.yaml b/tests/models/cassettes/test_huggingface/test_request_simple_success_with_vcr.yaml index c9a3b50f2..6996da033 100644 --- a/tests/models/cassettes/test_huggingface/test_request_simple_success_with_vcr.yaml +++ b/tests/models/cassettes/test_huggingface/test_request_simple_success_with_vcr.yaml @@ -81,7 +81,7 @@ interactions: connection: - keep-alive content-length: - - '700' + - '680' content-type: - application/json cross-origin-opener-policy: @@ -99,27 +99,27 @@ interactions: logprobs: null message: audio: null - content: Hello! It's great to meet you. How can I assist you today? Whether you have questions, need information, - or just want to chat, I'm here to help! + content: Hello! How can I assist you today? Feel free to ask me any questions or let me know if you need help with + anything specific. 
function_call: null reasoning_content: null refusal: null role: assistant tool_calls: [] stop_reason: null - created: 1749226637 - id: chatcmpl-f5783ce357b4415b8d59dbbf5b3cf9bf + created: 1749475549 + id: chatcmpl-6050852c70164258bb9bab4e93e2b69c model: Qwen/Qwen2.5-72B-Instruct-fast object: chat.completion prompt_logprobs: null service_tier: null system_fingerprint: null usage: - completion_tokens: 37 + completion_tokens: 29 completion_tokens_details: null prompt_tokens: 30 prompt_tokens_details: null - total_tokens: 67 + total_tokens: 59 status: code: 200 message: OK diff --git a/tests/models/test_huggingface.py b/tests/models/test_huggingface.py index 731728b7b..384328adf 100644 --- a/tests/models/test_huggingface.py +++ b/tests/models/test_huggingface.py @@ -657,7 +657,7 @@ async def test_request_simple_success_with_vcr(allow_model_requests: None, huggi agent = Agent(m) result = await agent.run('hello') assert result.output == snapshot( - "Hello! It's great to meet you. How can I assist you today? Whether you have questions, need information, or just want to chat, I'm here to help!" + 'Hello! How can I assist you today? Feel free to ask me any questions or let me know if you need help with anything specific.' ) @@ -684,7 +684,7 @@ def simple_instructions(ctx: RunContext): usage=Usage(requests=1, request_tokens=26, response_tokens=2, total_tokens=28), model_name='Qwen/Qwen2.5-72B-Instruct-fast', timestamp=IsDatetime(), - vendor_id='chatcmpl-54246cfb4fa046e88a984020c4efab20', + vendor_id='chatcmpl-6fa46f85f4f04beda9c936d5996b22a8', ), ] ) From cc982e5271ce10ae147d34c12e8068df8437fae2 Mon Sep 17 00:00:00 2001 From: Celina Hanouti Date: Mon, 9 Jun 2025 14:35:11 +0100 Subject: [PATCH 12/18] fix token name --- tests/conftest.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/conftest.py b/tests/conftest.py index a2c730a2a..1240bab6f 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -283,7 +283,7 @@ def openrouter_api_key() -> str: @pytest.fixture(scope='session') def huggingface_api_key() -> str: - return os.getenv('HF_TOKEN', 'mock-api-key') or os.getenv('HUGGINGFACE_API_KEY', 'mock-api-key') + return os.getenv('HF_TOKEN', 'hf_token') or os.getenv('HUGGINGFACE_API_KEY', 'hf_token') @pytest.fixture(scope='session') From 00da46ecf0cf92374df4c5a82b714c279bd67ff3 Mon Sep 17 00:00:00 2001 From: Celina Hanouti Date: Mon, 9 Jun 2025 14:50:23 +0100 Subject: [PATCH 13/18] fix examples test --- docs/models/huggingface.md | 10 +++++----- tests/test_examples.py | 1 + 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/docs/models/huggingface.md b/docs/models/huggingface.md index 8edeb8cc3..8d10a7ea8 100644 --- a/docs/models/huggingface.md +++ b/docs/models/huggingface.md @@ -18,7 +18,7 @@ To use `HuggingFaceModel` through their main API, go to [Inference Providers doc Once you have a HuggingFace Token, you can set it as an environment variable: ```bash -export HF_TOKEN='your-hf-token' +export HF_TOKEN='hf_token' ``` You can then use `HuggingFaceModel` by name: @@ -53,7 +53,7 @@ from pydantic_ai import Agent from pydantic_ai.models.huggingface import HuggingFaceModel from pydantic_ai.providers.huggingface import HuggingFaceProvider -model = HuggingFaceModel('Qwen/Qwen3-235B-A22B', provider=HuggingFaceProvider(api_key='your-api-key', provider="nebius")) +model = HuggingFaceModel('Qwen/Qwen3-235B-A22B', provider=HuggingFaceProvider(api_key='hf_token', provider='nebius')) agent = Agent(model) ... 
``` @@ -70,9 +70,9 @@ from pydantic_ai.models.huggingface import HuggingFaceModel from pydantic_ai.providers.huggingface import HuggingFaceProvider client = AsyncInferenceClient( - bill_to="openai", - api_key='your-api-key', - provider="fireworks-ai", + bill_to='openai', + api_key='hf_token', + provider='fireworks-ai', ) model = HuggingFaceModel( diff --git a/tests/test_examples.py b/tests/test_examples.py index ad377bedb..977f336f0 100644 --- a/tests/test_examples.py +++ b/tests/test_examples.py @@ -137,6 +137,7 @@ def print(self, *args: Any, **kwargs: Any) -> None: env.set('CO_API_KEY', 'testing') env.set('MISTRAL_API_KEY', 'testing') env.set('ANTHROPIC_API_KEY', 'testing') + env.set('HF_TOKEN', 'hf_testing') env.set('AWS_ACCESS_KEY_ID', 'testing') env.set('AWS_SECRET_ACCESS_KEY', 'testing') env.set('AWS_DEFAULT_REGION', 'us-east-1') From 922fd13161f3a06cd6f95f43d28cb207e278b33e Mon Sep 17 00:00:00 2001 From: Marcelo Trylesinski Date: Wed, 25 Jun 2025 12:53:40 +0200 Subject: [PATCH 14/18] Add API docs and refactor a bit the wording --- docs/api/models/huggingface.md | 7 +++++++ docs/api/providers.md | 2 ++ docs/models/huggingface.md | 27 +++++++++++++++++---------- mkdocs.yml | 1 + 4 files changed, 27 insertions(+), 10 deletions(-) create mode 100644 docs/api/models/huggingface.md diff --git a/docs/api/models/huggingface.md b/docs/api/models/huggingface.md new file mode 100644 index 000000000..72e78c4a3 --- /dev/null +++ b/docs/api/models/huggingface.md @@ -0,0 +1,7 @@ +# `pydantic_ai.models.huggingface` + +## Setup + +For details on how to set up authentication with this model, see [model configuration for Hugging Face](../../models/huggingface.md). + +::: pydantic_ai.models.huggingface diff --git a/docs/api/providers.md b/docs/api/providers.md index 926cf8e8b..8e808185a 100644 --- a/docs/api/providers.md +++ b/docs/api/providers.md @@ -29,3 +29,5 @@ ::: pydantic_ai.providers.heroku.HerokuProvider ::: pydantic_ai.providers.openrouter.OpenRouterProvider + +::: pydantic_ai.providers.huggingface.HuggingFaceProvider diff --git a/docs/models/huggingface.md b/docs/models/huggingface.md index 8d10a7ea8..e99a77f00 100644 --- a/docs/models/huggingface.md +++ b/docs/models/huggingface.md @@ -1,9 +1,8 @@ # Hugging Face - ## Install -To use `HuggingFace`, you need to either install `pydantic-ai`, or install `pydantic-ai-slim` with the `huggingface` optional group: +To use `HuggingFaceModel`, you need to either install `pydantic-ai`, or install `pydantic-ai-slim` with the `huggingface` optional group: ```bash pip/uv-add "pydantic-ai-slim[huggingface]" @@ -11,17 +10,19 @@ pip/uv-add "pydantic-ai-slim[huggingface]" ## Configuration -To use `HuggingFaceModel` through their main API, go to [Inference Providers documentation](https://huggingface.co/docs/inference-providers/pricing) for all the details, and you can generate a Hugging Face Token here: https://huggingface.co/settings/tokens. +To use [HuggingFace](https://huggingface.co/) through their main API, go to +[Inference Providers documentation](https://huggingface.co/docs/inference-providers/pricing) for all the details, +and you can generate a Hugging Face access token here: https://huggingface.co/settings/tokens. 
-## Environment variable
+## Hugging Face access token
 
-Once you have a HuggingFace Token, you can set it as an environment variable:
+Once you have a Hugging Face access token, you can set it as an environment variable:
 
```bash
export HF_TOKEN='hf_token'
```
 
-You can then use `HuggingFaceModel` by name:
+You can then use [`HuggingFaceModel`][pydantic_ai.models.huggingface.HuggingFaceModel] by name:
 
```python
from pydantic_ai import Agent
@@ -41,12 +42,15 @@ agent = Agent(model)
...
```
 
-By default, the `HuggingFaceModel` uses the `HuggingFaceProvider` that will select automatically the first of the inference providers (Cerebras, Together AI, Cohere..etc) available for the model, sorted by your preferred order in https://hf.co/settings/inference-providers.
+By default, the [`HuggingFaceModel`][pydantic_ai.models.huggingface.HuggingFaceModel] uses the
+[`HuggingFaceProvider`][pydantic_ai.providers.huggingface.HuggingFaceProvider], which automatically selects
+the first inference provider (Cerebras, Together AI, Cohere, etc.) available for the model, sorted by your
+preferred order in https://hf.co/settings/inference-providers.
 
## Configure the provider
 
If you want to pass parameters in code to the provider, you can programmatically instantiate the
-[HuggingFaceProvider][pydantic_ai.providers.huggingface.HuggingFaceProvider] and pass it to the model:
+[`HuggingFaceProvider`][pydantic_ai.providers.huggingface.HuggingFaceProvider] and pass it to the model:
 
```python
from pydantic_ai import Agent
@@ -58,9 +62,12 @@ agent = Agent(model)
...
```
 
-## Custom Hugging Face Client
+## Custom Hugging Face client
 
-`HuggingFaceProvider` also accepts a custom `AsyncInferenceClient` client via the `hf_client` parameter, so you can customise the `headers`, `bill_to` (billing to an HF organization you're a member of), `base_url` etc. as defined in the [Hugging Face Hub python library docs](https://huggingface.co/docs/huggingface_hub/package_reference/inference_client).
+[`HuggingFaceProvider`][pydantic_ai.providers.huggingface.HuggingFaceProvider] also accepts a custom
+[`AsyncInferenceClient`][huggingface_hub.AsyncInferenceClient] client via the `hf_client` parameter, so you can customise
+the `headers`, `bill_to` (billing to an HF organization you're a member of), `base_url`, etc., as defined in the
+[Hugging Face Hub python library docs](https://huggingface.co/docs/huggingface_hub/package_reference/inference_client).
```python from huggingface_hub import AsyncInferenceClient diff --git a/mkdocs.yml b/mkdocs.yml index d750c29bb..55fd86384 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -81,6 +81,7 @@ nav: - api/models/gemini.md - api/models/google.md - api/models/groq.md + - api/models/huggingface.md - api/models/instrumented.md - api/models/mistral.md - api/models/test.md From adfc2548918e804f7896adece711ac56c9ea7178 Mon Sep 17 00:00:00 2001 From: Celina Hanouti Date: Wed, 2 Jul 2025 17:57:10 +0200 Subject: [PATCH 15/18] review suggestions --- docs/models/huggingface.md | 2 +- .../pydantic_ai/providers/huggingface.py | 28 ++++++-- tests/conftest.py | 4 +- .../test_hf_model_instructions.yaml | 66 ++++++++++++++++++- tests/models/test_huggingface.py | 15 +++-- tests/providers/test_huggingface.py | 2 +- 6 files changed, 98 insertions(+), 19 deletions(-) diff --git a/docs/models/huggingface.md b/docs/models/huggingface.md index e99a77f00..6425d76cb 100644 --- a/docs/models/huggingface.md +++ b/docs/models/huggingface.md @@ -57,7 +57,7 @@ from pydantic_ai import Agent from pydantic_ai.models.huggingface import HuggingFaceModel from pydantic_ai.providers.huggingface import HuggingFaceProvider -model = HuggingFaceModel('Qwen/Qwen3-235B-A22B', provider=HuggingFaceProvider(api_key='hf_token', provider='nebius')) +model = HuggingFaceModel('Qwen/Qwen3-235B-A22B', provider=HuggingFaceProvider(api_key='hf_token', provider_name='nebius')) agent = Agent(model) ... ``` diff --git a/pydantic_ai_slim/pydantic_ai/providers/huggingface.py b/pydantic_ai_slim/pydantic_ai/providers/huggingface.py index e18a60d16..8afb41591 100644 --- a/pydantic_ai_slim/pydantic_ai/providers/huggingface.py +++ b/pydantic_ai_slim/pydantic_ai/providers/huggingface.py @@ -1,6 +1,7 @@ from __future__ import annotations as _annotations import os +from typing import overload from httpx import AsyncClient @@ -32,13 +33,26 @@ def base_url(self) -> str: def client(self) -> AsyncInferenceClient: return self._client + @overload + def __init__(self, *, base_url: str, api_key: str | None = None) -> None: ... + @overload + def __init__(self, *, provider_name: str, api_key: str | None = None) -> None: ... + @overload + def __init__(self, *, hf_client: AsyncInferenceClient, api_key: str | None = None) -> None: ... + @overload + def __init__(self, *, hf_client: AsyncInferenceClient, base_url: str, api_key: str | None = None) -> None: ... + @overload + def __init__(self, *, hf_client: AsyncInferenceClient, provider_name: str, api_key: str | None = None) -> None: ... + @overload + def __init__(self, *, api_key: str | None = None) -> None: ... + def __init__( self, base_url: str | None = None, api_key: str | None = None, hf_client: AsyncInferenceClient | None = None, http_client: AsyncClient | None = None, - provider: str | None = None, + provider_name: str | None = None, ) -> None: """Create a new Hugging Face provider. @@ -50,9 +64,9 @@ def __init__( [`AsyncInferenceClient`](https://huggingface.co/docs/huggingface_hub/v0.29.3/en/package_reference/inference_client#huggingface_hub.AsyncInferenceClient) client to use. If not provided, a new instance will be created. http_client: (currently ignored) An existing `httpx.AsyncClient` to use for making HTTP requests. - provider : Name of the provider to use for inference. available providers can be found in the [HF Inference Providers documentation](https://huggingface.co/docs/inference-providers/index#partners). + provider_name : Name of the provider to use for inference. 
available providers can be found in the [HF Inference Providers documentation](https://huggingface.co/docs/inference-providers/index#partners). defaults to "auto", which will select the first available provider for the model, the first of the providers available for the model, sorted by the user's order in https://hf.co/settings/inference-providers. - If `base_url` is passed, then `provider` is not used. + If `base_url` is passed, then `provider_name` is not used. """ api_key = api_key or os.environ.get('HF_TOKEN') @@ -63,12 +77,12 @@ def __init__( ) if http_client is not None: - raise ValueError('`http_client` is ignored for HuggingFace provider, please use `hf_client` instead') + raise ValueError('`http_client` is ignored for HuggingFace provider, please use `hf_client` instead.') - if base_url is not None and provider is not None: - raise ValueError('Cannot provide both `base_url` and `provider`') + if base_url is not None and provider_name is not None: + raise ValueError('Cannot provide both `base_url` and `provider_name`.') if hf_client is None: - self._client = AsyncInferenceClient(api_key=api_key, provider=provider, base_url=base_url) # type: ignore + self._client = AsyncInferenceClient(api_key=api_key, provider=provider_name, base_url=base_url) # type: ignore else: self._client = hf_client diff --git a/tests/conftest.py b/tests/conftest.py index 6cfc627bd..73c6e07cc 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -294,7 +294,7 @@ def openrouter_api_key() -> str: @pytest.fixture(scope='session') def huggingface_api_key() -> str: - return os.getenv('HF_TOKEN', 'hf_token') or os.getenv('HUGGINGFACE_API_KEY', 'hf_token') + return os.getenv('HF_TOKEN', 'hf_token') @pytest.fixture(scope='session') @@ -428,7 +428,7 @@ def model( return HuggingFaceModel( 'Qwen/Qwen2.5-72B-Instruct', - provider=HuggingFaceProvider(provider='nebius', api_key=huggingface_api_key), + provider=HuggingFaceProvider(provider_name='nebius', api_key=huggingface_api_key), ) else: raise ValueError(f'Unknown model: {request.param}') diff --git a/tests/models/cassettes/test_huggingface/test_hf_model_instructions.yaml b/tests/models/cassettes/test_huggingface/test_hf_model_instructions.yaml index 11bcb7596..d8a5ee07e 100644 --- a/tests/models/cassettes/test_huggingface/test_hf_model_instructions.yaml +++ b/tests/models/cassettes/test_huggingface/test_hf_model_instructions.yaml @@ -1,4 +1,66 @@ interactions: +- request: + body: null + headers: + accept: + - '*/*' + accept-encoding: + - gzip, deflate + connection: + - keep-alive + method: GET + uri: https://huggingface.co/api/models/Qwen/Qwen2.5-72B-Instruct?expand=inferenceProviderMapping + response: + headers: + access-control-allow-origin: + - https://huggingface.co + access-control-expose-headers: + - X-Repo-Commit,X-Request-Id,X-Error-Code,X-Error-Message,X-Total-Count,ETag,Link,Accept-Ranges,Content-Range,X-Linked-Size,X-Linked-ETag,X-Xet-Hash + connection: + - keep-alive + content-length: + - '701' + content-type: + - application/json; charset=utf-8 + cross-origin-opener-policy: + - same-origin + etag: + - W/"2bd-diYmxjldwbIbFgWNRPBqJ3SEIak" + referrer-policy: + - strict-origin-when-cross-origin + vary: + - Origin + parsed_body: + _id: 66e81cefd1b1391042d0e47e + id: Qwen/Qwen2.5-72B-Instruct + inferenceProviderMapping: + featherless-ai: + providerId: Qwen/Qwen2.5-72B-Instruct + status: live + task: conversational + fireworks-ai: + providerId: accounts/fireworks/models/qwen2p5-72b-instruct + status: live + task: conversational + hyperbolic: + providerId: 
Qwen/Qwen2.5-72B-Instruct + status: live + task: conversational + nebius: + providerId: Qwen/Qwen2.5-72B-Instruct-fast + status: live + task: conversational + novita: + providerId: qwen/qwen-2.5-72b-instruct + status: live + task: conversational + together: + providerId: Qwen/Qwen2.5-72B-Instruct-Turbo + status: live + task: conversational + status: + code: 200 + message: OK - request: body: null headers: {} @@ -40,8 +102,8 @@ interactions: role: assistant tool_calls: [] stop_reason: null - created: 1749475551 - id: chatcmpl-6fa46f85f4f04beda9c936d5996b22a8 + created: 1751470757 + id: chatcmpl-b3936940372c481b8d886e596dc75524 model: Qwen/Qwen2.5-72B-Instruct-fast object: chat.completion prompt_logprobs: null diff --git a/tests/models/test_huggingface.py b/tests/models/test_huggingface.py index 384328adf..cae1e2bfe 100644 --- a/tests/models/test_huggingface.py +++ b/tests/models/test_huggingface.py @@ -125,7 +125,8 @@ async def test_simple_completion(allow_model_requests: None): c = completion_message(ChatCompletionInputMessage(content='world', role='assistant')) # type:ignore mock_client = MockHuggingFace.create_mock(c) model = HuggingFaceModel( - 'Qwen/Qwen2.5-72B-Instruct', provider=HuggingFaceProvider(provider='nebius', hf_client=mock_client, api_key='x') + 'Qwen/Qwen2.5-72B-Instruct', + provider=HuggingFaceProvider(provider_name='nebius', hf_client=mock_client, api_key='x'), ) agent = Agent(model) @@ -148,7 +149,8 @@ async def test_request_simple_usage(allow_model_requests: None): c = completion_message(ChatCompletionInputMessage(content='world', role='assistant')) # type:ignore mock_client = MockHuggingFace.create_mock(c) model = HuggingFaceModel( - 'Qwen/Qwen2.5-72B-Instruct', provider=HuggingFaceProvider(provider='nebius', hf_client=mock_client, api_key='x') + 'Qwen/Qwen2.5-72B-Instruct', + provider=HuggingFaceProvider(provider_name='nebius', hf_client=mock_client, api_key='x'), ) agent = Agent(model) @@ -181,7 +183,8 @@ async def test_request_structured_response(allow_model_requests: None): mock_client = MockHuggingFace.create_mock(c) model = HuggingFaceModel( - 'Qwen/Qwen2.5-72B-Instruct', provider=HuggingFaceProvider(provider='nebius', hf_client=mock_client, api_key='x') + 'Qwen/Qwen2.5-72B-Instruct', + provider=HuggingFaceProvider(provider_name='nebius', hf_client=mock_client, api_key='x'), ) agent = Agent(model, output_type=list[int]) @@ -652,7 +655,7 @@ def test_model_status_error(allow_model_requests: None) -> None: @pytest.mark.vcr() async def test_request_simple_success_with_vcr(allow_model_requests: None, huggingface_api_key: str): m = HuggingFaceModel( - 'Qwen/Qwen2.5-72B-Instruct', provider=HuggingFaceProvider(provider='nebius', api_key=huggingface_api_key) + 'Qwen/Qwen2.5-72B-Instruct', provider=HuggingFaceProvider(provider_name='nebius', api_key=huggingface_api_key) ) agent = Agent(m) result = await agent.run('hello') @@ -664,7 +667,7 @@ async def test_request_simple_success_with_vcr(allow_model_requests: None, huggi @pytest.mark.vcr() async def test_hf_model_instructions(allow_model_requests: None, huggingface_api_key: str): m = HuggingFaceModel( - 'Qwen/Qwen2.5-72B-Instruct', provider=HuggingFaceProvider(provider='nebius', api_key=huggingface_api_key) + 'Qwen/Qwen2.5-72B-Instruct', provider=HuggingFaceProvider(provider_name='nebius', api_key=huggingface_api_key) ) def simple_instructions(ctx: RunContext): @@ -684,7 +687,7 @@ def simple_instructions(ctx: RunContext): usage=Usage(requests=1, request_tokens=26, response_tokens=2, total_tokens=28), 
model_name='Qwen/Qwen2.5-72B-Instruct-fast', timestamp=IsDatetime(), - vendor_id='chatcmpl-6fa46f85f4f04beda9c936d5996b22a8', + vendor_id='chatcmpl-b3936940372c481b8d886e596dc75524', ), ] ) diff --git a/tests/providers/test_huggingface.py b/tests/providers/test_huggingface.py index 970c9d636..944d418a0 100644 --- a/tests/providers/test_huggingface.py +++ b/tests/providers/test_huggingface.py @@ -44,7 +44,7 @@ def test_huggingface_provider_pass_http_client() -> None: ValueError, match=re.escape('`http_client` is ignored for HuggingFace provider, please use `hf_client` instead'), ): - HuggingFaceProvider(http_client=http_client, api_key='api-key') + HuggingFaceProvider(http_client=http_client, api_key='api-key') # type: ignore def test_huggingface_provider_pass_hf_client() -> None: From e4af59eee94bb68253b11c4a2ee5db0a050daac1 Mon Sep 17 00:00:00 2001 From: Celina Hanouti Date: Mon, 7 Jul 2025 17:27:17 +0200 Subject: [PATCH 16/18] more tests --- tests/providers/test_huggingface.py | 69 +++++++++++++++++++++++++++++ 1 file changed, 69 insertions(+) diff --git a/tests/providers/test_huggingface.py b/tests/providers/test_huggingface.py index 944d418a0..9c6074af7 100644 --- a/tests/providers/test_huggingface.py +++ b/tests/providers/test_huggingface.py @@ -1,6 +1,7 @@ from __future__ import annotations as _annotations import re +from unittest.mock import MagicMock, Mock, patch import httpx import pytest @@ -59,3 +60,71 @@ def test_hf_provider_with_base_url() -> None: hf_client=AsyncInferenceClient(base_url='https://router.huggingface.co/nebius/v1'), api_key='test-api-key' ) assert provider.base_url == 'https://router.huggingface.co/nebius/v1' + + +def test_huggingface_provider_properties(): + mock_client = Mock(spec=AsyncInferenceClient) + mock_client.model = 'test-model' + provider = HuggingFaceProvider(hf_client=mock_client) + assert provider.name == 'huggingface' + assert provider.base_url == 'test-model' + assert provider.client is mock_client + + +def test_huggingface_provider_init_api_key_error(monkeypatch: pytest.MonkeyPatch): + monkeypatch.delenv('HF_TOKEN', raising=False) + with pytest.raises(UserError, match='Set the `HF_TOKEN` environment variable'): + HuggingFaceProvider() + + +@patch('pydantic_ai.providers.huggingface.AsyncInferenceClient') +def test_huggingface_provider_init_api_key_from_env( + MockAsyncInferenceClient: MagicMock, monkeypatch: pytest.MonkeyPatch +): + monkeypatch.setenv('HF_TOKEN', 'env-key') + HuggingFaceProvider() + MockAsyncInferenceClient.assert_called_with(api_key='env-key', provider=None, base_url=None) + + +@patch('pydantic_ai.providers.huggingface.AsyncInferenceClient') +def test_huggingface_provider_init_api_key_from_arg( + MockAsyncInferenceClient: MagicMock, monkeypatch: pytest.MonkeyPatch +): + monkeypatch.setenv('HF_TOKEN', 'env-key') + HuggingFaceProvider(api_key='arg-key') + MockAsyncInferenceClient.assert_called_with(api_key='arg-key', provider=None, base_url=None) + + +def test_huggingface_provider_init_http_client_error(): + with pytest.raises(ValueError, match='`http_client` is ignored'): + HuggingFaceProvider(api_key='key', http_client=Mock()) # type: ignore[call-overload] + + +def test_huggingface_provider_init_base_url_and_provider_name_error(): + with pytest.raises(ValueError, match='Cannot provide both `base_url` and `provider_name`'): + HuggingFaceProvider(api_key='key', base_url='url', provider_name='provider') # type: ignore[call-overload] + + +def test_huggingface_provider_init_with_hf_client(): + mock_client = 
Mock(spec=AsyncInferenceClient) + provider = HuggingFaceProvider(hf_client=mock_client) + assert provider.client is mock_client + + +@patch('pydantic_ai.providers.huggingface.AsyncInferenceClient') +def test_huggingface_provider_init_without_hf_client(MockAsyncInferenceClient: MagicMock): + provider = HuggingFaceProvider(api_key='key') + assert provider.client is MockAsyncInferenceClient.return_value + MockAsyncInferenceClient.assert_called_with(api_key='key', provider=None, base_url=None) + + +@patch('pydantic_ai.providers.huggingface.AsyncInferenceClient') +def test_huggingface_provider_init_with_provider_name(MockAsyncInferenceClient: MagicMock): + HuggingFaceProvider(api_key='key', provider_name='test-provider') + MockAsyncInferenceClient.assert_called_once_with(api_key='key', provider='test-provider', base_url=None) + + +@patch('pydantic_ai.providers.huggingface.AsyncInferenceClient') +def test_huggingface_provider_init_with_base_url(MockAsyncInferenceClient: MagicMock): + HuggingFaceProvider(api_key='key', base_url='test-url') + MockAsyncInferenceClient.assert_called_once_with(api_key='key', provider=None, base_url='test-url') From cd76d7871003a8e37a9bcccbd9c5af1fb492cd7a Mon Sep 17 00:00:00 2001 From: Celina Hanouti Date: Mon, 7 Jul 2025 17:34:06 +0200 Subject: [PATCH 17/18] fix test --- tests/providers/test_huggingface.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/providers/test_huggingface.py b/tests/providers/test_huggingface.py index 9c6074af7..858e049cb 100644 --- a/tests/providers/test_huggingface.py +++ b/tests/providers/test_huggingface.py @@ -107,7 +107,7 @@ def test_huggingface_provider_init_base_url_and_provider_name_error(): def test_huggingface_provider_init_with_hf_client(): mock_client = Mock(spec=AsyncInferenceClient) - provider = HuggingFaceProvider(hf_client=mock_client) + provider = HuggingFaceProvider(hf_client=mock_client, api_key='key') assert provider.client is mock_client From 13ebbf9d5e4da1c128f358cecf28a2e1ad2b7fb6 Mon Sep 17 00:00:00 2001 From: Celina Hanouti Date: Mon, 7 Jul 2025 17:35:13 +0200 Subject: [PATCH 18/18] fix another test --- tests/providers/test_huggingface.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/providers/test_huggingface.py b/tests/providers/test_huggingface.py index 858e049cb..c9d263ec6 100644 --- a/tests/providers/test_huggingface.py +++ b/tests/providers/test_huggingface.py @@ -65,7 +65,7 @@ def test_hf_provider_with_base_url() -> None: def test_huggingface_provider_properties(): mock_client = Mock(spec=AsyncInferenceClient) mock_client.model = 'test-model' - provider = HuggingFaceProvider(hf_client=mock_client) + provider = HuggingFaceProvider(hf_client=mock_client, api_key='test-api-key') assert provider.name == 'huggingface' assert provider.base_url == 'test-model' assert provider.client is mock_client
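
Taken together, the `__init__` overloads from PATCH 15 and the tests above define the supported ways to construct a `HuggingFaceProvider`. A minimal sketch of the main forms, for reviewers (illustrative only: the `hf_token` placeholder, the Nebius router URL, and the model name are reused from the docs and cassettes above, not required by the API):

```python
from huggingface_hub import AsyncInferenceClient

from pydantic_ai import Agent
from pydantic_ai.models.huggingface import HuggingFaceModel
from pydantic_ai.providers.huggingface import HuggingFaceProvider

# 1. Automatic inference-provider selection; auth comes from api_key or HF_TOKEN.
auto = HuggingFaceProvider(api_key='hf_token')

# 2. Pin a named inference provider.
pinned = HuggingFaceProvider(provider_name='nebius', api_key='hf_token')

# 3. Target a fixed router endpoint instead; passing `base_url` together with
#    `provider_name` raises ValueError, as the tests above check.
routed = HuggingFaceProvider(base_url='https://router.huggingface.co/nebius/v1', api_key='hf_token')

# 4. Bring a pre-configured client, e.g. to bill an HF organization.
client = AsyncInferenceClient(api_key='hf_token', provider='fireworks-ai', bill_to='openai')
custom = HuggingFaceProvider(hf_client=client, api_key='hf_token')

agent = Agent(HuggingFaceModel('Qwen/Qwen2.5-72B-Instruct', provider=pinned))
result = agent.run_sync('hello')
print(result.output)
```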