Skip to content

Commit 251d7d3

Browse files
authored
Add support for JinaAI Re-Rankers V1 (#582)
1 parent 7a7e5fc commit 251d7d3

File tree

7 files changed

+140
-13
lines changed

7 files changed

+140
-13
lines changed

backends/candle/src/models/flash_jina.rs

Lines changed: 24 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ use crate::alibi::alibi_head_slopes;
22
use crate::flash_attn::flash_attn_varlen;
33
use crate::layers::{HiddenAct, LayerNorm, Linear};
44
use crate::models::bert::PositionEmbeddingType;
5-
use crate::models::jina::JinaEmbeddings;
5+
use crate::models::jina::{ClassificationHead, JinaBertClassificationHead, JinaEmbeddings};
66
use crate::models::{BertConfig, Model};
77
use candle::{DType, Device, IndexOp, Result, Tensor};
88
use candle_nn::VarBuilder;
@@ -227,6 +227,8 @@ pub struct FlashJinaBertModel {
227227
embeddings: JinaEmbeddings,
228228
encoder: JinaBertEncoder,
229229
pool: Pool,
230+
classifier: Option<Box<dyn ClassificationHead + Send>>,
231+
230232
pub device: Device,
231233

232234
span: tracing::Span,
@@ -255,15 +257,19 @@ impl FlashJinaBertModel {
255257
candle::bail!("FlashJinaBertModel requires DType::F16")
256258
}
257259

258-
let pool = match model_type {
260+
let (pool, classifier) = match model_type {
259261
ModelType::Classifier => {
260-
candle::bail!("`classifier` model type is not supported for Jina")
262+
let pool = Pool::Cls;
263+
264+
let classifier: Box<dyn ClassificationHead + Send> =
265+
Box::new(JinaBertClassificationHead::load(vb.clone(), config)?);
266+
(pool, Some(classifier))
261267
}
262268
ModelType::Embedding(pool) => {
263269
if pool == Pool::Splade {
264270
candle::bail!("`splade` is not supported for Jina")
265271
}
266-
pool
272+
(pool, None)
267273
}
268274
};
269275

@@ -288,6 +294,7 @@ impl FlashJinaBertModel {
288294
embeddings,
289295
encoder,
290296
pool,
297+
classifier,
291298
device: vb.device().clone(),
292299
span: tracing::span!(tracing::Level::TRACE, "model"),
293300
})
@@ -433,7 +440,20 @@ impl Model for FlashJinaBertModel {
433440
/// Always returns `false` for this model.
// NOTE(review): presumably tells the caller that this flash-attention variant
// consumes unpadded batches; confirm against the `Model` trait definition.
fn is_padded(&self) -> bool {
434441
false
435442
}
443+
436444
/// Runs the model on `batch` by delegating to `self.forward` and returns its
/// result unchanged: a pair of optional tensors, or the error `forward` produced.
fn embed(&self, batch: Batch) -> Result<(Option<Tensor>, Option<Tensor>)> {
437445
self.forward(batch)
438446
}
447+
448+
fn predict(&self, batch: Batch) -> Result<Tensor> {
449+
match &self.classifier {
450+
None => candle::bail!("`predict` is not implemented for this model"),
451+
Some(classifier) => {
452+
let (pooled_embeddings, _raw_embeddings) = self.forward(batch)?;
453+
let pooled_embeddings =
454+
pooled_embeddings.expect("pooled_embeddings is empty. This is a bug.");
455+
classifier.forward(&pooled_embeddings)
456+
}
457+
}
458+
}
439459
}

backends/candle/src/models/jina.rs

Lines changed: 78 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -339,11 +339,69 @@ impl JinaBertEncoder {
339339
}
340340
}
341341

342+
/// Scoring stage applied after the encoder in classifier mode.
///
/// The Jina models store it as `Box<dyn ClassificationHead + Send>`, so it is
/// used as a trait object.
pub trait ClassificationHead {
343+
/// Applies the head to `hidden_states` and returns the resulting tensor, or
/// the error raised by the underlying tensor operations.
fn forward(&self, hidden_states: &Tensor) -> Result<Tensor>;
344+
}
345+
346+
/// Classification head for Jina BERT checkpoints: an optional pooler layer
/// followed by the output projection.
pub struct JinaBertClassificationHead {
347+
/// `bert.pooler.dense` layer; `None` when `load` could not read its weight.
pooler: Option<Linear>,
348+
/// `classifier` projection, loaded with a `(n_classes, hidden_size)` weight.
output: Linear,
349+
/// TRACE-level span entered at the start of every `forward` call.
span: tracing::Span,
350+
}
351+
352+
impl JinaBertClassificationHead {
353+
pub(crate) fn load(vb: VarBuilder, config: &BertConfig) -> Result<Self> {
354+
let n_classes = match &config.id2label {
355+
None => candle::bail!("`id2label` must be set for classifier models"),
356+
Some(id2label) => id2label.len(),
357+
};
358+
359+
let pooler = if let Ok(pooler_weight) = vb
360+
.pp("bert.pooler.dense")
361+
.get((config.hidden_size, config.hidden_size), "weight")
362+
{
363+
let pooler_bias = vb.pp("bert.pooler.dense").get(config.hidden_size, "bias")?;
364+
Some(Linear::new(pooler_weight, Some(pooler_bias), None))
365+
} else {
366+
None
367+
};
368+
369+
let output_weight = vb
370+
.pp("classifier")
371+
.get((n_classes, config.hidden_size), "weight")?;
372+
let output_bias = vb.pp("classifier").get(n_classes, "bias")?;
373+
let output = Linear::new(output_weight, Some(output_bias), None);
374+
375+
Ok(Self {
376+
pooler,
377+
output,
378+
span: tracing::span!(tracing::Level::TRACE, "classifier"),
379+
})
380+
}
381+
}
382+
383+
impl ClassificationHead for JinaBertClassificationHead {
384+
fn forward(&self, hidden_states: &Tensor) -> Result<Tensor> {
385+
let _enter = self.span.enter();
386+
387+
let mut hidden_states = hidden_states.unsqueeze(1)?;
388+
if let Some(pooler) = self.pooler.as_ref() {
389+
hidden_states = pooler.forward(&hidden_states)?;
390+
hidden_states = hidden_states.tanh()?;
391+
}
392+
393+
let hidden_states = self.output.forward(&hidden_states)?;
394+
let hidden_states = hidden_states.squeeze(1)?;
395+
Ok(hidden_states)
396+
}
397+
}
398+
342399
pub struct JinaBertModel {
343400
embeddings: JinaEmbeddings,
344401
encoder: JinaBertEncoder,
345402
pool: Pool,
346403
alibi: Option<Tensor>,
404+
classifier: Option<Box<dyn ClassificationHead + Send>>,
347405

348406
num_attention_heads: usize,
349407

@@ -366,9 +424,12 @@ impl JinaBertModel {
366424
_ => candle::bail!("not supported"),
367425
};
368426

369-
let pool = match model_type {
427+
let (pool, classifier) = match model_type {
370428
ModelType::Classifier => {
371-
candle::bail!("`classifier` model type is not supported for Jina")
429+
let pool = Pool::Cls;
430+
let classifier: Box<dyn ClassificationHead + Send> =
431+
Box::new(JinaBertClassificationHead::load(vb.clone(), config)?);
432+
(pool, Some(classifier))
372433
}
373434
ModelType::Embedding(pool) => {
374435
if pool == Pool::Splade {
@@ -377,7 +438,7 @@ impl JinaBertModel {
377438
if pool == Pool::LastToken {
378439
candle::bail!("`last_token` is not supported for Jina");
379440
}
380-
pool
441+
(pool, None)
381442
}
382443
};
383444

@@ -403,6 +464,7 @@ impl JinaBertModel {
403464
encoder,
404465
pool,
405466
alibi,
467+
classifier,
406468
num_attention_heads: config.num_attention_heads,
407469
device: vb.device().clone(),
408470
dtype: vb.dtype(),
@@ -667,7 +729,20 @@ impl Model for JinaBertModel {
667729
/// Always returns `true` for this model.
// NOTE(review): presumably tells the caller that this model expects padded
// batches; confirm against the `Model` trait definition.
fn is_padded(&self) -> bool {
668730
true
669731
}
732+
670733
/// Runs the model on `batch` by delegating to `self.forward` and returns its
/// result unchanged: a pair of optional tensors, or the error `forward` produced.
fn embed(&self, batch: Batch) -> Result<(Option<Tensor>, Option<Tensor>)> {
671734
self.forward(batch)
672735
}
736+
737+
fn predict(&self, batch: Batch) -> Result<Tensor> {
738+
match &self.classifier {
739+
None => candle::bail!("`predict` is not implemented for this model"),
740+
Some(classifier) => {
741+
let (pooled_embeddings, _raw_embeddings) = self.forward(batch)?;
742+
let pooled_embeddings =
743+
pooled_embeddings.expect("pooled_embeddings is empty. This is a bug.");
744+
classifier.forward(&pooled_embeddings)
745+
}
746+
}
747+
}
673748
}

backends/candle/src/models/jina_code.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -656,6 +656,7 @@ impl Model for JinaCodeBertModel {
656656
/// Always returns `true` for this model.
// NOTE(review): presumably tells the caller that this model expects padded
// batches; confirm against the `Model` trait definition.
fn is_padded(&self) -> bool {
657657
true
658658
}
659+
659660
/// Runs the model on `batch` by delegating to `self.forward` and returns its
/// result unchanged: a pair of optional tensors, or the error `forward` produced.
fn embed(&self, batch: Batch) -> Result<(Option<Tensor>, Option<Tensor>)> {
660661
self.forward(batch)
661662
}

backends/candle/src/models/mod.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,6 @@ pub(crate) trait Model {
9292
}
9393

9494
/// Default `predict`: ignores the batch and bails with a "not implemented" error.
///
/// The two `bail!` lines below are the removed (`-`) and added (`+`) sides of
/// the diff; they differ only in the closing backtick after `predict` in the
/// error message.
fn predict(&self, _batch: Batch) -> Result<Tensor> {
95-
candle::bail!("`predict is not implemented for this model");
95+
candle::bail!("`predict` is not implemented for this model");
9696
}
9797
}
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
---
2+
source: backends/candle/tests/test_jina.rs
3+
expression: predictions
4+
---
5+
- - -0.6045344

backends/candle/tests/test_jina.rs

Lines changed: 28 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
mod common;
22

3-
use crate::common::{sort_embeddings, SnapshotEmbeddings};
3+
use crate::common::{sort_embeddings, SnapshotEmbeddings, SnapshotScores};
44
use anyhow::Result;
5-
use common::{batch, cosine_matcher, download_artifacts, load_tokenizer};
5+
use common::{batch, cosine_matcher, download_artifacts, load_tokenizer, relative_matcher};
66
use text_embeddings_backend_candle::CandleBackend;
77
use text_embeddings_backend_core::{Backend, ModelType, Pool};
88

@@ -48,3 +48,29 @@ fn test_jina_small() -> Result<()> {
4848

4949
Ok(())
5050
}
51+
52+
/// Snapshot test for the Jina reranker: scores one tokenized query with the
/// Candle backend in classifier mode and compares the result with the stored
/// `jinabert_reranker_single` snapshot using the relative matcher.
#[test]
#[serial_test::serial]
fn test_jina_rerank() -> Result<()> {
    let model_root = download_artifacts("jinaai/jina-reranker-v1-tiny-en", Some("refs/pr/11"))?;
    let tokenizer = load_tokenizer(&model_root)?;

    let backend = CandleBackend::new(&model_root, String::from("float32"), ModelType::Classifier)?;

    // One sequence, pooled at index 0, with no raw-embedding requests.
    let encoding = tokenizer.encode("What is Deep Learning?", true).unwrap();
    let input_single = batch(vec![encoding], vec![0], Vec::new());

    let scores = backend.predict(input_single)?;
    let predictions: Vec<Vec<f32>> = scores.into_iter().map(|(_, score)| score).collect();

    // The binding stays named `predictions` so the expression recorded in the
    // snapshot header is unchanged.
    let predictions = SnapshotScores::from(predictions);
    insta::assert_yaml_snapshot!("jinabert_reranker_single", predictions, &relative_matcher());

    Ok(())
}

backends/ort/src/lib.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,7 @@ impl OrtBackend {
2121
model_type: ModelType,
2222
) -> Result<Self, BackendError> {
2323
// Check dtype
24-
if dtype == "float32" {
25-
} else {
24+
if dtype != "float32" {
2625
return Err(BackendError::Start(format!(
2726
"DType {dtype} is not supported"
2827
)));
@@ -167,8 +166,8 @@ impl Backend for OrtBackend {
167166

168167
// Run model
169168
let outputs = self.session.run(inputs).e()?;
170-
// Get last_hidden_state ndarray
171169

170+
// Get last_hidden_state ndarray
172171
let outputs = outputs
173172
.get("last_hidden_state")
174173
.or(outputs.get("token_embeddings"))
@@ -362,6 +361,7 @@ impl Backend for OrtBackend {
362361

363362
// Run model
364363
let outputs = self.session.run(inputs).e()?;
364+
365365
// Get last_hidden_state ndarray
366366
let outputs = outputs["logits"]
367367
.try_extract_tensor::<f32>()

0 commit comments

Comments
 (0)