diff --git a/src/huggingface_hub/inference/_client.py b/src/huggingface_hub/inference/_client.py index e6bed1af67..14342693c1 100644 --- a/src/huggingface_hub/inference/_client.py +++ b/src/huggingface_hub/inference/_client.py @@ -132,7 +132,7 @@ class InferenceClient: Note: for better compatibility with OpenAI's client, `model` has been aliased as `base_url`. Those 2 arguments are mutually exclusive. If a URL is passed as `model` or `base_url` for chat completion, the `(/v1)/chat/completions` suffix path will be appended to the URL. provider (`str`, *optional*): - Name of the provider to use for inference. Can be `"black-forest-labs"`, `"cerebras"`, `"cohere"`, `"fal-ai"`, `"featherless-ai"`, `"fireworks-ai"`, `"groq"`, `"hf-inference"`, `"hyperbolic"`, `"nebius"`, `"novita"`, `"nscale"`, `"openai"`, `"replicate"`, "sambanova"` or `"together"`. + Name of the provider to use for inference. Can be `"bagelnet"`, `"black-forest-labs"`, `"cerebras"`, `"cohere"`, `"fal-ai"`, `"featherless-ai"`, `"fireworks-ai"`, `"groq"`, `"hf-inference"`, `"hyperbolic"`, `"nebius"`, `"novita"`, `"nscale"`, `"openai"`, `"replicate"`, "sambanova"` or `"together"`. Defaults to "auto" i.e. the first of the providers available for the model, sorted by the user's order in https://hf.co/settings/inference-providers. If model is a URL or `base_url` is passed, then `provider` is not used. token (`str`, *optional*): diff --git a/src/huggingface_hub/inference/_providers/__init__.py b/src/huggingface_hub/inference/_providers/__init__.py index 405087d485..ea5a6abfd9 100644 --- a/src/huggingface_hub/inference/_providers/__init__.py +++ b/src/huggingface_hub/inference/_providers/__init__.py @@ -7,6 +7,7 @@ from huggingface_hub.utils import logging from ._common import TaskProviderHelper, _fetch_inference_provider_mapping +from .bagelnet import BagelNetConversationalTask from .black_forest_labs import BlackForestLabsTextToImageTask from .cerebras import CerebrasConversationalTask from .cohere import CohereConversationalTask @@ -44,6 +45,7 @@ PROVIDER_T = Literal[ + "bagelnet", "black-forest-labs", "cerebras", "cohere", @@ -65,6 +67,9 @@ PROVIDER_OR_POLICY_T = Union[PROVIDER_T, Literal["auto"]] PROVIDERS: Dict[PROVIDER_T, Dict[str, TaskProviderHelper]] = { + "bagelnet": { + "conversational": BagelNetConversationalTask(), + }, "black-forest-labs": { "text-to-image": BlackForestLabsTextToImageTask(), }, diff --git a/src/huggingface_hub/inference/_providers/_common.py b/src/huggingface_hub/inference/_providers/_common.py index 7618dcf42d..0753f0f14a 100644 --- a/src/huggingface_hub/inference/_providers/_common.py +++ b/src/huggingface_hub/inference/_providers/_common.py @@ -22,6 +22,7 @@ # provider_id="Qwen2.5-Coder-32B-Instruct", # task="conversational", # status="live") + "bagelnet": {}, "cerebras": {}, "cohere": {}, "fal-ai": {}, diff --git a/src/huggingface_hub/inference/_providers/bagelnet.py b/src/huggingface_hub/inference/_providers/bagelnet.py new file mode 100644 index 0000000000..4b71b991b5 --- /dev/null +++ b/src/huggingface_hub/inference/_providers/bagelnet.py @@ -0,0 +1,6 @@ +from ._common import BaseConversationalTask + + +class BagelNetConversationalTask(BaseConversationalTask): + def __init__(self): + super().__init__(provider="bagelnet", base_url="https://api.bagel.net") \ No newline at end of file diff --git a/tests/test_bagelnet_provider.py b/tests/test_bagelnet_provider.py new file mode 100644 index 0000000000..7f3d33dd41 --- /dev/null +++ b/tests/test_bagelnet_provider.py @@ -0,0 +1,31 @@ +import pytest + +from huggingface_hub.inference._providers.bagelnet import BagelNetConversationalTask + + +class TestBagelNetConversationalTask: + def test_init(self): + """Test BagelNet provider initialization.""" + task = BagelNetConversationalTask() + assert task.provider == "bagelnet" + assert task.base_url == "https://api.bagel.net" + assert task.task == "conversational" + + def test_inheritance(self): + """Test BagelNet inherits from BaseConversationalTask.""" + from huggingface_hub.inference._providers._common import BaseConversationalTask + + task = BagelNetConversationalTask() + assert isinstance(task, BaseConversationalTask) + + def test_no_method_overrides(self): + """Test that BagelNet uses default implementations (no overrides needed).""" + task = BagelNetConversationalTask() + + # Should use default route + route = task._prepare_route("test_model", "test_key") + assert route == "/v1/chat/completions" + + # Should use default base URL behavior + direct_url = task._prepare_base_url("sk-test-key") # Non-HF key + assert direct_url == "https://api.bagel.net" \ No newline at end of file