Description
Original code:
```python
def forward(self, document_batch: torch.Tensor, device='cpu', bert_batch_size=0):
    # >>> changed section
    bert_output = torch.zeros(size=(document_batch.shape[0],
                                    min(document_batch.shape[1], bert_batch_size),
                                    self.bert.config.hidden_size),
                              dtype=torch.float, device=device)
    for doc_id in range(document_batch.shape[0]):
        bert_output[doc_id][:bert_batch_size] = self.dropout(
            self.bert(document_batch[doc_id][:bert_batch_size, 0],
                      token_type_ids=document_batch[doc_id][:bert_batch_size, 1],
                      attention_mask=document_batch[doc_id][:bert_batch_size, 2])[1])
    # <<< end changed section
    output, (_, _) = self.lstm(bert_output.permute(1, 0, 2))
    output = output.permute(1, 0, 2)                        # (batch_size, seq_len, num_hiddens)
    attention_w = torch.tanh(torch.matmul(output, self.w_omega) + self.b_omega)
    attention_u = torch.matmul(attention_w, self.u_omega)   # (batch_size, seq_len, 1)
    attention_score = F.softmax(attention_u, dim=1)         # (batch_size, seq_len, 1)
    attention_hidden = output * attention_score             # (batch_size, seq_len, num_hiddens)
    attention_hidden = torch.sum(attention_hidden, dim=1)   # weighted sum -> (batch_size, num_hiddens)
    prediction = self.mlp(attention_hidden)
    assert prediction.shape[0] == document_batch.shape[0]
    return prediction
```
Modified code:
```python
def forward(self, document_batch: torch.Tensor, device='cpu', bert_batch_size=0):
    # >>> changed section
    bert_output = torch.zeros(size=(document_batch.shape[0],
                                    # min(document_batch.shape[1], bert_batch_size),
                                    document_batch.shape[1],
                                    self.bert.config.hidden_size),
                              dtype=torch.float, device=device)
    for doc_id in range(document_batch.shape[0]):
        bert_output[doc_id][:document_batch.shape[1]] = self.dropout(
            self.bert(document_batch[doc_id][:document_batch.shape[1], 0],
                      token_type_ids=document_batch[doc_id][:document_batch.shape[1], 1],
                      attention_mask=document_batch[doc_id][:document_batch.shape[1], 2])[1])
    # <<< end changed section
    output, (_, _) = self.lstm(bert_output.permute(1, 0, 2))
    output = output.permute(1, 0, 2)                        # (batch_size, seq_len, num_hiddens)
    attention_w = torch.tanh(torch.matmul(output, self.w_omega) + self.b_omega)
    attention_u = torch.matmul(attention_w, self.u_omega)   # (batch_size, seq_len, 1)
    attention_score = F.softmax(attention_u, dim=1)         # (batch_size, seq_len, 1)
    attention_hidden = output * attention_score             # (batch_size, seq_len, num_hiddens)
    attention_hidden = torch.sum(attention_hidden, dim=1)   # weighted sum -> (batch_size, num_hiddens)
    prediction = self.mlp(attention_hidden)
    assert prediction.shape[0] == document_batch.shape[0]
    return prediction
```
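For context, here is a self-contained sketch of the logic in the modified `forward`, exercised end to end. The class name `DocClassifierSketch`, the toy `BertConfig`, and all hyperparameters below are hypothetical and not taken from this repo. It assumes `document_batch` is laid out as `(num_docs, num_chunks, 3, max_seq_len)`, with the third axis holding `input_ids`, `token_type_ids`, and `attention_mask`, which is what the indexing above implies, and it uses a randomly initialised BERT so it runs without downloading weights:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import BertConfig, BertModel

class DocClassifierSketch(nn.Module):
    """Hypothetical stand-in for the model in this issue: BERT chunk encoder
    + LSTM over chunk vectors + attention pooling + MLP classifier."""
    def __init__(self, num_hiddens=32, num_labels=2):
        super().__init__()
        # Tiny, randomly initialised BERT so the sketch runs offline.
        self.bert = BertModel(BertConfig(hidden_size=64, num_hidden_layers=2,
                                         num_attention_heads=2, intermediate_size=128,
                                         vocab_size=1000))
        self.dropout = nn.Dropout(0.1)
        self.lstm = nn.LSTM(self.bert.config.hidden_size, num_hiddens)
        self.w_omega = nn.Parameter(torch.randn(num_hiddens, num_hiddens) * 0.1)
        self.b_omega = nn.Parameter(torch.zeros(num_hiddens))
        self.u_omega = nn.Parameter(torch.randn(num_hiddens, 1) * 0.1)
        self.mlp = nn.Linear(num_hiddens, num_labels)

    def forward(self, document_batch, device='cpu'):
        num_docs, num_chunks = document_batch.shape[0], document_batch.shape[1]
        bert_output = torch.zeros(num_docs, num_chunks, self.bert.config.hidden_size,
                                  dtype=torch.float, device=device)
        for doc_id in range(num_docs):
            # Encode every chunk of the document; [1] is the pooled [CLS] output.
            pooled = self.bert(document_batch[doc_id][:, 0],
                               token_type_ids=document_batch[doc_id][:, 1],
                               attention_mask=document_batch[doc_id][:, 2])[1]
            bert_output[doc_id] = self.dropout(pooled)
        output, _ = self.lstm(bert_output.permute(1, 0, 2))  # (num_chunks, num_docs, H)
        output = output.permute(1, 0, 2)                     # (num_docs, num_chunks, H)
        attention_w = torch.tanh(output @ self.w_omega + self.b_omega)
        attention_score = F.softmax(attention_w @ self.u_omega, dim=1)  # (num_docs, num_chunks, 1)
        doc_vector = torch.sum(output * attention_score, dim=1)         # (num_docs, H)
        return self.mlp(doc_vector)

# Dummy batch: 2 documents, 4 chunks each, sequences of length 16.
docs = torch.randint(0, 1000, (2, 4, 3, 16))
docs[:, :, 1] = 0   # token_type_ids
docs[:, :, 2] = 1   # attention_mask
model = DocClassifierSketch()
print(model(docs).shape)   # torch.Size([2, 2])
```

With the `bert_batch_size` cap removed, every chunk contributes a pooled vector, so the LSTM and attention pooling see the whole document instead of only the first `bert_batch_size` chunks (which, with the default `bert_batch_size=0`, would have produced an empty BERT output in the original version).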