
Commit 44fa110

fabienric and Fabien Ric authored
Add OVHcloud as an inference provider (#1303)
### What

Adds OVHcloud as an inference provider.

### Test Plan

Added new tests for OVHcloud, both with and without streaming.

### What Should Reviewers Focus On?

I used the Cerebras PR as an example.

---------

Co-authored-by: Fabien Ric <[email protected]>
1 parent c10c090 commit 44fa110

File tree: 7 files changed (+193, -0 lines)


packages/inference/README.md

Lines changed: 2 additions & 0 deletions
@@ -54,6 +54,7 @@ Currently, we support the following providers:
 - [Nebius](https://studio.nebius.ai)
 - [Novita](https://novita.ai/?utm_source=github_huggingface&utm_medium=github_readme&utm_campaign=link)
 - [Nscale](https://nscale.com)
+- [OVHcloud](https://endpoints.ai.cloud.ovh.net/)
 - [Replicate](https://replicate.com)
 - [Sambanova](https://sambanova.ai)
 - [Together](https://together.xyz)
@@ -84,6 +85,7 @@ Only a subset of models are supported when requesting third-party providers. You
 - [Hyperbolic supported models](https://huggingface.co/api/partners/hyperbolic/models)
 - [Nebius supported models](https://huggingface.co/api/partners/nebius/models)
 - [Nscale supported models](https://huggingface.co/api/partners/nscale/models)
+- [OVHcloud supported models](https://huggingface.co/api/partners/ovhcloud/models)
 - [Replicate supported models](https://huggingface.co/api/partners/replicate/models)
 - [Sambanova supported models](https://huggingface.co/api/partners/sambanova/models)
 - [Together supported models](https://huggingface.co/api/partners/together/models)
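
For orientation, calling the new provider from client code looks like the sketch below, adapted from the tests added later in this commit (it assumes top-level await and an HF token in the environment):

```ts
import { HfInference } from "@huggingface/inference";

const client = new HfInference(process.env.HF_TOKEN ?? "");

const res = await client.chatCompletion({
	model: "meta-llama/llama-3.1-8b-instruct",
	provider: "ovhcloud",
	messages: [{ role: "user", content: "A, B, C, " }],
	max_tokens: 1,
});

// The tests expect the model to continue the sequence with "D".
console.log(res.choices[0].message?.content);
```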

packages/inference/src/lib/getProviderHelper.ts

Lines changed: 5 additions & 0 deletions
@@ -11,6 +11,7 @@ import * as Nebius from "../providers/nebius";
 import * as Novita from "../providers/novita";
 import * as Nscale from "../providers/nscale";
 import * as OpenAI from "../providers/openai";
+import * as OvhCloud from "../providers/ovhcloud";
 import type {
 	AudioClassificationTaskHelper,
 	AudioToAudioTaskHelper,
@@ -126,6 +127,10 @@ export const PROVIDERS: Record<InferenceProvider, Partial<Record<InferenceTask,
 	openai: {
 		conversational: new OpenAI.OpenAIConversationalTask(),
 	},
+	ovhcloud: {
+		conversational: new OvhCloud.OvhCloudConversationalTask(),
+		"text-generation": new OvhCloud.OvhCloudTextGenerationTask(),
+	},
 	replicate: {
 		"text-to-image": new Replicate.ReplicateTextToImageTask(),
 		"text-to-speech": new Replicate.ReplicateTextToSpeechTask(),

packages/inference/src/providers/consts.ts

Lines changed: 1 addition & 0 deletions
@@ -32,6 +32,7 @@ export const HARDCODED_MODEL_INFERENCE_MAPPING: Record<
 	novita: {},
 	nscale: {},
 	openai: {},
+	ovhcloud: {},
 	replicate: {},
 	sambanova: {},
 	together: {},
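
The new `ovhcloud: {}` entry is deliberately empty: the real model mapping lives on huggingface.co, and this dictionary is only a local-development override. The test file later in this commit populates it exactly this way (the import path here is illustrative):

```ts
import { HARDCODED_MODEL_INFERENCE_MAPPING } from "./consts";

// Dev-only override (copied from this commit's tests): route an HF model ID
// to OVHcloud's own model ID before the mapping is registered upstream.
HARDCODED_MODEL_INFERENCE_MAPPING["ovhcloud"] = {
	"meta-llama/llama-3.1-8b-instruct": {
		hfModelId: "meta-llama/llama-3.1-8b-instruct",
		providerId: "Llama-3.1-8B-Instruct",
		status: "live",
		task: "conversational",
	},
};
```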
packages/inference/src/providers/ovhcloud.ts

Lines changed: 80 additions & 0 deletions
@@ -0,0 +1,80 @@
+/**
+ * See the registered mapping of HF model ID => OVHcloud model ID here:
+ *
+ * https://huggingface.co/api/partners/ovhcloud/models
+ *
+ * This is a publicly available mapping.
+ *
+ * If you want to try to run inference for a new model locally before it's registered on huggingface.co,
+ * you can add it to the dictionary "HARDCODED_MODEL_INFERENCE_MAPPING" in consts.ts, for dev purposes.
+ *
+ * - If you work at OVHcloud and want to update this mapping, please use the model mapping API we provide on huggingface.co
+ * - If you're a community member and want to add a new supported HF model to OVHcloud, please open an issue on the present repo
+ *   and we will tag OVHcloud team members.
+ *
+ * Thanks!
+ */
+
+import { BaseConversationalTask, BaseTextGenerationTask } from "./providerHelper";
+import type {
+	ChatCompletionOutput,
+	TextGenerationOutput,
+	TextGenerationOutputFinishReason,
+} from "@huggingface/tasks";
+import { InferenceOutputError } from "../lib/InferenceOutputError";
+import type { BodyParams } from "../types";
+import { omit } from "../utils/omit";
+import type { TextGenerationInput } from "@huggingface/tasks";
+
+const OVHCLOUD_API_BASE_URL = "https://oai.endpoints.kepler.ai.cloud.ovh.net";
+
+interface OvhCloudTextCompletionOutput extends Omit<ChatCompletionOutput, "choices"> {
+	choices: Array<{
+		text: string;
+		finish_reason: TextGenerationOutputFinishReason;
+		logprobs: unknown;
+		index: number;
+	}>;
+}
+
+export class OvhCloudConversationalTask extends BaseConversationalTask {
+	constructor() {
+		super("ovhcloud", OVHCLOUD_API_BASE_URL);
+	}
+}
+
+export class OvhCloudTextGenerationTask extends BaseTextGenerationTask {
+	constructor() {
+		super("ovhcloud", OVHCLOUD_API_BASE_URL);
+	}
+
+	override preparePayload(params: BodyParams<TextGenerationInput>): Record<string, unknown> {
+		return {
+			model: params.model,
+			...omit(params.args, ["inputs", "parameters"]),
+			...(params.args.parameters
+				? {
+						max_tokens: (params.args.parameters as Record<string, unknown>).max_new_tokens,
+						...omit(params.args.parameters as Record<string, unknown>, "max_new_tokens"),
+				  }
+				: undefined),
+			prompt: params.args.inputs,
+		};
+	}
+
+	override async getResponse(response: OvhCloudTextCompletionOutput): Promise<TextGenerationOutput> {
+		if (
+			typeof response === "object" &&
+			"choices" in response &&
+			Array.isArray(response?.choices) &&
+			typeof response?.model === "string"
+		) {
+			const completion = response.choices[0];
+			return {
+				generated_text: completion.text,
+			};
+		}
+		throw new InferenceOutputError("Expected OVHcloud text generation response format");
+	}
+
+}
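
The one subtlety in this file is preparePayload: the base URL points at an OpenAI-style completion endpoint, so the helper renames HF's max_new_tokens parameter to max_tokens, forwards the remaining parameters as-is, and sends inputs as prompt. A before/after sketch of that mapping with example values (the provider-side model ID would come from the mapping shown earlier):

```ts
// HF-style textGeneration arguments:
const args = {
	inputs: "A B C ",
	parameters: { seed: 42, temperature: 0, max_new_tokens: 1 },
};

// Body preparePayload would produce from them (sketch):
const body = {
	model: "Llama-3.1-8B-Instruct", // resolved provider model ID
	seed: 42,
	temperature: 0,
	max_tokens: 1, // renamed from max_new_tokens
	prompt: "A B C ", // moved from inputs
};
```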

packages/inference/src/types.ts

Lines changed: 1 addition & 0 deletions
@@ -51,6 +51,7 @@ export const INFERENCE_PROVIDERS = [
 	"novita",
 	"nscale",
 	"openai",
+	"ovhcloud",
 	"replicate",
 	"sambanova",
 	"together",

packages/inference/test/InferenceClient.spec.ts

Lines changed: 103 additions & 0 deletions
@@ -1875,4 +1875,107 @@ describe.skip("InferenceClient", () => {
 		},
 		TIMEOUT
 	);
+	describe.concurrent(
+		"OVHcloud",
+		() => {
+			const client = new HfInference(env.HF_OVHCLOUD_KEY ?? "dummy");
+
+			HARDCODED_MODEL_INFERENCE_MAPPING["ovhcloud"] = {
+				"meta-llama/llama-3.1-8b-instruct": {
+					hfModelId: "meta-llama/llama-3.1-8b-instruct",
+					providerId: "Llama-3.1-8B-Instruct",
+					status: "live",
+					task: "conversational",
+				},
+			};
+
+			it("chatCompletion", async () => {
+				const res = await client.chatCompletion({
+					model: "meta-llama/llama-3.1-8b-instruct",
+					provider: "ovhcloud",
+					messages: [{ role: "user", content: "A, B, C, " }],
+					seed: 42,
+					temperature: 0,
+					top_p: 0.01,
+					max_tokens: 1,
+				});
+				expect(res.choices && res.choices.length > 0);
+				const completion = res.choices[0].message?.content;
+				expect(completion).toContain("D");
+			});
+
+			it("chatCompletion stream", async () => {
+				const stream = client.chatCompletionStream({
+					model: "meta-llama/llama-3.1-8b-instruct",
+					provider: "ovhcloud",
+					messages: [{ role: "user", content: "A, B, C, " }],
+					stream: true,
+					seed: 42,
+					temperature: 0,
+					top_p: 0.01,
+					max_tokens: 1,
+				}) as AsyncGenerator<ChatCompletionStreamOutput>;
+
+				let fullResponse = "";
+				for await (const chunk of stream) {
+					if (chunk.choices && chunk.choices.length > 0) {
+						const content = chunk.choices[0].delta?.content;
+						if (content) {
+							fullResponse += content;
+						}
+					}
+				}
+
+				// Verify we got a meaningful response
+				expect(fullResponse).toBeTruthy();
+				expect(fullResponse).toContain("D");
+			});
+
+			it("textGeneration", async () => {
+				const res = await client.textGeneration({
+					model: "meta-llama/llama-3.1-8b-instruct",
+					provider: "ovhcloud",
+					inputs: "A B C ",
+					parameters: {
+						seed: 42,
+						temperature: 0,
+						top_p: 0.01,
+						max_new_tokens: 1,
+					},
+				});
+				expect(res.generated_text.length > 0);
+				expect(res.generated_text).toContain("D");
+			});
+
+			it("textGeneration stream", async () => {
+				const stream = client.textGenerationStream({
+					model: "meta-llama/llama-3.1-8b-instruct",
+					provider: "ovhcloud",
+					inputs: "A B C ",
+					stream: true,
+					parameters: {
+						seed: 42,
+						temperature: 0,
+						top_p: 0.01,
+						max_new_tokens: 1,
+					},
+				}) as AsyncGenerator<ChatCompletionStreamOutput>;
+
+				let fullResponse = "";
+				for await (const chunk of stream) {
+					if (chunk.choices && chunk.choices.length > 0) {
+						const content = chunk.choices[0].text;
+						if (content) {
+							fullResponse += content;
+						}
+					}
+				}
+
+				// Verify we got a meaningful response
+				expect(fullResponse).toBeTruthy();
+				expect(fullResponse).toContain("D");
+			});
+		},
+		TIMEOUT
+	);
 });

packages/tasks/src/inference-providers.ts

Lines changed: 1 addition & 0 deletions
@@ -7,6 +7,7 @@ const INFERENCE_PROVIDERS = [
 	"fireworks-ai",
 	"hf-inference",
 	"hyperbolic",
+	"ovhcloud",
 	"replicate",
 	"sambanova",
 	"together",
