@@ -108,7 +108,7 @@ def forward(
             self.layernorm_weight.shape,
             self.layernorm_weight,
             self.layernorm_bias,
-            eps=self.config.layer_norm_eps
+            eps=self.config.layer_norm_eps,
         )
         embeddings = self.dropout(embeddings)
         return embeddings
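For reference, `F.layer_norm` takes the normalized shape as an explicit argument, which is why the weight's `.shape` is passed alongside the weight itself. A minimal sketch (sizes and the 1e-12 epsilon are illustrative, not taken from this file) showing the functional call agrees with the `nn.LayerNorm` module:

```python
import torch
import torch.nn.functional as F

hidden = 768  # illustrative hidden size
x = torch.randn(2, 16, hidden)
weight, bias, eps = torch.randn(hidden), torch.randn(hidden), 1e-12

ln = torch.nn.LayerNorm(hidden, eps=eps)
ln.weight.data, ln.bias.data = weight.clone(), bias.clone()

# F.layer_norm(input, normalized_shape, weight, bias, eps) normalizes over
# the trailing dimensions given by normalized_shape -- here weight.shape.
assert torch.allclose(F.layer_norm(x, weight.shape, weight, bias, eps=eps), ln(x))
```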
@@ -174,11 +174,12 @@ def forward(
             self.layer_norm_q_weight.shape,
             self.layer_norm_q_weight,
             self.layer_norm_q_bias,
-            eps=self.config.layer_norm_eps,)
+            eps=self.config.layer_norm_eps,
+        )
 
         k_hidden_states = F.linear(hidden_states, self.key_weight, self.key_bias)
         key_layer = self.transpose_for_scores(
-                F.layer_norm(
+            F.layer_norm(
                 k_hidden_states,
                 self.layer_norm_k_weight.shape,
                 self.layer_norm_k_weight,
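The hunk above layer-norms the key projection over the hidden dimension before `transpose_for_scores` splits it into attention heads. A hedged sketch of that ordering; the `transpose_for_scores` body and the head/dimension sizes are assumptions based on the standard BERT layout, not code from this file:

```python
import torch
import torch.nn.functional as F

batch, seq, heads, head_dim = 2, 16, 12, 64
hidden = heads * head_dim

def transpose_for_scores(x: torch.Tensor) -> torch.Tensor:
    # [batch, seq, hidden] -> [batch, heads, seq, head_dim]
    return x.view(batch, seq, heads, head_dim).permute(0, 2, 1, 3)

hidden_states = torch.randn(batch, seq, hidden)
key_weight, key_bias = torch.randn(hidden, hidden), torch.randn(hidden)
ln_k_weight, ln_k_bias = torch.ones(hidden), torch.zeros(hidden)

# Project, layer-norm over the hidden dimension, then split into heads.
k_hidden_states = F.linear(hidden_states, key_weight, key_bias)
key_layer = transpose_for_scores(
    F.layer_norm(k_hidden_states, ln_k_weight.shape, ln_k_weight, ln_k_bias, eps=1e-12)
)
assert key_layer.shape == (batch, heads, seq, head_dim)
```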
@@ -237,7 +238,7 @@ def forward(
         hidden_states = F.linear(hidden_states, self.dense_weight, self.dense_bias)
         hidden_states = self.dropout(hidden_states)
         hidden_states = F.layer_norm(
-                hidden_states + input_tensor,
+            hidden_states + input_tensor,
             self.layerNorm_weight.shape,
             self.layerNorm_weight,
             self.layerNorm_bias,
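This block is the usual BERT post-layernorm residual pattern: dense projection, dropout, then layer norm applied to the sum of the sublayer output and its input. A standalone sketch with illustrative sizes:

```python
import torch
import torch.nn.functional as F

hidden = 768
input_tensor = torch.randn(2, 16, hidden)   # residual branch (sublayer input)
hidden_states = torch.randn(2, 16, hidden)  # sublayer output
dense_weight, dense_bias = torch.randn(hidden, hidden), torch.randn(hidden)
ln_weight, ln_bias = torch.ones(hidden), torch.zeros(hidden)
dropout = torch.nn.Dropout(p=0.1)

# dense -> dropout -> residual add -> layer norm (post-LN, as in BERT)
hidden_states = F.linear(hidden_states, dense_weight, dense_bias)
hidden_states = dropout(hidden_states)
hidden_states = F.layer_norm(
    hidden_states + input_tensor, ln_weight.shape, ln_weight, ln_bias, eps=1e-12
)
```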
@@ -453,10 +454,28 @@ def forward(
         return hidden_states
 
 
+class JinaBertPooler:
+    def __init__(self, handle, device, dtype, config):
+        self.dense_weight = (
+            handle.get_tensor("pooler.dense.weight").to(dtype).to(device)
+        )
+        self.dense_bias = handle.get_tensor("pooler.dense.bias").to(dtype).to(device)
+        self.activation = nn.Tanh()
+
+    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        # We "pool" the model by simply taking the hidden state corresponding
+        # to the first token.
+        first_token_tensor = hidden_states[:, 0]
+        pooled_output = F.linear(first_token_tensor, self.dense_weight, self.dense_bias)
+        pooled_output = self.activation(pooled_output)
+        return pooled_output
+
+
 class FlashJinaBertModel:
     def __init__(self, handle, device, dtype, config: AutoConfig):
         self.embeddings = JinaBertEmbeddings(handle, device, dtype, config)
         self.encoder = JinaBertEncoder(handle, device, dtype, config)
+        self.pooler = JinaBertPooler(handle, device, dtype, config)
 
     def forward(
         self,
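The new `JinaBertPooler` is the standard BERT pooler: take the first (CLS) token's hidden state, apply a dense projection, then tanh. A usage sketch with illustrative shapes; the checkpoint `handle` loading is elided:

```python
import torch
import torch.nn.functional as F

batch, seq, hidden = 4, 32, 768
hidden_states = torch.randn(batch, seq, hidden)  # encoder output
dense_weight, dense_bias = torch.randn(hidden, hidden), torch.randn(hidden)

# CLS pooling: keep only the first token's hidden state per sequence,
# then project and squash, as in JinaBertPooler.forward.
first_token = hidden_states[:, 0]                # [batch, hidden]
pooled = torch.tanh(F.linear(first_token, dense_weight, dense_bias))
assert pooled.shape == (batch, hidden)
```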
@@ -467,7 +486,8 @@ def forward(
     ):
         embeddings = self.embeddings.forward(input_ids, token_type_ids, position_ids)
         encoder_outputs = self.encoder.forward(embeddings, attn_mask)
-        return encoder_outputs[0]
+        pooled_output = self.pooler.forward(encoder_outputs)
+        return pooled_output
 
 
 class FlashJinaBert(Model):
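One detail the hunk above depends on: these wrapper classes are plain Python objects rather than `nn.Module` subclasses, so their instances are not callable and `forward` must be invoked explicitly, matching the `self.embeddings.forward(...)` and `self.encoder.forward(...)` calls. A tiny illustration:

```python
class Plain:
    def forward(self, x):
        return x + 1

p = Plain()
p.forward(1)  # 2
# p(1) would raise TypeError: 'Plain' object is not callable;
# only nn.Module defines __call__ to dispatch to forward.
```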
@@ -507,9 +527,7 @@ def embed(self, batch: PaddedBatch) -> List[Embedding]:
         kwargs = {"input_ids": batch.input_ids, "attn_mask": batch.attention_mask}
         kwargs["token_type_ids"] = batch.token_type_ids
         kwargs["position_ids"] = batch.position_ids
-        output = self.model.forward(**kwargs)
-
-        embedding = self.pooling.forward(output, batch.attention_mask)
+        embedding = self.model.forward(**kwargs)
 
         cpu_results = embedding.view(-1).tolist()
 
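With pooling folded into the model, `embed` now receives the pooled `[batch, hidden]` tensor directly and only flattens it for the response. A sketch of that post-processing; the per-sequence slicing mirrors what the downstream `Embedding` construction presumably does and is an assumption, not code from this file:

```python
import torch

batch, hidden = 2, 4
embedding = torch.randn(batch, hidden)  # pooled output from the model

# Flatten to one list, then slice back out per sequence.
cpu_results = embedding.view(-1).tolist()
per_sequence = [cpu_results[i * hidden : (i + 1) * hidden] for i in range(batch)]
assert len(per_sequence) == batch and len(per_sequence[0]) == hidden
```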