
Commit f578c1a

Author: Anton Tcholakov
feat: add support for "model_type": "gte" (#519)
1 parent 11f4893 commit f578c1a

7 files changed: +3142 −12 lines

README.md

Lines changed: 8 additions & 6 deletions
@@ -71,17 +71,19 @@ Below are some examples of the currently supported models:
 
 | MTEB Rank | Model Size | Model Type | Model ID |
 |-----------|---------------------|-------------|--------------------------------------------------------------------------------------------------|
-| 1 | 7B (Very Expensive) | Mistral | [Salesforce/SFR-Embedding-2_R](https://hf.co/Salesforce/SFR-Embedding-2_R) |
-| 2 | 7B (Very Expensive) | Qwen2 | [Alibaba-NLP/gte-Qwen2-7B-instruct](https://hf.co/Alibaba-NLP/gte-Qwen2-7B-instruct) |
-| 9 | 1.5B (Expensive) | Qwen2 | [Alibaba-NLP/gte-Qwen2-1.5B-instruct](https://hf.co/Alibaba-NLP/gte-Qwen2-1.5B-instruct) |
-| 15 | 0.4B | Alibaba GTE | [Alibaba-NLP/gte-large-en-v1.5](https://hf.co/Alibaba-NLP/gte-large-en-v1.5) |
+| 3 | 7B (Very Expensive) | Qwen2 | [Alibaba-NLP/gte-Qwen2-7B-instruct](https://hf.co/Alibaba-NLP/gte-Qwen2-7B-instruct) |
+| 11 | 1.5B (Expensive) | Qwen2 | [Alibaba-NLP/gte-Qwen2-1.5B-instruct](https://hf.co/Alibaba-NLP/gte-Qwen2-1.5B-instruct) |
+| 14 | 7B (Very Expensive) | Mistral | [Salesforce/SFR-Embedding-2_R](https://hf.co/Salesforce/SFR-Embedding-2_R) |
 | 20 | 0.3B | Bert | [WhereIsAI/UAE-Large-V1](https://hf.co/WhereIsAI/UAE-Large-V1) |
-| 24 | 0.5B | XLM-RoBERTa | [intfloat/multilingual-e5-large-instruct](https://hf.co/intfloat/multilingual-e5-large-instruct) |
+| 31 | 0.5B | XLM-RoBERTa | [Snowflake/snowflake-arctic-embed-l-v2.0](https://hf.co/Snowflake/snowflake-arctic-embed-l-v2.0) |
+| 37 | 0.3B | Alibaba GTE | [Snowflake/snowflake-arctic-embed-m-v2.0](https://hf.co/Snowflake/snowflake-arctic-embed-m-v2.0) |
+| 49 | 0.5B | XLM-RoBERTa | [intfloat/multilingual-e5-large-instruct](https://hf.co/intfloat/multilingual-e5-large-instruct) |
+| N/A | 0.4B | Alibaba GTE | [Alibaba-NLP/gte-large-en-v1.5](https://hf.co/Alibaba-NLP/gte-large-en-v1.5) |
 | N/A | 0.1B | NomicBert | [nomic-ai/nomic-embed-text-v1](https://hf.co/nomic-ai/nomic-embed-text-v1) |
 | N/A | 0.1B | NomicBert | [nomic-ai/nomic-embed-text-v1.5](https://hf.co/nomic-ai/nomic-embed-text-v1.5) |
 | N/A | 0.1B | JinaBERT | [jinaai/jina-embeddings-v2-base-en](https://hf.co/jinaai/jina-embeddings-v2-base-en) |
 | N/A | 0.1B | JinaBERT | [jinaai/jina-embeddings-v2-base-code](https://hf.co/jinaai/jina-embeddings-v2-base-code) |
-| N/A | 0.1B | MPNet | [sentence-transformers/all-mpnet-base-v2](https://hf.co/sentence-transformers/all-mpnet-base-v2) |
+| N/A | 0.1B | MPNet | [sentence-transformers/all-mpnet-base-v2](https://hf.co/sentence-transformers/all-mpnet-base-v2) |
 
 To explore the list of best performing text embeddings models, visit the
 [Massive Text Embedding Benchmark (MTEB) Leaderboard](https://huggingface.co/spaces/mteb/leaderboard).

backends/candle/src/lib.rs

Lines changed: 4 additions & 3 deletions
@@ -59,8 +59,9 @@ enum Config {
     NomicBert(NomicConfig),
     #[allow(dead_code)]
     Mistral(MistralConfig),
-    #[serde(rename = "new")]
     Gte(GTEConfig),
+    #[serde(rename = "new")]
+    GteAlibaba(GTEConfig),
     #[allow(dead_code)]
     Qwen2(Qwen2Config),
     #[serde(rename = "mpnet")]
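The split above is what lets the backend accept both spellings of the architecture tag. The sketch below is illustrative only, not the backend's actual definitions: it assumes the `Config` enum is internally tagged on `model_type` (as the serde attributes in the hunk suggest) and stubs `GTEConfig` down to a single field so the snippet is self-contained.

```rust
// Illustrative sketch only: assumes `Config` is an internally tagged serde
// enum keyed on `model_type`; `GTEConfig` is stubbed to one field here.
use serde::Deserialize;

#[derive(Debug, Deserialize)]
struct GTEConfig {
    hidden_size: usize, // stand-in field, not the real config schema
}

#[derive(Debug, Deserialize)]
#[serde(tag = "model_type", rename_all = "lowercase")]
enum Config {
    // Newly accepted spelling: `"model_type": "gte"`.
    Gte(GTEConfig),
    // Previously the only spelling: `"model_type": "new"`, used by Alibaba GTE configs.
    #[serde(rename = "new")]
    GteAlibaba(GTEConfig),
}

fn main() -> Result<(), serde_json::Error> {
    let gte = r#"{ "model_type": "gte", "hidden_size": 768 }"#;
    let alibaba = r#"{ "model_type": "new", "hidden_size": 1024 }"#;
    println!("{:?}", serde_json::from_str::<Config>(gte)?);     // Gte(GTEConfig { .. })
    println!("{:?}", serde_json::from_str::<Config>(alibaba)?); // GteAlibaba(GTEConfig { .. })
    Ok(())
}
```

Under that assumption, a config carrying `"model_type": "gte"` deserializes to `Config::Gte`, while the Alibaba-style `"model_type": "new"` keeps mapping to the renamed `GteAlibaba` variant.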
@@ -223,7 +224,7 @@ impl CandleBackend {
                 "Mistral is only supported on Cuda devices in fp16 with flash attention enabled"
                     .to_string(),
             )),
-            (Config::Gte(config), Device::Cpu | Device::Metal(_)) => {
+            (Config::Gte(config) | Config::GteAlibaba(config), Device::Cpu | Device::Metal(_)) => {
                 tracing::info!("Starting GTE model on {:?}", device);
                 Ok(Box::new(GTEModel::load(vb, &config, model_type).s()?))
             }
@@ -354,7 +355,7 @@ impl CandleBackend {
                 ))
             }
             #[cfg(feature = "cuda")]
-            (Config::Gte(config), Device::Cuda(_)) => {
+            (Config::Gte(config) | Config::GteAlibaba(config), Device::Cuda(_)) => {
                 if dtype != DType::F16
                     || !cfg!(any(feature = "flash-attn", feature = "flash-attn-v1"))
                 {
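The two match arms changed above (CPU/Metal and CUDA) use an or-pattern, so either variant binds the same `config` and flows into the same GTE load path. A minimal, self-contained illustration of that binding, using stub types rather than the backend's real ones:

```rust
// Stand-alone illustration of the or-pattern used in the arms above;
// the types here are stubs, not the backend's real ones.
struct GTEConfig {
    hidden_size: usize,
}

enum Config {
    Gte(GTEConfig),        // "model_type": "gte"
    GteAlibaba(GTEConfig), // "model_type": "new"
    Other,
}

fn describe(config: &Config) -> String {
    match config {
        // One arm serves both spellings: the or-pattern binds `c` identically
        // for either variant, mirroring `Config::Gte(config) | Config::GteAlibaba(config)`.
        Config::Gte(c) | Config::GteAlibaba(c) => {
            format!("GTE-style model, hidden_size = {}", c.hidden_size)
        }
        Config::Other => "some other architecture".to_string(),
    }
}

fn main() {
    println!("{}", describe(&Config::Gte(GTEConfig { hidden_size: 768 })));
    println!("{}", describe(&Config::GteAlibaba(GTEConfig { hidden_size: 1024 })));
}
```

The effect is a single load path for GTE models; only the deserialized `model_type` tag distinguishes the two checkpoint families.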
