Last active
April 14, 2022 19:19
-
-
Save justheuristic/499ec116f1f353dfd3314de87f310f80 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This is free and unencumbered software released into the public domain. | |
Anyone is free to copy, modify, publish, use, compile, sell, or | |
distribute this software, either in source code form or as a compiled | |
binary, for any purpose, commercial or non-commercial, and by any | |
means. | |
In jurisdictions that recognize copyright laws, the author or authors | |
of this software dedicate any and all copyright interest in the | |
software to the public domain. We make this dedication for the benefit | |
of the public at large and to the detriment of our heirs and | |
successors. We intend this dedication to be an overt act of | |
relinquishment in perpetuity of all present and future rights to this | |
software under copyright law. | |
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, | |
EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF | |
MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. | |
IN NO EVENT SHALL THE AUTHORS BE LIABLE FOR ANY CLAIM, DAMAGES OR | |
OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, | |
ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR | |
OTHER DEALINGS IN THE SOFTWARE. | |
For more information, please refer to <http://unlicense.org/> |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch | |
import torch.nn as nn | |
from torch.utils.checkpoint import checkpoint | |
import numpy as np | |
from functools import partial | |
from typing import Sequence | |
class MonarchLinear(nn.Module):
    """Linear layer [..., in_features] -> [..., out_features] whose weight is a product of len(in_dims) factors."""

    def __init__(self, in_features: int, out_features: int,
                 in_dims: Sequence[int], out_dims: Sequence[int],
                 bias: bool = True, checkpoint: bool = False,
                 ):
        """
        Monarch linear layer, a generalization of https://arxiv.org/abs/2204.00595

        This implementation interprets Monarch as a product over an M by M grid (in_features=M ^ 2).
        The first product applies over all rows of the grid, the second runs over columns.
        In general, the grid may have uneven size or more than 2 dimensions.

        In the 2d case, the two products use [M x M x M] weight tensors. In the general case,
        it uses len(in_dims) weight tensors; the i-th one has shape
        [current_numel / in_dims[i], in_dims[i], out_dims[i]], where current_numel is the size of the
        intermediate grid before step i (it starts at in_features and is multiplied by
        out_dims[i] / in_dims[i] after each step, ending at out_features).

        :param in_features: input dimension, same as in nn.Linear
        :param out_features: output dimension, same as in nn.Linear
        :param in_dims: a tuple of numbers that multiply to in_features, see example below
        :param out_dims: a tuple of numbers that multiply to out_features, see example below
        :param bias: whether or not to use a bias term, same as in nn.Linear
        :param checkpoint: if True, apply gradient checkpointing over this entire layer.
          This adds ~30% compute overhead for forward+backward, but reduces the memory overhead;
          otherwise, monarch must store ndim - 1 additional tensors for intermediate activations.

        :example:

        >>> # classic monarch:
        >>> MonarchLinear(in_features=1024, in_dims=(32, 32), out_features=1024, out_dims=(32, 32))
        >>> # generalization to rectangular matrices
        >>> MonarchLinear(in_features=1024, in_dims=(32, 32), out_features=4096, out_dims=(64, 64))
        >>> MonarchLinear(in_features=1024, in_dims=(32, 32), out_features=1536, out_dims=(32, 48))
        >>> # generalization to higher dimension
        >>> MonarchLinear(in_features=4096, in_dims=(16, 16, 16), out_features=4096, out_dims=(16, 16, 16))
        >>> MonarchLinear(in_features=4096, in_dims=(16, 16, 16), out_features=1536, out_dims=(8, 12, 16))
        """
        super().__init__()
        assert len(in_dims) == len(out_dims) and len(in_dims) > 1, \
            "in_dims and out_dims must have the same length, which must be at least 2"
        assert np.prod(in_dims) == in_features, "in_dims must multiply to in_features"
        assert np.prod(out_dims) == out_features, "out_dims must multiply to out_features"
        self.in_features, self.out_features = in_features, out_features
        self.in_dims, self.out_dims = in_dims, out_dims
        self.checkpoint = checkpoint

        # construct weight tensors by keeping track of the intermediate grid size at each step:
        # step i maps the in_dims[i] axis to out_dims[i], batched over all other grid positions
        self.weights = nn.ParameterList()
        current_numel = in_features
        for in_dim, out_dim in zip(in_dims, out_dims):
            self.weights.append(nn.Parameter(torch.empty(current_numel // in_dim, in_dim, out_dim)))
            current_numel = current_numel // in_dim * out_dim
        assert current_numel == out_features  # sanity check: the final grid is exactly out_dims

        self.register_parameter('bias', nn.Parameter(torch.empty(out_features)) if bias else None)
        self.reset_parameters()

    def reset_parameters(self, gain: float = 1.0):
        """
        Re-initialize all weights and the bias in place.

        Every weight factor is drawn from N(0, std^2) with
        std = (gain / sqrt(in_features)) ** (1 / len(in_dims)), so that a product of one entry from
        each of the len(in_dims) factors has std gain / sqrt(in_features). The bias (if any) is drawn
        from U(-1 / sqrt(in_features), 1 / sqrt(in_features)).

        :param gain: multiplier for the effective std of the full (dense-equivalent) weight
        """
        init_std = (gain / np.sqrt(self.in_features)) ** (1 / len(self.in_dims))
        for weight in self.weights:
            nn.init.normal_(weight, std=init_std)
        if self.bias is not None:
            bound = 1 / np.sqrt(self.in_features)
            nn.init.uniform_(self.bias, -bound, bound)

    def forward(self, input: torch.Tensor, _inside_checkpoint: bool = False) -> torch.Tensor:
        """
        Apply the layer to the last dimension of ``input``.

        :param input: tensor of shape [..., in_features]; all leading dims are treated as batch dims
          (any of them may be zero, and the tensor need not be contiguous)
        :param _inside_checkpoint: internal flag used when forward is re-entered from checkpoint
        :returns: tensor of shape [..., out_features]
        :raises RuntimeError: if input is 0-dimensional or its last dimension is not in_features
        """
        if input.dim() == 0 or input.shape[-1] != self.in_features:
            raise RuntimeError(
                f"MonarchLinear expected input with last dimension {self.in_features}, "
                f"got shape {tuple(input.shape)}")
        if self.checkpoint and not _inside_checkpoint and torch.is_grad_enabled():
            # reentrant checkpointing only backpropagates (including into parameters) when at least
            # one tensor input requires grad, hence the detach().requires_grad_(True) workaround
            return checkpoint(partial(self.forward, _inside_checkpoint=True),
                              input if input.requires_grad else input.detach().requires_grad_(True),
                              preserve_rng_state=False, use_reentrant=True)

        input_shape = input.shape
        num_grid_dims = len(self.in_dims)
        # explicit row count instead of -1 so that inputs with zero rows can be reshaped
        flat_batch_size = int(np.prod(input_shape[:-1], dtype=np.int64))
        # reshape (unlike view) also accepts non-contiguous inputs
        tensor = input.reshape(flat_batch_size, *self.in_dims)
        # shape: [flat_batch_size, in_dim[0], ..., in_dim[N - 1]]
        del input
        tensor = tensor.permute(*np.roll(range(num_grid_dims + 1), -2))
        # new shape: [in_dim[1], ..., in_dim[N - 1], flat_batch_size, in_dim[0]]

        for i, (weight, out_dim) in enumerate(zip(self.weights, self.out_dims)):
            # loop maintains tensor in the following shape: [*other_dims, batch, in_dim[i]]
            # BMM: left  [prod(other_dims), batch, in_dim[i]]
            #      right [prod(other_dims), in_dim[i], out_dim[i]]
            #      out   [prod(other_dims), batch, out_dim[i]] -> viewed back to [*other_dims, batch, out_dim[i]]
            tensor = torch.bmm(tensor.flatten(0, -3), weight).view(*tensor.shape[:-1], out_dim)
            # move out_dim[i] into grid position i, which brings in_dim[i + 1] (or, after the last
            # step, the batch dim) to the end; swapaxes returns a zero-copy view
            tensor = tensor.swapaxes(-1, i)

        # after loop: [out_dim[0], ..., out_dim[N - 1], batch]
        tensor = tensor.flatten(0, -2).swapaxes(0, 1)
        tensor = tensor.reshape(*input_shape[:-1], self.out_features)
        if self.bias is not None:
            tensor += self.bias
        return tensor
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "code", | |
"execution_count": 1, | |
"id": "48e61d12", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"A100-SXM4-40GB\n" | |
] | |
} | |
], | |
"source": [ | |
"import torch\n", | |
"import torch.nn as nn\n", | |
"\n", | |
"from monarch import MonarchLinear # see monarch.py in the same gist\n", | |
"\n", | |
"device = 'cuda' if torch.cuda.is_available() else 'cpu'\n", | |
"if torch.cuda.is_available():\n", | |
" print(torch.cuda.get_device_name(0))" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "be65a46b", | |
"metadata": {}, | |
"source": [ | |
"__Baseline:__ pytorch linear with a full dense matrix" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 2, | |
"id": "b3ca6eaf", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"baseline = nn.Sequential(nn.Linear(4096, 16384), nn.ReLU(), nn.Linear(16384, 4096)).to(device=device)\n", | |
"baseline_size = sum(p.numel() for p in baseline.parameters())\n", | |
"input = torch.randn(16, 512, 4096, device=device)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 3, | |
"id": "0265fc13", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"CPU times: user 3.57 s, sys: 1.48 s, total: 5.05 s\n", | |
"Wall time: 5.64 s\n" | |
] | |
} | |
], | |
"source": [ | |
"%%time\n", | |
"for i in range(100):\n", | |
" out = baseline(input)\n", | |
" out.sum().backward()\n", | |
" \n", | |
"torch.cuda.synchronize()" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 4, | |
"id": "edc898db", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"del baseline, out\n", | |
"torch.cuda.empty_cache()" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "5b19be3f", | |
"metadata": {}, | |
"source": [ | |
"__Same, but with Monarch__" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 5, | |
"id": "67b41f28", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"monarch = nn.Sequential(\n", | |
" MonarchLinear(4096, 16384, in_dims=(64, 64), out_dims=(64, 256)),\n", | |
" nn.ReLU(),\n", | |
" MonarchLinear(16384, 4096, in_dims=(256, 64), out_dims=(64, 64))\n", | |
").to(device)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 6, | |
"id": "8abb1a1f", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"CPU times: user 1.64 s, sys: 920 ms, total: 2.56 s\n", | |
"Wall time: 3.07 s\n" | |
] | |
} | |
], | |
"source": [ | |
"%%time\n", | |
"for i in range(100):\n", | |
" out = monarch(input)\n", | |
" out.sum().backward()\n", | |
" \n", | |
"torch.cuda.synchronize()" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 7, | |
"id": "ffb67dbb", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Compression rate (parameters): 0.01968083483355201\n" | |
] | |
} | |
], | |
"source": [ | |
"monarch_size = sum(p.numel() for p in monarch.parameters())\n", | |
"print(\"Compression rate (parameters):\", monarch_size / baseline_size)" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "68051a51", | |
"metadata": {}, | |
"source": [ | |
"__Generalized 3D Monarch__" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 8, | |
"id": "4a168c5d", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"monarch = nn.Sequential(\n", | |
" MonarchLinear(4096, 16384, in_dims=(16, 16, 16), out_dims=(16, 16, 64)),\n", | |
" nn.ReLU(),\n", | |
" MonarchLinear(16384, 4096, in_dims=(64, 16, 16), out_dims=(16, 16, 16))\n", | |
").to(device)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 9, | |
"id": "35b3fa94", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"CPU times: user 1.94 s, sys: 958 ms, total: 2.9 s\n", | |
"Wall time: 2.9 s\n" | |
] | |
} | |
], | |
"source": [ | |
"%%time\n", | |
"for i in range(100):\n", | |
" out = monarch(input)\n", | |
" out.sum().backward()\n", | |
" \n", | |
"torch.cuda.synchronize()" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 10, | |
"id": "b303548d", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Compression rate (parameters): 0.006011045677844567\n" | |
] | |
} | |
], | |
"source": [ | |
"monarch_size = sum(p.numel() for p in monarch.parameters())\n", | |
"print(\"Compression rate (parameters):\", monarch_size / baseline_size)" | |
] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "Python 3", | |
"language": "python", | |
"name": "python3" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.8.10" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 5 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment