Skip to content

Commit 2bf7583

Browse files
authored
[hf_modelzoo] Add support for importing Rust models from Huggingface (#3125)
1 parent d0f38f3 commit 2bf7583

File tree

4 files changed

+171
-5
lines changed

4 files changed

+171
-5
lines changed

extensions/tokenizers/src/main/python/arg_parser.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ def converter_args():
2626
parser.add_argument("-f",
2727
"--output-format",
2828
default="PyTorch",
29-
choices=["PyTorch", "OnnxRuntime"],
29+
choices=["PyTorch", "OnnxRuntime", "Rust"],
3030
help="Model output format")
3131
parser.add_argument("-r",
3232
"--retry-failed",

extensions/tokenizers/src/main/python/huggingface_converter.py

Lines changed: 68 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,10 @@
1717
from argparse import Namespace
1818

1919
import onnx
20+
import safetensors_convert
2021
import torch
21-
from huggingface_hub import hf_hub_download
22-
from transformers import pipeline, AutoTokenizer
22+
from huggingface_hub import hf_hub_download, HfApi
23+
from transformers import pipeline, AutoTokenizer, AutoConfig
2324

2425
from metadata import HuggingfaceMetadata
2526
from shasum import sha1_sum
@@ -33,6 +34,12 @@ def __init__(self, tokenizer, model):
3334
self.model = model
3435

3536

37+
class ModelHolder(object):
    """Lightweight stand-in for a transformers model.

    Wraps only the model config so code paths that just need
    ``model.config`` (e.g. metadata generation) can run without
    loading the full model weights.
    """

    def __init__(self, config):
        # Only the config is retained; no weights are loaded.
        self.config = config
41+
42+
3643
class HuggingfaceConverter:
3744

3845
def __init__(self):
@@ -43,10 +50,13 @@ def __init__(self):
4350
self.translator = None
4451
self.inputs = None
4552
self.outputs = None
53+
self.api = HfApi()
4654

4755
def save_model(self, model_info, args: Namespace, temp_dir: str):
    """Dispatch model conversion based on the requested output format.

    :param model_info: Huggingface hub model info for the model to convert.
    :param args: parsed CLI arguments; ``args.output_format`` selects the backend.
    :param temp_dir: scratch directory used during conversion.
    :return: whatever the format-specific saver returns.
    """
    # Guard-style dispatch; any format other than OnnxRuntime/Rust
    # falls through to the PyTorch saver, as before.
    if args.output_format == "OnnxRuntime":
        return self.save_onnx_model(model_info, args, temp_dir)
    if args.output_format == "Rust":
        return self.save_rust_model(model_info, args, temp_dir)
    return self.save_pytorch_model(model_info, args, temp_dir)
5262

@@ -71,13 +81,67 @@ def save_onnx_model(self, model_info, args: Namespace, temp_dir: str):
7181
include_types = "token_type_id" in inputs
7282

7383
tokenizer = AutoTokenizer.from_pretrained(model_id)
74-
hf_pipeline = PipelineHolder(tokenizer, model)
84+
config = AutoConfig.from_pretrained(model_id)
85+
hf_pipeline = PipelineHolder(tokenizer, ModelHolder(config))
7586
size = self.save_to_model_zoo(model_info, args.output_dir,
7687
"OnnxRuntime", temp_dir, hf_pipeline,
7788
include_types)
7889

7990
return True, None, size
8091

92+
def save_rust_model(self, model_info, args: Namespace, temp_dir: str):
    """Convert a Huggingface model to the Rust (safetensors) format.

    Downloads ``model.safetensors`` when the repo already provides one,
    otherwise downloads ``pytorch_model.bin`` and converts it to
    safetensors, then saves tokenizer and model to the model zoo.

    :param model_info: Huggingface hub model info (``modelId`` is used).
    :param args: parsed CLI arguments (``output_dir`` is used).
    :return: tuple of (success, failure reason or None, size).
    """
    model_id = model_info.modelId

    config = AutoConfig.from_pretrained(model_id)
    # include_types tells the zoo whether token_type_ids are part of
    # the model inputs; only architectures we know are supported.
    # Fix: the original guarded with hasattr(config, "model_type") and
    # left include_types unbound when the attribute was missing, which
    # raised NameError at the save_to_model_zoo call below.
    model_type = getattr(config, "model_type", None)
    if model_type == "bert":
        include_types = True
    elif model_type == "distilbert":
        include_types = False
    else:
        return False, f"Unsupported model_type: {model_type}", -1

    logging.info(f"Saving rust model: {model_id} ...")

    if not os.path.exists(temp_dir):
        os.makedirs(temp_dir)

    tokenizer = AutoTokenizer.from_pretrained(model_id)
    hf_pipeline = PipelineHolder(tokenizer, ModelHolder(config))
    try:
        # Save tokenizer.json to temp dir
        self.save_tokenizer(hf_pipeline, temp_dir)
    except Exception as e:
        logging.warning(f"Failed to save tokenizer: {model_id}.")
        logging.warning(e, exc_info=True)
        return False, "Failed to save tokenizer", -1

    target = os.path.join(temp_dir, "model.safetensors")
    model = self.api.model_info(model_id, files_metadata=True)
    has_sf_file = False
    has_pt_file = False
    for sibling in model.siblings:
        if sibling.rfilename == "model.safetensors":
            has_sf_file = True
        elif sibling.rfilename == "pytorch_model.bin":
            has_pt_file = True

    if has_sf_file:
        # Native safetensors file available: copy it as-is.
        file = hf_hub_download(repo_id=model_id,
                               filename="model.safetensors")
        shutil.copyfile(file, target)
    elif has_pt_file:
        # Fall back to converting the pickled PyTorch weights.
        file = hf_hub_download(repo_id=model_id,
                               filename="pytorch_model.bin")
        safetensors_convert.convert_file(file, target)
    else:
        return False, f"No model file found for: {model_id}", -1

    size = self.save_to_model_zoo(model_info, args.output_dir, "Rust",
                                  temp_dir, hf_pipeline, include_types)

    return True, None, size
144+
81145
def save_pytorch_model(self, model_info, args: Namespace, temp_dir: str):
82146
model_id = model_info.modelId
83147
if not os.path.exists(temp_dir):
@@ -134,7 +198,7 @@ def save_tokenizer(hf_pipeline, temp_dir: str):
134198
hf_pipeline.tokenizer.save_pretrained(temp_dir)
135199
# only keep tokenizer.json file
136200
for path in os.listdir(temp_dir):
137-
if path != "tokenizer.json":
201+
if path != "tokenizer.json" and path != "tokenizer_config.json":
138202
os.remove(os.path.join(temp_dir, path))
139203

140204
def jit_trace_model(self, hf_pipeline, model_id: str, temp_dir: str,

extensions/tokenizers/src/main/python/requirements.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,3 +3,4 @@ transformers
33
torch
44
protobuf==3.20.2
55
optimum[exporters,onnxruntime]
6+
safetensors
Lines changed: 101 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,101 @@
1+
#!/usr/bin/env python
2+
#
3+
# Copyright 2024 Amazon.com, Inc. or its affiliates. All Rights Reserved.
4+
#
5+
# Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file
6+
# except in compliance with the License. A copy of the License is located at
7+
#
8+
# http://aws.amazon.com/apache2.0/
9+
#
10+
# or in the "LICENSE.txt" file accompanying this file. This file is distributed on an "AS IS"
11+
# BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, express or implied. See the License for
12+
# the specific language governing permissions and limitations under the License.
13+
import os
14+
from collections import defaultdict
15+
from typing import List, Dict
16+
17+
import torch
18+
from safetensors.torch import _find_shared_tensors, _is_complete, load_file, save_file
19+
20+
21+
def _remove_duplicate_names(
    state_dict: Dict[str, torch.Tensor],
    *,
    preferred_names: List[str] = None,
    discard_names: List[str] = None,
) -> Dict[str, List[str]]:
    """Decide which aliases of shared tensors to drop before saving.

    safetensors will not serialize multiple entries backed by the same
    storage, so for every group of shared tensors exactly one name is
    kept and the rest are scheduled for removal.

    :param state_dict: model weights, possibly containing shared tensors.
    :param preferred_names: names to keep when possible.
    :param discard_names: names to drop when an alternative exists.
    :return: mapping of kept name -> list of duplicate names to delete.
    """
    preferred = set(preferred_names or [])
    discard = set(discard_names or [])

    to_remove = defaultdict(list)
    for shared in _find_shared_tensors(state_dict):
        # Candidates to keep: names whose tensor covers the whole storage.
        complete_names = {
            name for name in shared if _is_complete(state_dict[name])
        }
        if not complete_names:
            if len(shared) != 1:
                raise RuntimeError(
                    f"Error while trying to find names to remove to save state dict, but found no suitable name to keep for saving amongst: {shared}. None is covering the entire storage.Refusing to save/load the model since you could be storing much more memory than needed. Please refer to https://huggingface.co/docs/safetensors/torch_shared_tensors for more information. Or open an issue."
                )
            # Lone partial tensor: force a contiguous private copy so it
            # becomes complete and can be kept.
            name = next(iter(shared))
            state_dict[name] = state_dict[name].clone()
            complete_names = {name}

        # Default choice: lexicographically first complete name.
        keep_name = min(complete_names)

        # Prefer any name not explicitly discarded...
        not_discarded = complete_names - discard
        if not_discarded:
            keep_name = min(not_discarded)

        # ...and above all, an explicitly preferred name.
        if preferred:
            wanted = preferred & complete_names
            if wanted:
                keep_name = min(wanted)

        for name in sorted(shared):
            if name != keep_name:
                to_remove[keep_name].append(name)
    return to_remove
64+
65+
66+
def convert_file(pt_filename: str, sf_filename: str):
    """Convert a pickled PyTorch checkpoint to a safetensors file.

    Shared-tensor duplicates are dropped (recorded in the safetensors
    metadata so loaders can restore the aliases), tensors are forced
    contiguous, and the written file is re-read to verify round-trip
    equality and a sane file size.

    :param pt_filename: path to the ``pytorch_model.bin`` checkpoint.
    :param sf_filename: destination path for the ``.safetensors`` file.
    :raises RuntimeError: if a written tensor does not match the source,
        or the output file size deviates by more than 1%.
    """
    # NOTE(review): torch.load unpickles arbitrary objects from the
    # downloaded checkpoint — callers must trust the model repo.
    loaded = torch.load(pt_filename, map_location="cpu")
    if "state_dict" in loaded:
        loaded = loaded["state_dict"]
    to_removes = _remove_duplicate_names(loaded)

    metadata = {"format": "pt"}
    for kept_name, to_remove_group in to_removes.items():
        for to_remove in to_remove_group:
            if to_remove not in metadata:
                # Remember which kept tensor the removed alias pointed at.
                metadata[to_remove] = kept_name
            del loaded[to_remove]
    # Force tensors to be contiguous
    loaded = {k: v.contiguous() for k, v in loaded.items()}

    dir_name = os.path.dirname(sf_filename)
    # Fix: os.makedirs("") raises FileNotFoundError when sf_filename
    # has no directory component; only create when there is one.
    if dir_name:
        os.makedirs(dir_name, exist_ok=True)
    save_file(loaded, sf_filename, metadata=metadata)
    check_file_size(sf_filename, pt_filename)
    reloaded = load_file(sf_filename)
    for k in loaded:
        pt_tensor = loaded[k]
        sf_tensor = reloaded[k]
        if not torch.equal(pt_tensor, sf_tensor):
            raise RuntimeError(f"The output tensors do not match for key {k}")
91+
92+
93+
def check_file_size(sf_filename: str, pt_filename: str):
    """Sanity-check that the converted file size is close to the source.

    A large growth usually means shared tensors were duplicated instead
    of de-duplicated during conversion.

    :param sf_filename: path of the converted safetensors file.
    :param pt_filename: path of the original PyTorch checkpoint.
    :raises RuntimeError: if the safetensors file is more than 1% larger
        than the PyTorch checkpoint.
    """
    sf_size = os.stat(sf_filename).st_size
    pt_size = os.stat(pt_filename).st_size

    if (sf_size - pt_size) / pt_size > 0.01:
        # Fix: message typo — "size different" -> "size difference".
        raise RuntimeError(f"""The file size difference is more than 1%:
 - {sf_filename}: {sf_size}
 - {pt_filename}: {pt_size}
 """)

0 commit comments

Comments
 (0)