@@ -44,11 +44,11 @@ const snippetImportInferenceClient = (accessToken: string, provider: SnippetInfe
from huggingface_hub import InferenceClient

client = InferenceClient(
-	provider="${provider}",
-	api_key="${accessToken || "{API_TOKEN}"}"
+	provider="${provider}",
+	api_key="${accessToken || "{API_TOKEN}"}",
)`;

-export const snippetConversational = (
+const snippetConversational = (
	model: ModelDataMinimal,
	accessToken: string,
	provider: SnippetInferenceProvider,
@@ -89,7 +89,7 @@ stream = client.chat.completions.create(
	model="${model.id}",
	messages=messages,
	${configStr}
-	stream=True
+	stream=True,
)

for chunk in stream:
@@ -159,7 +159,7 @@ print(completion.choices[0].message)`,
	}
};

-export const snippetZeroShotClassification = (model: ModelDataMinimal): InferenceSnippet[] => {
+const snippetZeroShotClassification = (model: ModelDataMinimal): InferenceSnippet[] => {
	return [
		{
			client: "requests",
@@ -176,12 +176,11 @@ output = query({
	];
};

-export const snippetZeroShotImageClassification = (model: ModelDataMinimal): InferenceSnippet[] => {
+const snippetZeroShotImageClassification = (model: ModelDataMinimal): InferenceSnippet[] => {
	return [
		{
			client: "requests",
-			content: `\
-def query(data):
+			content: `def query(data):
	with open(data["image_path"], "rb") as f:
		img = f.read()
	payload={
@@ -199,7 +198,7 @@ output = query({
	];
};

-export const snippetBasic = (
+const snippetBasic = (
	model: ModelDataMinimal,
	accessToken: string,
	provider: SnippetInferenceProvider
@@ -213,9 +212,8 @@ export const snippetBasic = (
${snippetImportInferenceClient(accessToken, provider)}

result = client.${HFH_INFERENCE_CLIENT_METHODS[model.pipeline_tag]}(
-	model="${model.id}",
	inputs=${getModelInputSnippet(model)},
-	provider="${provider}",
+	model="${model.id}",
)

print(result)
@@ -237,7 +235,7 @@ output = query({
	];
};

-export const snippetFile = (model: ModelDataMinimal): InferenceSnippet[] => {
+const snippetFile = (model: ModelDataMinimal): InferenceSnippet[] => {
	return [
		{
			client: "requests",
@@ -253,7 +251,7 @@ output = query(${getModelInputSnippet(model)})`,
	];
};

-export const snippetTextToImage = (
+const snippetTextToImage = (
	model: ModelDataMinimal,
	accessToken: string,
	provider: SnippetInferenceProvider,
@@ -268,7 +266,7 @@ ${snippetImportInferenceClient(accessToken, provider)}
# output is a PIL.Image object
image = client.text_to_image(
	${getModelInputSnippet(model)},
-	model="${model.id}"
+	model="${model.id}",
)`,
		},
		...(provider === "fal-ai"
@@ -312,7 +310,7 @@ image = Image.open(io.BytesIO(image_bytes))`,
	];
};

-export const snippetTextToVideo = (
+const snippetTextToVideo = (
	model: ModelDataMinimal,
	accessToken: string,
	provider: SnippetInferenceProvider
@@ -326,14 +324,14 @@ ${snippetImportInferenceClient(accessToken, provider)}

video = client.text_to_video(
	${getModelInputSnippet(model)},
-	model="${model.id}"
+	model="${model.id}",
)`,
				},
			]
		: [];
};

-export const snippetTabular = (model: ModelDataMinimal): InferenceSnippet[] => {
+const snippetTabular = (model: ModelDataMinimal): InferenceSnippet[] => {
	return [
		{
			client: "requests",
@@ -349,7 +347,7 @@ response = query({
	];
};

-export const snippetTextToAudio = (model: ModelDataMinimal): InferenceSnippet[] => {
+const snippetTextToAudio = (model: ModelDataMinimal): InferenceSnippet[] => {
	// Transformers TTS pipeline and api-inference-community (AIC) pipeline outputs are diverged
	// with the latest update to inference-api (IA).
	// Transformers IA returns a byte object (wav file), whereas AIC returns wav and sampling_rate.
@@ -374,8 +372,7 @@ Audio(audio_bytes)`,
	return [
		{
			client: "requests",
-			content: `\
-def query(payload):
+			content: `def query(payload):
	response = requests.post(API_URL, headers=headers, json=payload)
	return response.json()

@@ -390,26 +387,97 @@ Audio(audio, rate=sampling_rate)`,
	}
};

-export const snippetDocumentQuestionAnswering = (model: ModelDataMinimal): InferenceSnippet[] => {
+const snippetAutomaticSpeechRecognition = (
+	model: ModelDataMinimal,
+	accessToken: string,
+	provider: SnippetInferenceProvider
+): InferenceSnippet[] => {
+	return [
+		{
+			client: "huggingface_hub",
+			content: `${snippetImportInferenceClient(accessToken, provider)}
+output = client.automatic_speech_recognition(${getModelInputSnippet(model)}, model="${model.id}")`,
+		},
+		snippetFile(model)[0],
+	];
+};
+
+const snippetDocumentQuestionAnswering = (
+	model: ModelDataMinimal,
+	accessToken: string,
+	provider: SnippetInferenceProvider
+): InferenceSnippet[] => {
+	const inputsAsStr = getModelInputSnippet(model) as string;
+	const inputsAsObj = JSON.parse(inputsAsStr);
+
	return [
+		{
+			client: "huggingface_hub",
+			content: `${snippetImportInferenceClient(accessToken, provider)}
+output = client.document_question_answering(
+	"${inputsAsObj.image}",
+	question="${inputsAsObj.question}",
+	model="${model.id}",
+)`,
+		},
		{
			client: "requests",
-			content: `\
-def query(payload):
+			content: `def query(payload):
	with open(payload["image"], "rb") as f:
		img = f.read()
-	payload["image"] = base64.b64encode(img).decode("utf-8")
+		payload["image"] = base64.b64encode(img).decode("utf-8")
	response = requests.post(API_URL, headers=headers, json=payload)
	return response.json()

output = query({
-	"inputs": ${getModelInputSnippet(model)},
+	"inputs": ${inputsAsStr},
})`,
		},
	];
};

-export const pythonSnippets: Partial<
+const snippetImageToImage = (
+	model: ModelDataMinimal,
+	accessToken: string,
+	provider: SnippetInferenceProvider
+): InferenceSnippet[] => {
+	const inputsAsStr = getModelInputSnippet(model) as string;
+	const inputsAsObj = JSON.parse(inputsAsStr);
+
+	return [
+		{
+			client: "huggingface_hub",
+			content: `${snippetImportInferenceClient(accessToken, provider)}
+# output is a PIL.Image object
+image = client.image_to_image(
+	"${inputsAsObj.image}",
+	prompt="${inputsAsObj.prompt}",
+	model="${model.id}",
+)`,
+		},
+		{
+			client: "requests",
+			content: `def query(payload):
+	with open(payload["inputs"], "rb") as f:
+		img = f.read()
+	payload["inputs"] = base64.b64encode(img).decode("utf-8")
+	response = requests.post(API_URL, headers=headers, json=payload)
+	return response.content
+
+image_bytes = query({
+	"inputs": "${inputsAsObj.image}",
+	"parameters": {"prompt": "${inputsAsObj.prompt}"},
+})
+
+# You can access the image with PIL.Image for example
+import io
+from PIL import Image
+image = Image.open(io.BytesIO(image_bytes))`,
+		},
+	];
+};
+
+const pythonSnippets: Partial<
	Record<
		PipelineType,
		(
@@ -435,7 +503,7 @@ export const pythonSnippets: Partial<
	"image-text-to-text": snippetConversational,
	"fill-mask": snippetBasic,
	"sentence-similarity": snippetBasic,
-	"automatic-speech-recognition": snippetFile,
+	"automatic-speech-recognition": snippetAutomaticSpeechRecognition,
	"text-to-image": snippetTextToImage,
	"text-to-video": snippetTextToVideo,
	"text-to-speech": snippetTextToAudio,
@@ -449,6 +517,7 @@ export const pythonSnippets: Partial<
	"image-segmentation": snippetFile,
	"document-question-answering": snippetDocumentQuestionAnswering,
	"image-to-text": snippetFile,
+	"image-to-image": snippetImageToImage,
	"zero-shot-image-classification": snippetZeroShotImageClassification,
};
@@ -471,17 +540,24 @@ export function getPythonInferenceSnippet(
		return snippets.map((snippet) => {
			return {
				...snippet,
-				content:
-					snippet.client === "requests"
-						? `\
-import requests
-
-API_URL = "${openAIbaseUrl(provider)}"
-headers = {"Authorization": ${accessToken ? `"Bearer ${accessToken}"` : `f"Bearer {API_TOKEN}"`}}
-
-${snippet.content}`
-						: snippet.content,
+				content: addImportsToSnippet(snippet.content, model, accessToken),
			};
		});
	}
}
+
+const addImportsToSnippet = (snippet: string, model: ModelDataMinimal, accessToken: string): string => {
+	if (snippet.includes("requests")) {
+		snippet = `import requests
+
+API_URL = "https://router.huggingface.co/hf-inference/models/${model.id}"
+headers = {"Authorization": ${accessToken ? `"Bearer ${accessToken}"` : `f"Bearer {API_TOKEN}"`}}
+
+${snippet}`;
+	}
+	if (snippet.includes("base64")) {
+		snippet = `import base64
+${snippet}`;
+	}
+	return snippet;
+};
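For reference, the block below is a rough sketch of the Python that these changes end up generating for the "requests" client of an image-to-image model, once addImportsToSnippet has prepended the base64 and requests imports, the API_URL, and the auth header. The model id, token, image path, and prompt are placeholder values chosen for illustration; the real ones come from model.id, accessToken, and getModelInputSnippet(model).

# Illustrative sketch of a generated snippet; placeholders must be replaced before running.
import base64
import requests

API_URL = "https://router.huggingface.co/hf-inference/models/<MODEL_ID>"  # placeholder model id
headers = {"Authorization": "Bearer <HF_TOKEN>"}  # placeholder access token

def query(payload):
	# read the local image and send it base64-encoded in the JSON payload
	with open(payload["inputs"], "rb") as f:
		img = f.read()
	payload["inputs"] = base64.b64encode(img).decode("utf-8")
	response = requests.post(API_URL, headers=headers, json=payload)
	return response.content

image_bytes = query({
	"inputs": "cat.png",  # placeholder input image path
	"parameters": {"prompt": "Turn the cat into a tiger."},  # placeholder prompt
})

# You can access the returned image with PIL.Image, for example
import io
from PIL import Image
image = Image.open(io.BytesIO(image_bytes))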