Skip to content

Commit 115c537

Browse files
authored
add intuitive pytorch code for combo example
1 parent a7d8877 commit 115c537

File tree

1 file changed

+45
-16
lines changed

1 file changed

+45
-16
lines changed

README.md

Lines changed: 45 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -241,22 +241,51 @@ And we get:
241241

242242
## Intuitively Build Complex Architectures
243243

244-
Now we will visit a more complex example that combines several of the above operations. You will notice that as we add
245-
more and more complexity to our network, the Torch code becomes more and more verbose.
246-
On the other hand, thanks to autograd, the complexity of our PyTorch code does not increase at all.
244+
Now we will visit a more complex example that combines several of the above operations.
245+
The graph below is a random network that I created using the Torch [nngraph](https://github.com/torch/nngraph) package,
246+
and the Torch code definition using nngraph can be found [here](https://github.com/amdegroot/pytorch-containers/blob/master/complex_graph.lua) and a raw Torch implementation can be found [here](https://github.com/amdegroot/pytorch-containers/blob/master/complex_net.lua) for comparison to the PyTorch code that follows.
247247

248248
<img src= "https://github.com/amdegroot/pytorch-containers/blob/master/doc/complex_example.png" width="600px"/>
249249

250-
251-
252-
253-
254-
255-
256-
257-
258-
259-
260-
261-
262-
250+
```Python
251+
class Branch(nn.Module):
252+
def __init__(self,b2):
253+
super(Branch, self).__init__()
254+
self.b = nn.MaxPool2d(kernel_size=2, stride=2)
255+
self.b2 = b2
256+
257+
def forward(self,x):
258+
x = self.b(x)
259+
y = [self.b2(x).view(-1), self.b2(x).view(-1)]
260+
z = torch.cat((y[0],y[1]))
261+
return z
262+
263+
class ComplexNet(nn.Module):
264+
def __init__(self, m1, m2):
265+
super(ComplexNet, self).__init__()
266+
self.net1 = m1
267+
self.net2 = m2
268+
self.net3 = nn.Conv2d(128,256,kernel_size=3,padding=1)
269+
self.branch1 = Branch(nn.Conv2d(64,64,kernel_size=3,padding=1))
270+
self.branch2 = Branch(nn.Conv2d(128,256,kernel_size=3, padding=1))
271+
272+
def forward(self, x):
273+
x = self.net1(x)
274+
x1 = self.branch1(x)
275+
y = self.net2(x)
276+
x2 = self.branch2(y)
277+
x3 = self.net3(y).view(-1)
278+
output = torch.cat((x1,x2,x3),0)
279+
return output
280+
281+
def make_layers(params, ch):
282+
layers = []
283+
channels = ch
284+
for p in params:
285+
conv2d = nn.Conv2d(channels, p, kernel_size=3, padding=1)
286+
layers += [conv2d, nn.ReLU(inplace=True)]
287+
channels = p
288+
return nn.Sequential(*layers)
289+
290+
net = ComplexNet(make_layers([64,64],3),make_layers([128,128],64))
291+
```

0 commit comments

Comments
 (0)