Skip to content

Instantly share code, notes, and snippets.

@bkj
Created December 14, 2017 19:59
Show Gist options
  • Save bkj/435ded4afeedfa4024c28f4baaa745e0 to your computer and use it in GitHub Desktop.
Save bkj/435ded4afeedfa4024c28f4baaa745e0 to your computer and use it in GitHub Desktop.
pytorch + dask
#!/usr/bin/env python
"""
pytorch-dask.py
"""
import torch
from torch import nn
from torch.nn import functional as F
from torch.autograd import Variable
class DaskNet(nn.Module):
    """Wrap a dask-style task graph of layers as a single ``nn.Module``.

    ``graph`` maps keys to tasks of the form ``(callable, *args)``; an argument
    may name another key (or, as in the example below, a list of keys). The
    special key ``"data"`` is bound to the module input on every ``forward``.
    Every task whose callable is an ``nn.Module`` is registered as a submodule,
    so its parameters show up in ``parameters()`` / ``state_dict()``.

    Args:
        graph: dict of key -> task; ``None`` values and non-task literals are
            allowed and are simply not registered as submodules.
    """

    def __init__(self, graph):
        super(DaskNet, self).__init__()
        self.graph = graph
        self.compile()

    # NOTE(review): recent PyTorch releases define ``nn.Module.compile``; this
    # method shadows it on DaskNet instances.
    def compile(self):
        """Register each ``nn.Module`` that heads a task in ``self.graph``.

        Modules are registered under ``str(key)`` because ``add_module``
        requires string names, while graph keys may be ints.
        """
        for key, task in self.graph.items():
            # Only a non-empty tuple can be a task carrying a layer; skip None,
            # literal values and empty tuples rather than indexing into them.
            if isinstance(task, tuple) and task and isinstance(task[0], nn.Module):
                self.add_module(str(key), task[0])

    def forward(self, x, layer='output'):
        """Evaluate the graph on input ``x`` and return the value of ``layer``.

        Args:
            x: input bound to the ``"data"`` key of the graph.
            layer: key whose computed value is returned (``'output'`` by
                default); pass an intermediate key to inspect that node.
        """
        # Bind the input on a shallow copy so the caller's dict is not mutated
        # and the module does not hold on to the last input after returning.
        graph = dict(self.graph)
        graph['data'] = x
        # NOTE(review): `get` is not imported in this file; presumably a dask
        # scheduler such as `dask.get` — confirm and import it.
        return get(graph, layer)
# --
# Example

# Input is (batch=5, channels=1, 28, 28). Two 3x3 convs shrink 28 -> 26 -> 24,
# and the 2x2 max-pool gives 24 -> 12, so node 5 flattens to 64 * 12 * 12 = 9216
# features. Node 6 concatenates node 5 with itself, hence Linear(2 * 9216, 10).
model = DaskNet({
    0: (nn.Conv2d(1, 32, kernel_size=3), "data"),
    1: (nn.ReLU(), 0),
    2: (nn.Conv2d(32, 64, kernel_size=3), 1),
    3: (nn.ReLU(), 2),
    4: (nn.MaxPool2d(2), 3),
    5: (lambda x: x.view(x.size(0), -1), 4),      # << Lambda functions!
    6: (lambda x: torch.cat(x, dim=-1), [5, 5]),  # << Arbitrary graph structure!
    "output": (nn.Linear(2 * 9216, 10), 6),
})

X = Variable(torch.randn(5, 1, 28, 28))

# Run model. The sizes are printed so the example shows something when run as
# a script (a bare expression is only echoed in an interactive session).
print(model(X).size())  # expected: torch.Size([5, 10])

# Get intermediate layers (node 2 is the second conv's output).
print(model(X, layer=2).size())  # expected: torch.Size([5, 64, 24, 24])
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment