
@eileen-code4fun
Created January 21, 2022 05:55
Save eileen-code4fun/b610ab2c0efe727ded7708845c584d99 to your computer and use it in GitHub Desktop.
Plot Attention
import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
import tensorflow as tf  # needed for tf.squeeze below

def plot_attention(attention, spa, eng):
  # standardize() is the text-normalization helper defined elsewhere in this project.
  spa = standardize(spa).numpy().decode().split()
  eng = standardize(eng).numpy().decode().split()[1:]  # drop the leading [START] token
  fig = plt.figure(figsize=(10, 10))
  ax = fig.add_subplot(1, 1, 1)
  # Collapse any singleton batch/head dimensions to a 2-D (output x input) matrix.
  attention = tf.squeeze(attention).numpy()
  ax.matshow(attention, cmap='viridis', vmin=0.0)
  fontdict = {'fontsize': 14}
  # The leading '' aligns each token label with its cell once the
  # MultipleLocator below places a tick at every integer position.
  ax.set_xticklabels([''] + spa, fontdict=fontdict, rotation=90)
  ax.set_yticklabels([''] + eng, fontdict=fontdict)
  ax.xaxis.set_major_locator(ticker.MultipleLocator(1))
  ax.yaxis.set_major_locator(ticker.MultipleLocator(1))
  ax.set_xlabel('Input text')
  ax.set_ylabel('Output text')
  plt.suptitle('Attention weights')
  plt.show()
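The function above depends on a translation model and its standardize() helper, so here is a self-contained sketch of the same heatmap with made-up tokens and a random, row-normalized attention matrix. The sentences and weights are hypothetical stand-ins; a real call would pass the weights returned by the decoder's attention layer. It also labels cell centers via set_xticks, which sidesteps the leading-'' offset trick:

```python
import matplotlib
matplotlib.use('Agg')  # headless backend so the figure can be saved without a display
import matplotlib.pyplot as plt
import numpy as np

# Hypothetical inputs: a toy attention matrix over pre-tokenized sentences.
spa_tokens = ['[START]', 'hace', 'frio', 'aqui', '.', '[END]']
eng_tokens = ['it', 'is', 'cold', 'here', '.', '[END]']
rng = np.random.default_rng(0)
attention = rng.random((len(eng_tokens), len(spa_tokens)))
attention /= attention.sum(axis=1, keepdims=True)  # each output row sums to 1

fig, ax = plt.subplots(figsize=(10, 10))
ax.matshow(attention, cmap='viridis', vmin=0.0)
# Label cell centers directly: matshow draws cell (i, j) centered at x=j, y=i.
ax.set_xticks(range(len(spa_tokens)))
ax.set_xticklabels(spa_tokens, fontsize=14, rotation=90)
ax.set_yticks(range(len(eng_tokens)))
ax.set_yticklabels(eng_tokens, fontsize=14)
ax.set_xlabel('Input text')
ax.set_ylabel('Output text')
fig.suptitle('Attention weights')
fig.savefig('attention.png')
```

Bright cells in a row then show which input (Spanish) tokens the model attended to when emitting that output (English) token.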