17
17
from argparse import Namespace
18
18
19
19
import onnx
20
+ import safetensors_convert
20
21
import torch
21
- from huggingface_hub import hf_hub_download
22
- from transformers import pipeline , AutoTokenizer
22
+ from huggingface_hub import hf_hub_download , HfApi
23
+ from transformers import pipeline , AutoTokenizer , AutoConfig
23
24
24
25
from metadata import HuggingfaceMetadata
25
26
from shasum import sha1_sum
@@ -33,6 +34,12 @@ def __init__(self, tokenizer, model):
33
34
self .model = model
34
35
35
36
37
class ModelHolder(object):
    """Minimal stand-in for a loaded model.

    Carries only the Hugging Face ``config`` object, which is all that the
    downstream pipeline/serialization code reads when the full model weights
    are not needed in memory.
    """

    def __init__(self, config):
        # Model configuration (e.g. the result of AutoConfig.from_pretrained).
        self.config = config
36
43
class HuggingfaceConverter :
37
44
38
45
def __init__ (self ):
@@ -43,10 +50,13 @@ def __init__(self):
43
50
self .translator = None
44
51
self .inputs = None
45
52
self .outputs = None
53
+ self .api = HfApi ()
46
54
47
55
def save_model (self , model_info , args : Namespace , temp_dir : str ):
48
56
if args .output_format == "OnnxRuntime" :
49
57
return self .save_onnx_model (model_info , args , temp_dir )
58
+ elif args .output_format == "Rust" :
59
+ return self .save_rust_model (model_info , args , temp_dir )
50
60
else :
51
61
return self .save_pytorch_model (model_info , args , temp_dir )
52
62
@@ -71,13 +81,67 @@ def save_onnx_model(self, model_info, args: Namespace, temp_dir: str):
71
81
include_types = "token_type_id" in inputs
72
82
73
83
tokenizer = AutoTokenizer .from_pretrained (model_id )
74
- hf_pipeline = PipelineHolder (tokenizer , model )
84
+ config = AutoConfig .from_pretrained (model_id )
85
+ hf_pipeline = PipelineHolder (tokenizer , ModelHolder (config ))
75
86
size = self .save_to_model_zoo (model_info , args .output_dir ,
76
87
"OnnxRuntime" , temp_dir , hf_pipeline ,
77
88
include_types )
78
89
79
90
return True , None , size
80
91
92
+ def save_rust_model (self , model_info , args : Namespace , temp_dir : str ):
93
+ model_id = model_info .modelId
94
+
95
+ config = AutoConfig .from_pretrained (model_id )
96
+ if hasattr (config , "model_type" ):
97
+ if config .model_type == "bert" :
98
+ include_types = True
99
+ elif config .model_type == "distilbert" :
100
+ include_types = False
101
+ else :
102
+ return False , f"Unsupported model_type: { config .model_type } " , - 1
103
+
104
+ logging .info (f"Saving rust model: { model_id } ..." )
105
+
106
+ if not os .path .exists (temp_dir ):
107
+ os .makedirs (temp_dir )
108
+
109
+ tokenizer = AutoTokenizer .from_pretrained (model_id )
110
+ hf_pipeline = PipelineHolder (tokenizer , ModelHolder (config ))
111
+ try :
112
+ # Save tokenizer.json to temp dir
113
+ self .save_tokenizer (hf_pipeline , temp_dir )
114
+ except Exception as e :
115
+ logging .warning (f"Failed to save tokenizer: { model_id } ." )
116
+ logging .warning (e , exc_info = True )
117
+ return False , "Failed to save tokenizer" , - 1
118
+
119
+ target = os .path .join (temp_dir , "model.safetensors" )
120
+ model = self .api .model_info (model_id , files_metadata = True )
121
+ has_sf_file = False
122
+ has_pt_file = False
123
+ for sibling in model .siblings :
124
+ if sibling .rfilename == "model.safetensors" :
125
+ has_sf_file = True
126
+ elif sibling .rfilename == "pytorch_model.bin" :
127
+ has_pt_file = True
128
+
129
+ if has_sf_file :
130
+ file = hf_hub_download (repo_id = model_id ,
131
+ filename = "model.safetensors" )
132
+ shutil .copyfile (file , target )
133
+ elif has_pt_file :
134
+ file = hf_hub_download (repo_id = model_id ,
135
+ filename = "pytorch_model.bin" )
136
+ safetensors_convert .convert_file (file , target )
137
+ else :
138
+ return False , f"No model file found for: { model_id } " , - 1
139
+
140
+ size = self .save_to_model_zoo (model_info , args .output_dir , "Rust" ,
141
+ temp_dir , hf_pipeline , include_types )
142
+
143
+ return True , None , size
144
+
81
145
def save_pytorch_model (self , model_info , args : Namespace , temp_dir : str ):
82
146
model_id = model_info .modelId
83
147
if not os .path .exists (temp_dir ):
@@ -134,7 +198,7 @@ def save_tokenizer(hf_pipeline, temp_dir: str):
134
198
hf_pipeline .tokenizer .save_pretrained (temp_dir )
135
199
# only keep tokenizer.json file
136
200
for path in os .listdir (temp_dir ):
137
- if path != "tokenizer.json" :
201
+ if path != "tokenizer.json" and path != "tokenizer_config.json" :
138
202
os .remove (os .path .join (temp_dir , path ))
139
203
140
204
def jit_trace_model (self , hf_pipeline , model_id : str , temp_dir : str ,
0 commit comments