
Commit 086699f: Update complex example / side notes on tables
Parent: 115c537


README.md

Lines changed: 39 additions & 17 deletions
@@ -1,6 +1,6 @@
 # pytorch-containers
 
-This repository aims to help former Torchies more seemlessly transition to the "Containerless" world of
+This repository aims to help former Torchies more seamlessly transition to the "Containerless" world of
 [PyTorch](https://github.com/pytorch/pytorch)
 by providing a list of PyTorch implementations of [Torch Table Layers](https://github.com/torch/nn/blob/master/doc/table.md).
 
@@ -68,7 +68,7 @@ flexibility as your architectures become more complex, and it's also a lot easie
 remembering the exact functionality of ConcatTable, or any of the other tables for that matter.
 
 Two other things to note:
-- To work with autograd, we must wrap our input in a `Variable`
+- To work with autograd, we must wrap our input in a `Variable` (we can also pass a python iterable of Variables)
 - PyTorch requires us to add a batch dimension which is why we call `.unsqueeze(0)` on the input
 
 
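A minimal sketch of those two notes in code (using the legacy `torch.autograd.Variable` API this README is written against; the image shape here is just an illustrative assumption):

```Python
import torch
from torch.autograd import Variable

img = torch.randn(3, 224, 224)  # a single image: channels x height x width
x = Variable(img)               # wrap in a Variable so autograd can track operations on it
x = x.unsqueeze(0)              # add the batch dimension: size becomes (1, 3, 224, 224)
```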
@@ -202,7 +202,7 @@ class TableModule(nn.Module):
         super(TableModule,self).__init__()
 
     def forward(self,x1,x2):
-        x_sum = x1+x2
+        x_sum = x1+x2 # could use sum() if input given as python iterable
         x_sub = x1-x2
         x_div = x1/x2
         x_mul = x1*x2
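The new comment suggests an iterable-based variant; a hedged sketch of what that might look like (the list-style `forward` signature and the class name `IterableTableModule` are illustrative assumptions, not code from this repo):

```Python
import torch
import torch.nn as nn
from torch.autograd import Variable

class IterableTableModule(nn.Module):
    def forward(self, inputs):         # inputs: a python iterable of Variables
        x_sum = sum(inputs)            # CAddTable over the whole iterable
        x_mul = inputs[0] * inputs[1]  # CMulTable, pairwise as before
        return x_sum, x_mul

net = IterableTableModule()
xs = [Variable(torch.ones(2, 2)), Variable(torch.ones(2, 2) * 2)]
x_sum, x_mul = net(xs)  # x_sum is all 3s, x_mul is all 2s
```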
@@ -242,42 +242,62 @@ And we get:
 ## Intuitively Build Complex Architectures
 
 Now we will visit a more complex example that combines several of the above operations.
-The graph below is a random network that I created using the Torch [nngraph](https://github.com/torch/nngraph) package,
-and the Torch code definition using nngraph can be found [here](https://github.com/amdegroot/pytorch-containers/blob/master/complex_graph.lua) and a raw Torch implementation can be found [here](https://github.com/amdegroot/pytorch-containers/blob/master/complex_net.lua) for comparison to the PyTorch code that follows.
+The graph below is a random network that I created using the Torch [nngraph](https://github.com/torch/nngraph) package. The Torch model definition using nngraph can be found [here](https://github.com/amdegroot/pytorch-containers/blob/master/complex_graph.lua), and a raw Torch implementation can be found [here](https://github.com/amdegroot/pytorch-containers/blob/master/complex_net.lua) for comparison to the PyTorch code that follows.
 
 <img src="https://github.com/amdegroot/pytorch-containers/blob/master/doc/complex_example.png" width="600px"/>
 
 ```Python
 class Branch(nn.Module):
     def __init__(self,b2):
         super(Branch, self).__init__()
+        """
+        Upon closer examination of the structure, note that a
+        MaxPool2d with the same params is used in each branch,
+        so we can just reuse it and pass in the conv layer that
+        is repeated in parallel right after it (reusing that as well).
+        """
         self.b = nn.MaxPool2d(kernel_size=2, stride=2)
         self.b2 = b2
 
     def forward(self,x):
         x = self.b(x)
-        y = [self.b2(x).view(-1), self.b2(x).view(-1)]
-        z = torch.cat((y[0],y[1]))
+        y = [self.b2(x).view(-1), self.b2(x).view(-1)] # pytorch 'ParallelTable'
+        z = torch.cat((y[0],y[1])) # pytorch 'JoinTable'
         return z
-
+```
+Now that we have a branch class general enough to handle both branches, we can define the base segments
+and piece it all together in a very natural way.
+
+```Python
 class ComplexNet(nn.Module):
     def __init__(self, m1, m2):
         super(ComplexNet, self).__init__()
-        self.net1 = m1
-        self.net2 = m2
-        self.net3 = nn.Conv2d(128,256,kernel_size=3,padding=1)
-        self.branch1 = Branch(nn.Conv2d(64,64,kernel_size=3,padding=1))
+        # define each piece of our network shown above
+        self.net1 = m1 # segment 1 from VGG
+        self.net2 = m2 # segment 2 from VGG
+        self.net3 = nn.Conv2d(128,256,kernel_size=3,padding=1) # last layer
+        self.branch1 = Branch(nn.Conv2d(64,64,kernel_size=3,padding=1))
         self.branch2 = Branch(nn.Conv2d(128,256,kernel_size=3, padding=1))
 
     def forward(self, x):
+        """
+        Here we see that autograd allows us to safely reuse Variables in
+        defining the computational graph. We could also reuse Modules or even
+        use loops or conditional statements.
+        Note: some of this could be condensed, but it is laid out this way
+        for clarity.
+        """
         x = self.net1(x)
-        x1 = self.branch1(x)
-        y = self.net2(x)
-        x2 = self.branch2(y)
+        x1 = self.branch1(x) # SplitTable (implicitly)
+        y = self.net2(x)
+        x2 = self.branch2(y) # SplitTable (implicitly)
         x3 = self.net3(y).view(-1)
-        output = torch.cat((x1,x2,x3),0)
+        output = torch.cat((x1,x2,x3),0) # JoinTable
         return output
-
+```
+This is a loop to define our VGG conv layers, derived from [pytorch/vision](https://github.com/pytorch/vision/blob/master/torchvision/models/vgg.py) (maybe a little overkill for our small case):
+```Python
 def make_layers(params, ch):
     layers = []
     channels = ch
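As a quick shape sanity check on `Branch` (the input size below is an illustrative assumption, and this snippet presumes the `Branch` class above is in scope):

```Python
import torch
import torch.nn as nn
from torch.autograd import Variable

branch = Branch(nn.Conv2d(64, 64, kernel_size=3, padding=1))
x = Variable(torch.randn(64, 32, 32)).unsqueeze(0)  # add the batch dimension
out = branch(x)
print(out.size())  # torch.Size([32768]): 2 branches * 64 * 16 * 16, flattened and joined
```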
@@ -289,3 +309,5 @@ def make_layers(params, ch):
 
 net = ComplexNet(make_layers([64,64],3),make_layers([128,128],64))
 ```
+This documented Python code can be found [here](https://github.com/amdegroot/pytorch-containers/blob/master/complex_net.py).
+
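The diff elides the body of `make_layers`; modeled on the torchvision VGG helper it is derived from, it plausibly looks like the sketch below (an assumption for illustration, not the file's verified contents):

```Python
import torch.nn as nn

def make_layers(params, ch):
    # stack conv3x3 + ReLU pairs, threading the channel count through
    layers = []
    channels = ch
    for p in params:
        layers += [nn.Conv2d(channels, p, kernel_size=3, padding=1),
                   nn.ReLU(inplace=True)]
        channels = p
    return nn.Sequential(*layers)
```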
