Created
July 11, 2022 10:14
-
-
Save nilsleh/b6d8aeeb20f3b56d58dda762857e5283 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np | |
import torch | |
import torch.nn as nn | |
from functorch import vmap, jacrev, make_functional_with_buffers | |
# Sizes used by the toy example below.
batch_size = 2
in_channels = 5
out_channels = 20
feature_shape = 8  # side length of the square spatial grid

# Random input fed to the conv stack: (batch, channels, height, width).
feature = torch.rand(
    (batch_size, in_channels, feature_shape, feature_shape)
)
class ConvBlock(nn.Module):
    """Three stacked 3x3, stride-1 convolutions with no activation in between.

    The first two layers map ``in_ch`` channels to ``in_ch`` channels; the
    last layer maps ``in_ch`` to ``out_ch``.
    """

    def __init__(self, in_ch, out_ch):
        super().__init__()
        layers = [
            nn.Conv2d(in_ch, in_ch, kernel_size=3, stride=1),
            nn.Conv2d(in_ch, in_ch, kernel_size=3, stride=1),
            nn.Conv2d(in_ch, out_ch, kernel_size=3, stride=1),
        ]
        self.conv = nn.Sequential(*layers)

    def forward(self, x):
        """Apply the three convolutions to ``x`` in order."""
        return self.conv(x)
# Build the network, then obtain a stateless callable together with the
# parameter and buffer collections it expects as explicit arguments.
model = ConvBlock(in_ch=in_channels, out_ch=out_channels)
fmodel, params, buffers = make_functional_with_buffers(model)
def map_layer_param_to_flat_param(model):
    """Build a lookup table from flat parameter positions to per-tensor positions.

    Args:
        model: Module whose ``parameters()`` are enumerated in order.

    Returns:
        A ``torch.long`` tensor of shape ``(num_scalars, 3)``. Row ``k`` is
        ``[k, param_layer_idx, vec_idx]``: the position in the concatenation
        of all flattened parameters, the index of the parameter tensor within
        ``model.parameters()``, and the offset inside that tensor's flattened
        form. A model without any parameter scalars yields shape ``(0, 3)``
        so that column indexing on the result still works.
    """
    blocks = []
    num_params_so_far = 0
    for param_layer_idx, p in enumerate(model.parameters()):
        count = p.numel()
        # One vectorised block per parameter tensor instead of a Python
        # iteration per scalar.
        vec_idx = torch.arange(count, dtype=torch.long)
        blocks.append(
            torch.stack(
                (
                    vec_idx + num_params_so_far,
                    torch.full_like(vec_idx, param_layer_idx),
                    vec_idx,
                ),
                dim=1,
            )
        )
        num_params_so_far += count
    if not blocks:
        return torch.empty((0, 3), dtype=torch.long)
    return torch.cat(blocks, dim=0)
def define_subnet_and_other_indices(model, subnet_size=10):
    """Randomly partition the flat parameter positions into two index sets.

    Args:
        model: Module whose parameters define the flat index range
            ``0 .. num_scalars - 1``.
        subnet_size: Number of positions to draw for the subnet. Clamped to
            the number of available scalars.

    Returns:
        ``(subnet_indices, deterministic_indices)``: two sorted 1-D
        ``torch.long`` tensors that are disjoint and together cover every
        flat position exactly once.
    """
    num_params = sum(p.numel() for p in model.parameters())
    size = min(subnet_size, num_params)
    # replace=False: sampling with replacement can pick the same position
    # twice, which duplicates scalars when the parameters are later split
    # and makes the reassembled tensor too large for its original shape.
    chosen = np.sort(np.random.choice(num_params, size=size, replace=False))
    subnet_indices = torch.from_numpy(chosen).long()

    # Boolean mask instead of a per-index `in` scan over the tensor.
    is_deterministic = torch.ones(num_params, dtype=torch.bool)
    is_deterministic[subnet_indices] = False
    deterministic_indices = torch.nonzero(is_deterministic).flatten()
    return subnet_indices, deterministic_indices
def split(params, relevant_indices, other_indices):
    """Separate each parameter tensor's scalars into two groups.

    Args:
        params: Sequence of parameter tensors.
        relevant_indices: 2-D index table; column 1 holds the position of the
            parameter tensor in ``params`` and column 2 the offset within its
            flattened form.
        other_indices: Table with the same layout for the remaining scalars.

    Returns:
        ``(relevant_params, other_params, param_shapes)``. The first two are
        dicts keyed by parameter position holding 1-D tensors of the selected
        scalars; a key is absent when the table has no row for that
        parameter. ``param_shapes`` maps every parameter position to its
        original shape so the tensor can be rebuilt later.
    """
    relevant_params, other_params, param_shapes = {}, {}, {}
    for i, param in enumerate(params):
        flat = param.flatten()
        for table, bucket in (
            (relevant_indices, relevant_params),
            (other_indices, other_params),
        ):
            rows = table[table[:, 1] == i]
            if rows.nelement() == 0:
                continue
            bucket[i] = flat[rows[:, 2]]
        param_shapes[i] = param.shape
    return relevant_params, other_params, param_shapes
def combine(relevant_params, other_params, relevant_indices, other_indices, param_shapes):
    """Reconstruct the full parameter tensors from their two 1-D pieces.

    Args:
        relevant_params: Dict mapping parameter position to a 1-D tensor of
            the scalars listed for it in ``relevant_indices``.
        other_params: Same, for the scalars listed in ``other_indices``.
        relevant_indices: 2-D index table; column 1 is the parameter position
            and column 2 the offset within the flattened parameter.
        other_indices: Table with the same layout for the remaining scalars.
        param_shapes: Dict mapping parameter position to the target shape.

    Returns:
        Tuple of tensors, one per entry of ``param_shapes`` in iteration
        order, each viewed in its original shape.

    Raises:
        ValueError: If a parameter listed in ``param_shapes`` has no rows in
            either index table, so there is nothing to rebuild it from.
    """
    reconstructed_params = []
    sources = (
        (relevant_indices, relevant_params),
        (other_indices, other_params),
    )
    for p_idx, p_shape in param_shapes.items():
        offsets = []
        pieces = []
        for index_table, params_by_layer in sources:
            rows = index_table[index_table[:, 1] == p_idx]
            if rows.nelement() != 0:
                offsets.append(rows[:, 2])
                pieces.append(params_by_layer[p_idx])
        if not pieces:
            # Previously this case fell through and reused (or never bound)
            # the tensor left over from the preceding loop iteration.
            raise ValueError(
                f"parameter {p_idx} has no entries in relevant_indices or other_indices"
            )
        # Sort by offset within the flattened parameter so the scalars land
        # back in their original positions, whichever piece they came from.
        order = torch.argsort(torch.cat(offsets))
        all_params_at_p = torch.cat(pieces)[order]
        reconstructed_params.append(all_params_at_p.view(p_shape))
    return tuple(reconstructed_params)
def compute_output_stateless_model(relevant_params, other_params, relevant_indices, other_indices, param_shapes, buffers, feature):
    """Evaluate the functional model on one sample with reassembled parameters.

    The parameter tensors are rebuilt with ``combine`` from the two scalar
    groups, ``feature`` gets a leading axis of size one, and the result of
    the module-level ``fmodel`` is reshaped before being returned.

    Args:
        relevant_params, other_params, relevant_indices, other_indices,
        param_shapes: Forwarded unchanged to ``combine``.
        buffers: Buffers passed through to ``fmodel``.
        feature: Input without a batch axis (one is added here).

    Returns:
        The model output viewed as ``(1, -1, 8)``.
    """
    full_params = combine(
        relevant_params, other_params, relevant_indices, other_indices, param_shapes
    )
    single_batch = feature.unsqueeze(0)
    prediction = fmodel(full_params, buffers, single_batch)
    # NOTE(review): the trailing size 8 is hard-coded; this view only works
    # when the number of output elements is a multiple of 8 - confirm intent.
    return prediction.view(single_batch.shape[0], -1, 8)
# Lookup table: flat parameter position -> (parameter tensor, offset in it).
param_idx = map_layer_param_to_flat_param(model)

# Choose which flat positions form the subnet and translate both groups of
# positions into rows of the lookup table.
subnet_indices, deterministic_indices = define_subnet_and_other_indices(model)
relevant_indices = param_idx[subnet_indices]
other_indices = param_idx[deterministic_indices]

# Pull the actual scalar values apart along the same partition.
relevant_params, other_params, param_shapes = split(
    params, relevant_indices, other_indices
)

# jacrev is applied with its default argument selection, so the Jacobian is
# taken with respect to the first positional argument (relevant_params).
ft_compute_grad = jacrev(compute_output_stateless_model)

# Map only the last argument (feature) over its leading axis; the six
# arguments before it are shared by every sample.
ft_compute_sample_grad = vmap(ft_compute_grad, in_dims=(None,) * 6 + (0,))
ft_per_sample_grads = ft_compute_sample_grad(
    relevant_params,
    other_params,
    relevant_indices,
    other_indices,
    param_shapes,
    buffers,
    feature,
)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment