Skip to content

Commit 066b4d9

Browse files
FaranIdosubramen
andauthored
Fix testloss calculation in Quickstart and Optimization tutorials (pytorch#1536)
* Update optimization_tutorial.py * Update quickstart_tutorial.py * Fix typo in optimization tutorial Co-authored-by: suraj813 <[email protected]>
1 parent e8b6687 commit 066b4d9

File tree

2 files changed

+4
-2
lines changed

2 files changed

+4
-2
lines changed

beginner_source/basics/optimization_tutorial.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -167,6 +167,7 @@ def train_loop(dataloader, model, loss_fn, optimizer):
167167

168168
def test_loop(dataloader, model, loss_fn):
169169
size = len(dataloader.dataset)
170+
num_batches = len(dataloader)
170171
test_loss, correct = 0, 0
171172

172173
with torch.no_grad():
@@ -175,7 +176,7 @@ def test_loop(dataloader, model, loss_fn):
175176
test_loss += loss_fn(pred, y).item()
176177
correct += (pred.argmax(1) == y).type(torch.float).sum().item()
177178

178-
test_loss /= size
179+
test_loss /= num_batches
179180
correct /= size
180181
print(f"Test Error: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")
181182

beginner_source/basics/quickstart_tutorial.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -160,6 +160,7 @@ def train(dataloader, model, loss_fn, optimizer):
160160

161161
def test(dataloader, model, loss_fn):
162162
size = len(dataloader.dataset)
163+
num_batches = len(dataloader)
163164
model.eval()
164165
test_loss, correct = 0, 0
165166
with torch.no_grad():
@@ -168,7 +169,7 @@ def test(dataloader, model, loss_fn):
168169
pred = model(X)
169170
test_loss += loss_fn(pred, y).item()
170171
correct += (pred.argmax(1) == y).type(torch.float).sum().item()
171-
test_loss /= size
172+
test_loss /= num_batches
172173
correct /= size
173174
print(f"Test Error: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")
174175

0 commit comments

Comments
 (0)