Skip to content

Commit 1ebda4e

Browse files
authored
update scripts to support deepseekr1 (#3495)
1 parent 25fe8d0 commit 1ebda4e

File tree

7 files changed

+18
-7
lines changed

7 files changed

+18
-7
lines changed

examples/cpu/llm/inference/distributed/run_accuracy_with_deepspeed.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -402,7 +402,11 @@ def get_repo_root(model_name_or_path):
402402
def get_checkpoint_files(model_name_or_path):
403403
cached_repo_dir = get_repo_root(model_name_or_path)
404404
glob_pattern = "*.[bp][it][n]"
405-
if re.search("deepseek-v2", model_name_or_path, re.IGNORECASE):
405+
if (
406+
re.search("deepseek-v2", model_name_or_path, re.IGNORECASE)
407+
or re.search("deepseek-v3", model_name_or_path, re.IGNORECASE)
408+
or re.search("deepseek-r1", model_name_or_path, re.IGNORECASE)
409+
):
406410
glob_pattern = "*.[sbp][ait][fn][e][t][e][n][s][o][r][s]"
407411
# extensions: .bin | .pt
408412
# creates a list of paths from all downloaded files in cache dir

examples/cpu/llm/inference/distributed/run_generation_with_deepspeed.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -294,8 +294,10 @@ def get_repo_root(model_name_or_path):
294294
def get_checkpoint_files(model_name_or_path):
295295
cached_repo_dir = get_repo_root(model_name_or_path)
296296
glob_pattern = "*.[bp][it][n]"
297-
if re.search("deepseek-v2", model_name_or_path, re.IGNORECASE) or re.search(
298-
"deepseek-v3", model_name_or_path, re.IGNORECASE
297+
if (
298+
re.search("deepseek-v2", model_name_or_path, re.IGNORECASE)
299+
or re.search("deepseek-v3", model_name_or_path, re.IGNORECASE)
300+
or re.search("deepseek-r1", model_name_or_path, re.IGNORECASE)
299301
):
300302
glob_pattern = "*.[sbp][ait][fn][e][t][e][n][s][o][r][s]"
301303
# extensions: .bin | .pt
@@ -328,7 +330,7 @@ def get_checkpoint_files(model_name_or_path):
328330
model_type = next((x for x in MODEL_CLASSES.keys() if x in model_name.lower()), "auto")
329331
if model_type == "llama" and args.vision_text_model:
330332
model_type = "mllama"
331-
if model_type in ["maira-2", "deepseek-v2", "deepseek-v3"]:
333+
if model_type in ["maira-2", "deepseek-v2", "deepseek-v3", "deepseek-r1"]:
332334
model_type = model_type.replace("-", "")
333335
model_class = MODEL_CLASSES[model_type]
334336
tokenizer = model_class[1].from_pretrained(model_name, trust_remote_code=True)

examples/cpu/llm/inference/run.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -595,6 +595,7 @@ def main(args_in: Optional[List[str]] = None) -> None:
595595
"jamba": ("/jamba_local_shard"),
596596
"deepseek-v2": ("/deepseekv2_local_shard"),
597597
"deepseek-v3": ("/deepseekv3_local_shard"),
598+
"deepseek-r1": ("/deepseekr1_local_shard"),
598599
}
599600
model_type = next(
600601
(

examples/cpu/llm/inference/single_instance/run_generation.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -140,7 +140,7 @@
140140
)
141141
if model_type == "llama" and args.vision_text_model:
142142
model_type = "mllama"
143-
if model_type in ["maira-2", "deepseek-v2", "deepseek-v3"]:
143+
if model_type in ["maira-2", "deepseek-v2", "deepseek-v3", "deepseek-r1"]:
144144
model_type = model_type.replace("-", "")
145145
model_class = MODEL_CLASSES[model_type]
146146
if args.config_file is None:

examples/cpu/llm/inference/single_instance/run_quantization.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -441,6 +441,8 @@ def download_and_open(url: str) -> Image.Image:
441441
model = DeepseekV2Config(args.model_id)
442442
elif re.search("deepseekv3", config.architectures[0], re.IGNORECASE):
443443
model = DeepseekV3Config(args.model_id)
444+
if "deepseek-r1" in args.model_id.lower() or "deepseekr1" in args.model_id.lower():
445+
model.name = "deepseekr1"
444446
else:
445447
raise AssertionError("Not support %s." % (args.model_id))
446448

examples/cpu/llm/inference/utils/create_shard_model.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@
5353
)
5454
if model_type == "llama" and args.vision_text_model:
5555
model_type = "mllama"
56-
if model_type in ["maira-2", "deepseek-v2", "deepseek-v3"]:
56+
if model_type in ["maira-2", "deepseek-v2", "deepseek-v3", "deepseek-r1"]:
5757
model_type = model_type.replace("-", "")
5858
model_class = MODEL_CLASSES[model_type]
5959
load_dtype = torch.float32
@@ -83,7 +83,7 @@
8383
tokenizer.save_pretrained(save_directory=args.save_path)
8484
if model_type == "llava":
8585
image_processor.save_pretrained(save_directory=args.save_path)
86-
if model_type in ["maira2", "deepseekv2", "deepseekv3"]:
86+
if model_type in ["maira2", "deepseekv2", "deepseekv3", "deepseekr1"]:
8787
import inspect
8888
import shutil
8989

examples/cpu/llm/inference/utils/supported_models.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,8 +37,10 @@
3737
"jamba": (AutoModelForCausalLM, AutoTokenizer),
3838
"deepseek-v2": (AutoModelForCausalLM, AutoTokenizer),
3939
"deepseek-v3": (AutoModelForCausalLM, AutoTokenizer),
40+
"deepseek-r1": (AutoModelForCausalLM, AutoTokenizer),
4041
"deepseekv2": (AutoModelForCausalLM, AutoTokenizer),
4142
"deepseekv3": (AutoModelForCausalLM, AutoTokenizer),
43+
"deepseekr1": (AutoModelForCausalLM, AutoTokenizer),
4244
"auto": (AutoModelForCausalLM, AutoTokenizer),
4345
}
4446

0 commit comments

Comments (0)