@@ -154,7 +154,6 @@ def test_rnn(RNN, train_on_gpu):
154
154
b = torch .from_numpy (a )
155
155
hidden = rnn .init_hidden (batch_size )
156
156
157
-
158
157
if (train_on_gpu ):
159
158
rnn .cuda ()
160
159
b = b .cuda ()
@@ -172,13 +171,27 @@ def test_rnn(RNN, train_on_gpu):
172
171
173
172
# initialization
174
173
correct_hidden_size = (n_layers , batch_size , hidden_dim )
175
- assert_condition = hidden [0 ].size () == correct_hidden_size
174
+
175
+ if type (hidden ) == tuple :
176
+ # LSTM
177
+ assert_condition = hidden [0 ].size () == correct_hidden_size
178
+ else :
179
+ # GRU
180
+ assert_condition = hidden .size () == correct_hidden_size
181
+
176
182
assert_message = 'Wrong hidden state size. Expected type {}. Got type {}' .format (correct_hidden_size , hidden [0 ].size ())
177
183
assert_test .test (assert_condition , assert_message )
178
184
179
185
# output of rnn
180
186
correct_hidden_size = (n_layers , batch_size , hidden_dim )
181
- assert_condition = hidden_out [0 ].size () == correct_hidden_size
187
+
188
+ if type (hidden ) == tuple :
189
+ # LSTM
190
+ assert_condition = hidden_out [0 ].size () == correct_hidden_size
191
+ else :
192
+ # GRU
193
+ assert_condition = hidden_out .size () == correct_hidden_size
194
+
182
195
assert_message = 'Wrong hidden state size. Expected type {}. Got type {}' .format (correct_hidden_size , hidden_out [0 ].size ())
183
196
assert_test .test (assert_condition , assert_message )
184
197
@@ -218,7 +231,12 @@ def test_forward_back_prop(RNN, forward_back_prop, train_on_gpu):
218
231
219
232
loss , hidden_out = forward_back_prop (mock_decoder , mock_decoder_optimizer , mock_criterion , inp , target , hidden )
220
233
221
- assert (hidden_out [0 ][0 ]== hidden [0 ][0 ]).sum ()== batch_size * hidden_dim , 'Returned hidden state is the incorrect size.'
234
+ if type (hidden_out ) == tuple :
235
+ # LSTM
236
+ assert (hidden_out [0 ][0 ]== hidden [0 ][0 ]).sum ()== batch_size * hidden_dim , 'Returned hidden state is the incorrect size.'
237
+ else :
238
+ # GRU
239
+ assert (hidden_out [0 ]== hidden [0 ]).sum ()== batch_size * hidden_dim , 'Returned hidden state is the incorrect size.'
222
240
223
241
assert mock_decoder .zero_grad .called or mock_decoder_optimizer .zero_grad .called , 'Didn\' t set the gradients to 0.'
224
242
assert mock_decoder .forward_called , 'Forward propagation not called.'
0 commit comments