Skip to content

Instantly share code, notes, and snippets.

@minhlab
Last active July 17, 2017 07:48
Show Gist options
  • Save minhlab/d8d53156045a16762c0a677484975e1c to your computer and use it in GitHub Desktop.
Save minhlab/d8d53156045a16762c0a677484975e1c to your computer and use it in GitHub Desktop.
Testing the validity of the sanity check proposed in Batchkarov et al. (2016)
import numpy as np
from scipy.stats import spearmanr, pearsonr
from matplotlib import pyplot as pl
import sys
def rank_correlation_under_noise(gold_scores, sys_pairs, noise_levels, repeats):
    """Measure Spearman's rho against gold scores as noise is added to vectors.

    For every noise level ``n`` the system vectors are perturbed ``repeats``
    times with independent uniform noise drawn from ``[-n, n)``; each perturbed
    pair is scored by its dot product and the scores are rank-correlated with
    ``gold_scores``.

    Args:
        gold_scores: 1-D array with one reference score per pair.
        sys_pairs: array of shape ``(2, size, dim)`` holding the two vectors
            of every pair.
        noise_levels: sequence of noise amplitudes ``n``.
        repeats: number of independent perturbations per noise level.

    Returns:
        A tuple ``(rho_means, rho_stds)`` of 1-D arrays with one entry per
        noise level: the mean and standard deviation of rho over the repeats.
    """
    _, size, dim = sys_pairs.shape
    # Size the outputs from the input instead of repeating a magic number.
    rho_means = np.zeros(len(noise_levels))
    rho_stds = np.zeros(len(noise_levels))
    for i, n in enumerate(noise_levels):
        # The new axis broadcasts the system vectors across the repeats;
        # rand() is uniform on [0, 1), so rand*n*2-n is uniform on [-n, n).
        perturbed_pairs = (sys_pairs[:, np.newaxis, :, :]
                           + np.random.rand(2, repeats, size, dim) * n * 2 - n)
        # Dot product of the two vectors of each pair -> (repeats, size).
        perturbed_scores = (perturbed_pairs[0] * perturbed_pairs[1]).sum(axis=2)
        rhos = [spearmanr(gold_scores, perturbed_scores[j])[0]
                for j in range(repeats)]
        rho_means[i] = np.mean(rhos)
        rho_stds[i] = np.std(rhos)
    return rho_means, rho_stds


def main():
    """Run the noise sweep for each dataset size and plot rho vs. noise."""
    # Optional first CLI argument: number of perturbation repeats.
    repeats = int(sys.argv[1]) if len(sys.argv) >= 2 else 5  # change this and see what happens
    dim = 50
    # dataset name -> (number of pairs, line colour, shade colour)
    # Other sizes tried previously: simlex 999, ws353 353.
    sizes = {
        'men': (3000, '#0066ff', '#4d94ff'),
        'mc': (30, '#ff3300', '#ffd6cc80'),
        'rg': (65, '#00ff00', '#ccffcc80'),
    }
    n_vals = np.linspace(0, 3, 15)
    for name, (size, line_color, shade_color) in sizes.items():
        # Random "gold" vectors, uniform on [-1, 1); gold score = dot product.
        gold_pairs = np.random.rand(2, size, dim) * 2 - 1
        gold_scores = (gold_pairs[0] * gold_pairs[1]).sum(axis=1)
        # System vectors: gold plus uniform noise on [-0.6, 0.6), then halved.
        sys_pairs = 0.5 * (gold_pairs + np.random.rand(2, size, dim) * 1.2 - 0.6)
        print('Mean norm of system vectors: %.2f'
              % np.linalg.norm(sys_pairs, axis=2).ravel().mean())
        rho_means, rho_stds = rank_correlation_under_noise(
            gold_scores, sys_pairs, n_vals, repeats)
        # Pass the line style by keyword: a 'k-' fmt string would also set the
        # colour (black) and clash with the explicit color= argument.
        pl.plot(n_vals, rho_means, linestyle='-', label=name, color=line_color)
        pl.fill_between(n_vals, rho_means - rho_stds, rho_means + rho_stds,
                        color=shade_color)
    pl.legend()
    pl.show()


if __name__ == '__main__':
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment