import torch
from intel_extension_for_transformers.transformers import OptimizedModel
from intel_extension_for_transformers.transformers.utils.utility import LazyImport
-import transformers
+from collections import OrderedDict
from transformers import T5Config, MT5Config
from typing import Union, Optional
+from .utils import get_module_path

sentence_transformers = LazyImport("sentence_transformers")
@@ -53,12 +54,84 @@ def __init__(self, *args, **kwargs):
        """Initialize the OptimizedSentenceTransformer."""
        super().__init__(*args, **kwargs)

-    def _load_auto_model(self, model_name_or_path: str, token: Optional[Union[bool, str]], cache_folder: Optional[str]):
+    def _load_auto_model(
+            self, model_name_or_path: str, token: Optional[Union[bool, str]], cache_folder: Optional[str]):
        """
        Creates a simple Transformer + Mean Pooling model and returns the modules.
        """
        logger.warning("No sentence-transformers model found with name {}. "
                       "Creating a new one with MEAN pooling.".format(model_name_or_path))
-        transformer_model = OptimzedTransformer(model_name_or_path, cache_dir=cache_folder, model_args={"token": token})
-        pooling_model = sentence_transformers.models.Pooling(transformer_model.get_word_embedding_dimension(), 'mean')
-        return [transformer_model, pooling_model]
+        transformer_model = OptimzedTransformer(
+            model_name_or_path, cache_dir=cache_folder, model_args={"token": token})
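+        # Average token embeddings ("mean" pooling) into a fixed-size sentence vector.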
+        pooling_model = sentence_transformers.models.Pooling(
+            transformer_model.get_word_embedding_dimension(), 'mean')
+        return [transformer_model, pooling_model]
+
+    def _load_sbert_model(
+            self, model_name_or_path: str, token: Optional[Union[bool, str]], cache_folder: Optional[str]):
+        """
+        Loads a full sentence-transformers model.
+        """
+        # Check if the config_sentence_transformers.json file exists (it has existed since v2 of the framework)
+        config_sentence_transformers_json_path = sentence_transformers.util.load_file_path(
+            model_name_or_path, 'config_sentence_transformers.json', token=token, cache_folder=cache_folder)
+        if config_sentence_transformers_json_path is not None:
+            with open(config_sentence_transformers_json_path) as fIn:
+                self._model_config = json.load(fIn)
+
+            if '__version__' in self._model_config and \
+                    'sentence_transformers' in self._model_config['__version__'] and \
+                    self._model_config['__version__']['sentence_transformers'] > sentence_transformers.__version__:
+                logger.warning("You are trying to use a model that was created with version {}, "
+                               "but your version is {}. This might cause unexpected "
+                               "behavior or errors. In that case, try to update to the "
+                               "latest version.\n\n\n".format(
+                                   self._model_config['__version__']['sentence_transformers'],
+                                   sentence_transformers.__version__))
+
+        # Check if a readme (model card) exists
+        model_card_path = sentence_transformers.util.load_file_path(
+            model_name_or_path, 'README.md', token=token, cache_folder=cache_folder)
+        if model_card_path is not None:
+            try:
+                with open(model_card_path, encoding='utf8') as fIn:
+                    self._model_card_text = fIn.read()
+            except Exception:
+                pass
+
+        # Load the modules of the sentence transformer
+        modules_json_path = sentence_transformers.util.load_file_path(
+            model_name_or_path, 'modules.json', token=token, cache_folder=cache_folder)
+        with open(modules_json_path) as fIn:
+            modules_config = json.load(fIn)
+
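+        # Instantiate each configured module in order; together they form the encoding pipeline.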
+        modules = OrderedDict()
+        for module_config in modules_config:
+            module_class = sentence_transformers.util.import_from_string(module_config['type'])
+            # For Transformer, don't load the full directory; rely on `transformers` instead.
+            # But do load the config file first.
+            if module_class == sentence_transformers.models.Transformer and module_config['path'] == "":
+                kwargs = {}
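+                # The architecture-specific sentence config carries Transformer kwargs (e.g. max_seq_length).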
+                for config_name in ['sentence_bert_config.json', 'sentence_roberta_config.json',
+                                    'sentence_distilbert_config.json', 'sentence_camembert_config.json',
+                                    'sentence_albert_config.json', 'sentence_xlm-roberta_config.json',
+                                    'sentence_xlnet_config.json']:
+                    config_path = sentence_transformers.util.load_file_path(
+                        model_name_or_path, config_name, token=token, cache_folder=cache_folder)
+                    if config_path is not None:
+                        with open(config_path) as fIn:
+                            kwargs = json.load(fIn)
+                        break
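+                # Forward the auth token so gated or private Hub repositories stay accessible.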
+                if "model_args" in kwargs:
+                    kwargs["model_args"]["token"] = token
+                else:
+                    kwargs["model_args"] = {"token": token}
+                module = sentence_transformers.models.Transformer(
+                    model_name_or_path, cache_dir=cache_folder, **kwargs)
+            else:
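+                # Non-Transformer modules (e.g. Pooling, Normalize) are loaded from their own subfolder.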
+                module_path = get_module_path(
+                    model_name_or_path, module_config['path'], token=token, cache_folder=cache_folder)
+                module = module_class.load(module_path)
+            modules[module_config['name']] = module
+
+        return modules
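
Usage sketch (illustrative, not part of the commit): assuming OptimizedSentenceTransformer is exported from the same package as OptimizedModel above, it should act as a drop-in replacement for sentence_transformers.SentenceTransformer; the model id below is only an example.

# Hypothetical import path; adjust to wherever the class is actually exported.
from intel_extension_for_transformers.transformers import OptimizedSentenceTransformer

# Loading triggers _load_sbert_model when modules.json is present,
# or _load_auto_model (Transformer + mean pooling) as the fallback.
model = OptimizedSentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
embeddings = model.encode(["An example sentence to embed."])
print(embeddings.shape)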