Created
April 10, 2024 22:13
-
-
Save gngdb/6ff17112942f4f12d1af18e282de3470 to your computer and use it in GitHub Desktop.
einsum implemented with `itertools.product`
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch | |
import itertools | |
from collections import OrderedDict | |
def einsum_itertools(equation, *operands, verbose=False):
    """Evaluate an einsum expression by brute-force iteration over all indices.

    Every label in the equation gets its own loop range; the Cartesian product
    of those ranges is walked with ``itertools.product``, and for each index
    combination the product of the addressed operand elements is accumulated
    into the output element selected by the output labels.

    Args:
        equation: Subscript string such as ``"ij,jk->ik"``. Spaces are
            ignored. If ``->`` is omitted, the output labels are the labels
            that occur exactly once across the inputs, in sorted order.
            Ellipsis (``...``) is not supported.
        *operands: One tensor per comma-separated input subscript group.
        verbose: When true, print the parsed equation, the label sizes and
            every step of the iteration.

    Returns:
        A tensor whose shape follows the output labels. Its dtype is the
        promotion of the operand dtypes (bool is accumulated as int64) and it
        lives on the device of the first operand.

    Raises:
        ValueError: If the equation contains more than one ``->``, if the
            number of operands does not match the number of subscript groups,
            if an operand's rank differs from the length of its subscripts,
            or if one label is bound to two different sizes.
        KeyError: If an output label does not appear in any input subscript.
    """
    # Parse the equation
    equation = equation.replace(" ", "")
    input_part, arrow, output_labels = equation.partition("->")
    if "->" in output_labels:
        raise ValueError(f"equation may contain at most one '->': {equation!r}")
    input_labels = input_part.split(",")
    if not arrow:
        # Implicit mode: keep the labels that are not summed over (seen once).
        joined = "".join(input_labels)
        output_labels = "".join(
            sorted(label for label in set(joined) if joined.count(label) == 1)
        )
    if verbose:
        print(f"{input_labels=} {output_labels=}")

    if not operands:
        raise ValueError("einsum_itertools requires at least one operand")
    if len(input_labels) != len(operands):
        raise ValueError(
            f"equation has {len(input_labels)} input subscript group(s) "
            f"but {len(operands)} operand(s) were given"
        )

    # Get the dimensions of each operand
    input_dims = [list(op.shape) for op in operands]
    if verbose:
        print(f"{input_dims=}")

    # Create a dictionary mapping labels to dimensions
    label_dims = OrderedDict({})
    for position, (labels, dims) in enumerate(zip(input_labels, input_dims)):
        if len(labels) != len(dims):
            raise ValueError(
                f"operand {position} has {len(dims)} dimension(s) "
                f"but its subscripts {labels!r} name {len(labels)}"
            )
        for label, dim in zip(labels, dims):
            # A label must mean the same size everywhere it is used;
            # otherwise some operand would be indexed out of range.
            known = label_dims.setdefault(label, dim)
            if known != dim:
                raise ValueError(
                    f"label {label!r} has inconsistent sizes {known} and {dim}"
                )
    if verbose:
        print(f"{label_dims=}")

    # Compute the output shape
    output_shape = [label_dims[label] for label in output_labels]
    if verbose:
        print(f"{output_shape=}")

    # Create the output tensor with a dtype that can hold the products
    # without silently down-casting (e.g. float64 -> float32, int -> float).
    out_dtype = operands[0].dtype
    for op in operands[1:]:
        out_dtype = torch.promote_types(out_dtype, op.dtype)
    if out_dtype == torch.bool:
        # Sums of booleans are counts; a bool accumulator cannot hold them.
        out_dtype = torch.int64
    output = torch.zeros(output_shape, dtype=out_dtype, device=operands[0].device)

    # Generate the indices for iteration
    indices = [range(dim) for dim in label_dims.values()]
    if verbose:
        print(f"{[len(i) for i in indices]=}")

    # Perform the einsum operation using nested iteration
    for idx in itertools.product(*indices):
        if verbose:
            print(f" {idx=}")
        # Create a dictionary mapping labels to indices
        label_idx = OrderedDict(zip(label_dims.keys(), idx))
        if verbose:
            print(f" {label_idx=}")
        # Compute the product of the operands at the current indices
        product = 1
        for op, labels in zip(operands, input_labels):
            op_idx = tuple(label_idx[label] for label in labels)
            product *= op[op_idx]
        if verbose:
            print(f" {product=}")
        # Update the output tensor at the corresponding indices
        output_idx = tuple(label_idx[label] for label in output_labels)
        if verbose:
            print(f" {output_idx=}")
        output[output_idx] += product
    return output
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment