Skip to content

Instantly share code, notes, and snippets.

@pervognsen
Last active July 26, 2020 05:05
Show Gist options
  • Save pervognsen/0dbe377a146fec65d3d09c07db40e53b to your computer and use it in GitHub Desktop.
Save pervognsen/0dbe377a146fec65d3d09c07db40e53b to your computer and use it in GitHub Desktop.
def tensor(*args, **kwargs):
    """Contract tensors written in Einstein summation notation (wraps ``np.einsum``).

    Usage: ``tensor([output,] operand1, subscripts1, operand2, subscripts2, ...)``.
    Operands alternate with their subscript strings. An optional leading string
    names the output axes explicitly; without it ``np.einsum``'s implicit mode
    applies (indices appearing exactly once are kept in alphabetical order,
    repeated indices are summed over).

    Examples:
        Matrix-matrix multiply   tensor(A, 'ij', B, 'jk')
        Matrix-vector multiply   tensor(A, 'ij', v, 'j')
        Matrix trace             tensor(A, 'ii')
        Matrix transpose         tensor(A, 'ji')
        Matrix diagonal          tensor('i', A, 'ii')
        Inner product            tensor(v, 'i', w, 'i')
        Outer product            tensor(v, 'i', w, 'j')

    Keyword arguments are forwarded unchanged to ``np.einsum``.

    Raises:
        ValueError: if no operands are given, or operands and subscript strings
            do not come in pairs.
        TypeError: if a subscript position holds something other than a str.
    """
    output = ''
    if args and isinstance(args[0], str):
        # A leading string is the explicit output spec ("->out" in einsum syntax).
        output, args = '->' + args[0], args[1:]
    # Real exceptions rather than assert: asserts are stripped under ``python -O``.
    if not args or len(args) % 2:
        raise ValueError(
            'tensor() expects (operand, subscripts) pairs after the optional '
            f'output subscripts; got {len(args)} argument(s)'
        )
    operands, labels = args[::2], args[1::2]
    for label in labels:
        if not isinstance(label, str):
            raise TypeError(
                f'tensor() subscripts must be str, got {type(label).__name__}'
            )
    return np.einsum(','.join(labels) + output, *operands, **kwargs)
def shifts(data, stencil, step=1):
    """Return a strided view of every placement of a stencil over ``data``.

    The result has ``2 * data.ndim`` axes. The first ``data.ndim`` axes index
    the window position (in units of ``step``). The last ``data.ndim`` axes
    index the offset within the window. No data is copied: the view aliases
    ``data``.

    Args:
        data: Array to window; array-likes are converted with ``np.asarray``.
        stencil: Window shape. Accepts an int (same extent on every axis), an
            ndarray (only its ``.shape`` is used), or a sequence with one extent
            per axis.
        step: Distance between consecutive windows. Accepts an int (every axis)
            or a sequence with one entry per axis. All entries must be >= 1.

    Returns:
        A view of shape ``counts + stencil``, where ``counts[a]`` is the number
        of windows that fit along axis ``a`` (0 if the stencil is larger).

    Raises:
        ValueError: if ``stencil`` or ``step`` does not have one entry per axis
            of ``data``, or if any step is < 1.

    Note:
        Windows overlap in memory. Writing through the returned view also
        changes other windows and ``data`` itself, so treat it as read-only.
    """
    data = np.asarray(data)
    ndim = data.ndim
    # Broadcast scalar forms to one entry per axis; NumPy integer scalars count as ints.
    if isinstance(stencil, (int, np.integer)):
        stencil = [stencil] * ndim
    elif isinstance(stencil, np.ndarray):
        stencil = stencil.shape  # an array stencil contributes only its shape
    if isinstance(step, (int, np.integer)):
        step = [step] * ndim
    stencil, step = list(stencil), list(step)
    # Real exceptions rather than asserts, which are stripped under ``python -O``.
    if len(stencil) != ndim:
        raise ValueError(f'stencil has {len(stencil)} entries but data has {ndim} axes')
    if len(step) != ndim:
        raise ValueError(f'step has {len(step)} entries but data has {ndim} axes')
    if any(k < 1 for k in step):
        raise ValueError(f'every step must be >= 1, got {step}')
    # Moving to the next window along an axis skips ``step`` elements of that
    # axis. Moving within a window advances one element (the original stride).
    strides = [k * stride for k, stride in zip(step, data.strides)] + list(data.strides)
    # (m - n) // k + 1 windows fit when m >= n; max() clamps to 0 when the
    # stencil is larger than the axis.
    counts = [max(0, (m - n + k) // k) for m, n, k in zip(data.shape, stencil, step)]
    return np.lib.stride_tricks.as_strided(data, shape=counts + stencil, strides=strides)
def convolve1(signal, kernel):
    """Dot ``kernel`` with every full-length window of the 1-D ``signal``.

    The result is ``out[i] = sum_j signal[i + j] * kernel[j]`` for each ``i``
    at which the kernel fits inside the signal ("valid" positions only). The
    kernel is not flipped, so despite the name this is cross-correlation:
    ``[-1, 1]`` produces forward differences.

    ``kernel`` should be an ndarray, because ``shifts`` reads the window length
    from ``kernel.shape``. A plain list would be read as a window shape instead.
    """
    windows = shifts(data=signal, stencil=kernel)  # windows[i, j] == signal[i + j]
    return tensor(windows, 'ij', kernel, 'j')
def convolve2(signal, kernel):
    """Dot the 2-D ``kernel`` with every full-size window of the 2-D ``signal``.

    The result is ``out[i, j] = sum_{k, l} signal[i + k, j + l] * kernel[k, l]``,
    computed over valid positions only (where the kernel fits entirely). The
    kernel is not flipped, so this is cross-correlation.

    ``kernel`` should be an ndarray so that ``shifts`` can read its window shape.
    """
    windows = shifts(data=signal, stencil=kernel)  # windows[i, j, k, l] == signal[i + k, j + l]
    return tensor(windows, 'ijkl', kernel, 'kl')
def convolve(signal, kernel):
    """Slide ``kernel`` over ``signal`` in any number of dimensions ("valid" mode).

    This is the generic form of ``convolve1``/``convolve2``. Output axis ``a``
    has ``signal.shape[a] - kernel.shape[a] + 1`` entries (0 if the kernel is
    larger), one per position where the kernel fits entirely inside the signal.
    Each entry is the sum of the elementwise product of the kernel and the
    signal window it covers.

    The kernel is NOT flipped, so this is cross-correlation. For example,
    ``convolve(np.arange(10), np.array([-1, 1]))`` gives nine 1s.

    Args:
        signal: Array-like of any rank; converted with ``np.asarray``.
        kernel: Array-like with the same rank as ``signal``. It is converted
            with ``np.asarray``, so a nested list is read as data rather than
            (as ``shifts`` would) as a window shape.

    Raises:
        ValueError: if ``kernel`` and ``signal`` differ in rank, or if the rank
            exceeds the 26 dimensions the subscript alphabet can label.
    """
    signal = np.asarray(signal)
    kernel = np.asarray(kernel)
    n = signal.ndim
    if kernel.ndim != n:
        raise ValueError(f'kernel has {kernel.ndim} dimensions but signal has {n}')
    # np.einsum accepts both letter cases (52 labels); the windowed view needs 2 * n.
    labels = string.ascii_letters
    if 2 * n > len(labels):
        raise ValueError(f'convolve supports at most {len(labels) // 2} dimensions, got {n}')
    positions, offsets = labels[:n], labels[n:2 * n]
    # The view's axes are window positions followed by within-window offsets.
    # Sharing the offset labels with the kernel sums them out, leaving one value
    # per position. The output spec is explicit rather than relying on
    # implicit-mode ordering.
    return tensor(positions, shifts(signal, kernel), positions + offsets, kernel, offsets)
if __name__ == "__main__":
    # Demo only when run as a script, so importing this module prints nothing.
    # The trailing comments give the expected values. Kernels are not flipped:
    # [-1, 1] takes forward differences (all 1s on a ramp), and [-1, 2, -1] takes
    # a negated second difference (all 0s on a ramp).
    print(convolve1(np.arange(10), np.array([1, 1])))  # [1 3 5 7 9 11 13 15 17]
    print(convolve(np.arange(10), np.array([1, 1])))  # [1 3 5 7 9 11 13 15 17]
    print(convolve(np.arange(10), np.array([-1, 1])))  # [1 1 1 1 1 1 1 1 1]
    print(convolve(np.array([0, -1, 1, 0]), np.array([-1, 1])))  # [-1 2 -1]
    print(convolve(convolve(np.arange(10), np.array([-1, 1])), np.array([-1, 1])))  # [0 0 0 0 0 0 0 0]
    print(convolve(np.arange(10), np.array([-1, 2, -1])))  # [0 0 0 0 0 0 0 0]
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment