Created August 16, 2023 01:29
-
-
Save IvanaGyro/c3a217b992dd856a4222762d3a94a557 to your computer and use it in GitHub Desktop.
Decompose states into matrix product states.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from functools import reduce

import numpy as np


def mps_decompose(state):
    """Split a dense tensor into matrix-product-state (MPS) cores.

    Performs a left-to-right sweep of SVDs. At each site the remaining
    tensor is reshaped to ``(left_bond * site_dim, rest)``. The left
    singular vectors become that site's core, and ``diag(S) @ Vh`` is
    carried on to the next site. No singular values are truncated, so the
    decomposition is exact up to floating-point error. It is not a
    compression: each bond dimension is ``min(rows, cols)`` of the
    corresponding unfolding, so the cores can hold more numbers than
    ``state`` itself.

    Args:
        state: array-like with ``ndim >= 1``; each axis is treated as one
            site.

    Returns:
        A list of ``state.ndim`` cores. For ``ndim >= 2`` the first core
        has shape ``(d0, r1)``, interior cores ``(r_k, d_k, r_{k+1})`` and
        the last core ``(r_{n-1}, d_{n-1})``. For ``ndim == 1`` the list
        holds a single copy of ``state``.

    Raises:
        ValueError: if ``state`` is 0-dimensional.
    """
    state = np.asarray(state)
    if state.ndim == 0:
        raise ValueError("state must have at least one axis")
    if state.ndim == 1:
        return [state.copy()]

    dims = state.shape
    cores = []
    remainder = state
    left_bond = 1  # the first site has no left bond, so its core stays 2-D
    for site, dim in enumerate(dims[:-1]):
        u, singular, vh = np.linalg.svd(
            remainder.reshape(left_bond * dim, -1), full_matrices=False
        )
        right_bond = singular.size
        cores.append(u if site == 0 else u.reshape(left_bond, dim, right_bond))
        # Scaling the rows of Vh equals diag(S) @ Vh without building the
        # diagonal matrix.
        remainder = singular[:, None] * vh
        left_bond = right_bond
    # What is left is (left_bond, last_site_dim): the final core.
    cores.append(remainder.reshape(left_bond, dims[-1]))
    return cores


def mps_compose(cores):
    """Contract MPS cores back into a dense tensor.

    Each core's last axis is contracted with the next core's first axis,
    left to right.

    Args:
        cores: non-empty sequence of arrays, e.g. the output of
            ``mps_decompose``.

    Returns:
        The dense tensor obtained by contracting all bonds.

    Raises:
        ValueError: if ``cores`` is empty.
    """
    cores = list(cores)
    if not cores:
        raise ValueError("cores must be non-empty")
    return reduce(lambda left, right: np.tensordot(left, right, axes=(-1, 0)), cores)


# Demo: decompose random 3-site and 4-site tensors (site dimension 2),
# report how many numbers the cores hold, and check exact reconstruction.
for shape in [(2, 2, 2), (2, 2, 2, 2)]:
    s = np.random.randn(*shape)
    cores = mps_decompose(s)
    print('total parameters: ', sum(core.size for core in cores))
    print(np.allclose(mps_compose(cores), s))
Sign up for free to join this conversation on GitHub.
Already have an account? Sign in to comment.