Skip to content

Commit a4ca182

Browse files
pcuenca and Vaibhavs10 authored
Richer conversational snippet for AutoModel (#1611)
I can't remember if we already discussed about this. This is an attempt to make the `AutoModel` snippet more self-contained, for the particular but frequent case of conversational transformers models. The pipeline version already got this treatment in #1434, but the `AutoModel` version is currently too barebones in comparison, as shown in the screenshot below. In my opinion this could reduce the friction for users that want to experiment with the models in Colab or Kaggle. cc @Vaibhavs10 @ariG23498 <img width="789" height="534" alt="Screenshot 2025-07-11 at 14 05 15" src="https://github.com/user-attachments/assets/558837e3-ec56-4100-afa1-7a54d3d79305" /> --------- Co-authored-by: vb <[email protected]>
1 parent 2220487 commit a4ca182

File tree

1 file changed

+42
-11
lines changed

1 file changed

+42
-11
lines changed

packages/tasks/src/model-libraries-snippets.ts

Lines changed: 42 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1373,27 +1373,58 @@ export const transformers = (model: ModelData): string[] => {
13731373
}
13741374
const remote_code_snippet = model.tags.includes(TAG_CUSTOM_CODE) ? ", trust_remote_code=True" : "";
13751375

1376-
let autoSnippet: string;
1376+
const autoSnippet: string[] = [];
13771377
if (info.processor) {
1378-
const varName =
1378+
const processorVarName =
13791379
info.processor === "AutoTokenizer"
13801380
? "tokenizer"
13811381
: info.processor === "AutoFeatureExtractor"
13821382
? "extractor"
13831383
: "processor";
1384-
autoSnippet = [
1384+
autoSnippet.push(
13851385
"# Load model directly",
13861386
`from transformers import ${info.processor}, ${info.auto_model}`,
13871387
"",
1388-
`${varName} = ${info.processor}.from_pretrained("${model.id}"` + remote_code_snippet + ")",
1389-
`model = ${info.auto_model}.from_pretrained("${model.id}"` + remote_code_snippet + ")",
1390-
].join("\n");
1388+
`${processorVarName} = ${info.processor}.from_pretrained("${model.id}"` + remote_code_snippet + ")",
1389+
`model = ${info.auto_model}.from_pretrained("${model.id}"` + remote_code_snippet + ")"
1390+
);
1391+
if (model.tags.includes("conversational")) {
1392+
if (model.tags.includes("image-text-to-text")) {
1393+
autoSnippet.push(
1394+
"messages = [",
1395+
[
1396+
" {",
1397+
' "role": "user",',
1398+
' "content": [',
1399+
' {"type": "image", "url": "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/p-blog/candy.JPG"},',
1400+
' {"type": "text", "text": "What animal is on the candy?"}',
1401+
" ]",
1402+
" },",
1403+
].join("\n"),
1404+
"]"
1405+
);
1406+
} else {
1407+
autoSnippet.push("messages = [", ' {"role": "user", "content": "Who are you?"},', "]");
1408+
}
1409+
autoSnippet.push(
1410+
`inputs = ${processorVarName}.apply_chat_template(`,
1411+
" messages,",
1412+
" add_generation_prompt=True,",
1413+
" tokenize=True,",
1414+
" return_dict=True,",
1415+
' return_tensors="pt",',
1416+
").to(model.device)",
1417+
"",
1418+
"outputs = model.generate(**inputs, max_new_tokens=40)",
1419+
`print(${processorVarName}.decode(outputs[0][inputs["input_ids"].shape[-1]:]))`
1420+
);
1421+
}
13911422
} else {
1392-
autoSnippet = [
1423+
autoSnippet.push(
13931424
"# Load model directly",
13941425
`from transformers import ${info.auto_model}`,
1395-
`model = ${info.auto_model}.from_pretrained("${model.id}"` + remote_code_snippet + ")",
1396-
].join("\n");
1426+
`model = ${info.auto_model}.from_pretrained("${model.id}"` + remote_code_snippet + ', torch_dtype="auto")'
1427+
);
13971428
}
13981429

13991430
if (model.pipeline_tag && LIBRARY_TASK_MAPPING.transformers?.includes(model.pipeline_tag)) {
@@ -1437,9 +1468,9 @@ export const transformers = (model: ModelData): string[] => {
14371468
);
14381469
}
14391470

1440-
return [pipelineSnippet.join("\n"), autoSnippet];
1471+
return [pipelineSnippet.join("\n"), autoSnippet.join("\n")];
14411472
}
1442-
return [autoSnippet];
1473+
return [autoSnippet.join("\n")];
14431474
};
14441475

14451476
export const transformersJS = (model: ModelData): string[] => {

0 commit comments

Comments
 (0)