Skip to content

Instantly share code, notes, and snippets.

@norabelrose
Last active October 22, 2023 05:18
Show Gist options
  • Save norabelrose/861c1c839fe72fb98b32d149cd3c3bac to your computer and use it in GitHub Desktop.
Save norabelrose/861c1c839fe72fb98b32d149cd3c3bac to your computer and use it in GitHub Desktop.
Blocked moment generator
from itertools import (
combinations_with_replacement as pyramid
)
from typing import Iterable
import math
from opt_einsum import get_symbol
from torch import Tensor
import torch
def moments(X: Tensor, ord: int, block_size: int = 1) -> Iterable[Tensor]:
    """Lazily compute the blocked raw moment tensors of order ``ord`` of ``X``.

    Rows of ``X`` are samples and columns are features. The feature axis is
    split into ``d // block_size`` contiguous blocks. For every non-decreasing
    tuple of ``ord`` block indices ``(i_1, ..., i_ord)``, in
    ``itertools.combinations_with_replacement`` order, this yields the mean
    over rows of the outer product ``x[block i_1] ⊗ ... ⊗ x[block i_ord]``.
    Each yielded tensor has shape ``(block_size,) * ord``.

    The moments are *raw* (uncentered): ``X`` is not mean-subtracted here.

    Args:
        X: A 2-D tensor of shape ``(n, d)``.
        ord: The moment order. It must satisfy ``1 <= ord <= 51``, because
            ``torch.einsum`` only accepts the 52 ASCII letters as subscripts
            and one letter is reserved for the sample axis.
        block_size: Width of each feature block. It must be positive and
            must divide ``d``.

    Returns:
        A generator over the moment blocks. Arguments are validated when this
        function is called, not when the generator is first advanced.

    Raises:
        ValueError: If any of the constraints above is violated.
    """
    from string import ascii_letters

    if X.ndim != 2:
        raise ValueError("X must be a matrix")
    if block_size < 1:
        raise ValueError(f"block_size must be positive, got {block_size}")
    if not 1 <= ord < len(ascii_letters):
        raise ValueError(
            f"ord must be between 1 and {len(ascii_letters) - 1}, got {ord}"
        )

    n, d = X.shape
    num_blocks, rem = divmod(d, block_size)
    if rem != 0:
        raise ValueError(
            f"block_size must divide d, got {d} % {block_size} = {rem}"
        )

    # For example, ord=2 gives "ab,ac->bc". The letter "a" indexes samples and
    # is summed out; each remaining letter is one factor of the outer product.
    sample, axes = ascii_letters[0], ascii_letters[1 : ord + 1]
    einsum_str = ",".join(sample + ax for ax in axes) + "->" + axes

    # Shape: (n, num_blocks, block_size)
    blocked_X = X.unflatten(1, (num_blocks, block_size))

    # Returning a generator expression, rather than making this function a
    # generator, keeps the validation above eager while the einsums stay lazy.
    return (
        torch.einsum(einsum_str, *(blocked_X[:, idx, :] for idx in indices)) / n
        for indices in pyramid(range(num_blocks), ord)
    )
def condense_symmetric(X: Tensor, block_size: int) -> Tensor:
    """Pack the non-redundant blocks of a symmetric tensor into one stacked tensor.

    ``X`` has ``k = X.ndim`` axes, all of length ``d``. Each axis is split into
    ``num_blocks = d // block_size`` contiguous blocks. The block at every
    non-decreasing tuple of block indices is copied into the output, in
    ``itertools.combinations_with_replacement`` order.

    For a symmetric ``X``, the block at any permutation of such a tuple equals
    that block with its axes permuted, so these blocks determine all of ``X``.
    Only the shape is checked; symmetry itself is assumed, not verified.

    Args:
        X: A tensor with at least one axis, where every axis has length ``d``.
        block_size: Side length of each block. It must be positive and must
            divide ``d``.

    Returns:
        A tensor of shape
        ``(comb(num_blocks + k - 1, k), block_size, ..., block_size)``, with
        ``k`` trailing axes of size ``block_size``.

    Raises:
        ValueError: If ``X`` is 0-dimensional, if its axes differ in length,
            or if ``block_size`` is not a positive divisor of ``d``.
    """
    if X.ndim < 1:
        raise ValueError("X must have at least one dimension")
    d, *rest = X.shape
    if any(dim != d for dim in rest):
        raise ValueError(
            f"X must be a symmetric tensor, got shape {tuple(X.shape)}"
        )
    if block_size < 1:
        raise ValueError(f"block_size must be positive, got {block_size}")
    num_blocks, rem = divmod(d, block_size)
    if rem != 0:
        # Previously the trailing rows/cols were silently dropped here.
        raise ValueError(
            f"block_size must divide d, got {d} % {block_size} = {rem}"
        )

    k = X.ndim
    numel = math.comb(num_blocks + k - 1, k)
    # Preallocate instead of torch.stack so that numel == 0 (d == 0) still
    # yields a correctly shaped empty tensor.
    buffer = X.new_empty(numel, *([block_size] * k))
    for i, indices in enumerate(pyramid(range(num_blocks), k)):
        buffer[i] = X[
            tuple(slice(j * block_size, (j + 1) * block_size) for j in indices)
        ]
    return buffer
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment