-
-
Save kendricktan/9a776ec6322abaaf03cc9befd35508d4 to your computer and use it in GitHub Desktop.
""" | |
Dynamic Routing Between Capsules | |
https://arxiv.org/abs/1710.09829 | |
""" | |
import torch | |
import torch.nn as nn | |
import torch.optim as optim | |
import torch.nn.functional as F | |
import torchvision.transforms as transforms | |
import numpy as np | |
from torch.autograd import Variable | |
from torchvision.datasets.mnist import MNIST | |
from tqdm import tqdm | |
def index_to_one_hot(index_tensor, num_classes=10):
    """
    Converts index values to one-hot row vectors.

    e.g. [2, 5] (with 10 classes) becomes:
    [
        [0 0 1 0 0 0 0 0 0 0]
        [0 0 0 0 0 1 0 0 0 0]
    ]

    Args:
        index_tensor: 1-D tensor of class indices; it is cast to long
            before use, so float-typed indices are accepted.
        num_classes: width of every one-hot row (size of the identity
            matrix the rows are taken from).

    Returns:
        Float tensor of shape (len(index_tensor), num_classes), created
        on the same device as ``index_tensor``.
    """
    index_tensor = index_tensor.long()
    # Build the identity matrix on the index's device: index_select
    # refuses to mix a CPU source with a CUDA index (and vice versa).
    identity = torch.eye(num_classes, device=index_tensor.device)
    return identity.index_select(dim=0, index=index_tensor)
def squash_vector(tensor, dim=-1, eps=1e-8):
    """
    Non-linear 'squashing' to ensure short vectors get shrunk
    to almost zero length and long vectors get shrunk to a
    length slightly below 1.

    Computes ``|v|^2 / (1 + |v|^2) * v / |v|`` along ``dim``.

    Args:
        tensor: tensor whose vectors lie along ``dim``.
        dim: axis over which the vector norm is taken.
        eps: small constant added under the square root so that an
            all-zero vector maps to zero instead of 0/0 = NaN.

    Returns:
        Tensor of the same shape as ``tensor`` with every vector along
        ``dim`` rescaled to a length in [0, 1).
    """
    squared_norm = (tensor ** 2).sum(dim=dim, keepdim=True)
    scale = squared_norm / (1 + squared_norm)
    # eps keeps the denominator strictly positive for zero-length vectors
    return scale * tensor / torch.sqrt(squared_norm + eps)
def softmax(input, dim=1):
    """
    Apply softmax along a specific dimension.

    This wrapper originally emulated a ``dim`` argument by transposing
    the target axis to the end and flattening to 2-D, because
    ``F.softmax`` did not accept ``dim`` at the time
    (https://github.com/pytorch/pytorch/issues/3235). It now forwards
    ``dim`` directly, which gives the same result without relying on
    the deprecated implicit-dimension behaviour.

    Args:
        input: tensor to normalise.
        dim: axis along which the outputs sum to 1.

    Returns:
        Tensor of the same shape as ``input``.
    """
    return F.softmax(input, dim=dim)
class CapsuleLayer(nn.Module):
    """
    A layer of capsules that works in one of two modes, selected by
    ``num_routes``:

    * ``num_routes == -1``: ``num_capsules`` parallel ``nn.Conv2d``
      modules are applied to the input; their flattened outputs are
      stacked on the last axis and squashed.
    * otherwise: the input is multiplied by ``route_weights`` and the
      results are combined by ``num_iterations`` rounds of
      routing-by-agreement.

    Args:
        num_capsules: number of output capsules (routing mode) or number
            of parallel convolutions (convolutional mode).
        num_routes: number of input capsules routed to each output
            capsule, or -1 to select the convolutional mode.
        in_channels: input vector length (routing mode) or input
            channels of each convolution.
        out_channels: output vector length (routing mode) or output
            channels of each convolution.
        kernel_size: convolution kernel size (convolutional mode only).
        stride: convolution stride (convolutional mode only).
        num_iterations: number of routing iterations (routing mode only).

    Raises:
        ValueError: in routing mode when ``num_iterations`` is below 1,
            since no output would ever be computed.
    """

    def __init__(self, num_capsules, num_routes, in_channels, out_channels,
                 kernel_size=None, stride=None, num_iterations=3):
        super().__init__()
        self.num_routes = num_routes
        self.num_iterations = num_iterations
        self.num_capsules = num_capsules

        if num_routes != -1:
            if num_iterations < 1:
                raise ValueError('num_iterations must be >= 1 when routing')
            self.route_weights = nn.Parameter(
                torch.randn(num_capsules, num_routes,
                            in_channels, out_channels)
            )
        else:
            self.capsules = nn.ModuleList(
                [nn.Conv2d(in_channels,
                           out_channels,
                           kernel_size=kernel_size,
                           stride=stride,
                           padding=0)
                 for _ in range(num_capsules)]
            )

    def forward(self, x):
        """
        Run the layer on ``x``.

        In routing mode ``x`` is indexed as a 3-D tensor (presumably
        (batch, num_routes, in_channels) -- confirm against the caller)
        and the result has shape (num_capsules, batch, 1, 1, out_channels).
        In convolutional mode the result has shape
        (batch, -1, num_capsules).
        """
        # If routing is defined
        if self.num_routes != -1:
            # Prediction vectors:
            # (num_capsules, batch, num_routes, 1, out_channels)
            priors = x[None, :, :, None, :] @ self.route_weights[:, None, :, :, :]

            # Start from uniform coupling. zeros_like keeps the logits on
            # the same device/dtype as the priors, so CPU runs work too.
            logits = torch.zeros_like(priors)

            # Routing algorithm
            # NOTE(review): the coupling softmax runs over the input-route
            # axis (dim=2). The paper normalises over the output capsules
            # (dim=0 here); left unchanged because it alters training
            # behaviour -- confirm before switching.
            for i in range(self.num_iterations):
                probs = softmax(logits, dim=2)
                # Sum the weighted predictions first, then squash the
                # resulting vector (s_j -> v_j), not each term separately.
                outputs = squash_vector(
                    (probs * priors).sum(dim=2, keepdim=True))

                if i != self.num_iterations - 1:
                    # Agreement between each prediction and the current output
                    delta_logits = (priors * outputs).sum(dim=-1, keepdim=True)
                    logits = logits + delta_logits
        else:
            # One flattened column per convolutional capsule
            outputs = [capsule(x).view(x.size(0), -1, 1)
                       for capsule in self.capsules]
            outputs = torch.cat(outputs, dim=-1)
            outputs = squash_vector(outputs)

        return outputs
class MarginLoss(nn.Module):
    """
    Margin loss over the class scores plus a down-weighted,
    sum-of-squares reconstruction loss, averaged over the batch.
    """

    def __init__(self):
        super().__init__()
        # Reconstruction as regularization (summed, not averaged)
        self.reconstruction_loss = nn.MSELoss(reduction='sum')

    def forward(self, images, labels, classes, reconstructions):
        """
        Args:
            images: original inputs; flattened per sample to match
                ``reconstructions``.
            labels: one-hot targets, same shape as ``classes``.
            classes: per-class scores produced by the network.
            reconstructions: decoder output, one flat row per sample.

        Returns:
            Scalar tensor: (margin loss + 0.0005 * reconstruction loss)
            divided by the batch size.
        """
        # Penalise the true class for scoring below 0.9 ...
        left = F.relu(0.9 - classes) ** 2
        # ... and every other class for scoring above 0.1 (half weight)
        right = F.relu(classes - 0.1) ** 2

        margin_loss = (labels * left + 0.5 * (1. - labels) * right).sum()

        # The decoder emits flat rows, so flatten the images to the same
        # layout; MSELoss must not be handed mismatched shapes.
        images = images.reshape(reconstructions.size(0), -1)
        reconstruction_loss = self.reconstruction_loss(reconstructions, images)

        return (margin_loss + 0.0005 * reconstruction_loss) / images.size(0)
class CapsuleNet(nn.Module):
    """
    Capsule network: a 9x9 convolution (1 -> 256 channels), 8 primary
    convolutional capsules (256 -> 32 channels, 9x9, stride 2), 10 digit
    capsules routed from 32 * 6 * 6 inputs (8 -> 16 dimensions), and a
    fully connected decoder (160 -> 512 -> 1024 -> 784, sigmoid output).

    The 32 * 6 * 6 route count and the 784-wide decoder output correspond
    to single-channel 28x28 inputs.
    """

    def __init__(self):
        super().__init__()

        self.conv1 = nn.Conv2d(
            in_channels=1, out_channels=256, kernel_size=9, stride=1)
        self.primary_capsules = CapsuleLayer(
            8, -1, 256, 32, kernel_size=9, stride=2)

        # 10 is the number of classes
        self.digit_capsules = CapsuleLayer(10, 32 * 6 * 6, 8, 16)

        self.decoder = nn.Sequential(
            nn.Linear(16 * 10, 512),
            nn.ReLU(inplace=True),
            nn.Linear(512, 1024),
            nn.ReLU(inplace=True),
            nn.Linear(1024, 784),
            nn.Sigmoid()
        )

    def forward(self, x, y=None):
        """
        Args:
            x: batch of input images.
            y: optional one-hot mask (batch, 10) selecting which digit
                capsule feeds the decoder. When omitted, the capsule with
                the highest class score is used.

        Returns:
            (classes, reconstructions): per-class scores of shape
            (batch, 10) and decoder outputs of shape (batch, 784).
        """
        x = F.relu(self.conv1(x), inplace=True)
        x = self.primary_capsules(x)

        # digit_capsules yields (10, batch, 1, 1, 16). Squeeze only the two
        # singleton axes: a bare squeeze() would also drop a batch of 1.
        x = self.digit_capsules(x).squeeze(3).squeeze(2).transpose(0, 1)

        # Length of each digit capsule's output vector
        classes = (x ** 2).sum(dim=-1) ** 0.5
        # NOTE(review): the paper uses the raw capsule length as the class
        # score; this softmax is kept to preserve existing behaviour.
        classes = F.softmax(classes, dim=-1)

        if y is None:
            # In all batches, get the most active capsule
            _, max_length_indices = classes.max(dim=1)
            # Build the mask on the same device/dtype as the scores
            # instead of forcing CUDA.
            y = torch.eye(classes.size(1), device=classes.device,
                          dtype=classes.dtype).index_select(
                dim=0, index=max_length_indices)

        # Zero every capsule but the selected one, then flatten per sample.
        # reshape (not view): x comes from a transpose, so the product is
        # not guaranteed to be contiguous.
        reconstructions = self.decoder(
            (x * y[:, :, None]).reshape(x.size(0), -1))

        return classes, reconstructions
if __name__ == '__main__':
    # Globals
    # Use the GPU only when one is actually present
    CUDA = torch.cuda.is_available()
    EPOCH = 10

    # Model
    model = CapsuleNet()
    if CUDA:
        model.cuda()

    optimizer = optim.Adam(model.parameters())
    margin_loss = MarginLoss()

    train_loader = torch.utils.data.DataLoader(
        MNIST(root='/tmp', download=True, train=True,
              transform=transforms.ToTensor()),
        batch_size=8, shuffle=True)

    test_loader = torch.utils.data.DataLoader(
        MNIST(root='/tmp', download=True, train=False,
              transform=transforms.ToTensor()),
        batch_size=8, shuffle=True)

    for e in range(EPOCH):
        # Training
        train_loss = 0.0
        model.train()

        for img, target in tqdm(train_loader, desc='Training'):
            target = index_to_one_hot(target)

            if CUDA:
                img = img.cuda()
                target = target.cuda()

            optimizer.zero_grad()

            classes, reconstructions = model(img, target)
            loss = margin_loss(img, target, classes, reconstructions)
            loss.backward()
            optimizer.step()

            # The criterion returns a per-sample mean for the batch; weight
            # it by the batch size so the epoch figure is a true average.
            train_loss += loss.item() * img.size(0)

        train_loss /= len(train_loader.dataset)
        print('Training:, Avg Loss: {:.4f}'.format(train_loss))

        # Testing
        correct = 0
        test_loss = 0.0
        model.eval()

        # No gradients are needed for evaluation
        with torch.no_grad():
            for img, target_index in tqdm(test_loader, desc='test set'):
                target = index_to_one_hot(target_index)

                if CUDA:
                    img = img.cuda()
                    target = target.cuda()

                classes, reconstructions = model(img, target)
                batch_loss = margin_loss(img, target, classes, reconstructions)
                test_loss += batch_loss.item() * img.size(0)

                # Index of the highest class score for each sample
                pred = classes.max(dim=1)[1].cpu()
                correct += pred.eq(target_index.view_as(pred)).sum().item()

        test_loss /= len(test_loader.dataset)
        accuracy = 100. * correct / len(test_loader.dataset)
        print('Test Set: Avg Loss: {:.4f}, Accuracy: {:.4f}'.format(
            test_loss, accuracy))
@Atcold Oops, my bad. It appears that I pasted in an outdated version. I've updated the gist now and removed the redundant `tqdm` and `print` calls.
Very well, @kendricktan. Two more remarks.
You can (1) reintroduce `tqdm` in the training cycle (as long as you don't print the loss on screen), and (2) factor out the feed-forward pass and loss evaluation, which are shared by both the `training` and `testing` procedures. Furthermore, I'd recommend zeroing the gradient after the forward pass, just before the backward pass, to reduce confusion.
@Atcold done for your remark 1.
As for 2, I personally think that the state of the optimizer should be made explicit (zeroed) before anything else happens. Thanks for the feedback 👍
I have a doubt. `logits` at line 88 gets the size 10 x 128 x 1152 x 1 x 16 (assuming a batch size of 128), but the softmax is done with respect to dim 2. Should it not be with respect to dim 0, since we have 10 classes? Can you clarify?
@balassbals, there are a total of (6 × 6 × 32) 8D capsules $u$, which provide their prediction vectors $\hat{u}$. Each capsule input $s$ is the weighted average of the corresponding $\hat{u}$. The weighting coefficients $c$ are given by the `softmax` over the logits $b$, which are as many as the number of capsules in the layer below, i.e. 6 × 6 × 32. Therefore, it is correct to run the `softmax` on the 3rd dimension (i.e. dimension number 2). Please let me know if it is not clear.
@kendricktan, the optimiser is part of the back-propagation algorithm, which starts after the forward pass. This is why I would recommend not mixing the two things. I have students who confuse the two...
One last thing: this code does not run when `CUDA = False`, because of line 166. Instead of `cuda()`, use `type_as(other_tensor)`.
@Atcold, I understand what you say. But still I'm confused since the paper says that coupling coeffs between capsule i and all the capsules in the layer above sum to 1 and equation 3 in paper supports this statement. But from what you say my understanding is that the coeffs between all capsules in layer l and capsule j in layer l + 1 sum to one. Can you clarify?
@balassbals, you are correct. Today I gave a speech at NYU about this paper, and people pointed out that the `softmax` is done across the first dimension (i.e. dimension number 0). I missed this the first time I read the paper. My bad. So you are correct, there is a mistake in this implementation.
@kendricktan did you follow the conversation? If so, please fix.
@Atcold, but when I do it across dim 0 (10 classes), I don't get the expected results. Another PyTorch implementation I saw uses F.softmax wrongly. I actually implemented it myself first, but I'm not getting the results, so I'm looking for a working version in PyTorch.
Also, why is there a `softmax()` at line L152? This should simply be the capsule's norm! Correct?
@balassbals I have not found any working PyTorch implementations that softmax across the 10 classes (only across the 1152 routes, which does not match the paper). Have you discovered anything since?
Hello,
I have some experience with CapsNet written in TensorFlow, but I have no idea about PyTorch. Can you help me?
I want to input data of size (224, 224, 3), and the target will be binary (0 or 1). What modifications do I have to make for this kind of data?
Thanks in advance.
Also, `tqdm` and `print` make a mess on screen.