-
Notifications
You must be signed in to change notification settings - Fork 50
Expand file tree
/
Copy pathentity_detection.py
More file actions
executable file
·51 lines (47 loc) · 2.1 KB
/
entity_detection.py
File metadata and controls
executable file
·51 lines (47 loc) · 2.1 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
from torch import nn
import torch.nn.functional as F
class EntityDetection(nn.Module):
    """Sequence tagger for entity detection.

    Pipeline: token ids -> embedding -> bidirectional LSTM/GRU -> per-token
    MLP -> log-softmax tag scores.

    Args:
        config: namespace-like object providing
            label: number of output tag classes.
            words_num / words_dim: embedding vocabulary size / dimension.
            hidden_size / num_layer: RNN hidden size / layer count.
            rnn_dropout / rnn_fc_dropout: dropout inside the RNN / before the
                final classifier.
            train_embed: if falsy, the embedding table is frozen.
            entity_detection_mode: 'LSTM' or 'GRU' (case-insensitive).

    Raises:
        ValueError: if ``entity_detection_mode`` is neither 'LSTM' nor 'GRU'.
    """

    def __init__(self, config):
        super(EntityDetection, self).__init__()
        self.config = config
        target_size = config.label
        self.embed = nn.Embedding(config.words_num, config.words_dim)
        # Freeze embeddings (e.g. fixed pretrained vectors) when requested.
        if not config.train_embed:
            self.embed.weight.requires_grad = False
        mode = config.entity_detection_mode.upper()
        if mode == 'LSTM':
            self.lstm = nn.LSTM(input_size=config.words_dim,
                                hidden_size=config.hidden_size,
                                num_layers=config.num_layer,
                                dropout=config.rnn_dropout,
                                bidirectional=True)
        elif mode == 'GRU':
            self.gru = nn.GRU(input_size=config.words_dim,
                              hidden_size=config.hidden_size,
                              num_layers=config.num_layer,
                              dropout=config.rnn_dropout,
                              bidirectional=True)
        else:
            # Fail fast at construction time instead of print()/exit(1)
            # inside forward(), which would silently build a model with no RNN.
            raise ValueError(
                "Wrong Entity Prediction Mode: {!r} (expected 'LSTM' or 'GRU')"
                .format(config.entity_detection_mode))
        self.dropout = nn.Dropout(p=config.rnn_fc_dropout)
        self.relu = nn.ReLU()
        # Bidirectional RNN concatenates both directions -> hidden_size * 2.
        self.hidden2tag = nn.Sequential(
            nn.Linear(config.hidden_size * 2, config.hidden_size * 2),
            nn.BatchNorm1d(config.hidden_size * 2),
            self.relu,
            self.dropout,
            nn.Linear(config.hidden_size * 2, target_size)
        )

    def forward(self, x):
        """Compute per-token log-probabilities over tag classes.

        Args:
            x: batch object whose ``text`` attribute is a LongTensor of token
                ids shaped (sequence_length, batch_size).

        Returns:
            Tensor of shape (sequence_length * batch_size, label) holding
            log-softmax tag scores, tokens flattened in time-major order.
        """
        text = x.text
        # (seq_len, batch) -> (seq_len, batch, words_dim)
        embedded = self.embed(text)
        # __init__ guarantees the mode is LSTM or GRU, so no else branch.
        if self.config.entity_detection_mode.upper() == 'LSTM':
            outputs, _ = self.lstm(embedded)
        else:
            outputs, _ = self.gru(embedded)
        # Flatten time and batch so every token is classified independently.
        tags = self.hidden2tag(outputs.view(-1, outputs.size(2)))
        scores = F.log_softmax(tags, dim=1)
        return scores