Skip to content

Instantly share code, notes, and snippets.

@VoVAllen
Last active December 6, 2018 07:52
Show Gist options
  • Save VoVAllen/a07021991f4d01f9c1ba09b498c97e1e to your computer and use it in GitHub Desktop.
Save VoVAllen/a07021991f4d01f9c1ba09b498c97e1e to your computer and use it in GitHub Desktop.
animation
def att_animation(maps_array, mode, src, tgt, head_id, *, interval=500, repeat_delay=2000):
    """Animate one attention head's weights across a sequence of snapshots.

    Each frame draws the selected head's weights twice: as a heatmap with a
    colorbar on the left axes, and via ``graph_att_head`` on the right axes.
    The figure title shows the frame index as "epoch {i}".

    Args:
        maps_array: Sequence of attention-map snapshots, one per frame. Each
            snapshot is indexed as ``maps[mode2id[mode]][head_id]``.
        mode: Key into ``mode2id`` (defined elsewhere in the module, not
            visible here) that selects which attention map to show.
        src: Labels for the heatmap's y axis.
        tgt: Labels for the heatmap's x axis.
        head_id: Index of the attention head to show.
        interval: Delay between frames in milliseconds.
        repeat_delay: Delay in milliseconds before the animation repeats.

    Returns:
        The ``animation.FuncAnimation`` object. The caller must keep a
        reference to it, or the animation stops when it is garbage-collected.
    """
    # NOTE(review): each weights[i] presumably has shape (len(tgt), len(src)),
    # so that after transpose(-1, -2) rows match `src` and columns match `tgt`.
    # Values are presumably in [0, 1] (pcolor is pinned to vmin=0, vmax=1).
    # Confirm both against the callers.
    weights = [maps[mode2id[mode]][head_id] for maps in maps_array]
    fig, axes = plt.subplots(1, 2)
    heat_ax, graph_ax = axes
    colorbar = None

    def _label_heatmap_axes():
        # Set src/tgt tick labels on the heatmap. Axes.cla() resets ticks and
        # labels, so this has to run again after every clear. pcolor cell k
        # spans [k, k+1], so the +0.5 centres each label on its cell.
        heat_ax.set_yticks(np.arange(len(src)) + 0.5)
        heat_ax.set_xticks(np.arange(len(tgt)) + 0.5)
        heat_ax.set_yticklabels(src)
        heat_ax.set_xticklabels(tgt)
        plt.setp(heat_ax.get_xticklabels(), rotation=45, ha="right",
                 rotation_mode="anchor")

    def weight_animate(i):
        # `colorbar` belongs to att_animation's scope, so it must be declared
        # nonlocal. A `global` declaration would look it up at module level
        # and fail (NameError) or remove an unrelated colorbar.
        nonlocal colorbar
        if colorbar is not None:
            colorbar.remove()
        heat_ax.cla()
        heat_ax.set_title('heatmap')
        fig.suptitle('epoch {}'.format(i))
        weight = weights[i].transpose(-1, -2)
        heatmap = heat_ax.pcolor(weight, vmin=0, vmax=1, cmap=plt.cm.Blues)
        colorbar = plt.colorbar(heatmap, ax=heat_ax, fraction=0.046, pad=0.04)
        heat_ax.set_aspect('equal')
        _label_heatmap_axes()
        graph_ax.cla()
        graph_ax.axis("off")
        graph_att_head(src, tgt, weight, graph_ax, 'graph')

    # Label the axes before the first frame is drawn as well.
    _label_heatmap_axes()
    return animation.FuncAnimation(fig, weight_animate, frames=len(weights),
                                   interval=interval, repeat_delay=repeat_delay)
@VoVAllen
Copy link
Author

VoVAllen commented Dec 6, 2018

ani

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment