
Commit a9bb29b

Add double backward custom function tutorial (pytorch#1603)
* Add double backward custom function tutorial
* Move file to correct dir and make some updates
* Move file to correct dir and make some updates
* Add images
* Rename
* Update easy trap case
* Fix bullet points
* Try to fix scale
* Fix index.rst link
* Try fix scale again
* Remove torchviz dependency
* Fix gap
* Test
* Revert "Test"
  This reverts commit fc15597.
* Use RST instead

1 parent 2855a42 commit a9bb29b

File tree

3 files changed: +313 -1 lines changed


.gitignore

Lines changed: 4 additions & 1 deletion
@@ -120,4 +120,7 @@ cleanup.sh
 *.swp

 # PyTorch things
-*.pt
+*.pt
+
+# VSCode
+*.vscode

index.rst

Lines changed: 8 additions & 0 deletions
@@ -310,6 +310,13 @@ Welcome to PyTorch Tutorials
    :link: advanced/extend_dispatcher.html
    :tags: Extending-PyTorch,Frontend-APIs,C++

+.. customcarditem::
+   :header: Custom Function Tutorial: Double Backward
+   :card_description: Learn how to write a custom autograd Function that supports double backward.
+   :image: _static/img/thumbnails/cropped/generic-pytorch-logo.PNG
+   :link: intermediate/custom_function_double_backward_tutorial.html
+   :tags: Extending-PyTorch,Frontend-APIs
+
 .. customcarditem::
    :header: Custom Function Tutorial: Fusing Convolution and Batch Norm
    :card_description: Learn how to create a custom autograd Function that fuses batch norm into a convolution to improve memory usage.

@@ -663,6 +670,7 @@ Additional Resources
    :hidden:
    :caption: Extending PyTorch

+   intermediate/custom_function_double_backward
    intermediate/custom_function_conv_bn_tutorial
    advanced/cpp_extension
    advanced/torch_script_custom_ops

Lines changed: 301 additions & 0 deletions
@@ -0,0 +1,301 @@
Double Backward with Custom Functions
=====================================

It is sometimes useful to run backward twice through the backward graph, for
example to compute higher-order gradients. It takes an understanding of
autograd and some care to support double backward, however. Functions
that support performing backward a single time are not necessarily
equipped to support double backward. In this tutorial we show how to
write a custom autograd function that supports double backward, and
point out some things to look out for.


When writing a custom autograd function that you want to backward through twice,
it is important to know when operations performed in a custom function
are recorded by autograd, when they aren't, and, most importantly, how
`save_for_backward` works with all of this.

Custom functions implicitly affect grad mode in two ways:

- During forward, autograd does not record the graph for any
  operations performed within the forward function. When forward
  completes, the backward function of the custom function
  becomes the `grad_fn` of each of the forward's outputs.

- During backward, autograd records the computation graph used to
  compute the backward pass if `create_graph` is specified
  (both behaviors are illustrated in the sketch after this list).
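
The following is a minimal sketch, not part of the original tutorial, that makes
both points concrete. The `MyCube` function and the printed checks are purely
illustrative assumptions.

.. code-block:: python

    import torch

    class MyCube(torch.autograd.Function):
        @staticmethod
        def forward(ctx, x):
            # Forward runs without grad recording, so this intermediate
            # has no grad_fn here (expected to print None).
            y = x ** 3
            print("inside forward, y.grad_fn =", y.grad_fn)
            ctx.save_for_backward(x)
            return y

        @staticmethod
        def backward(ctx, grad_out):
            x, = ctx.saved_tensors
            return grad_out * 3 * x ** 2

    x = torch.tensor(2., requires_grad=True)
    out = MyCube.apply(x)
    # After forward returns, the custom backward becomes the output's grad_fn.
    print("out.grad_fn =", out.grad_fn)

    # With create_graph=True the operations performed in backward are recorded,
    # so the gradient itself is differentiable (it has a grad_fn).
    g, = torch.autograd.grad(out, x, create_graph=True)
    print("g.grad_fn =", g.grad_fn)
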
Next, to understand how `save_for_backward` interacts with the above,
we can explore a couple of examples:


Saving the Inputs
-------------------------------------------------------------------
Consider this simple squaring function. It saves an input tensor
for backward. Double backward works automatically when autograd
is able to record operations in the backward pass, so there is usually
nothing to worry about when we save an input for backward, as
the input should have a grad_fn if it is a function of any tensor
that requires grad. This allows the gradients to be properly propagated.

.. code:: python

    import torch

    class Square(torch.autograd.Function):
        @staticmethod
        def forward(ctx, x):
            # Because we are saving one of the inputs, use `save_for_backward`
            # Save non-tensors and non-inputs/non-outputs directly on ctx
            ctx.save_for_backward(x)
            return x**2

        @staticmethod
        def backward(ctx, grad_out):
            # A function supports double backward automatically if autograd
            # is able to record the computations performed in backward
            x, = ctx.saved_tensors
            return grad_out * 2 * x

    # Use double precision because finite differencing method magnifies errors
    x = torch.rand(3, 3, requires_grad=True, dtype=torch.double)
    torch.autograd.gradcheck(Square.apply, x)
    # Use gradgradcheck to verify second-order derivatives
    torch.autograd.gradgradcheck(Square.apply, x)
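
As a quick sanity check, not in the original tutorial, we can also differentiate
`Square` twice by hand. The expected values in the comments are assumptions
based on d/dx x**2 = 2x and d^2/dx^2 x**2 = 2.

.. code-block:: python

    # Differentiate Square twice manually (illustrative check).
    x = torch.tensor(3., requires_grad=True, dtype=torch.double)
    out = Square.apply(x)
    # create_graph=True records the backward computation so that it can be
    # differentiated again.
    dx, = torch.autograd.grad(out, x, create_graph=True)
    d2x, = torch.autograd.grad(dx, x)
    print(dx.item(), d2x.item())  # expected: 6.0 and 2.0
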
We can use torchviz to visualize the graph to see why this works

.. code-block:: python

    import torchviz

    x = torch.tensor(1., requires_grad=True).clone()
    out = Square.apply(x)
    grad_x, = torch.autograd.grad(out, x, create_graph=True)
    torchviz.make_dot((grad_x, x, out), {"grad_x": grad_x, "x": x, "out": out})

We can see that the gradient with respect to x is itself a function of x (dout/dx = 2x),
and the graph of this function has been properly constructed.

.. image:: https://user-images.githubusercontent.com/13428986/126559699-e04f3cb1-aaf2-4a9a-a83d-b8767d04fbd9.png
   :width: 400


Saving the Outputs
-------------------------------------------------------------------
A slight variation on the previous example is to save an output
instead of an input. The mechanics are similar because outputs are also
associated with a grad_fn.

.. code-block:: python

    class Exp(torch.autograd.Function):
        # Simple case where everything goes well
        @staticmethod
        def forward(ctx, x):
            # This time we save the output
            result = torch.exp(x)
            # Note that we should use `save_for_backward` here when
            # the tensor saved is an output (or an input).
            ctx.save_for_backward(result)
            return result

        @staticmethod
        def backward(ctx, grad_out):
            result, = ctx.saved_tensors
            return result * grad_out

    x = torch.tensor(1., requires_grad=True, dtype=torch.double).clone()
    # Validate our gradients using gradcheck
    torch.autograd.gradcheck(Exp.apply, x)
    torch.autograd.gradgradcheck(Exp.apply, x)
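
As an illustration beyond the original tutorial, every derivative of `exp` is
`exp` itself, so differentiating twice manually should return the same value
both times; the expected value in the comment is an assumption.

.. code-block:: python

    # Differentiate Exp twice manually (illustrative check).
    out = Exp.apply(x)
    dx, = torch.autograd.grad(out, x, create_graph=True)
    d2x, = torch.autograd.grad(dx, x)
    # Both derivatives equal exp(x); at x = 1 this is approximately 2.7183.
    print(dx.item(), d2x.item())
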
Use torchviz to visualize the graph:

.. code-block:: python

    out = Exp.apply(x)
    grad_x, = torch.autograd.grad(out, x, create_graph=True)
    torchviz.make_dot((grad_x, x, out), {"grad_x": grad_x, "x": x, "out": out})

.. image:: https://user-images.githubusercontent.com/13428986/126559780-d141f2ba-1ee8-4c33-b4eb-c9877b27a954.png
   :width: 332


Saving Intermediate Results
-------------------------------------------------------------------
A trickier case is when we need to save an intermediate result.
We demonstrate this case by implementing:

.. math::
    \sinh(x) := \frac{e^x - e^{-x}}{2}

Since the derivative of sinh is cosh, it might be useful to reuse
`exp(x)` and `exp(-x)`, the two intermediate results of the forward pass,
in the backward computation.

Intermediate results should not be directly saved and used in backward, though.
Because forward is performed in no-grad mode, if an intermediate result
of the forward pass is used to compute gradients in the backward pass,
the backward graph of the gradients would not include the operations
that computed the intermediate result. This leads to incorrect gradients.

.. code-block:: python

    class Sinh(torch.autograd.Function):
        @staticmethod
        def forward(ctx, x):
            expx = torch.exp(x)
            expnegx = torch.exp(-x)
            ctx.save_for_backward(expx, expnegx)
            # In order to be able to save the intermediate results, a trick is to
            # include them as our outputs, so that the backward graph is constructed
            return (expx - expnegx) / 2, expx, expnegx

        @staticmethod
        def backward(ctx, grad_out, _grad_out_exp, _grad_out_negexp):
            expx, expnegx = ctx.saved_tensors
            grad_input = grad_out * (expx + expnegx) / 2
            # We cannot skip accumulating these even though we won't use the outputs
            # directly. They will be used later in the second backward.
            grad_input += _grad_out_exp * expx
            grad_input -= _grad_out_negexp * expnegx
            return grad_input

    def sinh(x):
        # Create a wrapper that only returns the first output
        return Sinh.apply(x)[0]

    x = torch.rand(3, 3, requires_grad=True, dtype=torch.double)
    torch.autograd.gradcheck(sinh, x)
    torch.autograd.gradgradcheck(sinh, x)
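
Since the second derivative of sinh is sinh itself, a quick manual check, not in
the original tutorial, is to differentiate twice and compare against
`torch.sinh`; the comparison via `torch.allclose` is an assumption.

.. code-block:: python

    # The second derivative of sinh is sinh (illustrative check).
    x = torch.rand(3, 3, requires_grad=True, dtype=torch.double)
    out = sinh(x)
    dx, = torch.autograd.grad(out.sum(), x, create_graph=True)   # cosh(x)
    d2x, = torch.autograd.grad(dx.sum(), x)                      # sinh(x)
    print(torch.allclose(d2x, torch.sinh(x)))  # expected: True
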
Use torchviz to visualize the graph:

.. code-block:: python

    out = sinh(x)
    grad_x, = torch.autograd.grad(out.sum(), x, create_graph=True)
    torchviz.make_dot((grad_x, x, out), params={"grad_x": grad_x, "x": x, "out": out})

.. image:: https://user-images.githubusercontent.com/13428986/126560494-e48eba62-be84-4b29-8c90-a7f6f40b1438.png
   :width: 460


Saving Intermediate Results: What not to do
-------------------------------------------------------------------
Now we show what happens when we don't also return our intermediate
results as outputs: `grad_x` would not even have a backward graph,
because it is purely a function of `expx` and `expnegx`, which don't
require grad.

.. code-block:: python

    class SinhBad(torch.autograd.Function):
        # This is an example of what NOT to do!
        @staticmethod
        def forward(ctx, x):
            expx = torch.exp(x)
            expnegx = torch.exp(-x)
            ctx.expx = expx
            ctx.expnegx = expnegx
            return (expx - expnegx) / 2

        @staticmethod
        def backward(ctx, grad_out):
            expx = ctx.expx
            expnegx = ctx.expnegx
            grad_input = grad_out * (expx + expnegx) / 2
            return grad_input
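
A quick way to see the problem without torchviz, added here as an illustrative
check rather than part of the original tutorial, is to inspect the gradient
returned with `create_graph=True`: because it was computed only from tensors
that don't require grad, it cannot be differentiated again.

.. code-block:: python

    # The first backward still runs, but its result is detached from x.
    x = torch.rand(3, 3, requires_grad=True, dtype=torch.double)
    out = SinhBad.apply(x)
    grad_x, = torch.autograd.grad(out.sum(), x, create_graph=True)
    print(grad_x.requires_grad)  # expected: False, so a second backward is impossible
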
Use torchviz to visualize the graph. Notice that `grad_x` is not
part of the graph!

.. code-block:: python

    out = SinhBad.apply(x)
    grad_x, = torch.autograd.grad(out.sum(), x, create_graph=True)
    torchviz.make_dot((grad_x, x, out), params={"grad_x": grad_x, "x": x, "out": out})

.. image:: https://user-images.githubusercontent.com/13428986/126565889-13992f01-55bc-411a-8aee-05b721fe064a.png
   :width: 232


When Backward is not Tracked
-------------------------------------------------------------------
Finally, let's consider an example where it may not be possible for
autograd to track gradients for a function's backward at all.
We can imagine cube_backward to be a function that may require a
non-PyTorch library like SciPy or NumPy, or that is written as a
C++ extension. The workaround demonstrated here is to create another
custom function, CubeBackward, where you also manually specify the
backward of cube_backward!


.. code-block:: python

    def cube_forward(x):
        return x**3

    def cube_backward(grad_out, x):
        return grad_out * 3 * x**2

    def cube_backward_backward(grad_out, sav_grad_out, x):
        return grad_out * sav_grad_out * 6 * x

    def cube_backward_backward_grad_out(grad_out, x):
        return grad_out * 3 * x**2

    class Cube(torch.autograd.Function):
        @staticmethod
        def forward(ctx, x):
            ctx.save_for_backward(x)
            return cube_forward(x)

        @staticmethod
        def backward(ctx, grad_out):
            x, = ctx.saved_tensors
            return CubeBackward.apply(grad_out, x)

    class CubeBackward(torch.autograd.Function):
        @staticmethod
        def forward(ctx, grad_out, x):
            ctx.save_for_backward(x, grad_out)
            return cube_backward(grad_out, x)

        @staticmethod
        def backward(ctx, grad_out):
            x, sav_grad_out = ctx.saved_tensors
            dx = cube_backward_backward(grad_out, sav_grad_out, x)
            dgrad_out = cube_backward_backward_grad_out(grad_out, x)
            return dgrad_out, dx

    x = torch.tensor(2., requires_grad=True, dtype=torch.double)

    torch.autograd.gradcheck(Cube.apply, x)
    torch.autograd.gradgradcheck(Cube.apply, x)
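
As a final manual check that is not part of the original tutorial, the second
derivative of x**3 is 6x, so at x = 2 both the first and second derivatives
evaluate to 12; the expected values in the comments are assumptions from basic
calculus.

.. code-block:: python

    # Differentiate Cube twice manually (illustrative check).
    out = Cube.apply(x)
    dx, = torch.autograd.grad(out, x, create_graph=True)   # 3 * x**2 = 12 at x = 2
    d2x, = torch.autograd.grad(dx, x)                       # 6 * x    = 12 at x = 2
    print(dx.item(), d2x.item())  # expected: 12.0 and 12.0
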
Use torchviz to visualize the graph:

.. code-block:: python

    out = Cube.apply(x)
    grad_x, = torch.autograd.grad(out, x, create_graph=True)
    torchviz.make_dot((grad_x, x, out), params={"grad_x": grad_x, "x": x, "out": out})

.. image:: https://user-images.githubusercontent.com/13428986/126559935-74526b4d-d419-4983-b1f0-a6ee99428531.png
   :width: 352


To conclude, whether double backward works for your custom function
simply depends on whether the backward pass can be tracked by autograd.
With the first two examples we show situations where double backward
works out of the box. With the third and fourth examples, we demonstrate
techniques that enable a backward function to be tracked when it
otherwise would not be.
