From 967ab715988d6702a1b9e9ad6765b78eefa89c4f Mon Sep 17 00:00:00 2001
From: Vlad Ryzhkov
Date: Mon, 30 Jun 2025 16:12:57 -0600
Subject: [PATCH] Add BagelNet provider for Python client integration

- Add BagelNetConversationalTask provider implementation
- Register provider in PROVIDER_T and PROVIDERS with alphabetical ordering
- Update InferenceClient docstring to document bagelnet provider
- Add bagelnet entry to HARDCODED_MODEL_INFERENCE_MAPPING
- Add unit tests for BagelNet provider

Implements HuggingFace Inference Providers Step 5: Python client integration

Follows patterns from cerebras.py and other conversational providers
---
 src/huggingface_hub/inference/_client.py      |  2 +-
 .../inference/_providers/__init__.py          |  5 +++
 .../inference/_providers/_common.py           |  1 +
 .../inference/_providers/bagelnet.py          |  6 ++++
 tests/test_bagelnet_provider.py               | 31 +++++++++++++++++++
 5 files changed, 44 insertions(+), 1 deletion(-)
 create mode 100644 src/huggingface_hub/inference/_providers/bagelnet.py
 create mode 100644 tests/test_bagelnet_provider.py

diff --git a/src/huggingface_hub/inference/_client.py b/src/huggingface_hub/inference/_client.py
index 3439dafd89..ab973e09fd 100644
--- a/src/huggingface_hub/inference/_client.py
+++ b/src/huggingface_hub/inference/_client.py
@@ -134,7 +134,7 @@ class InferenceClient:
             path will be appended to the base URL (see the [TGI Messages API](https://huggingface.co/docs/text-generation-inference/en/messages_api)
             documentation for details). When passing a URL as `model`, the client will not append any suffix path to it.
         provider (`str`, *optional*):
-            Name of the provider to use for inference. Can be `"black-forest-labs"`, `"cerebras"`, `"cohere"`, `"fal-ai"`, `"featherless-ai"`, `"fireworks-ai"`, `"groq"`, `"hf-inference"`, `"hyperbolic"`, `"nebius"`, `"novita"`, `"nscale"`, `"openai"`, `"replicate"`, "sambanova"` or `"together"`.
+            Name of the provider to use for inference. Can be `"bagelnet"`, `"black-forest-labs"`, `"cerebras"`, `"cohere"`, `"fal-ai"`, `"featherless-ai"`, `"fireworks-ai"`, `"groq"`, `"hf-inference"`, `"hyperbolic"`, `"nebius"`, `"novita"`, `"nscale"`, `"openai"`, `"replicate"`, `"sambanova"` or `"together"`.
             Defaults to "auto" i.e. the first of the providers available for the model, sorted by the user's order in https://hf.co/settings/inference-providers.
             If model is a URL or `base_url` is passed, then `provider` is not used.
         token (`str`, *optional*):
diff --git a/src/huggingface_hub/inference/_providers/__init__.py b/src/huggingface_hub/inference/_providers/__init__.py
index 8d73b837fc..5559c7c46b 100644
--- a/src/huggingface_hub/inference/_providers/__init__.py
+++ b/src/huggingface_hub/inference/_providers/__init__.py
@@ -7,6 +7,7 @@
 from huggingface_hub.utils import logging
 
 from ._common import TaskProviderHelper, _fetch_inference_provider_mapping
+from .bagelnet import BagelNetConversationalTask
 from .black_forest_labs import BlackForestLabsTextToImageTask
 from .cerebras import CerebrasConversationalTask
 from .cohere import CohereConversationalTask
@@ -43,6 +44,7 @@
 
 
 PROVIDER_T = Literal[
+    "bagelnet",
     "black-forest-labs",
     "cerebras",
     "cohere",
@@ -64,6 +66,9 @@
 PROVIDER_OR_POLICY_T = Union[PROVIDER_T, Literal["auto"]]
 
 PROVIDERS: Dict[PROVIDER_T, Dict[str, TaskProviderHelper]] = {
+    "bagelnet": {
+        "conversational": BagelNetConversationalTask(),
+    },
     "black-forest-labs": {
         "text-to-image": BlackForestLabsTextToImageTask(),
     },
diff --git a/src/huggingface_hub/inference/_providers/_common.py b/src/huggingface_hub/inference/_providers/_common.py
index 8ce62d56ea..d507092d32 100644
--- a/src/huggingface_hub/inference/_providers/_common.py
+++ b/src/huggingface_hub/inference/_providers/_common.py
@@ -22,6 +22,7 @@
     #     provider_id="Qwen2.5-Coder-32B-Instruct",
     #     task="conversational",
     #     status="live")
+    "bagelnet": {},
     "cerebras": {},
     "cohere": {},
     "fal-ai": {},
diff --git a/src/huggingface_hub/inference/_providers/bagelnet.py b/src/huggingface_hub/inference/_providers/bagelnet.py
new file mode 100644
index 0000000000..4b71b991b5
--- /dev/null
+++ b/src/huggingface_hub/inference/_providers/bagelnet.py
@@ -0,0 +1,6 @@
+from ._common import BaseConversationalTask
+
+
+class BagelNetConversationalTask(BaseConversationalTask):
+    def __init__(self):
+        super().__init__(provider="bagelnet", base_url="https://api.bagel.net")
\ No newline at end of file
diff --git a/tests/test_bagelnet_provider.py b/tests/test_bagelnet_provider.py
new file mode 100644
index 0000000000..7f3d33dd41
--- /dev/null
+++ b/tests/test_bagelnet_provider.py
@@ -0,0 +1,31 @@
+import pytest
+
+from huggingface_hub.inference._providers.bagelnet import BagelNetConversationalTask
+
+
+class TestBagelNetConversationalTask:
+    def test_init(self):
+        """Test BagelNet provider initialization."""
+        task = BagelNetConversationalTask()
+        assert task.provider == "bagelnet"
+        assert task.base_url == "https://api.bagel.net"
+        assert task.task == "conversational"
+
+    def test_inheritance(self):
+        """Test BagelNet inherits from BaseConversationalTask."""
+        from huggingface_hub.inference._providers._common import BaseConversationalTask
+
+        task = BagelNetConversationalTask()
+        assert isinstance(task, BaseConversationalTask)
+
+    def test_no_method_overrides(self):
+        """Test that BagelNet uses default implementations (no overrides needed)."""
+        task = BagelNetConversationalTask()
+
+        # Should use default route
+        route = task._prepare_route("test_model", "test_key")
+        assert route == "/v1/chat/completions"
+
+        # Should use default base URL behavior
+        direct_url = task._prepare_base_url("sk-test-key")  # Non-HF key
+        assert direct_url == "https://api.bagel.net"
\ No newline at end of file