-
-
Save crazyoscarchang/c9a11b67c420202da1f26e0d20786750 to your computer and use it in GitHub Desktop.
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
import torch.optim as optim | |
import torchvision | |
from torchvision import datasets, transforms | |
import math | |
import numpy as np | |
# Hardcoded variables for hyperfan init.
#
# HyperNN generates every main-net conv weight in chunks of 96*9 values
# (Wout has 96*9 rows). A conv layer with C_in input and C_out output channels,
# stored as 3x3, therefore consumes C_in * C_out / 96 chunks. The per-chunk
# tables below are expanded from this layer list, in the order HyperNN.forward
# consumes the chunks.
hardcoded_input_size = 3
hardcoded_n_classes = 10

# (in_channels, out_channels, kernel_size) of each main-net conv layer.
_hyperfan_layers = (
    (hardcoded_input_size, 96, 3),  # conv1
    (96, 96, 3),                    # conv2
    (96, 96, 3),                    # conv3
    (96, 192, 3),                   # conv4
    (192, 192, 3),                  # conv5
    (192, 192, 3),                  # conv6
    (192, 192, 3),                  # conv7
    (192, 192, 1),                  # conv8
    (192, hardcoded_n_classes, 1),  # class_conv
)

# Hyperfan-in of a chunk is the input-channel count of the layer it feeds.
hardcoded_hyperfanin = [
    c_in
    for c_in, c_out, _k in _hyperfan_layers
    for _rep in range(c_in * c_out // 96)
]
# Hyperfan-out of a chunk is the output-channel count of the layer it feeds.
hardcoded_hyperfanout = [
    c_out
    for c_in, c_out, _k in _hyperfan_layers
    for _rep in range(c_in * c_out // 96)
]

# The 3x3 layers come first in the list. Every chunk index below this
# boundary therefore feeds a 3x3 kernel, and every index at or above it feeds
# a 1x1 kernel.
_hyperfan_spatial_boundary = sum(
    c_in * c_out // 96 for c_in, c_out, k in _hyperfan_layers if k == 3
)


def hardcoded_receptive(i):
    """Return the receptive-field size (kernel height * width) for chunk ``i``."""
    return 9 if i < _hyperfan_spatial_boundary else 1
def hyperfaninWi_init(i):
    """Return a hyperfan-in initializer for the hypernet weight ``Wi`` of chunk ``i``.

    The returned function fills ``Wi`` in place with U(-b, b), where
    b = sqrt(6 / (fan_in * hardcoded_hyperfanin[i]) / hardcoded_receptive(i)).
    Here ``fan_in`` is ``Wi.size(1)``. The resulting variance is
    2 / (fan_in * hyperfanin * receptive). The function returns ``Wi``.
    """
    def hyperfanin_init(Wi):
        fan_in = Wi.size(1)
        bound = math.sqrt(3*2 / (fan_in * hardcoded_hyperfanin[i]) / hardcoded_receptive(i))
        # no_grad lets this also initialize leaf tensors that require grad
        # (e.g. an nn.Parameter) without an in-place autograd error.
        with torch.no_grad():
            Wi.uniform_(-bound, bound)
        return Wi
    return hyperfanin_init
def hyperfanoutWi_init(i):
    """Return a hyperfan-out initializer for the hypernet weight ``Wi`` of chunk ``i``.

    The returned function fills ``Wi`` in place with U(-b, b), where
    b = sqrt(6 / (fan_in * hardcoded_hyperfanout[i]) / hardcoded_receptive(i)).
    Here ``fan_in`` is ``Wi.size(1)``. The function returns ``Wi``.
    """
    def hyperfanout_init(Wi):
        fan_in = Wi.size(1)
        bound = math.sqrt(3*2 / (fan_in * hardcoded_hyperfanout[i]) / hardcoded_receptive(i))
        # no_grad lets this also initialize leaf tensors that require grad
        # (e.g. an nn.Parameter) without an in-place autograd error.
        with torch.no_grad():
            Wi.uniform_(-bound, bound)
        return Wi
    return hyperfanout_init
def fanin_uniform(W):
    """Fill ``W`` in place with U(-b, b), where b = sqrt(6 / W.size(1)), and return it.

    ``W.size(1)`` is treated as the fan-in. The resulting variance is 2 / fan_in.
    """
    fan_in = W.size(1)
    bound = math.sqrt(3*2 / fan_in)
    # no_grad makes this usable on nn.Parameter / grad-requiring leaves too.
    with torch.no_grad():
        W.uniform_(-bound, bound)
    return W
def embed_uniform(e):
    """Fill ``e`` in place with U(-sqrt(3), sqrt(3)), which has unit variance, and return it."""
    bound = math.sqrt(3)
    # no_grad makes this usable on nn.Parameter / grad-requiring leaves too.
    with torch.no_grad():
        e.uniform_(-bound, bound)
    return e
# Adapted from https://github.com/StefOe/all-conv-pytorch/blob/master/allconv.py
class AllConvNet(nn.Module):
    """All-Convolutional Net whose conv weights are supplied externally.

    This class registers no weights of its own. ``forward`` reads attributes
    ``conv1_weight`` .. ``conv8_weight`` and ``class_conv_weight``, plus the
    matching ``*_bias`` attributes. A subclass must set these before calling
    ``forward``; HyperNN does this in its own ``forward``.

    Args:
        input_size: number of input channels.
        n_classes: number of output classes.
    """

    def __init__(self, input_size, n_classes):
        super(AllConvNet, self).__init__()
        self.input_size = input_size
        self.n_classes = n_classes

    def forward(self, x):
        """Return class scores of shape (batch, C) for the image batch ``x``.

        C is the number of output channels of ``class_conv_weight``.
        """
        # F.dropout defaults to training=True, so the module's mode must be
        # passed explicitly. Otherwise dropout stays active after model.eval()
        # and test-time predictions are random.
        training = self.training
        x_drop = F.dropout(x, .2, training=training)
        conv1_out = F.relu(F.conv2d(x_drop, self.conv1_weight, self.conv1_bias, padding=1))
        conv2_out = F.relu(F.conv2d(conv1_out, self.conv2_weight, self.conv2_bias, padding=1))
        # Stride-2 convs take the place of pooling layers.
        conv3_out = F.relu(F.conv2d(conv2_out, self.conv3_weight, self.conv3_bias, padding=1, stride=2))
        conv3_out_drop = F.dropout(conv3_out, .5, training=training)
        conv4_out = F.relu(F.conv2d(conv3_out_drop, self.conv4_weight, self.conv4_bias, padding=1))
        conv5_out = F.relu(F.conv2d(conv4_out, self.conv5_weight, self.conv5_bias, padding=1))
        conv6_out = F.relu(F.conv2d(conv5_out, self.conv6_weight, self.conv6_bias, padding=1, stride=2))
        conv6_out_drop = F.dropout(conv6_out, .5, training=training)
        conv7_out = F.relu(F.conv2d(conv6_out_drop, self.conv7_weight, self.conv7_bias, padding=1))
        conv8_out = F.relu(F.conv2d(conv7_out, self.conv8_weight, self.conv8_bias))
        # The ReLU on the class map is kept from the adapted implementation.
        class_out = F.relu(F.conv2d(conv8_out, self.class_conv_weight, self.class_conv_bias))
        # Global average pool: (N, C, H, W) -> (N, C, 1, 1) -> (N, C).
        pool_out = F.adaptive_avg_pool2d(class_out, 1)
        return pool_out.flatten(1)
class HyperNN(AllConvNet):
    """AllConvNet whose conv weights are generated by a linear hypernetwork.

    Main-net weights are produced in chunks of 96*9 values. Chunk ``k`` is
    computed from a fixed embedding ``z_k`` as

        chunk_k = Wout @ (Wi[k] @ z_k + Bi[k]) + Bout

    In ``forward``, consecutive chunks are reshaped into each conv layer's
    weight. ``Wi``, ``Bi``, ``Wout``, ``Bout`` and the conv biases are
    trainable. The embeddings are fixed.

    Args:
        input_size: number of input channels.
        n_classes: number of output classes.
        embed_size: size of each chunk embedding and of the hidden vector
            ``Wi[k] @ z_k + Bi[k]``.
        embedW_init_scheme: in-place initializer for the (num_kernels, embed_size)
            embedding tensor.
        hyperWi_init_scheme: factory ``i -> initializer`` applied to ``Wi[i]``.
        hyperWout_init_scheme: in-place initializer for the (96*9, embed_size) ``Wout``.
        device: device on which the embeddings are first created.

    NOTE(review): the hyperfan initializers in this file index tables built for
    input_size=3 and n_classes=10. Other sizes need matching tables.
    """

    def __init__(self, input_size, n_classes, embed_size, embedW_init_scheme,
                 hyperWi_init_scheme, hyperWout_init_scheme, device):
        super().__init__(input_size, n_classes)
        # One chunk per 96*9 main-net weights: conv1 uses input_size chunks,
        # conv2 and conv3 use 96 each, conv4 uses 192, conv5-conv8 use 384 each,
        # and class_conv uses 2*n_classes.
        self.num_kernels = input_size + 2*n_classes + 1920  # 96*2 + 192 + (192*2)*4
        # Fixed (non-trainable) chunk embeddings. They are registered as a
        # buffer so they follow model.to(...) like the parameters. As a plain
        # attribute they stayed on `device`, and forward failed after the model
        # was moved. Non-persistent keeps state_dict() unchanged, so recreate
        # them with the same seed when reloading a checkpoint.
        self.register_buffer(
            'weight_embeddings',
            embedW_init_scheme(torch.zeros(self.num_kernels, embed_size).to(device)),
            persistent=False,
        )
        # Hypernetwork parameters. Each Wi[i] is initialized by its own
        # chunk-specific scheme.
        Wi = torch.zeros(self.num_kernels, embed_size, embed_size)
        for i in range(self.num_kernels):
            Wi[i] = hyperWi_init_scheme(i)(Wi[i])
        Bi = torch.zeros(self.num_kernels, embed_size)
        Wout = hyperWout_init_scheme(torch.zeros(96*9, embed_size))
        Bout = torch.zeros(96*9)
        self.Wi = nn.Parameter(Wi)
        self.Bi = nn.Parameter(Bi)
        self.Wout = nn.Parameter(Wout)
        self.Bout = nn.Parameter(Bout)
        # The main-net biases are learned directly, not generated.
        self.conv1_bias = nn.Parameter(torch.zeros(96))
        self.conv2_bias = nn.Parameter(torch.zeros(96))
        self.conv3_bias = nn.Parameter(torch.zeros(96))
        self.conv4_bias = nn.Parameter(torch.zeros(192))
        self.conv5_bias = nn.Parameter(torch.zeros(192))
        self.conv6_bias = nn.Parameter(torch.zeros(192))
        self.conv7_bias = nn.Parameter(torch.zeros(192))
        self.conv8_bias = nn.Parameter(torch.zeros(192))
        self.class_conv_bias = nn.Parameter(torch.zeros(n_classes))

    def _generate_chunks(self, start, count):
        """Return chunks ``start .. start+count-1`` as a (count, 96*9) tensor."""
        rows = slice(start, start + count)
        # (count, E, E) @ (count, E, 1) + (count, E, 1) -> (count, E, 1)
        hidden = (self.Wi[rows] @ self.weight_embeddings[rows].unsqueeze(2)
                  + self.Bi[rows].unsqueeze(2))
        # (96*9, E) @ (count, E, 1) -> (count, 96*9, 1) -> (count, 96*9)
        return (self.Wout @ hidden).squeeze(2) + self.Bout.unsqueeze(0)

    def forward(self, x):
        """Generate all main-net conv weights, then run AllConvNet.forward on ``x``."""
        n_in, n_cls = self.input_size, self.n_classes
        # (attribute name, number of chunks, weight shape, is the layer 1x1?)
        layer_specs = (
            ('conv1_weight', n_in, (96, n_in, 3, 3), False),
            ('conv2_weight', 96, (96, 96, 3, 3), False),
            ('conv3_weight', 96, (96, 96, 3, 3), False),
            ('conv4_weight', 192, (192, 96, 3, 3), False),
            ('conv5_weight', 192*2, (192, 192, 3, 3), False),
            ('conv6_weight', 192*2, (192, 192, 3, 3), False),
            ('conv7_weight', 192*2, (192, 192, 3, 3), False),
            ('conv8_weight', 192*2, (192, 192, 3, 3), True),
            ('class_conv_weight', n_cls*2, (n_cls, 192, 3, 3), True),
        )
        idx = 0
        for name, count, shape, one_by_one in layer_specs:
            weight = self._generate_chunks(idx, count).view(shape)
            if one_by_one:
                # 1x1 layers are generated in the 3x3 layout; keep only the
                # top-left tap.
                weight = weight[:, :, :1, :1]
            setattr(self, name, weight)
            idx += count
        return super().forward(x)
# Configuration
# Use the first GPU when available. Fall back to CPU so the script also runs on
# machines without CUDA; a hard-coded 'cuda:0' fails there.
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
embed_size = 50  # size of each hypernet chunk embedding
embedW_init_scheme = embed_uniform
# Wi init: hyperfan-in variant. hyperfanoutWi_init is the hyperfan-out alternative.
hyperWi_init_scheme = hyperfaninWi_init
hyperWout_init_scheme = fanin_uniform
lr = 0.0005
training_batch_size = 100
test_batch_size = 1000
epochs = 500
log_interval = 100  # print training progress every `log_interval` batches
seed = 123
torch.manual_seed(seed)
# The training loss is averaged over each batch. The test loss is summed, so
# test() can divide by the dataset size to get a per-sample mean.
train_criterion = nn.CrossEntropyLoss(reduction='mean')
test_criterion = nn.CrossEntropyLoss(reduction='sum')
# Data for Loss Plots: one entry per epoch, filled by train() and test().
train_loss_list = []
test_loss_list = []
test_acc_list = []
# Training/Testing Functions
def train(model, device, train_loader, optimizer, epoch, log_interval, lr_scheduler):
    """Run one training epoch over ``train_loader``, then step ``lr_scheduler``.

    Prints progress every ``log_interval`` batches. Appends the epoch's mean
    per-sample training loss to ``train_loss_list``.
    """
    model.train()
    total_loss = 0.0
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = train_criterion(output, target)
        loss.backward()
        optimizer.step()
        # train_criterion averages over the batch, so weight each value by the
        # batch size to accumulate a per-sample sum. Summing batch means and
        # dividing by the dataset size would understate the loss by about the
        # batch size, and it would not match the per-sample test loss.
        total_loss += loss.item() * data.size(0)
        if batch_idx % log_interval == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss.item()))
    total_loss /= len(train_loader.dataset)
    train_loss_list.append(total_loss)
    lr_scheduler.step()
def test(model, device, test_loader):
    """Evaluate ``model`` on ``test_loader``, then record and print the results.

    Appends the mean per-sample loss to ``test_loss_list`` and the accuracy, as
    a percentage, to ``test_acc_list``.
    """
    model.eval()
    loss_sum = 0
    n_correct = 0
    with torch.no_grad():
        for inputs, labels in test_loader:
            inputs = inputs.to(device)
            labels = labels.to(device)
            scores = model(inputs)
            # test_criterion sums over the batch, so loss_sum is a running
            # per-sample total.
            loss_sum += test_criterion(scores, labels).item()
            # The predicted class is the highest-scoring output.
            predicted = scores.argmax(dim=1, keepdim=True)
            n_correct += predicted.eq(labels.view_as(predicted)).sum().item()
    n_samples = len(test_loader.dataset)
    mean_loss = loss_sum / n_samples
    test_loss_list.append(mean_loss)
    accuracy = 100. * n_correct / n_samples
    test_acc_list.append(accuracy)
    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        mean_loss, n_correct, n_samples, accuracy))
# CIFAR10 Data Loaders
# Per-channel normalization shared by the train and test pipelines.
# NOTE(review): these std values are widely copied for CIFAR-10 but appear to
# differ from the per-pixel dataset std (~0.247, 0.243, 0.261). They are kept
# unchanged so results stay comparable; confirm before changing them.
_cifar10_normalize = transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))

# Training augmentation: pad by 4 pixels, take a random 32x32 crop, and apply
# a random horizontal flip.
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    _cifar10_normalize,
])
transform_test = transforms.Compose([transforms.ToTensor(), _cifar10_normalize])

_data_root = './data'
trainset = datasets.CIFAR10(root=_data_root, train=True, download=True, transform=transform_train)
train_loader = torch.utils.data.DataLoader(
    trainset, batch_size=training_batch_size, shuffle=True, num_workers=2)
testset = datasets.CIFAR10(root=_data_root, train=False, download=True, transform=transform_test)
test_loader = torch.utils.data.DataLoader(
    testset, batch_size=test_batch_size, shuffle=False, num_workers=2)
# Model and Optimizer
model = HyperNN(
    hardcoded_input_size, hardcoded_n_classes, embed_size, embedW_init_scheme,
    hyperWi_init_scheme, hyperWout_init_scheme, device,
).to(device)
# Counts trainable parameters only; the fixed weight embeddings are not
# nn.Parameters.
num_params = sum(p.numel() for p in model.parameters())
print("number of params:", num_params)
# Plain SGD. train() steps the scheduler once per epoch, so the LR drops 10x
# at epochs 350 and 450.
optimizer = optim.SGD(model.parameters(), lr=lr)
lr_scheduler = optim.lr_scheduler.MultiStepLR(
    optimizer=optimizer, milestones=[350, 450], gamma=0.1)
# Actual HyperNet Training
for epoch in range(1, epochs + 1):
    # train() also advances lr_scheduler once per epoch.
    train(model, device, train_loader, optimizer, epoch, log_interval, lr_scheduler)
    test(model, device, test_loader)

# Save Experiment: per-epoch curves, stored as NumPy arrays.
result_dict = {
    key: np.array(values)
    for key, values in (
        ('train_loss_list', train_loss_list),
        ('test_loss_list', test_loss_list),
        ('test_acc_list', test_acc_list),
    )
}
torch.save(result_dict, 'results.dict')
Hi,
I have been interested in hypernetworks. The work you and your team have done is useful. However, there is a line in your code that is difficult for me to understand, even though I have read your paper several times.
My question is:
What does `hardcoded_receptive` mean?
It would be helpful if you could give me a hint!
Thanks!
Sorry for the late reply. We hardcode values depending on the sizes of the layers in the All Convolutional Net. In general, it might be difficult to have a hypernet initialization abstract enough that it can work off-the-shelf for different kinds of architectures in both the mainnet and the hypernet. We recommend adopting the general principle of preserving variance through the mainnet, and applying it to your specific neural network architecture.
Can you give an example for a feedforward network?
A simple network with one hidden layer.
Your paper is very difficult to understand.
Can you publish a version without hardcoded values?
Thanks!