Created September 21, 2018 16:59
Attempt at a linear unit that supports splitting the parameter space into a grid of gradient-checkpointed blocks. The issue right now is that when there is more than one segment, `backward()` only populates the gradient of the last weight parameter.
import torch
import torch.nn as nn
import torch.utils.checkpoint as checkpoint
def get_segments(total, max_length):
    # Split `total` into (segments - 1) chunks of `max_length`, with whatever is
    # left over going into the final chunk.
    if total > max_length:
        segments = (total // max_length)
    else:
        segments = 1
    return (segments-1)*[max_length] + [total - (segments-1)*max_length]
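# A quick sanity check of get_segments; the expected values follow directly from
# the definition above. Note that when `total` is not a multiple of `max_length`,
# the remainder is folded into the last segment, which can then exceed `max_length`.
assert get_segments(6, 2) == [2, 2, 2]
assert get_segments(7, 3) == [3, 4]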
class GradCheckpoint_Linear(nn.Module):
    """Linear unit whose weight matrix is split into a grid of blocks, each of
    which is wrapped in a gradient checkpoint during the forward pass."""

    def __init__(self, in_features, out_features, cpc_specs={}):
        super(GradCheckpoint_Linear, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        if 'in_max_segment' in cpc_specs:
            in_max_segment = cpc_specs['in_max_segment']
        else:
            in_max_segment = in_features
        self.in_segment_lengths = get_segments(in_features, in_max_segment)
        if 'out_max_segment' in cpc_specs:
            out_max_segment = cpc_specs['out_max_segment']
        else:
            out_max_segment = out_features
        if 'initializer' in cpc_specs:
            self.initializer = cpc_specs['initializer']
        else:
            def get_init(w, h):
                return torch.randn(w, h)
            self.initializer = get_init
        self.out_segment_lengths = get_segments(out_features, out_max_segment)
        print("in_segment_lengths: {0}".format(self.in_segment_lengths))
        print("out_segment_lengths: {0}".format(self.out_segment_lengths))
        weight_parameters_ = []
        bias_parameters_ = []
        # Maps an (input segment, output segment) grid position to the index of
        # the corresponding weight block in the ParameterList below (-1 = unset).
        self.array_to_weight_param = -torch.ones(len(self.in_segment_lengths), len(self.out_segment_lengths), dtype=torch.int32)
        for in_idx, in_s_length in enumerate(self.in_segment_lengths):
            for out_idx, out_s_length in enumerate(self.out_segment_lengths):
                param = nn.Parameter(self.initializer(out_s_length, in_s_length))
                self.array_to_weight_param[in_idx, out_idx] = len(weight_parameters_)
                weight_parameters_.append(param)
        self.weight_parameters = nn.ParameterList(weight_parameters_)
        for out_s_length in self.out_segment_lengths:
            bias_parameters_.append(nn.Parameter(self.initializer(1, out_s_length).view(out_s_length)))
        self.bias_parameters = nn.ParameterList(bias_parameters_)
    def reset_parameters(self):
        pass
    def forward(self, inp):
        unit_outs = []
        for out_idx, out_s_length in enumerate(self.out_segment_lengths):
            bias_param = self.bias_parameters[out_idx]
            in_offset = 0
            weight_outs = []
            for in_idx, in_s_length in enumerate(self.in_segment_lengths):
                weight_param = self.weight_parameters[self.array_to_weight_param[in_idx, out_idx]]
                def fwd_unit_segment(inp_):
                    return torch.mv(weight_param, inp_)
                # Checkpoint each block's matrix-vector product so its activations
                # are recomputed during backward() instead of being stored.
                weight_out = checkpoint.checkpoint(fwd_unit_segment, inp[in_offset:in_offset+in_s_length])
                in_offset += in_s_length
                weight_outs.append(weight_out)
            # Sum the partial products over the input segments and add the bias
            # for this output segment.
            unit_outs.append(bias_param + sum(weight_outs))
        result = torch.cat(unit_outs)
        return result
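# A small sanity-check helper (a sketch; `dense_equivalent` is just a checking
# utility, not part of the module): rebuild the full (out_features, in_features)
# weight matrix from the per-segment blocks and apply it directly, so forward()
# can be compared against an ordinary dense linear map.
def dense_equivalent(unit, x):
    rows = []
    for out_idx in range(len(unit.out_segment_lengths)):
        blocks = [unit.weight_parameters[int(unit.array_to_weight_param[in_idx, out_idx])]
                  for in_idx in range(len(unit.in_segment_lengths))]
        rows.append(torch.cat(blocks, dim=1))    # (out_s_length, in_features)
    weight = torch.cat(rows, dim=0)              # (out_features, in_features)
    bias = torch.cat([b for b in unit.bias_parameters])
    return torch.mv(weight, x) + bias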
def init_(w, h):
    return torch.ones(w, h)

u_original = GradCheckpoint_Linear(6, 2, cpc_specs={'initializer': init_})
u_split = GradCheckpoint_Linear(6, 2, cpc_specs={'in_max_segment': 2, 'out_max_segment': 1, 'initializer': init_})
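# For u_split, the parameter space is cut into a 3 x 2 grid: three input segments
# of length 2 and two output segments of length 1, i.e. six weight blocks of
# shape (1, 2) plus two bias blocks of length 1, so weight_parameters[5] below is
# the last block of the grid. A quick check of those counts:
assert len(u_split.weight_parameters) == 6
assert len(u_split.bias_parameters) == 2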
inp_0_orig = torch.ones(6, requires_grad=True)
inp_0_split = torch.ones(6, requires_grad=True)

u_original(inp_0_orig).sum().backward()
u_original.weight_parameters[0].grad  # <--- looks good

u_split(inp_0_split).sum().backward()
u_split.weight_parameters[0].grad  # <--- grad is None
u_split.weight_parameters[5].grad  # <--- the only weight parameter with a gradient is the last one?
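# A likely cause of the symptom above (offered as a sketch, not a verified fix):
# torch.utils.checkpoint re-invokes the saved function during backward(), and
# fwd_unit_segment picks up `weight_param` through a closure over the loop
# variable in forward(). By the time backward() re-runs those functions the loops
# have finished, so the closure variable refers to the last weight block, which
# is then the only parameter that receives a gradient. Passing the weight block
# to checkpoint as an explicit argument avoids the late-binding closure; the
# helper below (`forward_with_explicit_weights`, a hypothetical name) re-runs the
# split unit's forward pass that way so the gradients can be compared.
def forward_with_explicit_weights(unit, inp):
    def fwd_unit_segment(weight_, inp_):
        return torch.mv(weight_, inp_)
    unit_outs = []
    for out_idx, out_s_length in enumerate(unit.out_segment_lengths):
        in_offset = 0
        weight_outs = []
        for in_idx, in_s_length in enumerate(unit.in_segment_lengths):
            weight_param = unit.weight_parameters[int(unit.array_to_weight_param[in_idx, out_idx])]
            # The weight is now an explicit checkpoint input, so each call saves
            # its own block instead of sharing one closure variable.
            weight_outs.append(checkpoint.checkpoint(fwd_unit_segment, weight_param,
                                                     inp[in_offset:in_offset+in_s_length]))
            in_offset += in_s_length
        unit_outs.append(unit.bias_parameters[out_idx] + sum(weight_outs))
    return torch.cat(unit_outs)

# Example usage:
#   u_split.zero_grad()
#   forward_with_explicit_weights(u_split, torch.ones(6, requires_grad=True)).sum().backward()
#   u_split.weight_parameters[0].grad  # should now hold a gradient for every block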