@@ -339,11 +339,69 @@ impl JinaBertEncoder {
339
339
}
340
340
}
341
341
342
+ pub trait ClassificationHead {
343
+ fn forward ( & self , hidden_states : & Tensor ) -> Result < Tensor > ;
344
+ }
345
+
346
+ pub struct JinaBertClassificationHead {
347
+ pooler : Option < Linear > ,
348
+ output : Linear ,
349
+ span : tracing:: Span ,
350
+ }
351
+
352
+ impl JinaBertClassificationHead {
353
+ pub ( crate ) fn load ( vb : VarBuilder , config : & BertConfig ) -> Result < Self > {
354
+ let n_classes = match & config. id2label {
355
+ None => candle:: bail!( "`id2label` must be set for classifier models" ) ,
356
+ Some ( id2label) => id2label. len ( ) ,
357
+ } ;
358
+
359
+ let pooler = if let Ok ( pooler_weight) = vb
360
+ . pp ( "bert.pooler.dense" )
361
+ . get ( ( config. hidden_size , config. hidden_size ) , "weight" )
362
+ {
363
+ let pooler_bias = vb. pp ( "bert.pooler.dense" ) . get ( config. hidden_size , "bias" ) ?;
364
+ Some ( Linear :: new ( pooler_weight, Some ( pooler_bias) , None ) )
365
+ } else {
366
+ None
367
+ } ;
368
+
369
+ let output_weight = vb
370
+ . pp ( "classifier" )
371
+ . get ( ( n_classes, config. hidden_size ) , "weight" ) ?;
372
+ let output_bias = vb. pp ( "classifier" ) . get ( n_classes, "bias" ) ?;
373
+ let output = Linear :: new ( output_weight, Some ( output_bias) , None ) ;
374
+
375
+ Ok ( Self {
376
+ pooler,
377
+ output,
378
+ span : tracing:: span!( tracing:: Level :: TRACE , "classifier" ) ,
379
+ } )
380
+ }
381
+ }
382
+
383
+ impl ClassificationHead for JinaBertClassificationHead {
384
+ fn forward ( & self , hidden_states : & Tensor ) -> Result < Tensor > {
385
+ let _enter = self . span . enter ( ) ;
386
+
387
+ let mut hidden_states = hidden_states. unsqueeze ( 1 ) ?;
388
+ if let Some ( pooler) = self . pooler . as_ref ( ) {
389
+ hidden_states = pooler. forward ( & hidden_states) ?;
390
+ hidden_states = hidden_states. tanh ( ) ?;
391
+ }
392
+
393
+ let hidden_states = self . output . forward ( & hidden_states) ?;
394
+ let hidden_states = hidden_states. squeeze ( 1 ) ?;
395
+ Ok ( hidden_states)
396
+ }
397
+ }
398
+
342
399
pub struct JinaBertModel {
343
400
embeddings : JinaEmbeddings ,
344
401
encoder : JinaBertEncoder ,
345
402
pool : Pool ,
346
403
alibi : Option < Tensor > ,
404
+ classifier : Option < Box < dyn ClassificationHead + Send > > ,
347
405
348
406
num_attention_heads : usize ,
349
407
@@ -366,16 +424,21 @@ impl JinaBertModel {
366
424
_ => candle:: bail!( "not supported" ) ,
367
425
} ;
368
426
369
- let pool = match model_type {
370
- ModelType :: Classifier => Pool :: Cls ,
427
+ let ( pool, classifier) = match model_type {
428
+ ModelType :: Classifier => {
429
+ let pool = Pool :: Cls ;
430
+ let classifier: Box < dyn ClassificationHead + Send > =
431
+ Box :: new ( JinaBertClassificationHead :: load ( vb. clone ( ) , config) ?) ;
432
+ ( pool, Some ( classifier) )
433
+ }
371
434
ModelType :: Embedding ( pool) => {
372
435
if pool == Pool :: Splade {
373
436
candle:: bail!( "`splade` is not supported for Jina" )
374
437
}
375
438
if pool == Pool :: LastToken {
376
439
candle:: bail!( "`last_token` is not supported for Jina" ) ;
377
440
}
378
- pool
441
+ ( pool, None )
379
442
}
380
443
} ;
381
444
@@ -401,6 +464,7 @@ impl JinaBertModel {
401
464
encoder,
402
465
pool,
403
466
alibi,
467
+ classifier,
404
468
num_attention_heads : config. num_attention_heads ,
405
469
device : vb. device ( ) . clone ( ) ,
406
470
dtype : vb. dtype ( ) ,
@@ -665,7 +729,20 @@ impl Model for JinaBertModel {
665
729
fn is_padded ( & self ) -> bool {
666
730
true
667
731
}
732
+
668
733
fn embed ( & self , batch : Batch ) -> Result < ( Option < Tensor > , Option < Tensor > ) > {
669
734
self . forward ( batch)
670
735
}
736
+
737
+ fn predict ( & self , batch : Batch ) -> Result < Tensor > {
738
+ match & self . classifier {
739
+ None => candle:: bail!( "`predict` is not implemented for this model" ) ,
740
+ Some ( classifier) => {
741
+ let ( pooled_embeddings, _raw_embeddings) = self . forward ( batch) ?;
742
+ let pooled_embeddings =
743
+ pooled_embeddings. expect ( "pooled_embeddings is empty. This is a bug." ) ;
744
+ classifier. forward ( & pooled_embeddings)
745
+ }
746
+ }
747
+ }
671
748
}
0 commit comments