diff --git a/packages/inference/README.md b/packages/inference/README.md
index 0ea60b2be7..2ffee4306b 100644
--- a/packages/inference/README.md
+++ b/packages/inference/README.md
@@ -63,6 +63,7 @@ Currently, we support the following providers:
 - [Cohere](https://cohere.com)
 - [Cerebras](https://cerebras.ai/)
 - [Groq](https://groq.com)
+- [Swarmind](https://swarmind.ai)
 
 To send requests to a third-party provider, you have to pass the `provider` parameter to the inference function. The default value of the `provider` parameter is "auto", which will select the first of the providers available for the model, sorted by your preferred order in https://hf.co/settings/inference-providers.
 
diff --git a/packages/inference/src/lib/getProviderHelper.ts b/packages/inference/src/lib/getProviderHelper.ts
index 147fe08ec5..2026a4bade 100644
--- a/packages/inference/src/lib/getProviderHelper.ts
+++ b/packages/inference/src/lib/getProviderHelper.ts
@@ -12,6 +12,7 @@ import * as Novita from "../providers/novita.js";
 import * as Nscale from "../providers/nscale.js";
 import * as OpenAI from "../providers/openai.js";
 import * as OvhCloud from "../providers/ovhcloud.js";
+import * as Swarmind from "../providers/swarmind.js";
 import type {
 	AudioClassificationTaskHelper,
 	AudioToAudioTaskHelper,
@@ -147,6 +148,10 @@ export const PROVIDERS: Record<InferenceProvider, Partial<Record<InferenceTask | "conversational", TaskProviderHelper>>> = {
 		conversational: new Sambanova.SambanovaConversationalTask(),
 		"feature-extraction": new Sambanova.SambanovaFeatureExtractionTask(),
 	},
+	swarmind: {
+		conversational: new Swarmind.SwarmindConversationalTask(),
+		"text-generation": new Swarmind.SwarmindTextGenerationTask(),
+	},
 	together: {
 		conversational: new Together.TogetherConversationalTask(),
 		"text-generation": new Together.TogetherTextGenerationTask(),
diff --git a/packages/inference/src/providers/swarmind.ts b/packages/inference/src/providers/swarmind.ts
new file mode 100644
index 0000000000..a1b2c3d4e5
--- /dev/null
+++ b/packages/inference/src/providers/swarmind.ts
@@ -0,0 +1,40 @@
+import { BaseConversationalTask, BaseTextGenerationTask } from "./providerHelper.js";
+
+/**
+ * See the registered mapping of HF model ID => Swarmind model ID here:
+ *
+ * https://huggingface.co/api/partners/swarmind/models
+ *
+ * This is a publicly available mapping.
+ *
+ * If you want to try to run inference for a new model locally before it's registered on huggingface.co,
+ * you can add it to the dictionary "HARDCODED_MODEL_ID_MAPPING" in consts.ts, for dev purposes.
+ *
+ * - If you work at Swarmind and want to update this mapping, please use the model mapping API we provide on huggingface.co
+ * - If you're a community member and want to add a new supported HF model to Swarmind, please open an issue on the present repo
+ *   and we will tag Swarmind team members.
+ *
+ * Thanks!
+ */
+
+const SWARMIND_API_BASE_URL = "https://api.swarmind.ai/lai/private";
+
+export class SwarmindTextGenerationTask extends BaseTextGenerationTask {
+	constructor() {
+		super("swarmind", SWARMIND_API_BASE_URL);
+	}
+
+	override makeRoute(): string {
+		return "/v1/chat/completions";
+	}
+}
+
+export class SwarmindConversationalTask extends BaseConversationalTask {
+	constructor() {
+		super("swarmind", SWARMIND_API_BASE_URL);
+	}
+
+	override makeRoute(): string {
+		return "/v1/chat/completions";
+	}
+}
diff --git a/packages/inference/src/types.ts b/packages/inference/src/types.ts
index 5d6be233d8..80423a4d0f 100644
--- a/packages/inference/src/types.ts
+++ b/packages/inference/src/types.ts
@@ -61,6 +61,7 @@ export const INFERENCE_PROVIDERS = [
 	"ovhcloud",
 	"replicate",
 	"sambanova",
+	"swarmind",
 	"together",
 ] as const;
 
diff --git a/packages/inference/test/InferenceClient.spec.ts b/packages/inference/test/InferenceClient.spec.ts
index 0da191ea9d..804560c09b 100644
--- a/packages/inference/test/InferenceClient.spec.ts
+++ b/packages/inference/test/InferenceClient.spec.ts
@@ -2107,4 +2107,55 @@ describe.skip("InferenceClient", () => {
 		},
 		TIMEOUT
 	);
+	describe.concurrent(
+		"Swarmind",
+		() => {
+			const client = new InferenceClient(env.HF_SWARMIND_KEY ?? "dummy");
+
+			HARDCODED_MODEL_INFERENCE_MAPPING["swarmind"] = {
+				"TheDrummer/Tiger-Gemma-9B-v1": {
+					hfModelId: "TheDrummer/Tiger-Gemma-9B-v1",
+					providerId: "tiger-gemma-9b-v1-i1",
+					status: "live",
+					task: "conversational",
+				},
+			};
+
+			it("chatCompletion", async () => {
+				const res = await client.chatCompletion({
+					model: "TheDrummer/Tiger-Gemma-9B-v1",
+					provider: "swarmind",
+					messages: [{ role: "user", content: "Complete this sentence with words, one plus one is equal " }],
+				});
+				if (res.choices && res.choices.length > 0) {
+					const completion = res.choices[0].message?.content;
+					expect(completion).toContain("two");
+				}
+			});
+
+			it("chatCompletion stream", async () => {
+				const stream = client.chatCompletionStream({
+					model: "TheDrummer/Tiger-Gemma-9B-v1",
+					provider: "swarmind",
+					messages: [{ role: "user", content: "Say 'this is a test'" }],
+					stream: true,
+				}) as AsyncGenerator<ChatCompletionStreamOutput>;
+
+				let fullResponse = "";
+				for await (const chunk of stream) {
+					if (chunk.choices && chunk.choices.length > 0) {
+						const content = chunk.choices[0].delta?.content;
+						if (content) {
+							fullResponse += content;
+						}
+					}
+				}
+
+				// Verify we got a meaningful response
+				expect(fullResponse).toBeTruthy();
+				expect(fullResponse.length).toBeGreaterThan(0);
+			});
+		},
+		TIMEOUT
+	);
 });