Skip to content

Instantly share code, notes, and snippets.

@grey-area
Created January 25, 2022 12:21
Show Gist options
  • Save grey-area/9d090cf4e00457337e6af6f88e8ffd6a to your computer and use it in GitHub Desktop.
Save grey-area/9d090cf4e00457337e6af6f88e8ffd6a to your computer and use it in GitHub Desktop.
from einops import repeat, parse_shape
def make_broadcastable(array_list, pattern_list, result_pattern):
    """Lay out arrays in a common axis order so they broadcast together.

    Each array is described by an einops pattern naming its axes (e.g.
    ``'n x'``). Every array is passed through ``einops.repeat`` into the
    layout given by ``result_pattern``. Named axes that appear in some other
    input but not in this one are inserted with length 1. The results can
    then be combined with ordinary broadcasting without materialising the
    full product shape.

    Args:
        array_list: Arrays/tensors accepted by einops (any iterable).
        pattern_list: One einops pattern per array, in the same order.
        result_pattern: Target axis layout, e.g. ``'n x y z'``.

    Returns:
        A list of reshaped arrays, in the same order as ``array_list``.

    Raises:
        ValueError: If the number of arrays and patterns differ.
    """
    # Materialise once: both inputs are walked twice below. A generator
    # would otherwise be exhausted after the first pass, and the function
    # would silently return an empty list.
    arrays = list(array_list)
    patterns = list(pattern_list)
    if len(arrays) != len(patterns):
        raise ValueError(
            f"got {len(arrays)} arrays but {len(patterns)} patterns"
        )

    shapes = [parse_shape(array, pattern) for array, pattern in zip(arrays, patterns)]
    # Union of every named axis seen across all inputs.
    all_keys = set().union(*(shape.keys() for shape in shapes))

    results = []
    for array, pattern, shape in zip(arrays, patterns, shapes):
        # Sizes of axes this array already has come from parse_shape on the
        # same array, so they match. Axes it lacks become new length-1 axes.
        axes_lengths = {**shape, **{k: 1 for k in all_keys - shape.keys()}}
        results.append(repeat(array, f'{pattern} -> {result_pattern}', **axes_lengths))
    return results
if __name__ == "__main__":
    # Demo: three arrays share the batch axis ``n`` but each has its own
    # second axis. They are lifted into the common ``n x y z`` layout.
    import numpy as np

    batch = 2
    inputs = [np.random.random(size=(batch, width)) for width in (3, 5, 7)]
    xs, ys, zs = make_broadcastable(inputs, ['n x', 'n y', 'n z'], 'n x y z')
    for out in (xs, ys, zs):
        print(out.shape)
    # Expected, one per line: (2, 3, 1, 1), (2, 1, 5, 1), (2, 1, 1, 7)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment