
Commit b32dcb6

More python inference snippets (#1028)
Given the size of the PR and the internal changes, I think it's best to review the changes commit by commit. This PR:

- improves snippets for document-question-answering (0288149)
- improves snippets for automatic-speech-recognition (060cd21)
- adds snippets for image-to-image (6234469)
- fixes some base64 imports (0907d8f)

I added tests for all of the above. I did not add `InferenceClient` snippets for all tasks, as that is time-consuming; I only added tasks based on the trending models listed on [huggingface.co/models](https://huggingface.co/models). We might add more tasks in the future, but that's less of a priority.

Also includes some cleaning:

- remove useless `"export ..."` from the python snippets module (6d29991)
- order snippet definitions alphabetically for easier retrieval (0900c99)

---

**EDIT (05/03/2025):** rebased from main + adapted to take providers into account. The scope of this PR is more or less the same as before.
1 parent 4b0963d commit b32dcb6

24 files changed: 275 additions and 78 deletions

packages/inference/src/snippets/python.ts

Lines changed: 113 additions & 37 deletions
@@ -44,11 +44,11 @@ const snippetImportInferenceClient = (accessToken: string, provider: SnippetInferenceProvider
 from huggingface_hub import InferenceClient
 
 client = InferenceClient(
-	provider="${provider}",
-	api_key="${accessToken || "{API_TOKEN}"}"
+    provider="${provider}",
+    api_key="${accessToken || "{API_TOKEN}"}",
 )`;
 
-export const snippetConversational = (
+const snippetConversational = (
 	model: ModelDataMinimal,
 	accessToken: string,
 	provider: SnippetInferenceProvider,
@@ -89,7 +89,7 @@ stream = client.chat.completions.create(
     model="${model.id}",
     messages=messages,
 ${configStr}
-    stream=True
+    stream=True,
 )
 
 for chunk in stream:
@@ -159,7 +159,7 @@ print(completion.choices[0].message)`,
 	}
 };
 
-export const snippetZeroShotClassification = (model: ModelDataMinimal): InferenceSnippet[] => {
+const snippetZeroShotClassification = (model: ModelDataMinimal): InferenceSnippet[] => {
 	return [
 		{
 			client: "requests",
@@ -176,12 +176,11 @@ output = query({
 	];
 };
 
-export const snippetZeroShotImageClassification = (model: ModelDataMinimal): InferenceSnippet[] => {
+const snippetZeroShotImageClassification = (model: ModelDataMinimal): InferenceSnippet[] => {
 	return [
 		{
 			client: "requests",
-			content: `\
-def query(data):
+			content: `def query(data):
     with open(data["image_path"], "rb") as f:
         img = f.read()
     payload={
@@ -199,7 +198,7 @@ output = query({
 	];
 };
 
-export const snippetBasic = (
+const snippetBasic = (
 	model: ModelDataMinimal,
 	accessToken: string,
 	provider: SnippetInferenceProvider
@@ -213,9 +212,8 @@ export const snippetBasic = (
 ${snippetImportInferenceClient(accessToken, provider)}
 
 result = client.${HFH_INFERENCE_CLIENT_METHODS[model.pipeline_tag]}(
-    model="${model.id}",
     inputs=${getModelInputSnippet(model)},
-    provider="${provider}",
+    model="${model.id}",
 )
 
 print(result)
@@ -237,7 +235,7 @@ output = query({
 	];
 };
 
-export const snippetFile = (model: ModelDataMinimal): InferenceSnippet[] => {
+const snippetFile = (model: ModelDataMinimal): InferenceSnippet[] => {
 	return [
 		{
 			client: "requests",
@@ -253,7 +251,7 @@ output = query(${getModelInputSnippet(model)})`,
 	];
 };
 
-export const snippetTextToImage = (
+const snippetTextToImage = (
 	model: ModelDataMinimal,
 	accessToken: string,
 	provider: SnippetInferenceProvider,
@@ -268,7 +266,7 @@ ${snippetImportInferenceClient(accessToken, provider)}
 # output is a PIL.Image object
 image = client.text_to_image(
     ${getModelInputSnippet(model)},
-    model="${model.id}"
+    model="${model.id}",
 )`,
 		},
 		...(provider === "fal-ai"
@@ -312,7 +310,7 @@ image = Image.open(io.BytesIO(image_bytes))`,
 	];
 };
 
-export const snippetTextToVideo = (
+const snippetTextToVideo = (
 	model: ModelDataMinimal,
 	accessToken: string,
 	provider: SnippetInferenceProvider
@@ -326,14 +324,14 @@ ${snippetImportInferenceClient(accessToken, provider)}
 
 video = client.text_to_video(
     ${getModelInputSnippet(model)},
-    model="${model.id}"
+    model="${model.id}",
 )`,
 				},
 		  ]
 		: [];
 };
 
-export const snippetTabular = (model: ModelDataMinimal): InferenceSnippet[] => {
+const snippetTabular = (model: ModelDataMinimal): InferenceSnippet[] => {
 	return [
 		{
 			client: "requests",
@@ -349,7 +347,7 @@ response = query({
 	];
 };
 
-export const snippetTextToAudio = (model: ModelDataMinimal): InferenceSnippet[] => {
+const snippetTextToAudio = (model: ModelDataMinimal): InferenceSnippet[] => {
 	// Transformers TTS pipeline and api-inference-community (AIC) pipeline outputs are diverged
 	// with the latest update to inference-api (IA).
 	// Transformers IA returns a byte object (wav file), whereas AIC returns wav and sampling_rate.
@@ -374,8 +372,7 @@ Audio(audio_bytes)`,
 	return [
 		{
 			client: "requests",
-			content: `\
-def query(payload):
+			content: `def query(payload):
     response = requests.post(API_URL, headers=headers, json=payload)
     return response.json()
 
@@ -390,26 +387,97 @@ Audio(audio, rate=sampling_rate)`,
 	}
 };
 
-export const snippetDocumentQuestionAnswering = (model: ModelDataMinimal): InferenceSnippet[] => {
+const snippetAutomaticSpeechRecognition = (
+	model: ModelDataMinimal,
+	accessToken: string,
+	provider: SnippetInferenceProvider
+): InferenceSnippet[] => {
+	return [
+		{
+			client: "huggingface_hub",
+			content: `${snippetImportInferenceClient(accessToken, provider)}
+output = client.automatic_speech_recognition(${getModelInputSnippet(model)}, model="${model.id}")`,
+		},
+		snippetFile(model)[0],
+	];
+};
+
+const snippetDocumentQuestionAnswering = (
+	model: ModelDataMinimal,
+	accessToken: string,
+	provider: SnippetInferenceProvider
+): InferenceSnippet[] => {
+	const inputsAsStr = getModelInputSnippet(model) as string;
+	const inputsAsObj = JSON.parse(inputsAsStr);
+
 	return [
+		{
+			client: "huggingface_hub",
+			content: `${snippetImportInferenceClient(accessToken, provider)}
+output = client.document_question_answering(
+    "${inputsAsObj.image}",
+    question="${inputsAsObj.question}",
+    model="${model.id}",
+)`,
+		},
 		{
 			client: "requests",
-			content: `\
-def query(payload):
+			content: `def query(payload):
     with open(payload["image"], "rb") as f:
         img = f.read()
-    payload["image"] = base64.b64encode(img).decode("utf-8")
+        payload["image"] = base64.b64encode(img).decode("utf-8")
     response = requests.post(API_URL, headers=headers, json=payload)
     return response.json()
 
 output = query({
-    "inputs": ${getModelInputSnippet(model)},
+    "inputs": ${inputsAsStr},
 })`,
 		},
 	];
 };
 
-export const pythonSnippets: Partial<
+const snippetImageToImage = (
+	model: ModelDataMinimal,
+	accessToken: string,
+	provider: SnippetInferenceProvider
+): InferenceSnippet[] => {
+	const inputsAsStr = getModelInputSnippet(model) as string;
+	const inputsAsObj = JSON.parse(inputsAsStr);
+
+	return [
+		{
+			client: "huggingface_hub",
+			content: `${snippetImportInferenceClient(accessToken, provider)}
+# output is a PIL.Image object
+image = client.image_to_image(
+    "${inputsAsObj.image}",
+    prompt="${inputsAsObj.prompt}",
+    model="${model.id}",
+)`,
+		},
+		{
+			client: "requests",
+			content: `def query(payload):
+    with open(payload["inputs"], "rb") as f:
+        img = f.read()
+        payload["inputs"] = base64.b64encode(img).decode("utf-8")
+    response = requests.post(API_URL, headers=headers, json=payload)
+    return response.content
+
+image_bytes = query({
+    "inputs": "${inputsAsObj.image}",
+    "parameters": {"prompt": "${inputsAsObj.prompt}"},
+})
+
+# You can access the image with PIL.Image for example
+import io
+from PIL import Image
+image = Image.open(io.BytesIO(image_bytes))`,
+		},
+	];
+};
+
+const pythonSnippets: Partial<
 	Record<
 		PipelineType,
 		(
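
For context, the new document-question-answering `huggingface_hub` branch renders to a snippet along these lines for the `impira/layoutlm-invoices` model used in the tests below. This is a sketch only: the image path and question are placeholders, since the real defaults come from `getModelInputSnippet`:

```python
from huggingface_hub import InferenceClient

client = InferenceClient(
    provider="hf-inference",
    api_key="api_token",
)
# Placeholder inputs; the actual values come from the task's default input snippet.
output = client.document_question_answering(
    "invoice.png",
    question="What is the invoice number?",
    model="impira/layoutlm-invoices",
)
```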
@@ -435,7 +503,7 @@ export const pythonSnippets: Partial<
 	"image-text-to-text": snippetConversational,
 	"fill-mask": snippetBasic,
 	"sentence-similarity": snippetBasic,
-	"automatic-speech-recognition": snippetFile,
+	"automatic-speech-recognition": snippetAutomaticSpeechRecognition,
 	"text-to-image": snippetTextToImage,
 	"text-to-video": snippetTextToVideo,
 	"text-to-speech": snippetTextToAudio,
@@ -449,6 +517,7 @@ export const pythonSnippets: Partial<
 	"image-segmentation": snippetFile,
 	"document-question-answering": snippetDocumentQuestionAnswering,
 	"image-to-text": snippetFile,
+	"image-to-image": snippetImageToImage,
 	"zero-shot-image-classification": snippetZeroShotImageClassification,
 };
 

@@ -471,17 +540,24 @@ export function getPythonInferenceSnippet(
471540
return snippets.map((snippet) => {
472541
return {
473542
...snippet,
474-
content:
475-
snippet.client === "requests"
476-
? `\
477-
import requests
478-
479-
API_URL = "${openAIbaseUrl(provider)}"
480-
headers = {"Authorization": ${accessToken ? `"Bearer ${accessToken}"` : `f"Bearer {API_TOKEN}"`}}
481-
482-
${snippet.content}`
483-
: snippet.content,
543+
content: addImportsToSnippet(snippet.content, model, accessToken),
484544
};
485545
});
486546
}
487547
}
548+
549+
const addImportsToSnippet = (snippet: string, model: ModelDataMinimal, accessToken: string): string => {
550+
if (snippet.includes("requests")) {
551+
snippet = `import requests
552+
553+
API_URL = "https://router.huggingface.co/hf-inference/models/${model.id}"
554+
headers = {"Authorization": ${accessToken ? `"Bearer ${accessToken}"` : `f"Bearer {API_TOKEN}"`}}
555+
556+
${snippet}`;
557+
}
558+
if (snippet.includes("base64")) {
559+
snippet = `import base64
560+
${snippet}`;
561+
}
562+
return snippet;
563+
};
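
To make the import-injection concrete: a `requests` snippet whose body also mentions `base64` (such as the new image-to-image one) gets both blocks prepended, with `import base64` landing on the very first line because it is added last. A sketch of the assembled output, using placeholder image path and prompt values:

```python
import base64
import requests

API_URL = "https://router.huggingface.co/hf-inference/models/stabilityai/stable-diffusion-xl-refiner-1.0"
headers = {"Authorization": "Bearer api_token"}

def query(payload):
    # Read the input image and inline it as base64 so it can travel in the JSON payload
    with open(payload["inputs"], "rb") as f:
        img = f.read()
        payload["inputs"] = base64.b64encode(img).decode("utf-8")
    response = requests.post(API_URL, headers=headers, json=payload)
    return response.content

# Placeholder inputs; the actual values come from the task's default input snippet.
image_bytes = query({
    "inputs": "cat.png",
    "parameters": {"prompt": "Turn the cat into a tiger."},
})
```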

packages/tasks-gen/scripts/generate-snippets-fixtures.ts

Lines changed: 33 additions & 0 deletions
@@ -31,6 +31,17 @@ const TEST_CASES: {
 	providers: SnippetInferenceProvider[];
 	opts?: Record<string, unknown>;
 }[] = [
+	{
+		testName: "automatic-speech-recognition",
+		model: {
+			id: "openai/whisper-large-v3-turbo",
+			pipeline_tag: "automatic-speech-recognition",
+			tags: [],
+			inference: "",
+		},
+		languages: ["py"],
+		providers: ["hf-inference"],
+	},
 	{
 		testName: "conversational-llm-non-stream",
 		model: {
@@ -79,6 +90,28 @@ const TEST_CASES: {
 		providers: ["hf-inference", "fireworks-ai"],
 		opts: { streaming: true },
 	},
+	{
+		testName: "document-question-answering",
+		model: {
+			id: "impira/layoutlm-invoices",
+			pipeline_tag: "document-question-answering",
+			tags: [],
+			inference: "",
+		},
+		languages: ["py"],
+		providers: ["hf-inference"],
+	},
+	{
+		testName: "image-to-image",
+		model: {
+			id: "stabilityai/stable-diffusion-xl-refiner-1.0",
+			pipeline_tag: "image-to-image",
+			tags: [],
+			inference: "",
+		},
+		languages: ["py"],
+		providers: ["hf-inference"],
+	},
 	{
 		testName: "text-to-image",
 		model: {
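
Each test case above materializes as one fixture file per (client, provider, language) combination. Judging from the fixture paths visible in this diff, the naming convention is roughly the following (a hypothetical helper for illustration, not code from this PR):

```python
def fixture_path(test_name: str, index: int, client: str, provider: str, lang: str) -> str:
    # e.g. packages/tasks-gen/snippets-fixtures/automatic-speech-recognition/0.huggingface_hub.hf-inference.py
    return f"packages/tasks-gen/snippets-fixtures/{test_name}/{index}.{client}.{provider}.{lang}"

print(fixture_path("automatic-speech-recognition", 0, "huggingface_hub", "hf-inference", "py"))
```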
packages/tasks-gen/snippets-fixtures/automatic-speech-recognition/0.huggingface_hub.hf-inference.py

Lines changed: 7 additions & 0 deletions

@@ -0,0 +1,7 @@
+from huggingface_hub import InferenceClient
+
+client = InferenceClient(
+    provider="hf-inference",
+    api_key="api_token",
+)
+output = client.automatic_speech_recognition("sample1.flac", model="openai/whisper-large-v3-turbo")
packages/tasks-gen/snippets-fixtures/automatic-speech-recognition/0.requests.hf-inference.py

Lines changed: 12 additions & 0 deletions

@@ -0,0 +1,12 @@
+import requests
+
+API_URL = "https://router.huggingface.co/hf-inference/models/openai/whisper-large-v3-turbo"
+headers = {"Authorization": "Bearer api_token"}
+
+def query(filename):
+    with open(filename, "rb") as f:
+        data = f.read()
+    response = requests.post(API_URL, headers=headers, data=data)
+    return response.json()
+
+output = query("sample1.flac")
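
Both fixtures transcribe the same sample file. Assuming the usual hf-inference ASR response shape (a JSON object with a `text` field), the transcription would be read as follows; note that the `huggingface_hub` variant returns an output object rather than a plain dict:

```python
# requests variant: response.json() returns a dict such as {"text": "..."}
print(output["text"])

# huggingface_hub variant: the client returns an object exposing the transcription
print(output.text)
```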

packages/tasks-gen/snippets-fixtures/conversational-llm-non-stream/0.huggingface_hub.hf-inference.py

Lines changed: 2 additions & 2 deletions
@@ -1,8 +1,8 @@
 from huggingface_hub import InferenceClient
 
 client = InferenceClient(
-	provider="hf-inference",
-	api_key="api_token"
+    provider="hf-inference",
+    api_key="api_token",
 )
 
 messages = [

packages/tasks-gen/snippets-fixtures/conversational-llm-non-stream/0.huggingface_hub.together.py

Lines changed: 2 additions & 2 deletions
@@ -1,8 +1,8 @@
 from huggingface_hub import InferenceClient
 
 client = InferenceClient(
-	provider="together",
-	api_key="api_token"
+    provider="together",
+    api_key="api_token",
 )
 
 messages = [
