Skip to content

Instantly share code, notes, and snippets.

@justheuristic
Last active April 6, 2023 08:31
Show Gist options
  • Save justheuristic/9e4fb81381451a4bc8cbfee0a5100eba to your computer and use it in GitHub Desktop.
Save justheuristic/9e4fb81381451a4bc8cbfee0a5100eba to your computer and use it in GitHub Desktop.
Display the source blob
Display the rendered blob
Raw
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
This is free and unencumbered software released into the public domain.
Anyone is free to copy, modify, publish, use, compile, sell, or
distribute this software, either in source code form or as a compiled
binary, for any purpose, commercial or non-commercial, and by any
means.
In jurisdictions that recognize copyright laws, the author or authors
of this software dedicate any and all copyright interest in the
software to the public domain. We make this dedication for the benefit
of the public at large and to the detriment of our heirs and
successors. We intend this dedication to be an overt act of
relinquishment in perpetuity of all present and future rights to this
software under copyright law.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
IN NO EVENT SHALL THE AUTHORS BE LIABLE FOR ANY CLAIM, DAMAGES OR
OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE,
ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR
OTHER DEALINGS IN THE SOFTWARE.
For more information, please refer to <http://unlicense.org/>
import math
import functools
from typing import Optional
import torch
import torch.nn as nn
import torch.nn.functional as F
import einops
class PixelflyLinear(nn.Module):
    """
    Linear layer whose output is the sum of a low-rank projection and a block-sparse butterfly product.

    :param in_features: size of the last dimension of the input
    :param out_features: size of the last dimension of the output
    :param lowrank_size: inner dimension of the low-rank path (in_features -> lowrank_size -> out_features)
    :param block_size: side of one square weight block of the butterfly layout
    :param butterfly_size: forwarded to get_butterfly_indices
    :param n_factors: forwarded to get_butterfly_indices
    :param stretch: forwarded to get_butterfly_indices; note that the default here is True
    :param bias: whether the second low-rank projection has a bias term
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        lowrank_size: int,
        block_size: int,
        butterfly_size: int,
        n_factors: Optional[int] = None,
        stretch: bool = True,
        bias: bool = True,
    ):
        super().__init__()
        self.out_features, self.in_features = out_features, in_features
        # get_butterfly_indices is memoized with lru_cache, so the tensor it returns is shared by
        # every caller that passes the same arguments. Register a private copy: in-place buffer
        # writes (for example the copy_ performed by load_state_dict) must not leak into the
        # cache or into other layers built with the same configuration.
        self.register_buffer(
            "butterfly_flat_indices",
            get_butterfly_indices(
                out_features, in_features, block_size, butterfly_size, n_factors, stretch
            ).clone(),
        )
        self.lowrank_first = nn.Linear(in_features, lowrank_size, bias=False)
        self.lowrank_second = nn.Linear(lowrank_size, out_features, bias=bias)
        # the index tensor has one entry per active block, so dividing by the number of
        # input blocks gives the number of active blocks owned by each input block
        active_blocks_per_input = self.butterfly_flat_indices.numel() // (in_features // block_size)
        self.weight = nn.Parameter(torch.empty(in_features, active_blocks_per_input, block_size))
        nn.init.xavier_normal_(self.weight)

    def forward(self, input):
        """Return lowrank_second(lowrank_first(input)) plus the butterfly product of input and self.weight."""
        output = self.lowrank_second(self.lowrank_first(input))
        output += butterfly_matmul(input, self.weight, self.butterfly_flat_indices)
        return output
@functools.lru_cache
def get_butterfly_indices(
    out_features: int,
    in_features: int,
    block_size: int = 256,
    butterfly_size: int = 64,
    n_factors: Optional[int] = None,
    stretch: bool = False,
) -> torch.Tensor:
    """
    Get a matrix [num_output_blocks, num_active_input_blocks] with flat block indices for additive butterfly
    Based on the original implementation from https://arxiv.org/abs/2112.00029 .

    Results are memoized with functools.lru_cache: repeated calls with the same arguments return
    the same tensor object, so callers should not modify it in-place.

    :param out_features: number of output features, must be a multiple of block_size and butterfly_size
    :param in_features: number of input features, must be a multiple of block_size and butterfly_size
    :param block_size: side of one square block of the block-sparse layout
    :param butterfly_size: size of the base butterfly pattern, must be a power of 2 (at least 2)
    :param n_factors: number of butterfly factors to combine, defaults to log_2(butterfly_size)
    :param stretch: if True, non-square matrices will have stretched butterfly patterns,
      otherwise (the default of this function) the block rows / columns of the stretched pattern
      are re-ordered so that the square pattern is repeated a given number of times
    :returns: int64 tensor [num_output_blocks, active_blocks_per_output]; each entry equals
      input_block * active_blocks_per_input + (slot of this output block within that input block)
    """
    assert (
        out_features % in_features == 0 or in_features % out_features == 0
    ), "if matrix is not square, the longer dimension must be a multiple of the shorter dimension"
    assert out_features % block_size == 0 and in_features % block_size == 0
    # check the lower bound before math.log2, which raises ValueError for non-positive values
    if butterfly_size < 2:
        raise NotImplementedError("butterfly_size must be a power of 2")
    log_n = int(math.log2(butterfly_size))
    if butterfly_size != 2 ** log_n:
        raise NotImplementedError("butterfly_size must be a power of 2")
    n_factors = log_n if n_factors is None else n_factors
    if not (1 <= n_factors <= log_n):
        raise NotImplementedError(
            "n_factors must be a between 1 and log_2(butterfly_size)"
        )
    # the base pattern is repeated an integer number of times along each axis below;
    # without this check other sizes fail later with a hard-to-read shape error
    assert (
        out_features % butterfly_size == 0 and in_features % butterfly_size == 0
    ), "out_features and in_features must be multiples of butterfly_size"

    # all-ones twiddle: only the sparsity pattern of each butterfly factor matters here
    twiddle = torch.ones(butterfly_size // 2, 2, 2)
    layout = sum(
        butterfly_factor_to_matrix(twiddle, index) for index in range(n_factors)
    )
    layout = layout.bool().int()
    # Convert from (butterfly_size, butterfly_size) mask to (out_features, in_features) mask
    layout = einops.repeat(
        layout,
        "b b1 -> (b f) (b1 f1)",
        f=out_features // butterfly_size,
        f1=in_features // butterfly_size,
    )
    # Convert from (out_features, in_features) mask to
    # (out_features // block_size, in_features // block_size) mask
    layout = einops.rearrange(
        layout,
        "(p blksz) (r blksz1) -> p r (blksz blksz1)",
        blksz=block_size,
        blksz1=block_size,
    )
    layout = (layout > 0).any(
        dim=-1
    )  # [out_features // block_size, in_features // block_size]

    if not stretch:
        # re-order block rows (or columns) of the longer side so that the pattern is tiled
        out_blocks, in_blocks = layout.shape
        if out_blocks > in_blocks:
            ratio = out_blocks // in_blocks
            layout = (
                layout.view(out_blocks // ratio, ratio, in_blocks)
                .permute(1, 0, 2)
                .reshape_as(layout)
            )
        elif out_blocks < in_blocks:
            ratio = in_blocks // out_blocks
            layout = (
                layout.view(out_blocks, in_blocks // ratio, ratio)
                .permute(0, 2, 1)
                .reshape_as(layout)
            )

    # convert boolean layout to indices for F.embedding_bag
    num_output_blocks = out_features // block_size
    num_input_blocks = in_features // block_size
    active_blocks_per_output = layout.sum(1).unique()
    assert (
        len(active_blocks_per_output) == 1
    ), "butterfly layout must have the same number of blocks per row"
    active_blocks_per_output = active_blocks_per_output.item()
    active_blocks_per_input = layout.sum(0).unique()
    assert (
        len(active_blocks_per_input) == 1
    ), "butterfly layout must have the same number of blocks per column"
    active_blocks_per_input = active_blocks_per_input.item()

    # which input blocks should be added for i-th output
    input_block_index = layout.nonzero()[:, 1].view(
        num_output_blocks, active_blocks_per_output
    )
    # which output blocks does j-th input contribute to
    output_block_index = (
        layout.t().nonzero()[:, 1].view(num_input_blocks, active_blocks_per_input)
    )
    # which of the active blocks from the corresponding input_block should be used for i-th output
    active_block_index = torch.where(
        torch.eq(
            output_block_index[input_block_index],
            torch.arange(len(input_block_index))[:, None, None],
        )
    )[-1].view(input_block_index.shape)
    return input_block_index * active_blocks_per_input + active_block_index
def butterfly_factor_to_matrix(
    twiddle: torch.Tensor, factor_index: int
) -> torch.Tensor:
    """
    Expand one butterfly factor into a dense (n, n) matrix.

    Let b be the base (most commonly 2).
    Parameters:
        twiddle: (n // b, b, b)
        factor_index: an int from 0 to log_b(n) - 1
    Returns:
        (n, n) tensor; entries can be non-zero only where row and column fall into the same
        diagonal block of size b * stride and are equal modulo stride, with stride = b ** factor_index
    Raises:
        AssertionError: if b < 2, n is not a power of b, or factor_index is out of range
    """
    n_div_b, b, _ = twiddle.shape
    assert b >= 2, "the base (twiddle.shape[1]) must be at least 2"
    n = b * n_div_b
    # Exact integer logarithm. int(math.log(n) / math.log(b)) truncates a float quotient and
    # is off by one whenever the quotient lands just below an integer (e.g. n=243, b=3).
    log_b_n, power = 0, 1
    while power < n:
        power *= b
        log_b_n += 1
    assert n == b ** log_b_n, f"n must be a power of {b}"
    assert twiddle.shape == (n // b, b, b)
    # factor_index == log_b_n would give stride == n, which cannot be split into blocks of b * stride
    assert 0 <= factor_index < log_b_n, "factor_index must be in [0, log_b(n) - 1]"
    stride = b ** factor_index
    # "bs (diagblk j stride) -> bs diagblk j stride"
    x = torch.eye(n).reshape(n, n // (b * stride), b, stride)
    # "(diagblk stride) i j -> diagblk stride i j"
    t = twiddle.reshape(n_div_b // stride, stride, b, b)
    out = torch.einsum("d s i j, b d j s -> b d i s", t, x)
    # "b diagblk i stride -> b (diagblk i stride)"
    out = out.reshape(n, n)
    return (
        out.t()
    )  # Transpose because we assume the 1st dimension of x is the batch dimension
def butterfly_matmul(
    input: torch.Tensor, weight: torch.Tensor, butterfly_flat_indices: torch.Tensor
):
    """
    Multiply input by a block-sparse matrix whose active blocks are stored densely in weight.

    The memory of weight is read as [num_input_blocks, active_blocks_per_input, block_size, block_size]:
    every input block is multiplied by each of its active weight blocks, then each output block is
    the sum of the products selected by the corresponding row of butterfly_flat_indices.

    :param input: tensor [*batch_dims, in_features]; batch_dims may be empty
    :param weight: tensor [in_features, active_blocks_per_input, block_size]
    :param butterfly_flat_indices: outputs of get_butterfly_indices(...)
    :returns: tensor [*batch_dims, out_features]
    """
    assert input.shape[-1] == weight.shape[0]
    in_features, active_blocks_per_input, block_size = weight.shape
    num_input_blocks = in_features // block_size
    batch_dims = input.shape[:-1]
    # reshape instead of flatten(0, -2): flatten raises for an input without batch dimensions
    input = input.reshape(-1, in_features)
    flat_batch_size = input.shape[0]
    input_permuted = input.t().view(num_input_blocks, block_size, flat_batch_size)
    # ^-- shape: [num_input_blocks, block_size, flat_batch_dims]
    output_blocks = torch.bmm(
        weight.view(num_input_blocks, -1, block_size), input_permuted
    )
    # ^-- shape: [num_input_blocks, (active_blocks_per_input * block_size), flat_batch_dims]
    blocks_for_indexing = output_blocks.view(
        num_input_blocks * active_blocks_per_input, block_size * flat_batch_size
    )
    # ^-- shape: [(num_input_blocks * active_blocks_per_input), (block_size, flat_batch_dims)]
    aggregated_blocks = F.embedding_bag(
        butterfly_flat_indices, blocks_for_indexing, mode="sum"
    )
    # ^-- shape: [num_output_blocks, (block_size, flat_batch_dims)]
    outputs = aggregated_blocks.view(-1, flat_batch_size).t()
    # ^-- shape: [flat_batch_dims, (num_output_blocks * block_size)] aka [flat_batch_dims, out_features]
    return outputs.reshape(*batch_dims, outputs.shape[-1])
@JunweiLiang
Copy link

Could you try it with [(192, 768), (768, 192)] MLP (transformers in BERT and vision usually have 768/1024 hidden size)? From my observations, any MLPs smaller than 2048 would not see speed improvements.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment