Commit 0f03d82

use mean_pooling instead
Signed-off-by: Liu, Kaixuan <[email protected]>
1 parent: 9436c27

File tree: 1 file changed (+14, -22)

backends/python/server/text_embeddings_server/models/jinaBert_model.py

Lines changed: 14 additions & 22 deletions
@@ -454,28 +454,10 @@ def forward(
         return hidden_states
 
 
-class JinaBertPooler:
-    def __init__(self, handle, device, dtype, config):
-        self.dense_weight = (
-            handle.get_tensor(f"pooler.dense.weight").to(dtype).to(device)
-        )
-        self.dense_bias = handle.get_tensor(f"pooler.dense.bias").to(dtype).to(device)
-        self.activation = nn.Tanh()
-
-    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
-        # We "pool" the model by simply taking the hidden state corresponding
-        # to the first token.
-        first_token_tensor = hidden_states[:, 0]
-        pooled_output = F.linear(first_token_tensor, self.dense_weight, self.dense_bias)
-        pooled_output = self.activation(pooled_output)
-        return pooled_output
-
-
 class FlashJinaBertModel:
     def __init__(self, handle, device, dtype, config: AutoConfig):
         self.embeddings = JinaBertEmbeddings(handle, device, dtype, config)
         self.encoder = JinaBertEncoder(handle, device, dtype, config)
-        self.pooler = JinaBertPooler(handle, device, dtype, config)
 
     def forward(
         self,
@@ -486,8 +468,7 @@ def forward(
     ):
         embeddings = self.embeddings.forward(input_ids, token_type_ids, position_ids)
         encoder_outputs = self.encoder.forward(embeddings, attn_mask)
-        pooled_output = self.pooler(encoder_outputs)
-        return pooled_output
+        return encoder_outputs
 
 
 class FlashJinaBert(Model):
@@ -522,13 +503,24 @@ def __init__(
     def batch_type(self) -> Type[PaddedBatch]:
         return PaddedBatch
 
+    def mean_pooling(
+        self, token_embeddings: torch.Tensor, attention_mask: torch.Tensor
+    ):
+        input_mask_expanded = (
+            attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
+        )
+        return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(
+            input_mask_expanded.sum(1), min=1e-9
+        )
+
     @tracer.start_as_current_span("embed")
     def embed(self, batch: PaddedBatch) -> List[Embedding]:
-        kwargs = {"input_ids": batch.input_ids, "attn_mask": batch.attention_mask}
+        kwargs = {"input_ids": batch.input_ids}
         kwargs["token_type_ids"] = batch.token_type_ids
         kwargs["position_ids"] = batch.position_ids
-        embedding = self.model.forward(**kwargs)
+        outputs = self.model.forward(**kwargs)
 
+        embedding = self.mean_pooling(outputs, batch.attention_mask)
         cpu_results = embedding.view(-1).tolist()
 
         return [
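Note on the change: the commit replaces the removed CLS-style pooler (first-token hidden state, dense projection, tanh) with attention-mask-weighted mean pooling over all token embeddings. Below is a minimal, self-contained sketch of that computation; the tensor names and example shapes are illustrative assumptions, not taken from the repository.

import torch

def mean_pooling(token_embeddings: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
    # Broadcast the (batch, seq_len) mask to (batch, seq_len, hidden) so that
    # padded positions contribute nothing to the sum or to the token count.
    input_mask_expanded = (
        attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
    )
    summed = torch.sum(token_embeddings * input_mask_expanded, dim=1)
    counts = torch.clamp(input_mask_expanded.sum(dim=1), min=1e-9)  # guard against all-pad rows
    return summed / counts

# Hypothetical usage: 2 sequences, 4 token positions, hidden size 8.
token_embeddings = torch.randn(2, 4, 8)
attention_mask = torch.tensor([[1, 1, 1, 0], [1, 1, 0, 0]])
print(mean_pooling(token_embeddings, attention_mask).shape)  # torch.Size([2, 8])

Dividing by the clamped mask sum rather than the full sequence length keeps padded positions from diluting the average, so padded and unpadded batches yield the same sentence embedding.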
