Skip to content

Commit 59f15e3

Browse files
authored
Fix embeddings to support local files (intel#976)
1 parent 6185971 commit 59f15e3

File tree

4 files changed

+136
-11
lines changed

4 files changed

+136
-11
lines changed

intel_extension_for_transformers/langchain/embeddings/optimized_instructor_embedding.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
from intel_extension_for_transformers.transformers.utils.utility import LazyImport
2424
from transformers import T5Config, MT5Config
2525
from typing import Union, Optional
26-
26+
from .utils import get_module_path
2727
from .optimized_sentence_transformers import OptimzedTransformer
2828

2929
sentence_transformers = LazyImport("sentence_transformers")
@@ -126,14 +126,14 @@ def _load_sbert_model(self,
126126
module = OptimizedInstructorTransformer(model_name_or_path, cache_dir=cache_folder, **kwargs)
127127
elif module_config['idx']==1:
128128
module_class = InstructorEmbedding.INSTRUCTOR_Pooling
129-
module_path = sentence_transformers.util.load_dir_path(
130-
model_name_or_path, module_config['path'], token=token, cache_folder=cache_folder)
129+
module_path = get_module_path(
130+
model_name_or_path, module_config['path'], token, cache_folder)
131131
module = module_class.load(module_path)
132132
else:
133133
module_class = InstructorEmbedding.import_from_string(module_config['type'])
134-
module_path = sentence_transformers.util.load_dir_path(
135-
model_name_or_path, module_config['path'], token=token, cache_folder=cache_folder)
134+
module_path = get_module_path(
135+
model_name_or_path, module_config['path'], token, cache_folder)
136136
module = module_class.load(module_path)
137137
modules[module_config['name']] = module
138138

139-
return modules
139+
return modules

intel_extension_for_transformers/langchain/embeddings/optimized_sentence_transformers.py

Lines changed: 78 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -21,9 +21,10 @@
2121
import torch
2222
from intel_extension_for_transformers.transformers import OptimizedModel
2323
from intel_extension_for_transformers.transformers.utils.utility import LazyImport
24-
import transformers
24+
from collections import OrderedDict
2525
from transformers import T5Config, MT5Config
2626
from typing import Union, Optional
27+
from .utils import get_module_path
2728

2829
sentence_transformers = LazyImport("sentence_transformers")
2930

@@ -53,12 +54,84 @@ def __init__(self, *args, **kwargs):
5354
"""Initialize the OptimizedSentenceTransformer."""
5455
super().__init__(*args, **kwargs)
5556

56-
    def _load_auto_model(
        self, model_name_or_path: str, token: Optional[Union[bool, str]], cache_folder: Optional[str]):
        """
        Creates a simple Transformer + Mean Pooling model and returns the modules.

        Used as a fallback when *model_name_or_path* is not a packaged
        sentence-transformers model (no modules.json found).

        Args:
            model_name_or_path: local directory or hub model id of the base transformer.
            token: optional Hugging Face auth token forwarded to the model loader.
            cache_folder: optional cache directory for downloaded files.

        Returns:
            A two-element list [transformer, pooling] consumed by the
            SentenceTransformer module pipeline.
        """
        logger.warning("No sentence-transformers model found with name {}." \
            "Creating a new one with MEAN pooling.".format(model_name_or_path))
        # Optimized (quantized/traced) drop-in replacement for
        # sentence_transformers.models.Transformer.
        transformer_model = OptimzedTransformer(
            model_name_or_path, cache_dir=cache_folder, model_args={"token": token})
        # Mean pooling over token embeddings produces the sentence embedding.
        pooling_model = sentence_transformers.models.Pooling(
            transformer_model.get_word_embedding_dimension(), 'mean')
        return [transformer_model, pooling_model]
69+
70+
    def _load_sbert_model(
        self, model_name_or_path: str, token: Optional[Union[bool, str]], cache_folder: Optional[str]):
        """
        Loads a full sentence-transformers model.

        Mirrors SentenceTransformer._load_sbert_model, but resolves module
        directories through get_module_path so that *model_name_or_path* may be
        a plain local directory as well as a hub model id (intel#976).

        Args:
            model_name_or_path: local directory or hub model id.
            token: optional Hugging Face auth token for hub downloads.
            cache_folder: optional cache directory for downloaded files.

        Returns:
            OrderedDict mapping module names to instantiated modules, in
            modules.json order.
        """
        # Check if the config_sentence_transformers.json file exists (exists since v2 of the framework)
        config_sentence_transformers_json_path = sentence_transformers.util.load_file_path(
            model_name_or_path, 'config_sentence_transformers.json', token=token, cache_folder=cache_folder)
        if config_sentence_transformers_json_path is not None:
            with open(config_sentence_transformers_json_path) as fIn:
                self._model_config = json.load(fIn)

            # Warn (but continue) when the model was saved by a newer
            # sentence-transformers than the one installed.
            if '__version__' in self._model_config and \
                'sentence_transformers' in self._model_config['__version__'] and \
                self._model_config['__version__']['sentence_transformers'] > sentence_transformers.__version__:
                logger.warning("You try to use a model that was created with version {}, "\
                    "however, your version is {}. This might cause unexpected "\
                    "behavior or errors. In that case, try to update to the "\
                    "latest version.\n\n\n".format(
                        self._model_config['__version__']['sentence_transformers'],
                        sentence_transformers.__version__))

        # Check if a readme exists
        model_card_path = sentence_transformers.util.load_file_path(
            model_name_or_path, 'README.md', token=token, cache_folder=cache_folder)
        if model_card_path is not None:
            try:
                with open(model_card_path, encoding='utf8') as fIn:
                    self._model_card_text = fIn.read()
            except:
                # Deliberate best-effort: a missing/unreadable model card must
                # not block model loading.
                pass

        # Load the modules of sentence transformer
        modules_json_path = sentence_transformers.util.load_file_path(
            model_name_or_path, 'modules.json', token=token, cache_folder=cache_folder)
        with open(modules_json_path) as fIn:
            modules_config = json.load(fIn)

        modules = OrderedDict()
        for module_config in modules_config:
            # module_config['type'] is a dotted class path, e.g.
            # "sentence_transformers.models.Pooling".
            module_class = sentence_transformers.util.import_from_string(module_config['type'])
            # For Transformer, don't load the full directory, rely on `transformers` instead
            # But, do load the config file first.
            if module_class == sentence_transformers.models.Transformer and module_config['path'] == "":
                kwargs = {}
                # Probe the per-architecture config filenames; first hit wins.
                for config_name in ['sentence_bert_config.json', 'sentence_roberta_config.json',
                                    'sentence_distilbert_config.json', 'sentence_camembert_config.json',
                                    'sentence_albert_config.json', 'sentence_xlm-roberta_config.json',
                                    'sentence_xlnet_config.json']:
                    config_path = sentence_transformers.util.load_file_path(
                        model_name_or_path, config_name, token=token, cache_folder=cache_folder)
                    if config_path is not None:
                        with open(config_path) as fIn:
                            kwargs = json.load(fIn)
                        break
                # Propagate the auth token into the transformers model_args.
                if "model_args" in kwargs:
                    kwargs["model_args"]["token"] = token
                else:
                    kwargs["model_args"] = {"token": token}
                module = sentence_transformers.models.Transformer(
                    model_name_or_path, cache_dir=cache_folder, **kwargs)
            else:
                # get_module_path handles both local directories and hub
                # downloads (the fix for intel#976).
                module_path = get_module_path(
                    model_name_or_path, module_config['path'], token=token, cache_folder=cache_folder)
                module = module_class.load(module_path)
            modules[module_config['name']] = module

        return modules
Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
# !/usr/bin/env python
2+
# -*- coding: utf-8 -*-
3+
#
4+
# Copyright (c) 2023 Intel Corporation
5+
#
6+
# Licensed under the Apache License, Version 2.0 (the "License");
7+
# you may not use this file except in compliance with the License.
8+
# You may obtain a copy of the License at
9+
#
10+
# http://www.apache.org/licenses/LICENSE-2.0
11+
#
12+
# Unless required by applicable law or agreed to in writing, software
13+
# distributed under the License is distributed on an "AS IS" BASIS,
14+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15+
# See the License for the specific language governing permissions and
16+
# limitations under the License.
17+
18+
import os
19+
from typing import Union, Optional
20+
from intel_extension_for_transformers.transformers.utils.utility import LazyImport
21+
sentence_transformers = LazyImport("sentence_transformers")
22+
23+
def get_module_path(model_name_or_path: str,
                    path: str,
                    token: Optional[Union[bool, str]],
                    cache_folder: Optional[str]):
    """Resolve the on-disk directory of a sentence-transformers sub-module.

    Args:
        model_name_or_path: a local model directory or a hub model id.
        path: the sub-module's relative path from modules.json.
        token: optional Hugging Face auth token, forwarded for hub downloads.
        cache_folder: optional cache directory, forwarded for hub downloads.

    Returns:
        Absolute/relative path of the sub-module directory.
    """
    # Local checkout: the sub-module sits directly under the model directory.
    if os.path.isdir(model_name_or_path):
        return os.path.join(model_name_or_path, path)
    # Otherwise treat it as a hub model id and let sentence-transformers
    # download/locate the directory in the cache.
    return sentence_transformers.util.load_dir_path(
        model_name_or_path, path, token=token, cache_folder=cache_folder)

intel_extension_for_transformers/neural_chat/tests/ci/api/test_chatbot_build_api.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -144,6 +144,26 @@ def test_build_chatbot_with_retrieval_plugin_bge_int8(self):
144144
response = chatbot.predict(query="What is Intel extension for transformers?")
145145
self.assertIsNotNone(response)
146146
plugins.retrieval.enable = False
147+
148+
    def test_build_chatbot_with_retrieval_plugin_using_local_file(self):
        # Regression test for intel#976: the retrieval plugin must accept an
        # `embedding_model` that is a local directory, not only a hub model id.

        def _run_retrieval(local_dir):
            # Build a chatbot whose embedding model is loaded from local_dir
            # and verify it can answer a query end-to-end.
            plugins.retrieval.enable = True
            plugins.retrieval.args["input_path"] = "../../../README.md"
            plugins.retrieval.args["embedding_model"] = local_dir
            pipeline_config = PipelineConfig(model_name_or_path="facebook/opt-125m",
                                             plugins=plugins)
            chatbot = build_chatbot(pipeline_config)
            self.assertIsNotNone(chatbot)
            response = chatbot.predict(query="What is Intel extension for transformers?")
            self.assertIsNotNone(response)

        # test local file
        # NOTE(review): these are CI-host-specific dataset paths — the test
        # presumably only runs on runners with /tf_dataset2 mounted; confirm.
        _run_retrieval(local_dir="/tf_dataset2/inc-ut/gte-base")
        _run_retrieval(local_dir="/tf_dataset2/inc-ut/instructor-large")
        _run_retrieval(local_dir="/tf_dataset2/inc-ut/bge-base-en-v1.5")
147167

148168
if __name__ == '__main__':
149169
unittest.main()

0 commit comments

Comments
 (0)