
Commit 975df7f

Authored by pohnean, hanouticelina (célina), and github-actions[bot]

✨ Support for Featherless.ai as inference provider (#3081)

* add featherless-ai provider
* add featherless to the table
* Apply style fixes
* Update src/huggingface_hub/inference/_providers/featherless_ai.py (Co-authored-by: célina <hanouticelina@gmail.com>)
* update featherless api provider + pass api_key to prepare_mapping_info function
* Revert "Auxiliary commit to revert individual files from b479672" (this reverts commit 761553ee31e10b02105c33969e2705f24fe9a8b4)
* Update src/huggingface_hub/inference/_providers/featherless_ai.py

Co-authored-by: Celina Hanouti <hanouticelina@gmail.com>
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
1 parent 0ef2d68 commit 975df7f

File tree

6 files changed: +92 −32 lines changed


docs/source/en/guides/inference.md

Lines changed: 30 additions & 30 deletions
@@ -207,36 +207,36 @@ For more details, refer to the [Inference Providers pricing documentation](https

 [`InferenceClient`]'s goal is to provide the easiest interface to run inference on Hugging Face models, on any provider. It has a simple API that supports the most common tasks. Here is a table showing which providers support which tasks:

-| Task | Black Forest Labs | Cerebras | Cohere | fal-ai | Fireworks AI | HF Inference | Hyperbolic | Nebius AI Studio | Novita AI | Replicate | Sambanova | Together |
+| Task | Black Forest Labs | Cerebras | Cohere | fal-ai | Featherless AI | Fireworks AI | HF Inference | Hyperbolic | Nebius AI Studio | Novita AI | Replicate | Sambanova | Together |

[The per-task rows, from `audio_classification` through `zero_shot_classification`, were rewritten only to insert a "Featherless AI" column between "fal-ai" and "Fireworks AI"; the ✅/❌ support marks in the individual cells did not survive extraction. Per the provider registration in this commit, Featherless AI supports `chat_completion` (conversational) and `text_generation`.]

 <Tip>
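The new column means `chat_completion` and `text_generation` calls can now be routed to Featherless AI. Such providers expose an OpenAI-style request shape; as a hedged illustration (the helper name and field choices below are ours for demonstration, not taken from the library or this diff), a chat request body can be sketched as:

```python
from typing import Any, Dict, List, Optional


def build_chat_payload(
    model_id: str,
    messages: List[Dict[str, str]],
    max_tokens: Optional[int] = None,
) -> Dict[str, Any]:
    """Assemble an OpenAI-style chat request body (illustrative only)."""
    payload: Dict[str, Any] = {"model": model_id, "messages": messages}
    if max_tokens is not None:
        payload["max_tokens"] = max_tokens
    return payload


payload = build_chat_payload(
    "meta-llama/Llama-3.1-8B-Instruct",  # hypothetical model id
    [{"role": "user", "content": "Hello!"}],
    max_tokens=64,
)
print(sorted(payload))  # -> ['max_tokens', 'messages', 'model']
```

In the real client, this serialization is handled internally; the sketch only shows the shape of the body the provider ultimately receives.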

src/huggingface_hub/inference/_client.py

Lines changed: 1 addition & 1 deletion
@@ -134,7 +134,7 @@ class InferenceClient:
             path will be appended to the base URL (see the [TGI Messages API](https://huggingface.co/docs/text-generation-inference/en/messages_api)
             documentation for details). When passing a URL as `model`, the client will not append any suffix path to it.
         provider (`str`, *optional*):
-            Name of the provider to use for inference. Can be `"black-forest-labs"`, `"cerebras"`, `"cohere"`, `"fal-ai"`, `"fireworks-ai"`, `"hf-inference"`, `"hyperbolic"`, `"nebius"`, `"novita"`, `"nscale"`, `"openai"`, `"replicate"`, `"sambanova"` or `"together"`.
+            Name of the provider to use for inference. Can be `"black-forest-labs"`, `"cerebras"`, `"cohere"`, `"fal-ai"`, `"featherless-ai"`, `"fireworks-ai"`, `"hf-inference"`, `"hyperbolic"`, `"nebius"`, `"novita"`, `"nscale"`, `"openai"`, `"replicate"`, `"sambanova"` or `"together"`.
             Defaults to "auto" i.e. the first of the providers available for the model, sorted by the user's order in https://hf.co/settings/inference-providers.
             If model is a URL or `base_url` is passed, then `provider` is not used.
         token (`str`, *optional*):
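The docstring's `"auto"` behavior, picking the first provider available for the model in the user's preference order, can be sketched as follows. This is an illustrative reimplementation of the documented rule, not the library's actual resolution code:

```python
from typing import Iterable, List


def select_provider(requested: str, available: Iterable[str], user_order: List[str]) -> str:
    """Resolve a provider name: an explicit name must be available for the
    model; "auto" picks the first available provider in the user's order."""
    available_set = set(available)
    if requested != "auto":
        if requested not in available_set:
            raise ValueError(f"Provider {requested!r} is not available for this model")
        return requested
    for name in user_order:
        if name in available_set:
            return name
    raise ValueError("No configured provider serves this model")


choice = select_provider(
    "auto",
    {"featherless-ai", "together"},          # providers serving the model
    ["novita", "featherless-ai", "together"],  # user's preference order
)
print(choice)  # -> featherless-ai
```

Note that, as the docstring says, none of this applies when `model` is a URL or `base_url` is passed.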

src/huggingface_hub/inference/_generated/_async_client.py

Lines changed: 1 addition & 1 deletion
@@ -122,7 +122,7 @@ class AsyncInferenceClient:
             path will be appended to the base URL (see the [TGI Messages API](https://huggingface.co/docs/text-generation-inference/en/messages_api)
             documentation for details). When passing a URL as `model`, the client will not append any suffix path to it.
         provider (`str`, *optional*):
-            Name of the provider to use for inference. Can be `"black-forest-labs"`, `"cerebras"`, `"cohere"`, `"fal-ai"`, `"fireworks-ai"`, `"hf-inference"`, `"hyperbolic"`, `"nebius"`, `"novita"`, `"nscale"`, `"openai"`, `"replicate"`, `"sambanova"` or `"together"`.
+            Name of the provider to use for inference. Can be `"black-forest-labs"`, `"cerebras"`, `"cohere"`, `"fal-ai"`, `"featherless-ai"`, `"fireworks-ai"`, `"hf-inference"`, `"hyperbolic"`, `"nebius"`, `"novita"`, `"nscale"`, `"openai"`, `"replicate"`, `"sambanova"` or `"together"`.
             Defaults to "auto" i.e. the first of the providers available for the model, sorted by the user's order in https://hf.co/settings/inference-providers.
             If model is a URL or `base_url` is passed, then `provider` is not used.
         token (`str`, *optional*):

src/huggingface_hub/inference/_providers/__init__.py

Lines changed: 9 additions & 0 deletions
@@ -1,5 +1,9 @@
 from typing import Dict, Literal, Optional, Union

+from huggingface_hub.inference._providers.featherless_ai import (
+    FeatherlessConversationalTask,
+    FeatherlessTextGenerationTask,
+)
 from huggingface_hub.utils import logging

 from ._common import TaskProviderHelper, _fetch_inference_provider_mapping
@@ -42,6 +46,7 @@
     "cerebras",
     "cohere",
     "fal-ai",
+    "featherless-ai",
     "fireworks-ai",
     "hf-inference",
     "hyperbolic",
@@ -72,6 +77,10 @@
         "text-to-speech": FalAITextToSpeechTask(),
         "text-to-video": FalAITextToVideoTask(),
     },
+    "featherless-ai": {
+        "conversational": FeatherlessConversationalTask(),
+        "text-generation": FeatherlessTextGenerationTask(),
+    },
     "fireworks-ai": {
         "conversational": FireworksAIConversationalTask(),
     },
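The registry patched above is a two-level mapping: provider name → task name → helper instance. A minimal, self-contained sketch of that dispatch shape (`_Helper`, `REGISTRY`, and `get_helper` are plain stand-ins for the real `TaskProviderHelper` machinery, not library code):

```python
class _Helper:
    """Stand-in for a TaskProviderHelper-style object."""

    def __init__(self, provider: str, task: str) -> None:
        self.provider, self.task = provider, task


# Two-level registry, mirroring the shape of PROVIDERS in __init__.py.
REGISTRY = {
    "featherless-ai": {
        "conversational": _Helper("featherless-ai", "conversational"),
        "text-generation": _Helper("featherless-ai", "text-generation"),
    },
}


def get_helper(provider: str, task: str) -> _Helper:
    """Look up the helper for a (provider, task) pair, with clear errors."""
    try:
        tasks = REGISTRY[provider]
    except KeyError:
        raise ValueError(f"Unknown provider: {provider!r}") from None
    try:
        return tasks[task]
    except KeyError:
        raise ValueError(f"Task {task!r} is not supported by {provider!r}") from None


helper = get_helper("featherless-ai", "conversational")
print(helper.provider, helper.task)  # -> featherless-ai conversational
```

Registering a new provider is therefore just one more top-level entry plus the corresponding task helpers, which is why this diff touches only the import block, the provider-name literal, and the mapping.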
src/huggingface_hub/inference/_providers/featherless_ai.py (new file)

Lines changed: 38 additions & 0 deletions
@@ -0,0 +1,38 @@
+from typing import Any, Dict, Optional, Union
+
+from huggingface_hub.hf_api import InferenceProviderMapping
+from huggingface_hub.inference._common import RequestParameters, _as_dict
+
+from ._common import BaseConversationalTask, BaseTextGenerationTask, filter_none
+
+
+_PROVIDER = "featherless-ai"
+_BASE_URL = "https://api.featherless.ai"
+
+
+class FeatherlessTextGenerationTask(BaseTextGenerationTask):
+    def __init__(self):
+        super().__init__(provider=_PROVIDER, base_url=_BASE_URL)
+
+    def _prepare_payload_as_dict(
+        self, inputs: Any, parameters: Dict, provider_mapping_info: InferenceProviderMapping
+    ) -> Optional[Dict]:
+        params = filter_none(parameters.copy())
+        params["max_tokens"] = params.pop("max_new_tokens", None)
+
+        return {"prompt": inputs, **params, "model": provider_mapping_info.provider_id}
+
+    def get_response(self, response: Union[bytes, Dict], request_params: Optional[RequestParameters] = None) -> Any:
+        output = _as_dict(response)["choices"][0]
+        return {
+            "generated_text": output["text"],
+            "details": {
+                "finish_reason": output.get("finish_reason"),
+                "seed": output.get("seed"),
+            },
+        }
+
+
+class FeatherlessConversationalTask(BaseConversationalTask):
+    def __init__(self):
+        super().__init__(provider=_PROVIDER, base_url=_BASE_URL)
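Two details in this new file carry the real logic: `_prepare_payload_as_dict` renames Hugging Face's `max_new_tokens` to the completions-style `max_tokens` (note the committed code inserts `max_tokens=None` even when `max_new_tokens` is absent), and `get_response` flattens the first choice into the `generated_text`/`details` shape the client expects. A standalone sketch of both transformations as plain functions mirroring the methods above:

```python
from typing import Any, Dict


def prepare_payload(inputs: str, parameters: Dict[str, Any], provider_id: str) -> Dict[str, Any]:
    """Mirror of _prepare_payload_as_dict: drop None-valued params, then
    rename max_new_tokens -> max_tokens for the completions-style API."""
    params = {k: v for k, v in parameters.items() if v is not None}  # filter_none
    params["max_tokens"] = params.pop("max_new_tokens", None)
    return {"prompt": inputs, **params, "model": provider_id}


def parse_response(response: Dict[str, Any]) -> Dict[str, Any]:
    """Mirror of get_response: keep the first choice's text and metadata."""
    output = response["choices"][0]
    return {
        "generated_text": output["text"],
        "details": {
            "finish_reason": output.get("finish_reason"),
            "seed": output.get("seed"),
        },
    }


payload = prepare_payload("Once upon a time", {"max_new_tokens": 32, "seed": None}, "model-id")
print(payload)  # -> {'prompt': 'Once upon a time', 'max_tokens': 32, 'model': 'model-id'}
```

The conversational task needs no payload override because the base class already emits the OpenAI-compatible chat format, which is why `FeatherlessConversationalTask` only sets the provider name and base URL.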
