This is free and unencumbered software released into the public domain.

Anyone is free to copy, modify, publish, use, compile, sell, or
distribute this software, either in source code form or as a compiled
binary, for any purpose, commercial or non-commercial, and by any
means.

In jurisdictions that recognize copyright laws, the author or authors
of this software dedicate any and all copyright interest in the
software to the public domain. We make this dedication for the benefit
of the public at large and to the detriment of our heirs and
successors. We intend this dedication to be an overt act of
relinquishment in perpetuity of all present and future rights to this
software under copyright law.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
IN NO EVENT SHALL THE AUTHORS BE LIABLE FOR ANY CLAIM, DAMAGES OR
OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE,
ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR
OTHER DEALINGS IN THE SOFTWARE.

For more information, please refer to <http://unlicense.org/>

import math
import functools
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
import einops


class PixelflyLinear(nn.Module):
    def __init__(
        self, in_features: int, out_features: int,
        lowrank_size: int, block_size: int, butterfly_size: int,
        n_factors: Optional[int] = None, stretch: bool = True, bias: bool = True,
    ):
        super().__init__()
        self.out_features, self.in_features = out_features, in_features
        self.register_buffer("butterfly_flat_indices", get_butterfly_indices(
            out_features, in_features, block_size, butterfly_size, n_factors, stretch))
        self.lowrank_first = nn.Linear(in_features, lowrank_size, bias=False)
        self.lowrank_second = nn.Linear(lowrank_size, out_features, bias=bias)
        active_blocks_per_input = self.butterfly_flat_indices.numel() // (in_features // block_size)
        self.weight = nn.Parameter(torch.empty(in_features, active_blocks_per_input, block_size))
        nn.init.xavier_normal_(self.weight)

    def forward(self, input):
        output = self.lowrank_second(self.lowrank_first(input))
        output += butterfly_matmul(input, self.weight, self.butterfly_flat_indices)
        return output
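
# Usage sketch (hypothetical configuration, not part of the original gist): builds a square
# 1024 -> 1024 PixelflyLinear with 128 x 128 blocks, an 8 x 8 butterfly pattern and a rank-64
# low-rank term, then checks the output shape. Wrapped in a function so that the helpers used
# by PixelflyLinear (defined further down in this file) are resolved at call time.
def _example_pixelfly_usage():
    layer = PixelflyLinear(
        in_features=1024, out_features=1024,
        lowrank_size=64, block_size=128, butterfly_size=8,
    )
    x = torch.randn(2, 16, 1024)  # [*batch_dims, in_features]
    y = layer(x)
    assert y.shape == (2, 16, 1024)
    # the butterfly branch keeps 4 of 8 blocks per block-row, i.e. ~50% of a dense 1024 x 1024 matrix
    assert layer.weight.shape == (1024, 4, 128)
    return y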


@functools.lru_cache
def get_butterfly_indices(
    out_features: int,
    in_features: int,
    block_size: int = 256,
    butterfly_size: int = 64,
    n_factors: Optional[int] = None,
    stretch: bool = False,
) -> torch.LongTensor:
    """
    Get a matrix [num_output_blocks, active_blocks_per_output] of integer indices for the additive butterfly.
    Based on the original implementation from https://arxiv.org/abs/2112.00029 .
    :param stretch: if True, non-square matrices will have stretched butterfly patterns;
        otherwise the square pattern is repeated along the longer dimension
    """
    assert (
        out_features % in_features == 0 or in_features % out_features == 0
    ), "if the matrix is not square, the longer dimension must be a multiple of the shorter dimension"
    assert out_features % block_size == 0 and in_features % block_size == 0
    log_n = int(math.log2(butterfly_size))
    n_factors = log_n if n_factors is None else n_factors
    if butterfly_size != 2 ** log_n or butterfly_size < 2:
        raise NotImplementedError("butterfly_size must be a power of 2")
    if not (1 <= n_factors <= log_n):
        raise NotImplementedError("n_factors must be between 1 and log2(butterfly_size)")

    twiddle = torch.ones(butterfly_size // 2, 2, 2)
    layout = sum(
        butterfly_factor_to_matrix(twiddle, index) for index in range(n_factors)
    )
    layout = layout.bool().int()
    # Convert from (butterfly_size, butterfly_size) mask to (out_features, in_features) mask
    layout = einops.repeat(
        layout,
        "b b1 -> (b f) (b1 f1)",
        f=out_features // butterfly_size,
        f1=in_features // butterfly_size,
    )
    # Convert from (out_features, in_features) mask to
    # (out_features // block_size, in_features // block_size) mask
    layout = einops.rearrange(
        layout,
        "(p blksz) (r blksz1) -> p r (blksz blksz1)",
        blksz=block_size,
        blksz1=block_size,
    )
    layout = (layout > 0).any(dim=-1)  # [out_features // block_size, in_features // block_size]

    if not stretch:
        out_blocks, in_blocks = layout.shape
        if out_blocks > in_blocks:
            ratio = out_blocks // in_blocks
            layout = (
                layout.view(out_blocks // ratio, ratio, in_blocks)
                .permute(1, 0, 2)
                .reshape_as(layout)
            )
        elif out_blocks < in_blocks:
            ratio = in_blocks // out_blocks
            layout = (
                layout.view(out_blocks, in_blocks // ratio, ratio)
                .permute(0, 2, 1)
                .reshape_as(layout)
            )

    # convert boolean layout to indices for F.embedding_bag
    num_output_blocks = out_features // block_size
    num_input_blocks = in_features // block_size
    active_blocks_per_output = layout.sum(1).unique()
    assert (
        len(active_blocks_per_output) == 1
    ), "butterfly layout must have the same number of blocks per row"
    active_blocks_per_output = active_blocks_per_output.item()
    active_blocks_per_input = layout.sum(0).unique()
    assert (
        len(active_blocks_per_input) == 1
    ), "butterfly layout must have the same number of blocks per column"
    active_blocks_per_input = active_blocks_per_input.item()

    # which input blocks should be added for the i-th output block
    input_block_index = layout.nonzero()[:, 1].view(
        num_output_blocks, active_blocks_per_output
    )
    # which output blocks the j-th input block contributes to
    output_block_index = (
        layout.t().nonzero()[:, 1].view(num_input_blocks, active_blocks_per_input)
    )
    # which of the active blocks of the corresponding input block should be used for the i-th output block
    active_block_index = torch.where(
        torch.eq(
            output_block_index[input_block_index],
            torch.arange(len(input_block_index))[:, None, None],
        )
    )[-1].view(input_block_index.shape)
    return input_block_index * active_blocks_per_input + active_block_index
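
# Index-table sketch (hypothetical shapes): for a square 1024 x 1024 weight with 128 x 128 blocks
# and an 8 x 8 butterfly pattern (n_factors defaults to log2(8) = 3), there are 8 output blocks and
# each one sums 4 active input blocks, so the returned table has shape [8, 4] with values in
# range(8 * 4). Wrapped in a function because butterfly_factor_to_matrix is defined further down.
def _example_butterfly_indices():
    indices = get_butterfly_indices(1024, 1024, block_size=128, butterfly_size=8)
    assert indices.shape == (8, 4)  # [num_output_blocks, active_blocks_per_output]
    assert indices.min() >= 0 and indices.max() < 8 * 4
    return indices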


def butterfly_factor_to_matrix(
    twiddle: torch.Tensor, factor_index: int
) -> torch.Tensor:
    """
    Let b be the base (most commonly 2).
    Parameters:
        twiddle: (n // b, b, b)
        factor_index: an int from 0 to log_b(n) - 1
    """
    n_div_b, b, _ = twiddle.shape
    n = b * n_div_b
    log_b_n = int(math.log(n) / math.log(b))
    assert n == b ** log_b_n, f"n must be a power of {b}"
    assert twiddle.shape == (n // b, b, b)
    assert 0 <= factor_index < log_b_n  # consistent with the docstring range above
    stride = b ** factor_index
    x = einops.rearrange(
        torch.eye(n), "bs (diagblk j stride) -> bs diagblk j stride", stride=stride, j=b
    )
    t = einops.rearrange(
        twiddle, "(diagblk stride) i j -> diagblk stride i j", stride=stride
    )
    out = torch.einsum("d s i j, b d j s -> b d i s", t, x)
    out = einops.rearrange(out, "b diagblk i stride -> b (diagblk i stride)")
    return out.t()  # transpose because the 1st dimension of x is assumed to be the batch dimension
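
# Factor sketch (hypothetical size n = 8, base b = 2): with all-ones twiddles, each butterfly factor
# is an n x n matrix whose row i has nonzeros at columns i and i XOR 2**factor_index, i.e. it couples
# index pairs at stride 2**factor_index. get_butterfly_indices sums these factors to build its mask.
def _example_butterfly_factor():
    twiddle = torch.ones(4, 2, 2)  # n // b = 4, b = 2, hence n = 8
    factor = butterfly_factor_to_matrix(twiddle, factor_index=1)  # stride 2
    assert factor.shape == (8, 8)
    assert factor[0, 0] == 1 and factor[0, 2] == 1  # row 0 couples columns 0 and 2
    return factor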


def butterfly_matmul(
    input: torch.Tensor, weight: torch.Tensor, butterfly_flat_indices: torch.Tensor
) -> torch.Tensor:
    """
    :param input: tensor [*batch_dims, in_features]
    :param weight: tensor [in_features, active_blocks_per_input, block_size]
    :param butterfly_flat_indices: output of get_butterfly_indices(...)
    :returns: tensor [*batch_dims, out_features]
    """
    assert input.shape[-1] == weight.shape[0]
    in_features, active_blocks_per_input, block_size = weight.shape
    num_input_blocks = in_features // block_size
    batch_dims = input.shape[:-1]
    input = input.flatten(0, -2)

    input_permuted = input.t().view(
        input.shape[1] // block_size, block_size, input.shape[0]
    )
    output_blocks = torch.bmm(
        weight.view(num_input_blocks, -1, block_size), input_permuted
    )
    # ^-- shape: [num_input_blocks, (active_blocks_per_input * block_size), flat_batch_dims]
    blocks_for_indexing = output_blocks.view(
        num_input_blocks * active_blocks_per_input, block_size * input.shape[0]
    )
    # ^-- shape: [(num_input_blocks * active_blocks_per_input), (block_size * flat_batch_dims)]
    aggregated_blocks = F.embedding_bag(
        butterfly_flat_indices, blocks_for_indexing, mode="sum"
    )
    # ^-- shape: [num_output_blocks, (block_size * flat_batch_dims)]
    outputs = aggregated_blocks.view(-1, input.shape[0]).t()
    # ^-- shape: [flat_batch_dims, (num_output_blocks * block_size)] aka [flat_batch_dims, out_features]
    return outputs.view(*batch_dims, outputs.shape[-1])
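

# Smoke-test sketch (hypothetical shapes, not part of the original gist): builds a small layer
# end to end and checks that gradients reach both the butterfly weight and the low-rank branch.
if __name__ == "__main__":
    layer = PixelflyLinear(1024, 1024, lowrank_size=64, block_size=128, butterfly_size=8)
    out = layer(torch.randn(4, 1024))
    out.sum().backward()
    assert out.shape == (4, 1024)
    assert layer.weight.grad is not None
    assert layer.lowrank_first.weight.grad is not None
    print("PixelflyLinear smoke test passed:", tuple(out.shape))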

Could you try it with a [(192, 768), (768, 192)] MLP? (Transformers in BERT and vision models usually have a hidden size of 768 or 1024.) From my observations, MLPs smaller than 2048 would not see speed improvements.
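
A minimal timing sketch one could use to check this, reusing PixelflyLinear from the gist above. The lowrank_size=32, block_size=24 and butterfly_size=8 values are hypothetical, chosen only so that the 192 -> 768 shape passes the layout asserts (block_size must divide both dimensions and line up with the stretched butterfly cells); on typical hardware the dense baseline is expected to win at these widths, consistent with the 2048 threshold mentioned above.

import time
import torch

def avg_forward_ms(layer, x, n_iter=100):
    # average forward latency in milliseconds, without autograd
    with torch.no_grad():
        for _ in range(10):  # warm-up
            layer(x)
        start = time.perf_counter()
        for _ in range(n_iter):
            layer(x)
        return (time.perf_counter() - start) / n_iter * 1e3

x = torch.randn(32 * 128, 192)  # e.g. 32 sequences of length 128
dense = torch.nn.Linear(192, 768)
pixelfly = PixelflyLinear(192, 768, lowrank_size=32, block_size=24, butterfly_size=8)
print("dense    :", avg_forward_ms(dense, x), "ms")
print("pixelfly :", avg_forward_ms(pixelfly, x), "ms")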