|
| 1 | +Double Backward with Custom Functions |
| 2 | +===================================== |
| 3 | + |
| 4 | +It is sometimes useful to run backwards twice through backward graph, for |
| 5 | +example to compute higher-order gradients. It takes an understanding of |
| 6 | +autograd and some care to support double backwards, however. Functions |
| 7 | +that support performing backward a single time are not necessarily |
| 8 | +equipped to support double backward. In this tutorial we show how to |
| 9 | +write a custom autograd function that supports double backward, and |
| 10 | +point out some things to look out for. |
| 11 | + |
| 12 | + |
| 13 | +When writing a custom autograd function to backward through twice, |
| 14 | +it is important to know when operations performed in a custom function |
| 15 | +are recorded by autograd, when they aren't, and most importantly, how |
| 16 | +`save_for_backward` works with all of this. |
| 17 | + |
| 18 | +Custom functions implicitly affects grad mode in two ways: |
| 19 | + |
| 20 | +- During forward, autograd does not record any the graph for any |
| 21 | + operations performed within the forward function. When forward |
| 22 | + completes, the backward function of the custom function |
| 23 | + becomes the `grad_fn` of each of the forward's outputs |
| 24 | + |
| 25 | +- During backward, autograd records the computation graph used to |
| 26 | + compute the backward pass if create_graph is specified |
| 27 | + |
| 28 | +Next, to understand how `save_for_backward` interacts with the above, |
| 29 | +we can explore a couple examples: |
| 30 | + |
| 31 | + |
| 32 | +Saving the Inputs |
| 33 | +------------------------------------------------------------------- |
| 34 | +Consider this simple squaring function. It saves an input tensor |
| 35 | +for backward. Double backward works automatically when autograd |
| 36 | +is able to record operations in the backward pass, so there is usually |
| 37 | +nothing to worry about when we save an input for backward as |
| 38 | +the input should have grad_fn if it is a function of any tensor |
| 39 | +that requires grad. This allows the gradients to be properly propagated. |
| 40 | + |
| 41 | +.. code:: python |
| 42 | +
|
| 43 | + import torch |
| 44 | +
|
| 45 | + class Square(torch.autograd.Function): |
| 46 | + @staticmethod |
| 47 | + def forward(ctx, x): |
| 48 | + # Because we are saving one of the inputs use `save_for_backward` |
| 49 | + # Save non-tensors and non-inputs/non-outputs directly on ctx |
| 50 | + ctx.save_for_backward(x) |
| 51 | + return x**2 |
| 52 | +
|
| 53 | + @staticmethod |
| 54 | + def backward(ctx, grad_out): |
| 55 | + # A function support double backward automatically if autograd |
| 56 | + # is able to record the computations performed in backward |
| 57 | + x, = ctx.saved_tensors |
| 58 | + return grad_out * 2 * x |
| 59 | +
|
| 60 | + # Use double precision because finite differencing method magnifies errors |
| 61 | + x = torch.rand(3, 3, requires_grad=True, dtype=torch.double) |
| 62 | + torch.autograd.gradcheck(Square.apply, x) |
| 63 | + # Use gradcheck to verify second-order derivatives |
| 64 | + torch.autograd.gradgradcheck(Square.apply, x) |
| 65 | +
|
| 66 | +
|
| 67 | +We can use torchviz to visualize the graph to see why this works |
| 68 | + |
| 69 | +.. code-block:: python |
| 70 | +
|
| 71 | + import torchviz |
| 72 | +
|
| 73 | + x = torch.tensor(1., requires_grad=True).clone() |
| 74 | + out = Square.apply(x) |
| 75 | + grad_x, = torch.autograd.grad(out, x, create_graph=True) |
| 76 | + torchviz.make_dot((grad_x, x, out), {"grad_x": grad_x, "x": x, "out": out}) |
| 77 | +
|
| 78 | +We can see that the gradient wrt to x, is itself a function of x (dout/dx = 2x) |
| 79 | +And the graph of this function has been properly constructed |
| 80 | + |
| 81 | +.. image:: https://user-images.githubusercontent.com/13428986/126559699-e04f3cb1-aaf2-4a9a-a83d-b8767d04fbd9.png |
| 82 | + :width: 400 |
| 83 | + |
| 84 | + |
| 85 | +Saving the Outputs |
| 86 | +------------------------------------------------------------------- |
| 87 | +A slight variation on the previous example is to save an output |
| 88 | +instead of input. The mechanics are similar because outputs are also |
| 89 | +associated with a grad_fn. |
| 90 | + |
| 91 | +.. code-block:: python |
| 92 | +
|
| 93 | + class Exp(torch.autograd.Function): |
| 94 | + # Simple case where everything goes well |
| 95 | + @staticmethod |
| 96 | + def forward(ctx, x): |
| 97 | + # This time we save the output |
| 98 | + result = torch.exp(x) |
| 99 | + # Note that we should use `save_for_backward` here when |
| 100 | + # the tensor saved is an ouptut (or an input). |
| 101 | + ctx.save_for_backward(result) |
| 102 | + return result |
| 103 | +
|
| 104 | + @staticmethod |
| 105 | + def backward(ctx, grad_out): |
| 106 | + result, = ctx.saved_tensors |
| 107 | + return result * grad_out |
| 108 | +
|
| 109 | + x = torch.tensor(1., requires_grad=True, dtype=torch.double).clone() |
| 110 | + # Validate our gradients using gradcheck |
| 111 | + torch.autograd.gradcheck(Exp.apply, x) |
| 112 | + torch.autograd.gradgradcheck(Exp.apply, x) |
| 113 | +
|
| 114 | +Use torchviz to visualize the graph: |
| 115 | + |
| 116 | +.. code-block:: python |
| 117 | +
|
| 118 | + out = Exp.apply(x) |
| 119 | + grad_x, = torch.autograd.grad(out, x, create_graph=True) |
| 120 | + torchviz.make_dot((grad_x, x, out), {"grad_x": grad_x, "x": x, "out": out}) |
| 121 | +
|
| 122 | +.. image:: https://user-images.githubusercontent.com/13428986/126559780-d141f2ba-1ee8-4c33-b4eb-c9877b27a954.png |
| 123 | + :width: 332 |
| 124 | + |
| 125 | + |
| 126 | +Saving Intermediate Results |
| 127 | +------------------------------------------------------------------- |
| 128 | +A more tricky case is when we need to save an intermediate result. |
| 129 | +We demonstrate this case by implementing: |
| 130 | + |
| 131 | +.. math:: |
| 132 | + sinh(x) := \frac{e^x - e^{-x}}{2} |
| 133 | +
|
| 134 | +Since the derivative of sinh is cosh, it might be useful to reuse |
| 135 | +`exp(x)` and `exp(-x)`, the two intermediate results in forward |
| 136 | +in the backward computation. |
| 137 | + |
| 138 | +Intermediate results should not be directly saved and used in backward though. |
| 139 | +Because forward is performed in no-grad mode, if an intermediate result |
| 140 | +of the forward pass is used to compute gradients in the backward pass |
| 141 | +the backward graph of the gradients would not include the operations |
| 142 | +that computed the intermediate result. This leads to incorrect gradients. |
| 143 | + |
| 144 | +.. code-block:: python |
| 145 | +
|
| 146 | + class Sinh(torch.autograd.Function): |
| 147 | + @staticmethod |
| 148 | + def forward(ctx, x): |
| 149 | + expx = torch.exp(x) |
| 150 | + expnegx = torch.exp(-x) |
| 151 | + ctx.save_for_backward(expx, expnegx) |
| 152 | + # In order to be able to save the intermediate results, a trick is to |
| 153 | + # include them as our outputs, so that the backward graph is constructed |
| 154 | + return (expx - expnegx) / 2, expx, expnegx |
| 155 | +
|
| 156 | + @staticmethod |
| 157 | + def backward(ctx, grad_out, _grad_out_exp, _grad_out_negexp): |
| 158 | + expx, expnegx = ctx.saved_tensors |
| 159 | + grad_input = grad_out * (expx + expnegx) / 2 |
| 160 | + # We cannot skip accumulating these even though we won't use the outputs |
| 161 | + # directly. They will be used later in the second backward. |
| 162 | + grad_input += _grad_out_exp * expx |
| 163 | + grad_input -= _grad_out_negexp * expnegx |
| 164 | + return grad_input |
| 165 | +
|
| 166 | + def sinh(x): |
| 167 | + # Create a wrapper that only returns the first output |
| 168 | + return Sinh.apply(x)[0] |
| 169 | +
|
| 170 | + x = torch.rand(3, 3, requires_grad=True, dtype=torch.double) |
| 171 | + torch.autograd.gradcheck(sinh, x) |
| 172 | + torch.autograd.gradgradcheck(sinh, x) |
| 173 | +
|
| 174 | +
|
| 175 | +Use torchviz to visualize the graph: |
| 176 | + |
| 177 | +.. code-block:: python |
| 178 | +
|
| 179 | + out = sinh(x) |
| 180 | + grad_x, = torch.autograd.grad(out.sum(), x, create_graph=True) |
| 181 | + torchviz.make_dot((grad_x, x, out), params={"grad_x": grad_x, "x": x, "out": out}) |
| 182 | +
|
| 183 | +.. image:: https://user-images.githubusercontent.com/13428986/126560494-e48eba62-be84-4b29-8c90-a7f6f40b1438.png |
| 184 | + :width: 460 |
| 185 | + |
| 186 | + |
| 187 | +Saving Intermediate Results: What not to do |
| 188 | +------------------------------------------------------------------- |
| 189 | +Now we show what happens when we don't also return our intermediate |
| 190 | +results as outputs: `grad_x` would not even have a backward graph |
| 191 | +because it is purely a function `exp` and `expnegx`, which don't |
| 192 | +require grad. |
| 193 | + |
| 194 | +.. code-block:: python |
| 195 | +
|
| 196 | + class SinhBad(torch.autograd.Function): |
| 197 | + # This is an example of what NOT to do! |
| 198 | + @staticmethod |
| 199 | + def forward(ctx, x): |
| 200 | + expx = torch.exp(x) |
| 201 | + expnegx = torch.exp(-x) |
| 202 | + ctx.expx = expx |
| 203 | + ctx.expnegx = expnegx |
| 204 | + return (expx - expnegx) / 2 |
| 205 | +
|
| 206 | + @staticmethod |
| 207 | + def backward(ctx, grad_out): |
| 208 | + expx = ctx.expx |
| 209 | + expnegx = ctx.expnegx |
| 210 | + grad_input = grad_out * (expx + expnegx) / 2 |
| 211 | + return grad_input |
| 212 | +
|
| 213 | +
|
| 214 | +Use torchviz to visualize the graph. Notice that `grad_x` is not |
| 215 | +part of the graph! |
| 216 | + |
| 217 | +.. code-block:: python |
| 218 | +
|
| 219 | + out = SinhBad.apply(x) |
| 220 | + grad_x, = torch.autograd.grad(out.sum(), x, create_graph=True) |
| 221 | + torchviz.make_dot((grad_x, x, out), params={"grad_x": grad_x, "x": x, "out": out}) |
| 222 | +
|
| 223 | +.. image:: https://user-images.githubusercontent.com/13428986/126565889-13992f01-55bc-411a-8aee-05b721fe064a.png |
| 224 | + :width: 232 |
| 225 | + |
| 226 | + |
| 227 | + |
| 228 | +When Backward is not Tracked |
| 229 | +------------------------------------------------------------------- |
| 230 | +Finally, let's consider an example when it may not be possible for |
| 231 | +autograd to track gradients for a functions backward at all. |
| 232 | +We can imagine cube_backward to be a function that may require a |
| 233 | +non-PyTorch library like SciPy or NumPy, or written as a |
| 234 | +C++ extension. The workaround demonstrated here is to create another |
| 235 | +custom function CubeBackward where you also manually specify the |
| 236 | +backward of cube_backward! |
| 237 | + |
| 238 | + |
| 239 | +.. code-block:: python |
| 240 | +
|
| 241 | + def cube_forward(x): |
| 242 | + return x**3 |
| 243 | +
|
| 244 | + def cube_backward(grad_out, x): |
| 245 | + return grad_out * 3 * x**2 |
| 246 | +
|
| 247 | + def cube_backward_backward(grad_out, sav_grad_out, x): |
| 248 | + return grad_out * sav_grad_out * 6 * x |
| 249 | +
|
| 250 | + def cube_backward_backward_grad_out(grad_out, x): |
| 251 | + return grad_out * 3 * x**2 |
| 252 | +
|
| 253 | + class Cube(torch.autograd.Function): |
| 254 | + @staticmethod |
| 255 | + def forward(ctx, x): |
| 256 | + ctx.save_for_backward(x) |
| 257 | + return cube_forward(x) |
| 258 | +
|
| 259 | + @staticmethod |
| 260 | + def backward(ctx, grad_out): |
| 261 | + x, = ctx.saved_tensors |
| 262 | + return CubeBackward.apply(grad_out, x) |
| 263 | +
|
| 264 | + class CubeBackward(torch.autograd.Function): |
| 265 | + @staticmethod |
| 266 | + def forward(ctx, grad_out, x): |
| 267 | + ctx.save_for_backward(x, grad_out) |
| 268 | + return cube_backward(grad_out, x) |
| 269 | +
|
| 270 | + @staticmethod |
| 271 | + def backward(ctx, grad_out): |
| 272 | + x, sav_grad_out = ctx.saved_tensors |
| 273 | + dx = cube_backward_backward(grad_out, sav_grad_out, x) |
| 274 | + dgrad_out = cube_backward_backward_grad_out(grad_out, x) |
| 275 | + return dgrad_out, dx |
| 276 | +
|
| 277 | + x = torch.tensor(2., requires_grad=True, dtype=torch.double) |
| 278 | +
|
| 279 | + torch.autograd.gradcheck(Cube.apply, x) |
| 280 | + torch.autograd.gradgradcheck(Cube.apply, x) |
| 281 | +
|
| 282 | +
|
| 283 | +Use torchviz to visualize the graph: |
| 284 | + |
| 285 | +.. code-block:: python |
| 286 | +
|
| 287 | + out = Cube.apply(x) |
| 288 | + grad_x, = torch.autograd.grad(out, x, create_graph=True) |
| 289 | + torchviz.make_dot((grad_x, x, out), params={"grad_x": grad_x, "x": x, "out": out}) |
| 290 | +
|
| 291 | +.. image:: https://user-images.githubusercontent.com/13428986/126559935-74526b4d-d419-4983-b1f0-a6ee99428531.png |
| 292 | + :width: 352 |
| 293 | + |
| 294 | + |
| 295 | +To conclude, whether double backward works for your custom function |
| 296 | +simply depends on whether the backward pass can be tracked by autograd. |
| 297 | +With the first two examples we show situations where double backward |
| 298 | +works out of the box. With the third and fourth examples, we demonstrate |
| 299 | +techniques that enable a backward function to be tracked, when they |
| 300 | +otherwise would not be. |
| 301 | + |
0 commit comments