Skip to content

Commit d8021c3

Browse files
authored
Enable flash mistral model for HPU device (#594)
Signed-off-by: Liu, Kaixuan <[email protected]>
1 parent 245a244 commit d8021c3

File tree

3 files changed

+456
-2
lines changed

3 files changed

+456
-2
lines changed

backends/python/server/text_embeddings_server/models/__init__.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
from text_embeddings_server.models.masked_model import MaskedLanguageModel
1212
from text_embeddings_server.models.default_model import DefaultModel
1313
from text_embeddings_server.models.classification_model import ClassificationModel
14+
from text_embeddings_server.models.flash_mistral import FlashMistral
1415
from text_embeddings_server.utils.device import get_device, use_ipex
1516

1617
__all__ = ["Model"]
@@ -89,6 +90,22 @@ def get_model(model_path: Path, dtype: Optional[str], pool: str):
8990
pool,
9091
trust_remote=TRUST_REMOTE_CODE,
9192
)
93+
elif config.model_type == "mistral" and device.type == "hpu":
94+
try:
95+
return FlashMistral(
96+
model_path,
97+
device,
98+
datatype,
99+
pool,
100+
)
101+
except FileNotFoundError as e:
102+
return DefaultModel(
103+
model_path,
104+
device,
105+
datatype,
106+
pool,
107+
trust_remote=TRUST_REMOTE_CODE,
108+
)
92109
else:
93110
if device.type == "hpu":
94111
from habana_frameworks.torch.hpu import wrap_in_hpu_graph

0 commit comments

Comments (0)