Initial take on a variable-size linear output layer
import torch
import torch.nn as nn
class Net(nn.Module):
    """Small MLP whose linear output layer can grow over time."""

    def __init__(self, input_size=16, hidden_size=32, initial_output_size=5):
        super().__init__()

        self.input_size = input_size
        self.hidden_size = hidden_size
        self.initial_output_size = initial_output_size

        self.embedder = nn.Linear(self.input_size, self.hidden_size)
        self.linear = nn.Linear(self.hidden_size, initial_output_size)

    def forward(self, x: torch.Tensor):
        x = self.embedder(x)
        x = self.linear(x)

        return x
def expand_linear_layer(linear_layer: nn.Linear, output_size_increment: int):
    """Grow the output size of a linear layer in place.

    Existing weights and biases are kept; the rows for the new outputs are
    drawn from a standard normal (coarser than nn.Linear's default init).

    Args:
        linear_layer (nn.Linear): Linear layer to be expanded in place.
        output_size_increment (int): Number of outputs to add.
    """
    input_size = linear_layer.in_features
    output_size = linear_layer.out_features
    new_output_size = output_size + output_size_increment

    # Allocate the enlarged weight and bias on the same device as the old ones.
    new_weight = torch.randn(new_output_size, input_size).to(linear_layer.weight.device)
    new_bias = torch.randn(new_output_size).to(linear_layer.bias.device)

    # Copy the existing parameters into the first `output_size` rows/entries.
    new_weight[:output_size] = linear_layer.weight.data
    new_bias[:output_size] = linear_layer.bias.data

    linear_layer.out_features = new_output_size
    linear_layer.weight = torch.nn.Parameter(new_weight)
    linear_layer.bias = torch.nn.Parameter(new_bias)
def compute_restricted_outputs(outputs: torch.Tensor, output_size_increment: int):
    """Zero out all but the last `output_size_increment` outputs.

    Note that masked logits become 0 rather than -inf, so they still take
    part in the softmax inside the cross-entropy loss.

    Args:
        outputs (torch.Tensor): Output tensor from a linear layer.
        output_size_increment (int): Number of trailing outputs to keep.

    Returns:
        torch.Tensor: Masked output tensor.
    """
    num_classes = outputs.size(1)

    restricted_mask = torch.zeros(num_classes)
    restricted_mask[-output_size_increment:] = 1.0
    restricted_mask = restricted_mask.to(outputs.device)

    # Apply the restricted mask to the outputs
    restricted_outputs = outputs * restricted_mask

    return restricted_outputs
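

# Quick sanity check of the two helpers (a minimal sketch, not part of the
# training demo below): expanding must preserve the original parameters, and
# the restriction must zero every logit outside the newest block.
_layer = nn.Linear(3, 2)
_old_weight = _layer.weight.data.clone()
_old_bias = _layer.bias.data.clone()

expand_linear_layer(_layer, 2)
assert _layer.out_features == 4
assert torch.equal(_layer.weight.data[:2], _old_weight)
assert torch.equal(_layer.bias.data[:2], _old_bias)

_restricted = compute_restricted_outputs(torch.ones(1, 4), 2)
assert torch.equal(_restricted, torch.tensor([[0.0, 0.0, 1.0, 1.0]]))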
# -------------------------------------------------------------------------------
# Demo: train for a few steps, expanding the output layer between steps so that
# each step introduces `output_size_increment` fresh classes.

batch_size = 32
initial_linear_layer_size = 5
output_size_increment = 5
steps = 4
input_size = 16
hidden_size = 4

net = Net(input_size, hidden_size, initial_linear_layer_size)
criterion = nn.CrossEntropyLoss()

for step in range(steps):
    # Rebuild the optimizer every step: expanding the layer replaces its weight
    # and bias with new Parameter objects, which the old optimizer would not track.
    opt = torch.optim.Adam(net.parameters(), lr=0.01)
    opt.zero_grad()

    inputs = torch.randn(batch_size, input_size)
    targets = torch.randint(0, output_size_increment, (batch_size,))

    outputs = net(inputs)
    outputs = compute_restricted_outputs(outputs, output_size_increment)

    # Shift the targets into the newest block of classes.
    loss = criterion(outputs, targets + net.linear.out_features - output_size_increment)
    loss.backward()
    opt.step()

    if step != steps - 1:
        expand_linear_layer(net.linear, output_size_increment)
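
# After `steps` steps the layer has been expanded `steps - 1` times; a small
# illustrative check that the final layer size matches.
assert net.linear.out_features == initial_linear_layer_size + (steps - 1) * output_size_increment
print(f"Final output size: {net.linear.out_features}, last loss: {loss.item():.4f}")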