Skip to content

Commit ae5e44f

Browse files
committed
Make changes to unit tests to support GRUs
The unit tests for the TV script generation project expect the hidden state to be a tuple. This holds for LSTMs, but not for GRUs, whose hidden state is a single tensor. This commit changes the unit tests to handle both a tuple hidden state and a single-tensor hidden state.
1 parent c53f439 commit ae5e44f

File tree

1 file changed

+22
-4
lines changed

1 file changed

+22
-4
lines changed

project-tv-script-generation/problem_unittests.py

Lines changed: 22 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -154,7 +154,6 @@ def test_rnn(RNN, train_on_gpu):
154154
b = torch.from_numpy(a)
155155
hidden = rnn.init_hidden(batch_size)
156156

157-
158157
if(train_on_gpu):
159158
rnn.cuda()
160159
b = b.cuda()
@@ -172,13 +171,27 @@ def test_rnn(RNN, train_on_gpu):
172171

173172
# initialization
174173
correct_hidden_size = (n_layers, batch_size, hidden_dim)
175-
assert_condition = hidden[0].size() == correct_hidden_size
174+
175+
if type(hidden) == tuple:
176+
# LSTM
177+
assert_condition = hidden[0].size() == correct_hidden_size
178+
else:
179+
# GRU
180+
assert_condition = hidden.size() == correct_hidden_size
181+
176182
assert_message = 'Wrong hidden state size. Expected type {}. Got type {}'.format(correct_hidden_size, hidden[0].size())
177183
assert_test.test(assert_condition, assert_message)
178184

179185
# output of rnn
180186
correct_hidden_size = (n_layers, batch_size, hidden_dim)
181-
assert_condition = hidden_out[0].size() == correct_hidden_size
187+
188+
if type(hidden) == tuple:
189+
# LSTM
190+
assert_condition = hidden_out[0].size() == correct_hidden_size
191+
else:
192+
# GRU
193+
assert_condition = hidden_out.size() == correct_hidden_size
194+
182195
assert_message = 'Wrong hidden state size. Expected type {}. Got type {}'.format(correct_hidden_size, hidden_out[0].size())
183196
assert_test.test(assert_condition, assert_message)
184197

@@ -218,7 +231,12 @@ def test_forward_back_prop(RNN, forward_back_prop, train_on_gpu):
218231

219232
loss, hidden_out = forward_back_prop(mock_decoder, mock_decoder_optimizer, mock_criterion, inp, target, hidden)
220233

221-
assert (hidden_out[0][0]==hidden[0][0]).sum()==batch_size*hidden_dim, 'Returned hidden state is the incorrect size.'
234+
if type(hidden_out) == tuple:
235+
# LSTM
236+
assert (hidden_out[0][0]==hidden[0][0]).sum()==batch_size*hidden_dim, 'Returned hidden state is the incorrect size.'
237+
else:
238+
# GRU
239+
assert (hidden_out[0]==hidden[0]).sum()==batch_size*hidden_dim, 'Returned hidden state is the incorrect size.'
222240

223241
assert mock_decoder.zero_grad.called or mock_decoder_optimizer.zero_grad.called, 'Didn\'t set the gradients to 0.'
224242
assert mock_decoder.forward_called, 'Forward propagation not called.'

0 commit comments

Comments
 (0)