Difference between nn.Flatten() and nn.Flatten(start_dim=0) #388
Hi @avimittal30, here's a quick example to explain. Let's create a tensor with the numbers 0 to 9, reshape it to multiple dimensions, and then try out different values of flattening. See the example notebook here: https://colab.research.google.com/drive/1tjr1K03miTzC9bXj-bDb3Ppndf__O0f0?usp=sharing

```python
import torch
from torch import nn

# Create a tensor with the numbers 0 to 9 and reshape to multiple dimensions
x = torch.arange(0, 10).reshape(2, 5, 1, 1)
x, x.shape
```

Output: the tensor and its shape, `torch.Size([2, 5, 1, 1])`.
Next we'll try the default behaviour (`nn.Flatten()` is equivalent to `nn.Flatten(start_dim=1)`), which removes the trailing 1-dimensions and flattens the tensor to shape `(2, 5)`:

```python
flatten = nn.Flatten(start_dim=1)  # flatten starting at dimension 1 (the 5)
output = flatten(x)
output, output.shape
```

Output: the flattened tensor and its shape, `torch.Size([2, 5])`.
Finally we'll flatten the tensor starting at the 0th dimension (the 2), which collapses *all* dimensions into one:

```python
flatten_1 = nn.Flatten(start_dim=0)  # flattens *all* dimensions into one
output_1 = flatten_1(x)
output_1, output_1.shape
```

Output: a 1-dimensional tensor with shape `torch.Size([10])`.
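Putting the two cases side by side, here's a minimal runnable sketch (the `assert`s and the `torch.flatten` comparison are additions, not part of the original reply) that confirms the shapes described above:

```python
import torch
from torch import nn

x = torch.arange(0, 10).reshape(2, 5, 1, 1)

# Default: start_dim=1 keeps the 0th (batch) dimension and flattens the rest.
assert nn.Flatten()(x).shape == (2, 5)

# start_dim=0 flattens every dimension into one.
assert nn.Flatten(start_dim=0)(x).shape == (10,)

# The functional equivalents behave the same way.
assert torch.flatten(x, start_dim=1).shape == (2, 5)
assert torch.flatten(x, start_dim=0).shape == (10,)
```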
For more of an understanding, I'd encourage you to try it out for yourself in the notebook above. Practice making different tensors and exploring different options with the `nn.Flatten` layer.
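One more option worth exploring (not covered in the reply above, but part of the same layer): `nn.Flatten` also takes an `end_dim` argument, so you can flatten only a slice of the dimensions:

```python
import torch
from torch import nn

x = torch.arange(0, 10).reshape(2, 5, 1, 1)

# Only flatten dimensions 1 through 2 (the 5 and the first 1),
# leaving the last dimension untouched -> shape (2, 5, 1)
flatten_partial = nn.Flatten(start_dim=1, end_dim=2)
print(flatten_partial(x).shape)  # torch.Size([2, 5, 1])
```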
Can anyone explain this with the help of an example?