Skip to content

Commit 4e4730c

Browse files
authored
Fix text and image mismatch (pytorch#1630)
Signed-off-by: Behrooz <[email protected]> Thank you @drbeh!
1 parent e90a8a2 commit 4e4730c

File tree

1 file changed

+41
-34
lines changed

1 file changed

+41
-34
lines changed

intermediate_source/memory_format_tutorial.py

Lines changed: 41 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
1010
Channels last memory format is an alternative way of ordering NCHW tensors in memory preserving dimensions ordering. Channels last tensors ordered in such a way that channels become the densest dimension (aka storing images pixel-per-pixel).
1111
12-
For example, classic (contiguous) storage of NCHW tensor (in our case it is two 2x2 images with 3 color channels) look like this:
12+
For example, classic (contiguous) storage of NCHW tensor (in our case it is two 4x4 images with 3 color channels) look like this:
1313
1414
.. figure:: /_static/img/classic_memory_format.png
1515
:alt: classic_memory_format
@@ -37,29 +37,30 @@
3737
######################################################################
3838
# Classic PyTorch contiguous tensor
3939
import torch
40+
4041
N, C, H, W = 10, 3, 32, 32
4142
x = torch.empty(N, C, H, W)
42-
print(x.stride()) # Ouputs: (3072, 1024, 32, 1)
43+
print(x.stride()) # Ouputs: (3072, 1024, 32, 1)
4344

4445
######################################################################
4546
# Conversion operator
4647
x = x.to(memory_format=torch.channels_last)
47-
print(x.shape) # Outputs: (10, 3, 32, 32) as dimensions order preserved
48-
print(x.stride()) # Outputs: (3072, 1, 96, 3)
48+
print(x.shape) # Outputs: (10, 3, 32, 32) as dimensions order preserved
49+
print(x.stride()) # Outputs: (3072, 1, 96, 3)
4950

5051
######################################################################
5152
# Back to contiguous
5253
x = x.to(memory_format=torch.contiguous_format)
53-
print(x.stride()) # Outputs: (3072, 1024, 32, 1)
54+
print(x.stride()) # Outputs: (3072, 1024, 32, 1)
5455

5556
######################################################################
5657
# Alternative option
5758
x = x.contiguous(memory_format=torch.channels_last)
58-
print(x.stride()) # Ouputs: (3072, 1, 96, 3)
59+
print(x.stride()) # Ouputs: (3072, 1, 96, 3)
5960

6061
######################################################################
6162
# Format checks
62-
print(x.is_contiguous(memory_format=torch.channels_last)) # Ouputs: True
63+
print(x.is_contiguous(memory_format=torch.channels_last)) # Ouputs: True
6364

6465
######################################################################
6566
# There are minor difference between the two APIs ``to`` and
@@ -81,8 +82,8 @@
8182
# sizes are 1 in order to properly represent the intended memory
8283
# format
8384
special_x = torch.empty(4, 1, 4, 4)
84-
print(special_x.is_contiguous(memory_format=torch.channels_last)) # Ouputs: True
85-
print(special_x.is_contiguous(memory_format=torch.contiguous_format)) # Ouputs: True
85+
print(special_x.is_contiguous(memory_format=torch.channels_last)) # Ouputs: True
86+
print(special_x.is_contiguous(memory_format=torch.contiguous_format)) # Ouputs: True
8687

8788
######################################################################
8889
# Same thing applies to explicit permutation API ``permute``. In
@@ -99,28 +100,28 @@
99100
######################################################################
100101
# Create as channels last
101102
x = torch.empty(N, C, H, W, memory_format=torch.channels_last)
102-
print(x.stride()) # Ouputs: (3072, 1, 96, 3)
103+
print(x.stride()) # Ouputs: (3072, 1, 96, 3)
103104

104105
######################################################################
105106
# ``clone`` preserves memory format
106107
y = x.clone()
107-
print(y.stride()) # Ouputs: (3072, 1, 96, 3)
108+
print(y.stride()) # Ouputs: (3072, 1, 96, 3)
108109

109110
######################################################################
110111
# ``to``, ``cuda``, ``float`` ... preserves memory format
111112
if torch.cuda.is_available():
112113
y = x.cuda()
113-
print(y.stride()) # Ouputs: (3072, 1, 96, 3)
114+
print(y.stride()) # Ouputs: (3072, 1, 96, 3)
114115

115116
######################################################################
116117
# ``empty_like``, ``*_like`` operators preserves memory format
117118
y = torch.empty_like(x)
118-
print(y.stride()) # Ouputs: (3072, 1, 96, 3)
119+
print(y.stride()) # Ouputs: (3072, 1, 96, 3)
119120

120121
######################################################################
121122
# Pointwise operators preserves memory format
122123
z = x + y
123-
print(z.stride()) # Ouputs: (3072, 1, 96, 3)
124+
print(z.stride()) # Ouputs: (3072, 1, 96, 3)
124125

125126
######################################################################
126127
# Conv, Batchnorm modules using cudnn backends support channels last
@@ -132,13 +133,13 @@
132133

133134
if torch.backends.cudnn.version() >= 7603:
134135
model = torch.nn.Conv2d(8, 4, 3).cuda().half()
135-
model = model.to(memory_format=torch.channels_last) # Module parameters need to be channels last
136+
model = model.to(memory_format=torch.channels_last) # Module parameters need to be channels last
136137

137138
input = torch.randint(1, 10, (2, 8, 4, 4), dtype=torch.float32, requires_grad=True)
138139
input = input.to(device="cuda", memory_format=torch.channels_last, dtype=torch.float16)
139140

140141
out = model(input)
141-
print(out.is_contiguous(memory_format=torch.channels_last)) # Ouputs: True
142+
print(out.is_contiguous(memory_format=torch.channels_last)) # Ouputs: True
142143

143144
######################################################################
144145
# When input tensor reaches a operator without channels last support,
@@ -195,7 +196,7 @@
195196

196197
######################################################################
197198
# Passing ``--channels-last true`` allows running a model in Channels last format with observed 22% perf gain.
198-
#
199+
#
199200
# ``python main_amp.py -a resnet50 --b 200 --workers 16 --opt-level O2 --channels-last true ./data``
200201

201202
# opt_level = O2
@@ -250,10 +251,10 @@
250251
#
251252

252253
# Need to be done once, after model initialization (or load)
253-
model = model.to(memory_format=torch.channels_last) # Replace with your model
254+
model = model.to(memory_format=torch.channels_last) # Replace with your model
254255

255256
# Need to be done for every input
256-
input = input.to(memory_format=torch.channels_last) # Replace with your input
257+
input = input.to(memory_format=torch.channels_last) # Replace with your input
257258
output = model(input)
258259

259260
#######################################################################
@@ -271,8 +272,8 @@
271272
# operatos in your model that does not support channels last, if you
272273
# want to improve the performance of converted model.
273274
#
274-
# That means you need to verify the list of used operators
275-
# against supported operators list https://github.com/pytorch/pytorch/wiki/Operators-with-Channels-Last-support,
275+
# That means you need to verify the list of used operators
276+
# against supported operators list https://github.com/pytorch/pytorch/wiki/Operators-with-Channels-Last-support,
276277
# or introduce memory format checks into eager execution mode and run your model.
277278
#
278279
# After running the code below, operators will raise an exception if the output of the
@@ -290,13 +291,13 @@ def contains_cl(args):
290291
return False
291292

292293

293-
def print_inputs(args, indent=''):
294+
def print_inputs(args, indent=""):
294295
for t in args:
295296
if isinstance(t, torch.Tensor):
296297
print(indent, t.stride(), t.shape, t.device, t.dtype)
297298
elif isinstance(t, list) or isinstance(t, tuple):
298299
print(indent, type(t))
299-
print_inputs(list(t), indent=indent + ' ')
300+
print_inputs(list(t), indent=indent + " ")
300301
else:
301302
print(indent, t)
302303

@@ -311,32 +312,38 @@ def check_cl(*args, **kwargs):
311312
except Exception as e:
312313
print("`{}` inputs are:".format(name))
313314
print_inputs(args)
314-
print('-------------------')
315+
print("-------------------")
315316
raise e
316317
failed = False
317318
if was_cl:
318319
if isinstance(result, torch.Tensor):
319320
if result.dim() == 4 and not result.is_contiguous(memory_format=torch.channels_last):
320-
print("`{}` got channels_last input, but output is not channels_last:".format(name),
321-
result.shape, result.stride(), result.device, result.dtype)
321+
print(
322+
"`{}` got channels_last input, but output is not channels_last:".format(name),
323+
result.shape,
324+
result.stride(),
325+
result.device,
326+
result.dtype,
327+
)
322328
failed = True
323329
if failed and True:
324330
print("`{}` inputs are:".format(name))
325331
print_inputs(args)
326-
raise Exception(
327-
'Operator `{}` lost channels_last property'.format(name))
332+
raise Exception("Operator `{}` lost channels_last property".format(name))
328333
return result
334+
329335
return check_cl
330336

337+
331338
old_attrs = dict()
332339

340+
333341
def attribute(m):
334342
old_attrs[m] = dict()
335343
for i in dir(m):
336344
e = getattr(m, i)
337-
exclude_functions = ['is_cuda', 'has_names', 'numel',
338-
'stride', 'Tensor', 'is_contiguous', '__class__']
339-
if i not in exclude_functions and not i.startswith('_') and '__call__' in dir(e):
345+
exclude_functions = ["is_cuda", "has_names", "numel", "stride", "Tensor", "is_contiguous", "__class__"]
346+
if i not in exclude_functions and not i.startswith("_") and "__call__" in dir(e):
340347
try:
341348
old_attrs[m][i] = e
342349
setattr(m, i, check_wrapper(e))
@@ -352,16 +359,16 @@ def attribute(m):
352359

353360
######################################################################
354361
# If you found an operator that doesn't support channels last tensors
355-
# and you want to contribute, feel free to use following developers
362+
# and you want to contribute, feel free to use following developers
356363
# guide https://github.com/pytorch/pytorch/wiki/Writing-memory-format-aware-operators.
357364
#
358365

359366
######################################################################
360367
# Code below is to recover the attributes of torch.
361368

362369
for (m, attrs) in old_attrs.items():
363-
for (k,v) in attrs.items():
364-
setattr(m, k, v)
370+
for (k, v) in attrs.items():
371+
setattr(m, k, v)
365372

366373
######################################################################
367374
# Work to do

0 commit comments

Comments
 (0)