Skip to content

Commit 260965d

Browse files
authored
Support classification head for DistilBERT (#487)
1 parent 3611262 commit 260965d

File tree

1 file changed

+76
-4
lines changed

1 file changed

+76
-4
lines changed

backends/candle/src/models/distilbert.rs

Lines changed: 76 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ use crate::models::Model;
33
use candle::{DType, Device, IndexOp, Module, Result, Tensor, D};
44
use candle_nn::{Embedding, VarBuilder};
55
use serde::Deserialize;
6+
use std::collections::HashMap;
67
use text_embeddings_backend_core::{Batch, ModelType, Pool};
78

89
#[derive(Debug, Clone, PartialEq, Deserialize)]
@@ -16,6 +17,8 @@ pub struct DistilBertConfig {
1617
pub max_position_embeddings: usize,
1718
pub pad_token_id: usize,
1819
pub model_type: Option<String>,
20+
pub classifier_dropout: Option<f64>,
21+
pub id2label: Option<HashMap<String, String>>,
1922
}
2023

2124
#[derive(Debug)]
@@ -318,6 +321,56 @@ impl DistilBertEncoder {
318321
}
319322
}
320323

324+
pub trait ClassificationHead {
325+
fn forward(&self, hidden_states: &Tensor) -> Result<Tensor>;
326+
}
327+
328+
pub struct DistilBertClassificationHead {
329+
pre_classifier: Linear,
330+
classifier: Linear,
331+
span: tracing::Span,
332+
}
333+
334+
impl DistilBertClassificationHead {
335+
pub(crate) fn load(vb: VarBuilder, config: &DistilBertConfig) -> Result<Self> {
336+
let n_classes = match &config.id2label {
337+
None => candle::bail!("`id2label` must be set for classifier models"),
338+
Some(id2label) => id2label.len(),
339+
};
340+
341+
let pre_classifier_weight = vb
342+
.pp("pre_classifier")
343+
.get((config.dim, config.dim), "weight")?;
344+
let pre_classifier_bias = vb.pp("pre_classifier").get(config.dim, "bias")?;
345+
let pre_classifier = Linear::new(pre_classifier_weight, Some(pre_classifier_bias), None);
346+
347+
let classifier_weight = vb.pp("classifier").get((n_classes, config.dim), "weight")?;
348+
let classifier_bias = vb.pp("classifier").get(n_classes, "bias")?;
349+
let classifier = Linear::new(classifier_weight, Some(classifier_bias), None);
350+
351+
Ok(Self {
352+
pre_classifier,
353+
classifier,
354+
span: tracing::span!(tracing::Level::TRACE, "classifier"),
355+
})
356+
}
357+
}
358+
359+
impl ClassificationHead for DistilBertClassificationHead {
360+
fn forward(&self, hidden_states: &Tensor) -> Result<Tensor> {
361+
let _enter = self.span.enter();
362+
363+
let hidden_states = hidden_states.unsqueeze(1)?;
364+
365+
let hidden_states = self.pre_classifier.forward(&hidden_states)?;
366+
let hidden_states = hidden_states.relu()?;
367+
368+
let hidden_states = self.classifier.forward(&hidden_states)?;
369+
let hidden_states = hidden_states.squeeze(1)?;
370+
Ok(hidden_states)
371+
}
372+
}
373+
321374
#[derive(Debug)]
322375
pub struct DistilBertSpladeHead {
323376
vocab_transform: Linear,
@@ -368,11 +421,11 @@ impl DistilBertSpladeHead {
368421
}
369422
}
370423

371-
#[derive(Debug)]
372424
pub struct DistilBertModel {
373425
embeddings: DistilBertEmbeddings,
374426
encoder: DistilBertEncoder,
375427
pool: Pool,
428+
classifier: Option<Box<dyn ClassificationHead + Send>>,
376429
splade: Option<DistilBertSpladeHead>,
377430

378431
num_attention_heads: usize,
@@ -385,15 +438,21 @@ pub struct DistilBertModel {
385438

386439
impl DistilBertModel {
387440
pub fn load(vb: VarBuilder, config: &DistilBertConfig, model_type: ModelType) -> Result<Self> {
388-
let pool = match model_type {
441+
let (pool, classifier) = match model_type {
442+
// Classifier models always use CLS pooling
389443
ModelType::Classifier => {
390-
candle::bail!("`classifier` model type is not supported for DistilBert")
444+
let pool = Pool::Cls;
445+
446+
let classifier: Box<dyn ClassificationHead + Send> =
447+
Box::new(DistilBertClassificationHead::load(vb.clone(), config)?);
448+
(pool, Some(classifier))
391449
}
392450
ModelType::Embedding(pool) => {
393451
if pool == Pool::LastToken {
394452
candle::bail!("`last_token` is not supported for DistilBert");
395453
}
396-
pool
454+
455+
(pool, None)
397456
}
398457
};
399458

@@ -424,6 +483,7 @@ impl DistilBertModel {
424483
embeddings,
425484
encoder,
426485
pool,
486+
classifier,
427487
splade,
428488
num_attention_heads: config.n_heads,
429489
device: vb.device().clone(),
@@ -660,4 +720,16 @@ impl Model for DistilBertModel {
660720
/// Embed a batch; pooling behaviour is handled inside `forward`.
fn embed(&self, batch: Batch) -> Result<(Option<Tensor>, Option<Tensor>)> {
    self.forward(batch)
}
723+
724+
fn predict(&self, batch: Batch) -> Result<Tensor> {
725+
match &self.classifier {
726+
None => candle::bail!("`predict` is not implemented for this model"),
727+
Some(classifier) => {
728+
let (pooled_embeddings, _raw_embeddings) = self.forward(batch)?;
729+
let pooled_embeddings =
730+
pooled_embeddings.expect("pooled_embeddings is empty. This is a bug.");
731+
classifier.forward(&pooled_embeddings)
732+
}
733+
}
734+
}
663735
}

0 commit comments

Comments (0)