Skip to content

Commit 1a902e2

Browse files
committed
Add JinaBertClassificationHead
1 parent ea24c2c commit 1a902e2

File tree

1 file changed

+80
-3
lines changed

1 file changed

+80
-3
lines changed

backends/candle/src/models/jina.rs

Lines changed: 80 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -339,11 +339,69 @@ impl JinaBertEncoder {
339339
}
340340
}
341341

342+
pub trait ClassificationHead {
343+
fn forward(&self, hidden_states: &Tensor) -> Result<Tensor>;
344+
}
345+
346+
pub struct JinaBertClassificationHead {
347+
pooler: Option<Linear>,
348+
output: Linear,
349+
span: tracing::Span,
350+
}
351+
352+
impl JinaBertClassificationHead {
353+
pub(crate) fn load(vb: VarBuilder, config: &BertConfig) -> Result<Self> {
354+
let n_classes = match &config.id2label {
355+
None => candle::bail!("`id2label` must be set for classifier models"),
356+
Some(id2label) => id2label.len(),
357+
};
358+
359+
let pooler = if let Ok(pooler_weight) = vb
360+
.pp("bert.pooler.dense")
361+
.get((config.hidden_size, config.hidden_size), "weight")
362+
{
363+
let pooler_bias = vb.pp("bert.pooler.dense").get(config.hidden_size, "bias")?;
364+
Some(Linear::new(pooler_weight, Some(pooler_bias), None))
365+
} else {
366+
None
367+
};
368+
369+
let output_weight = vb
370+
.pp("classifier")
371+
.get((n_classes, config.hidden_size), "weight")?;
372+
let output_bias = vb.pp("classifier").get(n_classes, "bias")?;
373+
let output = Linear::new(output_weight, Some(output_bias), None);
374+
375+
Ok(Self {
376+
pooler,
377+
output,
378+
span: tracing::span!(tracing::Level::TRACE, "classifier"),
379+
})
380+
}
381+
}
382+
383+
impl ClassificationHead for JinaBertClassificationHead {
384+
fn forward(&self, hidden_states: &Tensor) -> Result<Tensor> {
385+
let _enter = self.span.enter();
386+
387+
let mut hidden_states = hidden_states.unsqueeze(1)?;
388+
if let Some(pooler) = self.pooler.as_ref() {
389+
hidden_states = pooler.forward(&hidden_states)?;
390+
hidden_states = hidden_states.tanh()?;
391+
}
392+
393+
let hidden_states = self.output.forward(&hidden_states)?;
394+
let hidden_states = hidden_states.squeeze(1)?;
395+
Ok(hidden_states)
396+
}
397+
}
398+
342399
pub struct JinaBertModel {
343400
embeddings: JinaEmbeddings,
344401
encoder: JinaBertEncoder,
345402
pool: Pool,
346403
alibi: Option<Tensor>,
404+
classifier: Option<Box<dyn ClassificationHead + Send>>,
347405

348406
num_attention_heads: usize,
349407

@@ -366,16 +424,21 @@ impl JinaBertModel {
366424
_ => candle::bail!("not supported"),
367425
};
368426

369-
let pool = match model_type {
370-
ModelType::Classifier => Pool::Cls,
427+
let (pool, classifier) = match model_type {
428+
ModelType::Classifier => {
429+
let pool = Pool::Cls;
430+
let classifier: Box<dyn ClassificationHead + Send> =
431+
Box::new(JinaBertClassificationHead::load(vb.clone(), config)?);
432+
(pool, Some(classifier))
433+
}
371434
ModelType::Embedding(pool) => {
372435
if pool == Pool::Splade {
373436
candle::bail!("`splade` is not supported for Jina")
374437
}
375438
if pool == Pool::LastToken {
376439
candle::bail!("`last_token` is not supported for Jina");
377440
}
378-
pool
441+
(pool, None)
379442
}
380443
};
381444

@@ -401,6 +464,7 @@ impl JinaBertModel {
401464
encoder,
402465
pool,
403466
alibi,
467+
classifier,
404468
num_attention_heads: config.num_attention_heads,
405469
device: vb.device().clone(),
406470
dtype: vb.dtype(),
@@ -665,7 +729,20 @@ impl Model for JinaBertModel {
665729
fn is_padded(&self) -> bool {
666730
true
667731
}
732+
668733
fn embed(&self, batch: Batch) -> Result<(Option<Tensor>, Option<Tensor>)> {
669734
self.forward(batch)
670735
}
736+
737+
fn predict(&self, batch: Batch) -> Result<Tensor> {
738+
match &self.classifier {
739+
None => candle::bail!("`predict` is not implemented for this model"),
740+
Some(classifier) => {
741+
let (pooled_embeddings, _raw_embeddings) = self.forward(batch)?;
742+
let pooled_embeddings =
743+
pooled_embeddings.expect("pooled_embeddings is empty. This is a bug.");
744+
classifier.forward(&pooled_embeddings)
745+
}
746+
}
747+
}
671748
}

0 commit comments

Comments
 (0)