Skip to content

Commit c8b5348

Browse files
authored
Merge branch 'main' into vrdn-23/fix-gelu-activation
2 parents d365667 + 02f60f0 commit c8b5348

File tree

7 files changed

+55
-13
lines changed

7 files changed

+55
-13
lines changed

README.md

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -137,14 +137,15 @@ To see all options to serve your models:
137137
$ text-embeddings-router --help
138138
Text Embedding Webserver
139139

140-
Usage: text-embeddings-router [OPTIONS]
140+
Usage: text-embeddings-router [OPTIONS] --model-id <MODEL_ID>
141141

142142
Options:
143143
--model-id <MODEL_ID>
144-
The name of the model to load. Can be a MODEL_ID as listed on <https://hf.co/models> like `BAAI/bge-large-en-v1.5`. Or it can be a local directory containing the necessary files as saved by `save_pretrained(...)` methods of transformers
144+
The Hugging Face model ID; it can be any model listed on <https://huggingface.co/models> with the `text-embeddings-inference` tag (meaning it's compatible with Text Embeddings Inference).
145+
146+
Alternatively, the specified ID can also be a path to a local directory containing the necessary model files saved by the `save_pretrained(...)` methods of either Transformers or Sentence Transformers.
145147

146148
[env: MODEL_ID=]
147-
[default: BAAI/bge-large-en-v1.5]
148149

149150
--revision <REVISION>
150151
The actual revision of the model if you're referring to a model on the hub. You can use a specific commit id or a branch like `refs/pr/2`
@@ -162,6 +163,11 @@ Options:
162163
[env: DTYPE=]
163164
[possible values: float16, float32]
164165

166+
--served-model-name <SERVED_MODEL_NAME>
167+
The name of the model that is being served. If not specified, defaults to `--model-id`. It is only used for the OpenAI-compatible endpoints via HTTP
168+
169+
[env: SERVED_MODEL_NAME=]
170+
165171
--pooling <POOLING>
166172
Optionally control the pooling method for embedding models.
167173

@@ -238,10 +244,9 @@ Options:
238244

239245
Some embedding models require an extra `Dense` module which contains a single Linear layer and an activation function. By default, those `Dense` modules are stored under the `2_Dense` directory, but there might be cases where different `Dense` modules are provided, to convert the pooled embeddings into different dimensions, available as `2_Dense_<dims>` e.g. https://huggingface.co/NovaSearch/stella_en_400M_v5.
240246

241-
Note that this argument is optional, only required to be set if the path to the `Dense` module is other than `2_Dense`. And it also applies when leveraging the `candle` backend.
247+
Note that this argument is optional: it only needs to be set if there is no `modules.json` file, or if you want to override a single Dense module path. It only applies when running with the `candle` backend.
242248

243249
[env: DENSE_PATH=]
244-
[default: 2_Dense]
245250

246251
--hf-token <HF_TOKEN>
247252
Your Hugging Face Hub token

docs/openapi.json

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1215,6 +1215,7 @@
12151215
"required": [
12161216
"model_id",
12171217
"model_dtype",
1218+
"served_model_name",
12181219
"model_type",
12191220
"max_concurrent_requests",
12201221
"max_input_length",
@@ -1278,6 +1279,10 @@
12781279
"model_type": {
12791280
"$ref": "#/components/schemas/ModelType"
12801281
},
1282+
"served_model_name": {
1283+
"type": "string",
1284+
"example": "thenlper/gte-base"
1285+
},
12811286
"sha": {
12821287
"type": "string",
12831288
"example": "null",

docs/source/en/cli_arguments.md

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -22,14 +22,15 @@ To see all options to serve your models, run the following:
2222
$ text-embeddings-router --help
2323
Text Embedding Webserver
2424

25-
Usage: text-embeddings-router [OPTIONS]
25+
Usage: text-embeddings-router [OPTIONS] --model-id <MODEL_ID>
2626

2727
Options:
2828
--model-id <MODEL_ID>
29-
The name of the model to load. Can be a MODEL_ID as listed on <https://hf.co/models> like `BAAI/bge-large-en-v1.5`. Or it can be a local directory containing the necessary files as saved by `save_pretrained(...)` methods of transformers
29+
The Hugging Face model ID; it can be any model listed on <https://huggingface.co/models> with the `text-embeddings-inference` tag (meaning it's compatible with Text Embeddings Inference).
30+
31+
Alternatively, the specified ID can also be a path to a local directory containing the necessary model files saved by the `save_pretrained(...)` methods of either Transformers or Sentence Transformers.
3032

3133
[env: MODEL_ID=]
32-
[default: BAAI/bge-large-en-v1.5]
3334

3435
--revision <REVISION>
3536
The actual revision of the model if you're referring to a model on the hub. You can use a specific commit id or a branch like `refs/pr/2`
@@ -47,6 +48,11 @@ Options:
4748
[env: DTYPE=]
4849
[possible values: float16, float32]
4950

51+
--served-model-name <SERVED_MODEL_NAME>
52+
The name of the model that is being served. If not specified, defaults to `--model-id`. It is only used for the OpenAI-compatible endpoints via HTTP
53+
54+
[env: SERVED_MODEL_NAME=]
55+
5056
--pooling <POOLING>
5157
Optionally control the pooling method for embedding models.
5258

@@ -123,10 +129,9 @@ Options:
123129

124130
Some embedding models require an extra `Dense` module which contains a single Linear layer and an activation function. By default, those `Dense` modules are stored under the `2_Dense` directory, but there might be cases where different `Dense` modules are provided, to convert the pooled embeddings into different dimensions, available as `2_Dense_<dims>` e.g. https://huggingface.co/NovaSearch/stella_en_400M_v5.
125131

126-
Note that this argument is optional, only required to be set if the path to the `Dense` module is other than `2_Dense`. And it also applies when leveraging the `candle` backend.
132+
Note that this argument is optional: it only needs to be set if there is no `modules.json` file, or if you want to override a single Dense module path. It only applies when running with the `candle` backend.
127133

128134
[env: DENSE_PATH=]
129-
[default: 2_Dense]
130135

131136
--hf-token <HF_TOKEN>
132137
Your Hugging Face Hub token

router/src/http/server.rs

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1153,6 +1153,18 @@ async fn openai_embed(
11531153
span.set_parent(context);
11541154
}
11551155

1156+
// NOTE: Validation of `model` won't fail for the time being given that Text Embeddings
1157+
// Inference can only serve a single model at a time so no need for the `model` parameter to
1158+
// differentiate one model from the other, but we at least raise a warning.
1159+
if let Some(requested_model) = &req.model {
1160+
if requested_model != &info.served_model_name {
1161+
tracing::warn!(
1162+
"The provided `model={}` has not been found, the `model` parameter should be provided either empty or with `model={}` instead.",
1163+
requested_model, info.served_model_name
1164+
);
1165+
}
1166+
}
1167+
11561168
let start_time = Instant::now();
11571169

11581170
let truncate = info.auto_truncate;
@@ -1308,7 +1320,7 @@ async fn openai_embed(
13081320
let response = OpenAICompatResponse {
13091321
object: "list",
13101322
data: embeddings,
1311-
model: info.model_id.clone(),
1323+
model: info.served_model_name.clone(),
13121324
usage: OpenAICompatUsage {
13131325
prompt_tokens: compute_tokens,
13141326
total_tokens: compute_tokens,

router/src/lib.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@ pub async fn run(
4646
revision: Option<String>,
4747
tokenization_workers: Option<usize>,
4848
dtype: Option<DType>,
49+
served_model_name: String,
4950
pooling: Option<text_embeddings_backend::Pool>,
5051
max_concurrent_requests: usize,
5152
max_batch_tokens: usize,
@@ -323,6 +324,7 @@ pub async fn run(
323324
model_id,
324325
model_sha: revision,
325326
model_dtype: dtype.to_string(),
327+
served_model_name,
326328
model_type,
327329
max_concurrent_requests,
328330
max_input_length,
@@ -539,6 +541,8 @@ pub struct Info {
539541
pub model_sha: Option<String>,
540542
#[cfg_attr(feature = "http", schema(example = "float16"))]
541543
pub model_dtype: String,
544+
#[cfg_attr(feature = "http", schema(example = "thenlper/gte-base"))]
545+
pub served_model_name: String,
542546
pub model_type: ModelType,
543547
/// Router Parameters
544548
#[cfg_attr(feature = "http", schema(example = "128"))]

router/src/main.rs

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ static GLOBAL: mimalloc::MiMalloc = mimalloc::MiMalloc;
1414
struct Args {
1515
/// The Hugging Face model ID, can be any model listed on <https://huggingface.co/models> with
1616
/// the `text-embeddings-inference` tag (meaning it's compatible with Text Embeddings
17-
/// Inference)
17+
/// Inference).
1818
///
1919
/// Alternatively, the specified ID can also be a path to a local directory containing the
2020
/// necessary model files saved by the `save_pretrained(...)` methods of either Transformers or
@@ -38,6 +38,11 @@ struct Args {
3838
#[clap(long, env, value_enum)]
3939
dtype: Option<DType>,
4040

41+
/// The name of the model that is being served. If not specified, defaults to `--model-id`. It
42+
/// is only used for the OpenAI-compatible endpoints via HTTP.
43+
#[clap(long, env)]
44+
served_model_name: Option<String>,
45+
4146
/// Optionally control the pooling method for embedding models.
4247
///
4348
/// If `pooling` is not set, the pooling configuration will be parsed from the
@@ -225,11 +230,16 @@ async fn main() -> Result<()> {
225230
}
226231
let token = args.hf_token.or(args.hf_api_token);
227232

233+
let served_model_name = args
234+
.served_model_name
235+
.unwrap_or_else(|| args.model_id.clone());
236+
228237
text_embeddings_router::run(
229238
args.model_id,
230239
args.revision,
231240
args.tokenization_workers,
232241
args.dtype,
242+
served_model_name,
233243
args.pooling,
234244
args.max_concurrent_requests,
235245
args.max_batch_tokens,

router/tests/common.rs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,10 +46,11 @@ async fn check_health(port: u16, timeout: Duration) -> Result<()> {
4646
pub async fn start_server(model_id: String, revision: Option<String>, dtype: DType) -> Result<()> {
4747
let server_task = tokio::spawn({
4848
run(
49-
model_id,
49+
model_id.clone(),
5050
revision,
5151
Some(1),
5252
Some(dtype),
53+
model_id,
5354
None,
5455
4,
5556
1024,

0 commit comments

Comments
 (0)