Skip to content

Instantly share code, notes, and snippets.

@yueyericardo
Created June 27, 2019 14:38
Show Gist options
  • Save yueyericardo/8a599ad72328acecea723cb77a8a1624 to your computer and use it in GitHub Desktop.
Save yueyericardo/8a599ad72328acecea723cb77a8a1624 to your computer and use it in GitHub Desktop.
def _padding_cutoff(row):
    """Return the length of ``row`` once its trailing run of zeros is dropped.

    Zeros that are followed by a non-zero value are kept; only the zeros at
    the end count as padding. Returns 0 for an all-zero (or empty) row.
    """
    nonzero = np.flatnonzero(row)
    return int(nonzero[-1]) + 1 if nonzero.size else 0


def _mark_orbital(x, y, color, label, text, below):
    """Draw a short vertical arrow pointing at ``(x, y)`` with a caption.

    With ``below=True`` the arrow starts 2 units under the point and points
    up, and the caption sits 2.6 units under the point. Otherwise the arrow
    and caption are mirrored above the point.
    """
    sign = -1 if below else 1
    plt.arrow(x, y + 2 * sign, dx=0, dy=-1.0 * sign, head_width=0.5,
              head_length=0.5, color=color, label=label)
    plt.text(x, y + 2.6 * sign, fontsize=12, color=color,
             s=text, horizontalalignment='center')


def plot_MO(ground_true, pred, i, num_heavy_atom, num_all_atom, num_e,
            hide_padding=True, threshold=0.04, ylim=(-20, 6)):
    """Plot predicted vs. reference molecular-orbital energies for one molecule.

    The function draws the reference and predicted energy curves against a
    1-based orbital index. It highlights points whose absolute error exceeds
    ``threshold``. It annotates three orbital positions with arrows:
    ``num_heavy_atom[i]``, ``num_all_atom[i]`` and ``num_e[i] / 2`` (labelled
    HOMO). Finally it shows the figure with ``plt.show()``.

    Args:
        ground_true: indexable collection. ``ground_true[i]`` is a 1-D array
            of reference energies (hartree, per the y-axis label). It is
            presumably zero-padded at the end to a common length; TODO
            confirm against the data pipeline.
        pred: predicted energies, indexed like ``ground_true``.
        i: index of the molecule to plot.
        num_heavy_atom: per-molecule counts, used as 1-based orbital
            positions for the "heavy atoms" arrow.
        num_all_atom: per-molecule counts, used as 1-based orbital positions
            for the "all atoms" arrow.
        num_e: per-molecule electron counts. Orbital ``num_e[i] / 2`` is
            marked as the HOMO.
        hide_padding: if True, drop the trailing run of zeros in
            ``ground_true[i]``, and the matching entries of ``pred[i]``,
            before plotting.
        threshold: absolute error above which a point is highlighted.
            Defaults to the previously hard-coded 0.04.
        ylim: ``(bottom, top)`` y-axis limits, or None to keep matplotlib's
            autoscaling.
    """
    truth_row = np.asarray(ground_true[i])
    pred_row = np.asarray(pred[i])

    if hide_padding:
        # Trim only the trailing zeros. The previous search for the first
        # pair of consecutive zeros raised IndexError when a row had no
        # padding, or only a single padded zero.
        cutoff = _padding_cutoff(truth_row)
        ground_true_data = truth_row[:cutoff]
        pred_data = pred_row[:cutoff]
    else:
        ground_true_data = truth_row
        pred_data = pred_row

    x = np.arange(1, len(ground_true_data) + 1)
    diff = np.absolute(ground_true_data - pred_data)
    # flatnonzero always returns a 1-D index array. squeeze(argwhere(...))
    # collapsed a single hit into a 0-d array, which made len() raise.
    diff_large_idx = np.flatnonzero(diff > threshold)
    diff_large_data = ground_true_data[diff_large_idx]

    plt.figure(figsize=(18, 10))
    # plot points whose error exceeds the threshold
    plt.plot(diff_large_idx + 1, diff_large_data,
             label='diff > {} ({} points)'.format(threshold, len(diff_large_idx)),
             marker='x', linestyle='None', markersize=10, markeredgewidth=1,
             color='magenta')
    # plot ground truth
    plt.plot(x, ground_true_data, label='ground truth', marker='.',
             linestyle='-', markersize=10, markeredgewidth=1, color='r')
    # plot prediction
    plt.plot(x, pred_data, label='prediction', marker='.', linestyle='-',
             markersize=10, markeredgewidth=1, color='g')

    # Arrow y positions come from the untrimmed prediction row. The counts are
    # 1-based orbital positions, hence the "- 1" when indexing.
    heavy = int(num_heavy_atom[i])
    _mark_orbital(heavy, pred_row[heavy - 1], 'red', 'num_heavy_atom',
                  'heavy atoms', below=True)
    all_atoms = int(num_all_atom[i])
    _mark_orbital(all_atoms, pred_row[all_atoms - 1], 'blue', 'num_all_atom',
                  'all atoms', below=False)
    homo_x = num_e[i] / 2
    _mark_orbital(homo_x, pred_row[int(homo_x) - 1], 'purple', 'num_e',
                  'HOMO', below=True)

    # title and labels
    plt.title('Using ANI to Predict Molecular Orbitals', fontsize=18)
    plt.xlabel('Molecular Orbitals', fontsize=18)
    plt.ylabel('MO Energy / hartree', fontsize=18)
    plt.legend(frameon=False, fontsize=16)
    if ylim is not None:
        plt.ylim(*ylim)
    plt.show()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment