9
9
10
10
Channels last memory format is an alternative way of ordering NCHW tensors in memory while preserving the dimension ordering. Channels last tensors are ordered in such a way that channels become the densest dimension (aka storing images pixel-per-pixel).
11
11
12
- For example, classic (contiguous) storage of NCHW tensor (in our case it is two 2x2 images with 3 color channels) look like this:
12
+ For example, classic (contiguous) storage of NCHW tensor (in our case it is two 4x4 images with 3 color channels) looks like this:
13
13
14
14
.. figure:: /_static/img/classic_memory_format.png
15
15
:alt: classic_memory_format
37
37
######################################################################
38
38
# Classic PyTorch contiguous tensor
39
39
import torch
40
+
40
41
N , C , H , W = 10 , 3 , 32 , 32
41
42
x = torch .empty (N , C , H , W )
42
- print (x .stride ()) # Ouputs: (3072, 1024, 32, 1)
43
+ print (x .stride ()) # Outputs: (3072, 1024, 32, 1)
43
44
44
45
######################################################################
45
46
# Conversion operator
46
47
x = x .to (memory_format = torch .channels_last )
47
- print (x .shape ) # Outputs: (10, 3, 32, 32) as dimensions order preserved
48
- print (x .stride ()) # Outputs: (3072, 1, 96, 3)
48
+ print (x .shape ) # Outputs: (10, 3, 32, 32) as dimensions order preserved
49
+ print (x .stride ()) # Outputs: (3072, 1, 96, 3)
49
50
50
51
######################################################################
51
52
# Back to contiguous
52
53
x = x .to (memory_format = torch .contiguous_format )
53
- print (x .stride ()) # Outputs: (3072, 1024, 32, 1)
54
+ print (x .stride ()) # Outputs: (3072, 1024, 32, 1)
54
55
55
56
######################################################################
56
57
# Alternative option
57
58
x = x .contiguous (memory_format = torch .channels_last )
58
- print (x .stride ()) # Ouputs: (3072, 1, 96, 3)
59
+ print (x .stride ()) # Outputs: (3072, 1, 96, 3)
59
60
60
61
######################################################################
61
62
# Format checks
62
- print (x .is_contiguous (memory_format = torch .channels_last )) # Ouputs: True
63
+ print (x .is_contiguous (memory_format = torch .channels_last )) # Outputs: True
63
64
64
65
######################################################################
65
66
# There are minor differences between the two APIs ``to`` and
81
82
# sizes are 1 in order to properly represent the intended memory
82
83
# format
83
84
special_x = torch .empty (4 , 1 , 4 , 4 )
84
- print (special_x .is_contiguous (memory_format = torch .channels_last )) # Ouputs: True
85
- print (special_x .is_contiguous (memory_format = torch .contiguous_format )) # Ouputs: True
85
+ print (special_x .is_contiguous (memory_format = torch .channels_last )) # Outputs: True
86
+ print (special_x .is_contiguous (memory_format = torch .contiguous_format )) # Outputs: True
86
87
87
88
######################################################################
88
89
# Same thing applies to explicit permutation API ``permute``. In
99
100
######################################################################
100
101
# Create as channels last
101
102
x = torch .empty (N , C , H , W , memory_format = torch .channels_last )
102
- print (x .stride ()) # Ouputs: (3072, 1, 96, 3)
103
+ print (x .stride ()) # Outputs: (3072, 1, 96, 3)
103
104
104
105
######################################################################
105
106
# ``clone`` preserves memory format
106
107
y = x .clone ()
107
- print (y .stride ()) # Ouputs: (3072, 1, 96, 3)
108
+ print (y .stride ()) # Outputs: (3072, 1, 96, 3)
108
109
109
110
######################################################################
110
111
# ``to``, ``cuda``, ``float`` ... preserve memory format
111
112
if torch .cuda .is_available ():
112
113
y = x .cuda ()
113
- print (y .stride ()) # Ouputs: (3072, 1, 96, 3)
114
+ print (y .stride ()) # Outputs: (3072, 1, 96, 3)
114
115
115
116
######################################################################
116
117
# ``empty_like``, ``*_like`` operators preserve memory format
117
118
y = torch .empty_like (x )
118
- print (y .stride ()) # Ouputs: (3072, 1, 96, 3)
119
+ print (y .stride ()) # Outputs: (3072, 1, 96, 3)
119
120
120
121
######################################################################
121
122
# Pointwise operators preserve memory format
122
123
z = x + y
123
- print (z .stride ()) # Ouputs: (3072, 1, 96, 3)
124
+ print (z .stride ()) # Outputs: (3072, 1, 96, 3)
124
125
125
126
######################################################################
126
127
# Conv, Batchnorm modules using cudnn backends support channels last
132
133
133
134
if torch .backends .cudnn .version () >= 7603 :
134
135
model = torch .nn .Conv2d (8 , 4 , 3 ).cuda ().half ()
135
- model = model .to (memory_format = torch .channels_last ) # Module parameters need to be channels last
136
+ model = model .to (memory_format = torch .channels_last ) # Module parameters need to be channels last
136
137
137
138
input = torch .randint (1 , 10 , (2 , 8 , 4 , 4 ), dtype = torch .float32 , requires_grad = True )
138
139
input = input .to (device = "cuda" , memory_format = torch .channels_last , dtype = torch .float16 )
139
140
140
141
out = model (input )
141
- print (out .is_contiguous (memory_format = torch .channels_last )) # Ouputs: True
142
+ print (out .is_contiguous (memory_format = torch .channels_last )) # Outputs: True
142
143
143
144
######################################################################
144
145
# When an input tensor reaches an operator without channels last support,
195
196
196
197
######################################################################
197
198
# Passing ``--channels-last true`` allows running a model in Channels last format with an observed 22% performance gain.
198
- #
199
+ #
199
200
# ``python main_amp.py -a resnet50 --b 200 --workers 16 --opt-level O2 --channels-last true ./data``
200
201
201
202
# opt_level = O2
250
251
#
251
252
252
253
# Need to be done once, after model initialization (or load)
253
- model = model .to (memory_format = torch .channels_last ) # Replace with your model
254
+ model = model .to (memory_format = torch .channels_last ) # Replace with your model
254
255
255
256
# Need to be done for every input
256
- input = input .to (memory_format = torch .channels_last ) # Replace with your input
257
+ input = input .to (memory_format = torch .channels_last ) # Replace with your input
257
258
output = model (input )
258
259
259
260
#######################################################################
271
272
# operators in your model that do not support channels last, if you
272
273
# want to improve the performance of converted model.
273
274
#
274
- # That means you need to verify the list of used operators
275
- # against supported operators list https://github.com/pytorch/pytorch/wiki/Operators-with-Channels-Last-support,
275
+ # That means you need to verify the list of used operators
276
+ # against the supported operators list https://github.com/pytorch/pytorch/wiki/Operators-with-Channels-Last-support,
276
277
# or introduce memory format checks into eager execution mode and run your model.
277
278
#
278
279
# After running the code below, operators will raise an exception if the output of the
@@ -290,13 +291,13 @@ def contains_cl(args):
290
291
return False
291
292
292
293
293
- def print_inputs (args , indent = '' ):
294
+ def print_inputs (args , indent = "" ):
294
295
for t in args :
295
296
if isinstance (t , torch .Tensor ):
296
297
print (indent , t .stride (), t .shape , t .device , t .dtype )
297
298
elif isinstance (t , list ) or isinstance (t , tuple ):
298
299
print (indent , type (t ))
299
- print_inputs (list (t ), indent = indent + ' ' )
300
+ print_inputs (list (t ), indent = indent + " " )
300
301
else :
301
302
print (indent , t )
302
303
@@ -311,32 +312,38 @@ def check_cl(*args, **kwargs):
311
312
except Exception as e :
312
313
print ("`{}` inputs are:" .format (name ))
313
314
print_inputs (args )
314
- print (' -------------------' )
315
+ print (" -------------------" )
315
316
raise e
316
317
failed = False
317
318
if was_cl :
318
319
if isinstance (result , torch .Tensor ):
319
320
if result .dim () == 4 and not result .is_contiguous (memory_format = torch .channels_last ):
320
- print ("`{}` got channels_last input, but output is not channels_last:" .format (name ),
321
- result .shape , result .stride (), result .device , result .dtype )
321
+ print (
322
+ "`{}` got channels_last input, but output is not channels_last:" .format (name ),
323
+ result .shape ,
324
+ result .stride (),
325
+ result .device ,
326
+ result .dtype ,
327
+ )
322
328
failed = True
323
329
if failed and True :
324
330
print ("`{}` inputs are:" .format (name ))
325
331
print_inputs (args )
326
- raise Exception (
327
- 'Operator `{}` lost channels_last property' .format (name ))
332
+ raise Exception ("Operator `{}` lost channels_last property" .format (name ))
328
333
return result
334
+
329
335
return check_cl
330
336
337
+
331
338
old_attrs = dict ()
332
339
340
+
333
341
def attribute (m ):
334
342
old_attrs [m ] = dict ()
335
343
for i in dir (m ):
336
344
e = getattr (m , i )
337
- exclude_functions = ['is_cuda' , 'has_names' , 'numel' ,
338
- 'stride' , 'Tensor' , 'is_contiguous' , '__class__' ]
339
- if i not in exclude_functions and not i .startswith ('_' ) and '__call__' in dir (e ):
345
+ exclude_functions = ["is_cuda" , "has_names" , "numel" , "stride" , "Tensor" , "is_contiguous" , "__class__" ]
346
+ if i not in exclude_functions and not i .startswith ("_" ) and "__call__" in dir (e ):
340
347
try :
341
348
old_attrs [m ][i ] = e
342
349
setattr (m , i , check_wrapper (e ))
@@ -352,16 +359,16 @@ def attribute(m):
352
359
353
360
######################################################################
354
361
# If you found an operator that doesn't support channels last tensors
355
- # and you want to contribute, feel free to use following developers
362
+ # and you want to contribute, feel free to use the following developers
356
363
# guide https://github.com/pytorch/pytorch/wiki/Writing-memory-format-aware-operators.
357
364
#
358
365
359
366
######################################################################
360
367
# Code below is to recover the attributes of torch.
361
368
362
369
for (m , attrs ) in old_attrs .items ():
363
- for (k ,v ) in attrs .items ():
364
- setattr (m , k , v )
370
+ for (k , v ) in attrs .items ():
371
+ setattr (m , k , v )
365
372
366
373
######################################################################
367
374
# Work to do
0 commit comments