Commit 9436c27 (parent: f49aeea)

add related pooler process

Signed-off-by: Liu, Kaixuan <[email protected]>
1 file changed: backends/python/server/text_embeddings_server/models/jinaBert_model.py
Lines changed: 26 additions & 8 deletions
@@ -108,7 +108,7 @@ def forward(
             self.layernorm_weight.shape,
             self.layernorm_weight,
             self.layernorm_bias,
-            eps=self.config.layer_norm_eps
+            eps=self.config.layer_norm_eps,
         )
         embeddings = self.dropout(embeddings)
         return embeddings
@@ -174,11 +174,12 @@ def forward(
                 self.layer_norm_q_weight.shape,
                 self.layer_norm_q_weight,
                 self.layer_norm_q_bias,
-                eps=self.config.layer_norm_eps,)
+                eps=self.config.layer_norm_eps,
+            )

         k_hidden_states = F.linear(hidden_states, self.key_weight, self.key_bias)
         key_layer = self.transpose_for_scores(
-             F.layer_norm(
+            F.layer_norm(
                 k_hidden_states,
                 self.layer_norm_k_weight.shape,
                 self.layer_norm_k_weight,
@@ -237,7 +238,7 @@ def forward(
         hidden_states = F.linear(hidden_states, self.dense_weight, self.dense_bias)
         hidden_states = self.dropout(hidden_states)
         hidden_states = F.layer_norm(
-            hidden_states+input_tensor,
+            hidden_states + input_tensor,
             self.layerNorm_weight.shape,
             self.layerNorm_weight,
             self.layerNorm_bias,
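
The three hunks above are small cleanups around the functional layer norm, torch.nn.functional.layer_norm(input, normalized_shape, weight, bias, eps), where normalized_shape is read off the weight tensor itself. A minimal sketch of that call pattern; the names x, ln_weight, ln_bias and the eps value are illustrative, not taken from the repo:

import torch
import torch.nn.functional as F

hidden = 16
x = torch.randn(2, 8, hidden)   # illustrative [batch, seq, hidden] activations
ln_weight = torch.ones(hidden)  # per-feature scale
ln_bias = torch.zeros(hidden)   # per-feature shift

# Same pattern as the diff: the weight's own shape is passed as normalized_shape.
y = F.layer_norm(x, ln_weight.shape, ln_weight, ln_bias, eps=1e-12)
print(y.shape)  # torch.Size([2, 8, 16])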
@@ -453,10 +454,28 @@ def forward(
         return hidden_states


+class JinaBertPooler:
+    def __init__(self, handle, device, dtype, config):
+        self.dense_weight = (
+            handle.get_tensor(f"pooler.dense.weight").to(dtype).to(device)
+        )
+        self.dense_bias = handle.get_tensor(f"pooler.dense.bias").to(dtype).to(device)
+        self.activation = nn.Tanh()
+
+    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        # We "pool" the model by simply taking the hidden state corresponding
+        # to the first token.
+        first_token_tensor = hidden_states[:, 0]
+        pooled_output = F.linear(first_token_tensor, self.dense_weight, self.dense_bias)
+        pooled_output = self.activation(pooled_output)
+        return pooled_output
+
+
 class FlashJinaBertModel:
     def __init__(self, handle, device, dtype, config: AutoConfig):
         self.embeddings = JinaBertEmbeddings(handle, device, dtype, config)
         self.encoder = JinaBertEncoder(handle, device, dtype, config)
+        self.pooler = JinaBertPooler(handle, device, dtype, config)

     def forward(
         self,
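
For context, the new JinaBertPooler implements standard BERT-style first-token ("CLS") pooling: take the hidden state at position 0, apply a dense projection, then tanh. A self-contained sketch of the same computation in plain PyTorch; all tensor names and shapes below are illustrative, not taken from the model:

import torch
import torch.nn.functional as F

# Illustrative shapes: batch of 2, 8 tokens, hidden size 16.
hidden_states = torch.randn(2, 8, 16)
dense_weight = torch.randn(16, 16)  # stands in for pooler.dense.weight
dense_bias = torch.randn(16)        # stands in for pooler.dense.bias

# First-token ("CLS") pooling, mirroring JinaBertPooler.forward:
first_token = hidden_states[:, 0]                                     # [2, 16]
pooled = torch.tanh(F.linear(first_token, dense_weight, dense_bias))  # [2, 16]
print(pooled.shape)  # torch.Size([2, 16])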
@@ -467,7 +486,8 @@ def forward(
     ):
         embeddings = self.embeddings.forward(input_ids, token_type_ids, position_ids)
         encoder_outputs = self.encoder.forward(embeddings, attn_mask)
-        return encoder_outputs[0]
+        pooled_output = self.pooler(encoder_outputs)
+        return pooled_output


 class FlashJinaBert(Model):
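
One detail worth noting: the other submodules are invoked explicitly as .forward(...), while the pooler is invoked as self.pooler(encoder_outputs). For a plain class like JinaBertPooler (no nn.Module base is visible in this diff), the call form works only if __call__ is defined somewhere not shown here. A hypothetical sketch of the distinction, with an illustrative class name:

import torch

class PlainPooler:
    # A plain class: instances are not callable unless __call__ exists.
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x[:, 0]

    # Without this alias, pooler(x) raises
    # TypeError: 'PlainPooler' object is not callable.
    __call__ = forward

pooler = PlainPooler()
x = torch.randn(2, 8, 16)
assert torch.equal(pooler(x), pooler.forward(x))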
@@ -507,9 +527,7 @@ def embed(self, batch: PaddedBatch) -> List[Embedding]:
         kwargs = {"input_ids": batch.input_ids, "attn_mask": batch.attention_mask}
         kwargs["token_type_ids"] = batch.token_type_ids
         kwargs["position_ids"] = batch.position_ids
-        output = self.model.forward(**kwargs)
-
-        embedding = self.pooling.forward(output, batch.attention_mask)
+        embedding = self.model.forward(**kwargs)

         cpu_results = embedding.view(-1).tolist()
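
With this change, embed receives the pooled output directly from the model instead of running a separate pooling step over batch.attention_mask. Assuming the pooled output is a [batch, hidden] tensor, the flatten-and-slice pattern around cpu_results would map back to per-sequence embeddings as sketched below; the variable names here are illustrative, not from the repo:

import torch

# Suppose the model returns a pooled output of shape [batch, hidden].
batch, hidden = 3, 16
embedding = torch.randn(batch, hidden)

# As in embed(): flatten everything into one Python list...
cpu_results = embedding.view(-1).tolist()

# ...so each sequence's embedding is a contiguous hidden-size slice.
per_sequence = [cpu_results[i * hidden : (i + 1) * hidden] for i in range(batch)]
assert per_sequence[0] == embedding[0].tolist()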