Skip to content

Instantly share code, notes, and snippets.

@3h4
Created November 18, 2016 14:47
Show Gist options
  • Save 3h4/b57d49f2cfc77f683b2fe18ec7eef498 to your computer and use it in GitHub Desktop.
Save 3h4/b57d49f2cfc77f683b2fe18ec7eef498 to your computer and use it in GitHub Desktop.
def plot(loss_list, predictions_series, batchX, batchY):
    """Redraw the progress figure: the loss curve plus sample series.

    The figure is a 2x3 grid. Cell 1 shows ``loss_list``; cells 2-6 each
    show one series of the batch as three overlaid bar charts: the input
    (blue, full height), the target (red, scaled by 0.5) and the
    thresholded prediction (green, scaled by 0.3).

    Args:
        loss_list: Sequence of loss values, passed straight to ``plt.plot``.
        predictions_series: Array-like that converts to a 3-D array and is
            indexed as ``[:, series, :]``; presumably
            (time step, batch, class) -- confirm against the caller.
        batchX: 2-D array indexed as ``[series, :]``; drawn in blue.
        batchY: 2-D array indexed as ``[series, :]``; drawn in red.

    Note:
        Relies on the module-level names ``plt``, ``np`` and
        ``truncated_backprop_length``.
    """
    plt.subplot(2, 3, 1)
    plt.cla()
    plt.plot(loss_list)

    # Convert once: the conversion does not depend on the loop variable.
    predictions = np.array(predictions_series)

    # The grid leaves five cells for series; plot fewer when the batch
    # dimension is smaller instead of failing with an IndexError.
    max_series = 5
    num_series = min(max_series, predictions.shape[1])

    left_offset = range(truncated_backprop_length)
    for batch_series_idx in range(num_series):
        one_hot_output_series = predictions[:, batch_series_idx, :]
        # 1 where the first output column is below 0.5, otherwise 0
        # (presumably column 0 is the score for class 0 -- confirm).
        single_output_series = (one_hot_output_series[:, 0] < 0.5).astype(int)

        plt.subplot(2, 3, batch_series_idx + 2)
        plt.cla()
        plt.axis([0, truncated_backprop_length, 0, 2])
        plt.bar(left_offset, batchX[batch_series_idx, :], width=1, color="blue")
        plt.bar(left_offset, batchY[batch_series_idx, :] * 0.5, width=1, color="red")
        plt.bar(left_offset, single_output_series * 0.3, width=1, color="green")

    plt.draw()
    # A very short pause gives the GUI event loop a chance to show the redraw.
    plt.pause(0.0001)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment