Commit 62137bd

[InferenceSnippets] Add endpointUrl option (#1521)
Add an option to pass a custom endpoint URL for inference snippets generation.

```ts
snippets.getInferenceSnippets(
	model,
	selectedProvider,
	{
		hfModelId: ...,
		providerId: ...,
		status: "live",
		task: pipeline,
	},
	{
		streaming,
		endpointUrl: "http://localhost:8080/v1",
	}
);
```

cc @gary149
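For reference, `getInferenceSnippets` renders one snippet per supported client, so a caller can inspect the effect of the new option directly. A minimal sketch of a full call (the model and provider-mapping values are placeholders, and the `{ language, client, content }` return shape is assumed):

```ts
import { snippets } from "@huggingface/inference";

// Sketch only: placeholder model + provider mapping for a local OpenAI-compatible endpoint.
const rendered = snippets.getInferenceSnippets(
	{
		id: "meta-llama/Llama-3.1-8B-Instruct",
		pipeline_tag: "text-generation",
		tags: ["conversational"],
		inference: "",
	},
	"hf-inference",
	{
		hfModelId: "meta-llama/Llama-3.1-8B-Instruct",
		providerId: "meta-llama/Llama-3.1-8B-Instruct",
		status: "live",
		task: "conversational",
	},
	{ streaming: false, endpointUrl: "http://localhost:8080/v1" }
);

for (const { language, client, content } of rendered) {
	console.log(`--- ${language} / ${client} ---\n${content}`);
}
```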
1 parent a444bd0 commit 62137bd

File tree: 18 files changed, +176 −12 lines

packages/inference/README.md

Lines changed: 4 additions & 3 deletions
````diff
@@ -651,9 +651,10 @@ You can use any Chat Completion API-compatible provider with the `chatCompletion
 ```typescript
 // Chat Completion Example
 const MISTRAL_KEY = process.env.MISTRAL_KEY;
-const hf = new InferenceClient(MISTRAL_KEY);
-const ep = hf.endpoint("https://api.mistral.ai");
-const stream = ep.chatCompletionStream({
+const hf = new InferenceClient(MISTRAL_KEY, {
+	endpointUrl: "https://api.mistral.ai",
+});
+const stream = hf.chatCompletionStream({
   model: "mistral-tiny",
   messages: [{ role: "user", content: "Complete the equation one + one = , just the answer" }],
 });
````

packages/inference/src/snippets/getInferenceSnippets.ts

Lines changed: 23 additions & 9 deletions
```diff
@@ -18,7 +18,8 @@ export type InferenceSnippetOptions = {
 	streaming?: boolean;
 	billTo?: string;
 	accessToken?: string;
-	directRequest?: boolean;
+	directRequest?: boolean; // to bypass HF routing and call the provider directly
+	endpointUrl?: string; // to call a local endpoint directly
 } & Record<string, unknown>;

 const PYTHON_CLIENTS = ["huggingface_hub", "fal_client", "requests", "openai"] as const;
@@ -53,6 +54,7 @@ interface TemplateParams {
 	methodName?: string; // specific to snippetBasic
 	importBase64?: boolean; // specific to snippetImportRequests
 	importJson?: boolean; // specific to snippetImportRequests
+	endpointUrl?: string;
 }

 // Helpers to find + load templates
@@ -172,6 +174,7 @@ const snippetGenerator = (templateName: string, inputPreparationFn?: InputPrepar
 		{
 			accessToken: accessTokenOrPlaceholder,
 			provider,
+			endpointUrl: opts?.endpointUrl,
 			...inputs,
 		} as RequestArgs,
 		inferenceProviderMapping,
@@ -217,6 +220,7 @@ const snippetGenerator = (templateName: string, inputPreparationFn?: InputPrepar
 		provider,
 		providerModelId: providerModelId ?? model.id,
 		billTo: opts?.billTo,
+		endpointUrl: opts?.endpointUrl,
 	};

 	/// Iterate over clients => check if a snippet exists => generate
@@ -265,7 +269,14 @@ const snippetGenerator = (templateName: string, inputPreparationFn?: InputPrepar

 		/// Replace access token placeholder
 		if (snippet.includes(placeholder)) {
-			snippet = replaceAccessTokenPlaceholder(opts?.directRequest, placeholder, snippet, language, provider);
+			snippet = replaceAccessTokenPlaceholder(
+				opts?.directRequest,
+				placeholder,
+				snippet,
+				language,
+				provider,
+				opts?.endpointUrl
+			);
 		}

 		/// Snippet is ready!
@@ -444,21 +455,24 @@ function replaceAccessTokenPlaceholder(
 	placeholder: string,
 	snippet: string,
 	language: InferenceSnippetLanguage,
-	provider: InferenceProviderOrPolicy
+	provider: InferenceProviderOrPolicy,
+	endpointUrl?: string
 ): string {
 	// If "opts.accessToken" is not set, the snippets are generated with a placeholder.
 	// Once snippets are rendered, we replace the placeholder with code to fetch the access token from an environment variable.

 	// Determine if HF_TOKEN or specific provider token should be used
 	const useHfToken =
-		provider == "hf-inference" || // hf-inference provider => use $HF_TOKEN
-		(!directRequest && // if explicit directRequest => use provider-specific token
-			(!snippet.includes("https://") || // no URL provided => using a client => use $HF_TOKEN
-				snippet.includes("https://router.huggingface.co"))); // explicit routed request => use $HF_TOKEN
-
+		!endpointUrl && // custom endpointUrl => use a generic API_TOKEN
+		(provider == "hf-inference" || // hf-inference provider => use $HF_TOKEN
+			(!directRequest && // if explicit directRequest => use provider-specific token
+				(!snippet.includes("https://") || // no URL provided => using a client => use $HF_TOKEN
+					snippet.includes("https://router.huggingface.co")))); // explicit routed request => use $HF_TOKEN
 	const accessTokenEnvVar = useHfToken
 		? "HF_TOKEN" // e.g. routed request or hf-inference
-		: provider.toUpperCase().replace("-", "_") + "_API_KEY"; // e.g. "REPLICATE_API_KEY"
+		: endpointUrl
+			? "API_TOKEN"
+			: provider.toUpperCase().replace("-", "_") + "_API_KEY"; // e.g. "REPLICATE_API_KEY"

 	// Replace the placeholder with the env variable
 	if (language === "sh") {
```
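The upshot of the `useHfToken` change is a three-way choice of environment variable in the rendered snippets: `HF_TOKEN` for routed or hf-inference requests, a provider-specific key for direct provider calls, and a generic `API_TOKEN` when a custom endpoint is set. A condensed sketch of that decision, mirroring the logic above (illustrative, not the shipped code):

```ts
// Illustrative decision table for the access-token env var (not the shipped code).
function pickAccessTokenEnvVar(
	provider: string,
	endpointUrl: string | undefined,
	directRequest: boolean | undefined,
	snippet: string
): string {
	// Custom endpoint => the snippet can't assume HF or provider billing => generic token.
	if (endpointUrl) return "API_TOKEN";
	// hf-inference, client-based snippets (no raw URL), and routed requests => HF token.
	const useHfToken =
		provider === "hf-inference" ||
		(!directRequest &&
			(!snippet.includes("https://") || snippet.includes("https://router.huggingface.co")));
	if (useHfToken) return "HF_TOKEN";
	// Direct provider call => that provider's own key, e.g. "REPLICATE_API_KEY".
	return provider.toUpperCase().replace("-", "_") + "_API_KEY";
}
```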

packages/inference/src/snippets/templates/js/huggingface.js/basic.jinja

Lines changed: 3 additions & 0 deletions
```diff
@@ -3,6 +3,9 @@ import { InferenceClient } from "@huggingface/inference";
 const client = new InferenceClient("{{ accessToken }}");

 const output = await client.{{ methodName }}({
+{% if endpointUrl %}
+	endpointUrl: "{{ endpointUrl }}",
+{% endif %}
 	model: "{{ model.id }}",
 	inputs: {{ inputs.asObj.inputs }},
 	provider: "{{ provider }}",
```

packages/inference/src/snippets/templates/js/huggingface.js/basicAudio.jinja

Lines changed: 3 additions & 0 deletions
```diff
@@ -5,6 +5,9 @@ const client = new InferenceClient("{{ accessToken }}");
 const data = fs.readFileSync({{inputs.asObj.inputs}});

 const output = await client.{{ methodName }}({
+{% if endpointUrl %}
+	endpointUrl: "{{ endpointUrl }}",
+{% endif %}
 	data,
 	model: "{{ model.id }}",
 	provider: "{{ provider }}",
```

packages/inference/src/snippets/templates/js/huggingface.js/basicImage.jinja

Lines changed: 3 additions & 0 deletions
```diff
@@ -5,6 +5,9 @@ const client = new InferenceClient("{{ accessToken }}");
 const data = fs.readFileSync({{inputs.asObj.inputs}});

 const output = await client.{{ methodName }}({
+{% if endpointUrl %}
+	endpointUrl: "{{ endpointUrl }}",
+{% endif %}
 	data,
 	model: "{{ model.id }}",
 	provider: "{{ provider }}",
```

packages/inference/src/snippets/templates/js/huggingface.js/conversational.jinja

Lines changed: 3 additions & 0 deletions
```diff
@@ -3,6 +3,9 @@ import { InferenceClient } from "@huggingface/inference";
 const client = new InferenceClient("{{ accessToken }}");

 const chatCompletion = await client.chatCompletion({
+{% if endpointUrl %}
+	endpointUrl: "{{ endpointUrl }}",
+{% endif %}
 	provider: "{{ provider }}",
 	model: "{{ model.id }}",
 	{{ inputs.asTsString }}
```

packages/inference/src/snippets/templates/js/huggingface.js/conversationalStream.jinja

Lines changed: 3 additions & 0 deletions
```diff
@@ -5,6 +5,9 @@ const client = new InferenceClient("{{ accessToken }}");
 let out = "";

 const stream = client.chatCompletionStream({
+{% if endpointUrl %}
+	endpointUrl: "{{ endpointUrl }}",
+{% endif %}
 	provider: "{{ provider }}",
 	model: "{{ model.id }}",
 	{{ inputs.asTsString }}
```

packages/inference/src/snippets/templates/js/huggingface.js/textToImage.jinja

Lines changed: 3 additions & 0 deletions
```diff
@@ -3,6 +3,9 @@ import { InferenceClient } from "@huggingface/inference";
 const client = new InferenceClient("{{ accessToken }}");

 const image = await client.textToImage({
+{% if endpointUrl %}
+	endpointUrl: "{{ endpointUrl }}",
+{% endif %}
 	provider: "{{ provider }}",
 	model: "{{ model.id }}",
 	inputs: {{ inputs.asObj.inputs }},
```

packages/inference/src/snippets/templates/js/huggingface.js/textToSpeech.jinja

Lines changed: 3 additions & 0 deletions
```diff
@@ -3,6 +3,9 @@ import { InferenceClient } from "@huggingface/inference";
 const client = new InferenceClient("{{ accessToken }}");

 const audio = await client.textToSpeech({
+{% if endpointUrl %}
+	endpointUrl: "{{ endpointUrl }}",
+{% endif %}
 	provider: "{{ provider }}",
 	model: "{{ model.id }}",
 	inputs: {{ inputs.asObj.inputs }},
```

packages/inference/src/snippets/templates/js/huggingface.js/textToVideo.jinja

Lines changed: 3 additions & 0 deletions
```diff
@@ -3,6 +3,9 @@ import { InferenceClient } from "@huggingface/inference";
 const client = new InferenceClient("{{ accessToken }}");

 const video = await client.textToVideo({
+{% if endpointUrl %}
+	endpointUrl: "{{ endpointUrl }}",
+{% endif %}
 	provider: "{{ provider }}",
 	model: "{{ model.id }}",
 	inputs: {{ inputs.asObj.inputs }},
```

packages/inference/src/snippets/templates/python/huggingface_hub/importInferenceClient.jinja

Lines changed: 3 additions & 0 deletions
```diff
@@ -1,6 +1,9 @@
 from huggingface_hub import InferenceClient

 client = InferenceClient(
+{% if endpointUrl %}
+    base_url="{{ baseUrl }}",
+{% endif %}
     provider="{{ provider }}",
     api_key="{{ accessToken }}",
 {% if billTo %}
```

packages/tasks-gen/scripts/generate-snippets-fixtures.ts

Lines changed: 12 additions & 0 deletions
```diff
@@ -95,6 +95,18 @@ const TEST_CASES: {
 		providers: ["hf-inference", "fireworks-ai"],
 		opts: { streaming: true },
 	},
+	{
+		testName: "conversational-llm-custom-endpoint",
+		task: "conversational",
+		model: {
+			id: "meta-llama/Llama-3.1-8B-Instruct",
+			pipeline_tag: "text-generation",
+			tags: ["conversational"],
+			inference: "",
+		},
+		providers: ["hf-inference"],
+		opts: { endpointUrl: "http://localhost:8080/v1" },
+	},
 	{
 		testName: "document-question-answering",
 		task: "document-question-answering",
```
The six new files below are the generated snapshot fixtures for the `conversational-llm-custom-endpoint` test case.

New fixture: js, `huggingface.js` client

Lines changed: 17 additions & 0 deletions

```diff
@@ -0,0 +1,17 @@
+import { InferenceClient } from "@huggingface/inference";
+
+const client = new InferenceClient(process.env.API_TOKEN);
+
+const chatCompletion = await client.chatCompletion({
+	endpointUrl: "http://localhost:8080/v1",
+	provider: "hf-inference",
+	model: "meta-llama/Llama-3.1-8B-Instruct",
+	messages: [
+		{
+			role: "user",
+			content: "What is the capital of France?",
+		},
+	],
+});
+
+console.log(chatCompletion.choices[0].message);
```
New fixture: js, `openai` client

Lines changed: 18 additions & 0 deletions

```diff
@@ -0,0 +1,18 @@
+import { OpenAI } from "openai";
+
+const client = new OpenAI({
+	baseURL: "http://localhost:8080/v1",
+	apiKey: process.env.API_TOKEN,
+});
+
+const chatCompletion = await client.chat.completions.create({
+	model: "meta-llama/Llama-3.1-8B-Instruct",
+	messages: [
+		{
+			role: "user",
+			content: "What is the capital of France?",
+		},
+	],
+});
+
+console.log(chatCompletion.choices[0].message);
```
New fixture: python, `huggingface_hub` client

Lines changed: 20 additions & 0 deletions

```diff
@@ -0,0 +1,20 @@
+import os
+from huggingface_hub import InferenceClient
+
+client = InferenceClient(
+    base_url="http://localhost:8080/v1",
+    provider="hf-inference",
+    api_key=os.environ["API_TOKEN"],
+)
+
+completion = client.chat.completions.create(
+    model="meta-llama/Llama-3.1-8B-Instruct",
+    messages=[
+        {
+            "role": "user",
+            "content": "What is the capital of France?"
+        }
+    ],
+)
+
+print(completion.choices[0].message)
```
New fixture: python, `openai` client

Lines changed: 19 additions & 0 deletions

```diff
@@ -0,0 +1,19 @@
+import os
+from openai import OpenAI
+
+client = OpenAI(
+    base_url="http://localhost:8080/v1",
+    api_key=os.environ["API_TOKEN"],
+)
+
+completion = client.chat.completions.create(
+    model="meta-llama/Llama-3.1-8B-Instruct",
+    messages=[
+        {
+            "role": "user",
+            "content": "What is the capital of France?"
+        }
+    ],
+)
+
+print(completion.choices[0].message)
```
New fixture: python, `requests`

Lines changed: 23 additions & 0 deletions

```diff
@@ -0,0 +1,23 @@
+import os
+import requests
+
+API_URL = "http://localhost:8080/v1/chat/completions"
+headers = {
+    "Authorization": f"Bearer {os.environ['API_TOKEN']}",
+}
+
+def query(payload):
+    response = requests.post(API_URL, headers=headers, json=payload)
+    return response.json()
+
+response = query({
+    "messages": [
+        {
+            "role": "user",
+            "content": "What is the capital of France?"
+        }
+    ],
+    "model": "meta-llama/Llama-3.1-8B-Instruct"
+})
+
+print(response["choices"][0]["message"])
```
New fixture: sh, `curl`

Lines changed: 13 additions & 0 deletions

```diff
@@ -0,0 +1,13 @@
+curl http://localhost:8080/v1/chat/completions \
+    -H "Authorization: Bearer $API_TOKEN" \
+    -H 'Content-Type: application/json' \
+    -d '{
+        "messages": [
+            {
+                "role": "user",
+                "content": "What is the capital of France?"
+            }
+        ],
+        "model": "meta-llama/Llama-3.1-8B-Instruct",
+        "stream": false
+    }'
```
