Skip to content

Commit 580903e

Browse files
authored
Replace num_flat_features with torch.flatten (pytorch#1505)
Applies the best practice from AlexNet: https://github.com/pytorch/vision/blob/master/torchvision/models/alexnet.py#L48
1 parent 0244815 commit 580903e

File tree

2 files changed

+3
-10
lines changed

2 files changed

+3
-10
lines changed

beginner_source/blitz/cifar10_tutorial.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -136,7 +136,7 @@ def __init__(self):
136136
def forward(self, x):
137137
x = self.pool(F.relu(self.conv1(x)))
138138
x = self.pool(F.relu(self.conv2(x)))
139-
x = x.view(-1, 16 * 5 * 5)
139+
x = torch.flatten(x, 1) # flatten all dimensions except batch
140140
x = F.relu(self.fc1(x))
141141
x = F.relu(self.fc2(x))
142142
x = self.fc3(x)

beginner_source/blitz/neural_networks_tutorial.py

Lines changed: 2 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -60,19 +60,12 @@ def forward(self, x):
6060
x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2))
6161
# If the size is a square, you can specify with a single number
6262
x = F.max_pool2d(F.relu(self.conv2(x)), 2)
63-
x = x.view(-1, self.num_flat_features(x))
63+
x = torch.flatten(x, 1) # flatten all dimensions except the batch dimension
6464
x = F.relu(self.fc1(x))
6565
x = F.relu(self.fc2(x))
6666
x = self.fc3(x)
6767
return x
6868

69-
def num_flat_features(self, x):
70-
size = x.size()[1:] # all dimensions except the batch dimension
71-
num_features = 1
72-
for s in size:
73-
num_features *= s
74-
return num_features
75-
7669

7770
net = Net()
7871
print(net)
@@ -171,7 +164,7 @@ def num_flat_features(self, x):
171164
# ::
172165
#
173166
# input -> conv2d -> relu -> maxpool2d -> conv2d -> relu -> maxpool2d
174-
# -> view -> linear -> relu -> linear -> relu -> linear
167+
# -> flatten -> linear -> relu -> linear -> relu -> linear
175168
# -> MSELoss
176169
# -> loss
177170
#

0 commit comments

Comments
 (0)