Skip to content

Instantly share code, notes, and snippets.

@tamlt2704
Created June 1, 2018 11:50
Show Gist options
  • Save tamlt2704/3841322984cb00055e2f3edf88733706 to your computer and use it in GitHub Desktop.
Save tamlt2704/3841322984cb00055e2f3edf88733706 to your computer and use it in GitHub Desktop.
from __future__ import division
import random
import numpy as np
import matplotlib.pyplot as plt
import imageio
import os
%matplotlib inline
# Matrix dimensions: A is M x P and B is P x N.
M, P, N = 2, 3, 5

# Random operands whose entries are 0 or 1 (upper bound 2 is exclusive).
A = np.random.randint(0, 2, size=(M, P))
B = np.random.randint(0, 2, size=(P, N))

# The figure is laid out as a grid of (M + P) rows by (P + N) columns.
nb_rows = M + P
nb_cols = P + N

# Fractions of the figure taken by each band of the layout, and the gap
# trimmed from / added to each rectangle.
_gap = 0.05
_left_w = P / nb_cols
_right_w = N / nb_cols
_bottom_h = M / nb_rows
_top_h = P / nb_rows

# Axes rectangles passed to ``fig.add_axes``: (left, bottom, width, height).
# A sits bottom-left, C bottom-right, and B above C.
rect_A = (0, 0, _left_w - _gap, _bottom_h - _gap)
rect_B = (_left_w + _gap, _bottom_h + _gap, _right_w - _gap, _top_h - _gap)
rect_C = (_left_w + _gap, 0, _right_w - _gap, _bottom_h - _gap)
def plot_matrix(fig, data, rect, title, highlight_args=None, xvline=None, yhline=None):
    """Draw ``data`` as a grid of annotated cells in a new axes on ``fig``.

    Args:
        fig: Figure that receives the new axes (created with ``fig.add_axes``).
        data: 2-D array; each entry is written as red text in its cell.
        rect: Axes rectangle, forwarded unchanged to ``fig.add_axes``.
        title: Title set on the axes.
        highlight_args: Optional ``(row, col)`` pair. That cell of the
            background image is set to 1 while every other cell stays 0.
        xvline: Optional x position of a dotted vertical guide line.
        yhline: Optional y position of a dotted horizontal guide line.
    """
    n_rows, n_cols = data.shape
    ax = fig.add_axes(rect)

    # Background image: all zeros, except for the highlighted cell (if any).
    background = np.zeros((n_rows, n_cols))
    if highlight_args:
        row, col = highlight_args
        background[row, col] = 1
    ax.imshow(background, aspect='auto')

    # Ticks sit half a cell before each integer index so the grid lines fall
    # on cell borders; the tick labels themselves are hidden.
    ax.grid()
    ax.set_yticks([r - 0.5 for r in range(n_rows)])
    ax.set_xticks([c - 0.5 for c in range(n_cols)])
    ax.set_xticklabels([])
    ax.set_yticklabels([])

    # Write every value into its cell, row by row (x is the column index,
    # y is the row index).
    for r, c in np.ndindex(n_rows, n_cols):
        ax.annotate(data[r, c], xy=(c, r), fontsize=12, color='red', xycoords='data')

    ax.set_title(title)

    # Guide lines. A requested position of exactly 0 is drawn at 0.15 instead.
    # NOTE(review): the reason for this nudge is not evident from the code.
    if xvline is not None:
        ax.axvline(x=0.15 if xvline == 0 else xvline, ls=':', lw=3)
    if yhline is not None:
        ax.axhline(y=0.15 if yhline == 0 else yhline, ls=':', lw=3)
def multiple_matrix_naive_solution(A, B):
    """Multiply ``A`` by ``B`` with the naive triple loop, saving one image per step.

    After every multiply-accumulate step a figure showing A, B and the
    partial product C is written to ``matrix/s1_<step>`` (step is zero-padded
    to two digits). In each frame the cells involved in that step are
    highlighted and guide lines mark the current row of A and column of B.

    The loop bounds are taken from the shapes of the arguments. The axes
    rectangles (``rect_A``, ``rect_B``, ``rect_C``) and the figure size
    (``nb_cols``, ``nb_rows``) still come from the module-level layout.

    Args:
        A: 2-D array of shape (m, p).
        B: 2-D array of shape (p, n).

    Returns:
        The (m, n) product as an integer array. Accumulation happens in an
        integer array, so fractional contributions are truncated at each step.

    Raises:
        ValueError: If the inner dimensions of ``A`` and ``B`` differ.
    """
    m, p = A.shape
    p_b, n = B.shape
    if p != p_b:
        raise ValueError(
            'cannot multiply shapes {} and {}'.format(A.shape, B.shape))

    # savefig does not create missing directories, so make sure the output
    # folder exists before the first frame is written.
    if not os.path.isdir('matrix'):
        os.makedirs('matrix')

    # Integer accumulator: avoids re-copying C with astype(int) on every step.
    C = np.zeros((m, n), dtype=int)
    frame = 0
    for i in range(m):
        for j in range(n):
            for k in range(p):
                C[i, j] += A[i, k] * B[k, j]

                fig = plt.figure()
                plot_matrix(fig, A, rect_A, 'A', (i, k), xvline=None, yhline=i)
                plot_matrix(fig, B, rect_B, 'B', (k, j), xvline=j, yhline=None)
                plot_matrix(fig, C, rect_C, 'C', (i, j))
                fig.set_size_inches(nb_cols, nb_rows)
                fig.tight_layout()
                fig.savefig('matrix/s1_{0:02d}'.format(frame))
                # Close this specific figure so figures do not pile up in memory.
                plt.close(fig)
                frame += 1
    return C
# Run the animated multiplication; this writes the frames to matrix/s1_*.
C = multiple_matrix_naive_solution(A, B)

# os.listdir returns entries in arbitrary order, so sort the frame names to
# stitch them in the order they were written. The two-digit zero padding of
# the frame counter makes lexicographic order match numeric order for up to
# 100 frames.
filenames = sorted(fn for fn in os.listdir('matrix') if fn.startswith('s1'))
images = [imageio.imread('matrix/{}'.format(filename)) for filename in filenames]
imageio.mimsave('matrix/movie.gif', images, duration=1)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment