@@ -3,6 +3,7 @@ use crate::models::Model;
 use candle::{DType, Device, IndexOp, Module, Result, Tensor, D};
 use candle_nn::{Embedding, VarBuilder};
 use serde::Deserialize;
+use std::collections::HashMap;
 use text_embeddings_backend_core::{Batch, ModelType, Pool};

 #[derive(Debug, Clone, PartialEq, Deserialize)]
@@ -16,6 +17,8 @@ pub struct DistilBertConfig {
     pub max_position_embeddings: usize,
     pub pad_token_id: usize,
     pub model_type: Option<String>,
+    pub classifier_dropout: Option<f64>,
+    pub id2label: Option<HashMap<String, String>>,
 }

 #[derive(Debug)]
@@ -318,6 +321,56 @@ impl DistilBertEncoder {
     }
 }

+pub trait ClassificationHead {
+    fn forward(&self, hidden_states: &Tensor) -> Result<Tensor>;
+}
+
+pub struct DistilBertClassificationHead {
+    pre_classifier: Linear,
+    classifier: Linear,
+    span: tracing::Span,
+}
+
+impl DistilBertClassificationHead {
+    pub(crate) fn load(vb: VarBuilder, config: &DistilBertConfig) -> Result<Self> {
+        let n_classes = match &config.id2label {
+            None => candle::bail!("`id2label` must be set for classifier models"),
+            Some(id2label) => id2label.len(),
+        };
+
+        let pre_classifier_weight = vb
+            .pp("pre_classifier")
+            .get((config.dim, config.dim), "weight")?;
+        let pre_classifier_bias = vb.pp("pre_classifier").get(config.dim, "bias")?;
+        let pre_classifier = Linear::new(pre_classifier_weight, Some(pre_classifier_bias), None);
+
+        let classifier_weight = vb.pp("classifier").get((n_classes, config.dim), "weight")?;
+        let classifier_bias = vb.pp("classifier").get(n_classes, "bias")?;
+        let classifier = Linear::new(classifier_weight, Some(classifier_bias), None);
+
+        Ok(Self {
+            pre_classifier,
+            classifier,
+            span: tracing::span!(tracing::Level::TRACE, "classifier"),
+        })
+    }
+}
+
+impl ClassificationHead for DistilBertClassificationHead {
+    fn forward(&self, hidden_states: &Tensor) -> Result<Tensor> {
+        let _enter = self.span.enter();
+
+        let hidden_states = hidden_states.unsqueeze(1)?;
+
+        let hidden_states = self.pre_classifier.forward(&hidden_states)?;
+        let hidden_states = hidden_states.relu()?;
+
+        let hidden_states = self.classifier.forward(&hidden_states)?;
+        let hidden_states = hidden_states.squeeze(1)?;
+        Ok(hidden_states)
+    }
+}
+
 #[derive(Debug)]
 pub struct DistilBertSpladeHead {
     vocab_transform: Linear,
@@ -368,11 +421,11 @@ impl DistilBertSpladeHead {
     }
 }

-#[derive(Debug)]
 pub struct DistilBertModel {
     embeddings: DistilBertEmbeddings,
     encoder: DistilBertEncoder,
     pool: Pool,
+    classifier: Option<Box<dyn ClassificationHead + Send>>,
     splade: Option<DistilBertSpladeHead>,

     num_attention_heads: usize,
@@ -385,15 +438,21 @@ pub struct DistilBertModel {

 impl DistilBertModel {
     pub fn load(vb: VarBuilder, config: &DistilBertConfig, model_type: ModelType) -> Result<Self> {
-        let pool = match model_type {
+        let (pool, classifier) = match model_type {
+            // Classifier models always use CLS pooling
             ModelType::Classifier => {
-                candle::bail!("`classifier` model type is not supported for DistilBert")
+                let pool = Pool::Cls;
+
+                let classifier: Box<dyn ClassificationHead + Send> =
+                    Box::new(DistilBertClassificationHead::load(vb.clone(), config)?);
+                (pool, Some(classifier))
             }
             ModelType::Embedding(pool) => {
                 if pool == Pool::LastToken {
                     candle::bail!("`last_token` is not supported for DistilBert");
                 }
-                pool
+
+                (pool, None)
             }
         };

@@ -424,6 +483,7 @@ impl DistilBertModel {
             embeddings,
             encoder,
             pool,
+            classifier,
             splade,
             num_attention_heads: config.n_heads,
             device: vb.device().clone(),
@@ -660,4 +720,16 @@ impl Model for DistilBertModel {
     fn embed(&self, batch: Batch) -> Result<(Option<Tensor>, Option<Tensor>)> {
         self.forward(batch)
     }
+
+    fn predict(&self, batch: Batch) -> Result<Tensor> {
+        match &self.classifier {
+            None => candle::bail!("`predict` is not implemented for this model"),
+            Some(classifier) => {
+                let (pooled_embeddings, _raw_embeddings) = self.forward(batch)?;
+                let pooled_embeddings =
+                    pooled_embeddings.expect("pooled_embeddings is empty. This is a bug.");
+                classifier.forward(&pooled_embeddings)
+            }
+        }
+    }
 }
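
Taken together, the diff routes ModelType::Classifier through the new DistilBertClassificationHead (with CLS pooling) instead of bailing out, and exposes it via Model::predict. Below is a minimal sketch, not part of the diff, of how the new path could be driven. It assumes a VarBuilder, a parsed DistilBertConfig, and a tokenized Batch are already in hand, that DistilBertModel, DistilBertConfig, and the Model trait are in scope, and that `classify` is a hypothetical helper; only the signatures shown in the diff are used.

use candle::{Result, Tensor};
use candle_nn::VarBuilder;
use text_embeddings_backend_core::{Batch, ModelType};

// Hypothetical helper: load a DistilBERT sequence-classification checkpoint and
// run one tokenized batch through it, returning logits of shape [batch_size, n_classes].
fn classify(vb: VarBuilder, config: &DistilBertConfig, batch: Batch) -> Result<Tensor> {
    // With this change, ModelType::Classifier selects CLS pooling and loads the
    // `pre_classifier`/`classifier` weights instead of returning an error.
    let model = DistilBertModel::load(vb, config, ModelType::Classifier)?;
    // `predict` pools the batch, then applies the classification head
    // (pre_classifier -> ReLU -> classifier) to the pooled embeddings.
    model.predict(batch)
}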