@@ -454,28 +454,10 @@ def forward(
         return hidden_states
 
 
-class JinaBertPooler:
-    def __init__(self, handle, device, dtype, config):
-        self.dense_weight = (
-            handle.get_tensor(f"pooler.dense.weight").to(dtype).to(device)
-        )
-        self.dense_bias = handle.get_tensor(f"pooler.dense.bias").to(dtype).to(device)
-        self.activation = nn.Tanh()
-
-    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
-        # We "pool" the model by simply taking the hidden state corresponding
-        # to the first token.
-        first_token_tensor = hidden_states[:, 0]
-        pooled_output = F.linear(first_token_tensor, self.dense_weight, self.dense_bias)
-        pooled_output = self.activation(pooled_output)
-        return pooled_output
-
-
 class FlashJinaBertModel:
     def __init__(self, handle, device, dtype, config: AutoConfig):
         self.embeddings = JinaBertEmbeddings(handle, device, dtype, config)
         self.encoder = JinaBertEncoder(handle, device, dtype, config)
-        self.pooler = JinaBertPooler(handle, device, dtype, config)
 
     def forward(
         self,
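For reference, the removed `JinaBertPooler` implemented BERT-style CLS pooling: take the first token's hidden state, then apply a dense layer and tanh. A minimal standalone sketch of that computation (tensor shapes and random weights are illustrative only):

```python
import torch
import torch.nn.functional as F

# Illustrative shapes: batch of 2 sequences, 8 tokens, hidden size 16.
hidden_states = torch.randn(2, 8, 16)
dense_weight = torch.randn(16, 16)  # stand-in for pooler.dense.weight
dense_bias = torch.randn(16)        # stand-in for pooler.dense.bias

# CLS pooling: first token's hidden state, then dense + tanh.
first_token = hidden_states[:, 0]
pooled = torch.tanh(F.linear(first_token, dense_weight, dense_bias))
print(pooled.shape)  # torch.Size([2, 16])
```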
@@ -486,8 +468,7 @@ def forward(
     ):
         embeddings = self.embeddings.forward(input_ids, token_type_ids, position_ids)
         encoder_outputs = self.encoder.forward(embeddings, attn_mask)
-        pooled_output = self.pooler(encoder_outputs)
-        return pooled_output
+        return encoder_outputs
 
 
 class FlashJinaBert(Model):
@@ -522,13 +503,24 @@ def __init__(
     def batch_type(self) -> Type[PaddedBatch]:
         return PaddedBatch
 
+    def mean_pooling(
+        self, token_embeddings: torch.Tensor, attention_mask: torch.Tensor
+    ):
+        input_mask_expanded = (
+            attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
+        )
+        return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(
+            input_mask_expanded.sum(1), min=1e-9
+        )
+
     @tracer.start_as_current_span("embed")
     def embed(self, batch: PaddedBatch) -> List[Embedding]:
-        kwargs = {"input_ids": batch.input_ids, "attn_mask": batch.attention_mask}
+        kwargs = {"input_ids": batch.input_ids}
         kwargs["token_type_ids"] = batch.token_type_ids
         kwargs["position_ids"] = batch.position_ids
-        embedding = self.model.forward(**kwargs)
+        outputs = self.model.forward(**kwargs)
 
+        embedding = self.mean_pooling(outputs, batch.attention_mask)
         cpu_results = embedding.view(-1).tolist()
 
         return [
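The model now returns per-token hidden states, and pooling moves into `embed()` as an attention-mask-aware mean. A minimal standalone sketch of the `mean_pooling` added above (toy tensors, illustrative shapes): padding positions are zeroed out via the mask before averaging, and the divisor is clamped so a fully masked row cannot divide by zero.

```python
import torch

# Illustrative shapes: batch of 2 sequences, 4 tokens, hidden size 8.
token_embeddings = torch.randn(2, 4, 8)
attention_mask = torch.tensor([[1, 1, 1, 0],   # last token is padding
                               [1, 1, 0, 0]])  # last two tokens are padding

# Broadcast the mask over the hidden dimension, zero out padding,
# then average over real tokens only (clamp guards against division by zero).
mask = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
embedding = torch.sum(token_embeddings * mask, 1) / torch.clamp(mask.sum(1), min=1e-9)
print(embedding.shape)  # torch.Size([2, 8])
```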