Skip to content

Instantly share code, notes, and snippets.

@rahulvigneswaran
Created January 9, 2020 08:37
Show Gist options
  • Save rahulvigneswaran/16fb15b7c909d20793c031f800739597 to your computer and use it in GitHub Desktop.
Save rahulvigneswaran/16fb15b7c909d20793c031f800739597 to your computer and use it in GitHub Desktop.
import os
import numpy as np
from scipy.stats import wasserstein_distance, energy_distance
from matplotlib import pyplot as plt
# Epoch to analyse. `epoch` indexes the per-epoch rows of the full-model
# spectrum; `epoch_name` is the filename substring that selects this epoch's
# per-layer files.
epoch = 1
epoch_name = "epoch_{}_".format(epoch)

# Immediate contents of the working directory; the sub-directories (sorted
# by name) are the candidate model folders processed below.
a = os.getcwd()
root, subdir_names, files = next(os.walk(a))
dirs = sorted(subdir_names)
def norm_vals(arr):
    """Replace every zero entry of ``arr``, in place, with its smallest non-zero value.

    Call sites label this step "Convert to log scale", so presumably the goal
    is to keep zeros out of the data (log(0) is -inf) -- TODO confirm intent.

    Parameters
    ----------
    arr : mutable 1-D sequence of numbers (list or numpy array)
        Modified in place.

    Returns
    -------
    The same ``arr`` object, so callers can reassign the result. An empty or
    all-zero ``arr`` is returned unchanged.
    """
    # Take the minimum over the NON-zero entries: a minimum over the whole
    # array is 0 for non-negative data that contains a zero, which made the
    # replacement a no-op. When negatives are present this equals the plain
    # minimum, matching the previous behaviour in that case.
    nonzero = [value for value in arr if value != 0]
    if not nonzero:
        return arr
    # Computed once instead of re-scanning arr for every zero (was O(n^2)).
    fill = min(nonzero)
    for idx, value in enumerate(arr):
        if value == 0:
            arr[idx] = fill
    return arr
# Keep only directories whose name contains "model". The list is rebuilt
# rather than calling dirs.remove() inside a loop over dirs: each removal
# shifts the remaining items left, so the entry right after a removed one was
# never checked and a non-model directory could survive the filter.
# Slice assignment keeps the same list object and its sorted order.
dirs[:] = [name for name in dirs if 'model' in name]
def _minmax_scale(values):
    """Linearly rescale ``values`` to [0, 1] (minimum -> 0, maximum -> 1).

    Returns all zeros when every value is equal, instead of the NaNs produced
    by dividing by a zero range.
    """
    values = np.asarray(values, dtype=float)
    lo = values.min()
    span = values.max() - lo
    if span == 0:
        return np.zeros_like(values)
    return (values - lo) / span


plots_dir = os.path.join(os.getcwd(), "plots")

for i in dirs:
    ckpt_dir = os.path.join(os.getcwd(), i, "ckpt")

    # Even rows of the full-model file are eigenvalues and odd rows their
    # densities; [epoch] picks one row of each (presumably one pair per
    # saved epoch -- TODO confirm against the script that writes this file).
    full = np.load(os.path.join(ckpt_dir, "training_eigenspectrum_full.npy"))
    full_eigval = full[::2]
    full_eigval_density = full[1::2]

    # The full-model side is the same for every layer, so prepare it once
    # (it used to be renormalized in place on every layer iteration).
    full_density = norm_vals(full_eigval_density[epoch])
    full_support = _minmax_scale(full_eigval[epoch])

    print(f"\nFolder in Progress: {i}")
    print("==================================\n")

    # Only regular files whose name contains "layer" are considered. Build a
    # filtered list instead of removing while iterating (which skips items),
    # and sort so the x-axis order is deterministic (os.walk order is not).
    _, _, ckpt_files = next(os.walk(ckpt_dir))
    files = sorted(name for name in ckpt_files if 'layer' in name)

    distance = []
    tick = []
    for j in files:
        if epoch_name not in j:
            continue
        layer = np.load(os.path.join(ckpt_dir, j))
        layer_density = norm_vals(layer['eigval_density'])
        layer_eigval = _minmax_scale(layer['eigval'])
        b = wasserstein_distance(full_support, layer_eigval, full_density, layer_density)
        distance.append(b)
        print(f"Layer : {j} | Wasserstein Distance : {b}")
        # Tick label: the part of the filename after the epoch marker.
        tick.append(j.split(epoch_name)[1])

    positions = np.arange(len(tick))
    plt.plot(positions, distance)
    plt.xticks(positions, tick, rotation='vertical')
    plt.title(f"(Epoch : {epoch} | {i})")
    # Both supports are rescaled to [0, 1], so the 1-Wasserstein distance
    # cannot exceed 1; the previous limit of 100 flattened every curve.
    plt.ylim(0, 1)
    # savefig does not create missing directories.
    os.makedirs(plots_dir, exist_ok=True)
    plt.savefig(os.path.join(plots_dir, f"{i}_epoch_{epoch}.png"))
    plt.close()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment