Last active
October 22, 2023 05:18
-
-
Save norabelrose/861c1c839fe72fb98b32d149cd3c3bac to your computer and use it in GitHub Desktop.
Blocked moment generator
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from itertools import ( | |
combinations_with_replacement as pyramid | |
) | |
from typing import Iterable | |
import math | |
from opt_einsum import get_symbol | |
from torch import Tensor | |
import torch | |
def moments(X: Tensor, ord: int, block_size: int = 1) -> Iterable[Tensor]:
    """Lazily yield the unique blocks of the order-`ord` raw moment tensor of `X`.

    The columns of `X` are split into `d // block_size` contiguous groups. For
    every non-decreasing tuple of `ord` group indices (in the order produced by
    `itertools.combinations_with_replacement`), the matching column groups are
    multiplied together, summed over the rows of `X`, and divided by the number
    of rows.

    Args:
        X: Matrix of shape `(n, d)`.
        ord: Number of column groups multiplied together per block.
        block_size: Number of columns per group; must divide `d`.

    Yields:
        One tensor with `ord` axes, each of length `block_size`, per tuple of
        group indices.

    Raises:
        AssertionError: If `X` is not 2-dimensional or `block_size` does not
            divide `d`. Raised on first iteration, since this is a generator.
    """
    assert X.ndim == 2, "X must be a matrix"
    n, d = X.shape

    num_blocks, rem = divmod(d, block_size)
    assert rem == 0, f"block_size must divide d, got {d} % {block_size} = {rem}"

    # Build the einsum equation: symbol 0 is the shared row axis that gets
    # summed out; symbols 1..ord label the output axes.
    row_sym = get_symbol(0)
    axis_syms = [get_symbol(i) for i in range(1, ord + 1)]
    operands = ",".join(row_sym + sym for sym in axis_syms)
    equation = operands + "->" + "".join(axis_syms)

    # View the columns as (num_blocks, block_size) so a group is one index away.
    grouped = X.unflatten(1, (num_blocks, block_size))
    for combo in pyramid(range(num_blocks), ord):
        factors = tuple(grouped[:, j, :] for j in combo)
        yield torch.einsum(equation, *factors) / n
def condense_symmetric(X: Tensor, block_size: int) -> Tensor:
    """Stack the non-redundant blocks of a blocked symmetric tensor.

    `X` has `k = X.ndim` axes, all of the same length `d`. Each axis is split
    into `d // block_size` contiguous blocks, and only the blocks whose block
    indices are non-decreasing are kept, in the order produced by
    `itertools.combinations_with_replacement`.

    Only the shape of `X` is checked; symmetry of its values is not verified.

    Args:
        X: Tensor whose axes all have the same length `d`.
        block_size: Side length of each block; must divide `d`.

    Returns:
        A tensor of shape `(numel, block_size, ..., block_size)` with `k`
        trailing axes, where `numel = comb(num_blocks + k - 1, k)`. It is
        allocated with `X.new_empty`, so it shares `X`'s dtype and device.

    Raises:
        AssertionError: If the axes of `X` differ in length, or if
            `block_size` does not divide `d`.
    """
    d, *rest = X.shape
    assert all(dim == d for dim in rest), "X must be a symmetric tensor"

    # Same check as `moments`: with plain floor division a non-dividing
    # block size would silently drop the trailing entries of every axis.
    num_blocks, rem = divmod(d, block_size)
    assert rem == 0, f"block_size must divide d, got {d} % {block_size} = {rem}"

    k = X.ndim
    # Number of multisets of size k drawn from num_blocks block indices.
    numel = math.comb(num_blocks + k - 1, k)
    buffer = X.new_empty(numel, *([block_size] * k))

    for i, indices in enumerate(pyramid(range(num_blocks), k)):
        # One slice per axis selects the block at these block indices.
        slices = tuple(
            slice(idx * block_size, (idx + 1) * block_size)
            for idx in indices
        )
        buffer[i] = X[slices]

    return buffer
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment