Skip to content

feat(inference): Add BagelNet provider for Python client integration #3189

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 6 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/huggingface_hub/inference/_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ class InferenceClient:
Note: for better compatibility with OpenAI's client, `model` has been aliased as `base_url`. Those 2
arguments are mutually exclusive. If a URL is passed as `model` or `base_url` for chat completion, the `(/v1)/chat/completions` suffix path will be appended to the URL.
provider (`str`, *optional*):
Name of the provider to use for inference. Can be `"black-forest-labs"`, `"cerebras"`, `"cohere"`, `"fal-ai"`, `"featherless-ai"`, `"fireworks-ai"`, `"groq"`, `"hf-inference"`, `"hyperbolic"`, `"nebius"`, `"novita"`, `"nscale"`, `"openai"`, `"replicate"`, "sambanova"` or `"together"`.
Name of the provider to use for inference. Can be `"bagelnet"`, `"black-forest-labs"`, `"cerebras"`, `"cohere"`, `"fal-ai"`, `"featherless-ai"`, `"fireworks-ai"`, `"groq"`, `"hf-inference"`, `"hyperbolic"`, `"nebius"`, `"novita"`, `"nscale"`, `"openai"`, `"replicate"`, "sambanova"` or `"together"`.
Defaults to "auto" i.e. the first of the providers available for the model, sorted by the user's order in https://hf.co/settings/inference-providers.
If model is a URL or `base_url` is passed, then `provider` is not used.
token (`str`, *optional*):
Expand Down
5 changes: 5 additions & 0 deletions src/huggingface_hub/inference/_providers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from huggingface_hub.utils import logging

from ._common import TaskProviderHelper, _fetch_inference_provider_mapping
from .bagelnet import BagelNetConversationalTask
from .black_forest_labs import BlackForestLabsTextToImageTask
from .cerebras import CerebrasConversationalTask
from .cohere import CohereConversationalTask
Expand Down Expand Up @@ -44,6 +45,7 @@


PROVIDER_T = Literal[
"bagelnet",
"black-forest-labs",
"cerebras",
"cohere",
Expand All @@ -65,6 +67,9 @@
PROVIDER_OR_POLICY_T = Union[PROVIDER_T, Literal["auto"]]

PROVIDERS: Dict[PROVIDER_T, Dict[str, TaskProviderHelper]] = {
"bagelnet": {
"conversational": BagelNetConversationalTask(),
},
"black-forest-labs": {
"text-to-image": BlackForestLabsTextToImageTask(),
},
Expand Down
1 change: 1 addition & 0 deletions src/huggingface_hub/inference/_providers/_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
# provider_id="Qwen2.5-Coder-32B-Instruct",
# task="conversational",
# status="live")
"bagelnet": {},
"cerebras": {},
"cohere": {},
"fal-ai": {},
Expand Down
6 changes: 6 additions & 0 deletions src/huggingface_hub/inference/_providers/bagelnet.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
from ._common import BaseConversationalTask


class BagelNetConversationalTask(BaseConversationalTask):
def __init__(self):
super().__init__(provider="bagelnet", base_url="https://api.bagel.net")
31 changes: 31 additions & 0 deletions tests/test_bagelnet_provider.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
import pytest

from huggingface_hub.inference._providers.bagelnet import BagelNetConversationalTask


class TestBagelNetConversationalTask:
def test_init(self):
"""Test BagelNet provider initialization."""
task = BagelNetConversationalTask()
assert task.provider == "bagelnet"
assert task.base_url == "https://api.bagel.net"
assert task.task == "conversational"

def test_inheritance(self):
"""Test BagelNet inherits from BaseConversationalTask."""
from huggingface_hub.inference._providers._common import BaseConversationalTask

task = BagelNetConversationalTask()
assert isinstance(task, BaseConversationalTask)

def test_no_method_overrides(self):
"""Test that BagelNet uses default implementations (no overrides needed)."""
task = BagelNetConversationalTask()

# Should use default route
route = task._prepare_route("test_model", "test_key")
assert route == "/v1/chat/completions"

# Should use default base URL behavior
direct_url = task._prepare_base_url("sk-test-key") # Non-HF key
assert direct_url == "https://api.bagel.net"