Commit 71cb589

Merge pull request #325 from donglihe-hub/embed

Extend arguments accepted by Embedding

2 parents cb651c1 + f05f883

5 files changed: +30 -9 lines changed

5 files changed

+30
-9
lines changed
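The change below adds a freeze flag to Embedding so callers can keep pre-trained word vectors fixed during training, and switches call sites to pass the dropout rate by keyword. A minimal sketch of the extended constructor (the random embed_vecs tensor is a stand-in for real pre-trained vectors):

    import torch
    from libmultilabel.nn.networks.modules import Embedding

    # Stand-in for pre-trained vectors of shape (vocab_size, embed_dim)
    embed_vecs = torch.randn(5000, 300)

    frozen = Embedding(embed_vecs, freeze=True, dropout=0.2)  # vectors stay fixed
    trainable = Embedding(embed_vecs)  # defaults: freeze=False, dropout=0.2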

docs/examples/plot_KimCNN_quickstart.py
Lines changed: 6 additions & 1 deletion

@@ -45,7 +45,12 @@
 # We consider the following settings for the KimCNN model.
 
 model_name = "KimCNN"
-network_config = {"embed_dropout": 0.2, "post_encoder_dropout": 0.2, "filter_sizes": [2, 4, 8], "num_filter_per_size": 128}
+network_config = {
+    "embed_dropout": 0.2,
+    "post_encoder_dropout": 0.2,
+    "filter_sizes": [2, 4, 8],
+    "num_filter_per_size": 128,
+}
 learning_rate = 0.0003
 model = init_model(
     model_name=model_name,

libmultilabel/nn/networks/kim_cnn.py
Lines changed: 1 addition & 1 deletion

@@ -28,7 +28,7 @@ def __init__(
         activation="relu",
     ):
         super(KimCNN, self).__init__()
-        self.embedding = Embedding(embed_vecs, embed_dropout)
+        self.embedding = Embedding(embed_vecs, dropout=embed_dropout)
         self.encoder = CNNEncoder(
             embed_vecs.shape[1], filter_sizes, num_filter_per_size, activation, post_encoder_dropout, num_pool=1
         )
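KimCNN and the other call sites now pass the dropout rate by keyword. With the extended signature Embedding(embed_vecs, freeze=False, dropout=0.2), the old positional call would silently bind the rate to freeze instead. A toy illustration of the pitfall (stand-in tensor, not from this commit):

    import torch
    from libmultilabel.nn.networks.modules import Embedding

    embed_vecs = torch.randn(100, 8)  # stand-in pre-trained vectors

    # Positional 0.2 now lands on freeze; any nonzero float is truthy,
    # so the weights get frozen and dropout stays at its default.
    wrong = Embedding(embed_vecs, 0.2)
    print(wrong.embedding.weight.requires_grad)  # False -- accidentally frozen

    # The keyword form keeps the intended meaning.
    right = Embedding(embed_vecs, dropout=0.2)
    print(right.embedding.weight.requires_grad)  # True -- still trainable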

libmultilabel/nn/networks/labelwise_attention_networks.py
Lines changed: 10 additions & 3 deletions

@@ -27,7 +27,7 @@ class LabelwiseAttentionNetwork(ABC, nn.Module):
 
     def __init__(self, embed_vecs, num_classes, embed_dropout, encoder_dropout, post_encoder_dropout, hidden_dim):
         super(LabelwiseAttentionNetwork, self).__init__()
-        self.embedding = Embedding(embed_vecs, embed_dropout)
+        self.embedding = Embedding(embed_vecs, dropout=embed_dropout)
         self.encoder = self._get_encoder(embed_vecs.shape[1], encoder_dropout, post_encoder_dropout)
         self.attention = self._get_attention()
         self.output = LabelwiseLinearOutput(hidden_dim, num_classes)

@@ -199,7 +199,9 @@ def _get_encoder(self, input_size, encoder_dropout, post_encoder_dropout):
         return LSTMEncoder(input_size, self.rnn_dim // 2, self.rnn_layers, encoder_dropout, post_encoder_dropout)
 
     def _get_attention(self):
-        return LabelwiseMultiHeadAttention(self.rnn_dim, self.num_classes, self.num_heads, self.labelwise_attention_dropout)
+        return LabelwiseMultiHeadAttention(
+            self.rnn_dim, self.num_classes, self.num_heads, self.labelwise_attention_dropout
+        )
 
     def forward(self, input):
         # (batch_size, sequence_length, embed_dim)

@@ -246,7 +248,12 @@ def __init__(
     def _get_encoder(self, input_size, encoder_dropout, post_encoder_dropout):
         # encoder dropout is unused for CNN, we accept it to satisfy LabelwiseAttentionNetwork API
         return CNNEncoder(
-            input_size, self.filter_sizes, self.num_filter_per_size, self.activation, post_encoder_dropout, channel_last=True
+            input_size,
+            self.filter_sizes,
+            self.num_filter_per_size,
+            self.activation,
+            post_encoder_dropout,
+            channel_last=True,
         )
 
     def _get_attention(self):
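The comment kept in the last hunk explains the design: LabelwiseAttentionNetwork drives construction through _get_encoder, handing every subclass the same arguments even when a variant ignores some of them. A stripped-down sketch of that template-method pattern (hypothetical classes, not from this repo):

    from abc import ABC, abstractmethod

    class Network(ABC):
        def __init__(self, encoder_dropout):
            # The base class always passes encoder_dropout...
            self.encoder = self._get_encoder(encoder_dropout)

        @abstractmethod
        def _get_encoder(self, encoder_dropout):
            ...

    class RNNVariant(Network):
        def _get_encoder(self, encoder_dropout):
            return ("lstm", encoder_dropout)  # ...which the RNN uses,

    class CNNVariant(Network):
        def _get_encoder(self, encoder_dropout):
            return ("cnn",)  # ...while the CNN accepts and ignores it.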

libmultilabel/nn/networks/modules.py
Lines changed: 12 additions & 3 deletions

@@ -11,12 +11,14 @@ class Embedding(nn.Module):
 
     Args:
         embed_vecs (torch.Tensor): The pre-trained word vectors of shape (vocab_size, embed_dim).
+        freeze (bool): If True, the tensor does not get updated in the learning process.
+            Equivalent to embedding.weight.requires_grad = False. Default: False.
         dropout (float): The dropout rate of the word embedding. Defaults to 0.2.
     """
 
-    def __init__(self, embed_vecs, dropout=0.2):
+    def __init__(self, embed_vecs, freeze=False, dropout=0.2):
         super(Embedding, self).__init__()
-        self.embedding = nn.Embedding.from_pretrained(embed_vecs, freeze=False, padding_idx=0)
+        self.embedding = nn.Embedding.from_pretrained(embed_vecs, freeze=freeze, padding_idx=0)
         self.dropout = nn.Dropout(dropout)
 
     def forward(self, input):

@@ -105,7 +107,14 @@ class CNNEncoder(nn.Module):
     """
 
     def __init__(
-        self, input_size, filter_sizes, num_filter_per_size, activation, post_encoder_dropout=0, num_pool=0, channel_last=False
+        self,
+        input_size,
+        filter_sizes,
+        num_filter_per_size,
+        activation,
+        post_encoder_dropout=0,
+        num_pool=0,
+        channel_last=False,
     ):
         super(CNNEncoder, self).__init__()
         if not filter_sizes:
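As the new docstring notes, freeze=True is equivalent to disabling gradients on the embedding weight. A quick check of that equivalence with plain PyTorch (toy tensor, not part of this commit):

    import torch
    import torch.nn as nn

    vecs = torch.randn(10, 4)  # toy pre-trained vectors

    frozen = nn.Embedding.from_pretrained(vecs, freeze=True, padding_idx=0)
    print(frozen.weight.requires_grad)  # False: excluded from optimizer updates

    manual = nn.Embedding.from_pretrained(vecs, freeze=False, padding_idx=0)
    manual.weight.requires_grad = False  # the hand-rolled equivalent
    print(manual.weight.requires_grad)  # False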

libmultilabel/nn/networks/xml_cnn.py
Lines changed: 1 addition & 1 deletion

@@ -33,7 +33,7 @@ def __init__(
         activation="relu",
     ):
         super(XMLCNN, self).__init__()
-        self.embedding = Embedding(embed_vecs, embed_dropout)
+        self.embedding = Embedding(embed_vecs, dropout=embed_dropout)
         self.encoder = CNNEncoder(embed_vecs.shape[1], filter_sizes, num_filter_per_size, activation, num_pool=num_pool)
         total_output_size = len(filter_sizes) * num_filter_per_size * num_pool
         self.linear1 = nn.Linear(total_output_size, hidden_dim)
