Last active
July 17, 2017 07:48
-
-
Save minhlab/d8d53156045a16762c0a677484975e1c to your computer and use it in GitHub Desktop.
Testing the validity of the sanity check proposed in Batchkarov et al. (2016)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np | |
from scipy.stats import spearmanr, pearsonr | |
from matplotlib import pyplot as pl | |
import sys | |
def make_synthetic_dataset(size, dim):
    """Generate a synthetic word-similarity dataset and a noisy "system" model.

    Gold vectors are drawn uniformly from [-1, 1). The system vectors are the
    gold vectors plus uniform noise in [-0.6, 0.6), all scaled by 0.5.

    Args:
        size: number of word pairs in the dataset.
        dim: dimensionality of every vector.

    Returns:
        (gold_scores, sys_pairs): ``gold_scores`` has shape (size,) and holds
        the dot product of each gold vector pair; ``sys_pairs`` has shape
        (2, size, dim), where axis 0 selects the first/second word of a pair.
    """
    gold_pairs = np.random.rand(2, size, dim) * 2 - 1
    gold_scores = (gold_pairs[0] * gold_pairs[1]).sum(axis=1)
    sys_pairs = 0.5 * (gold_pairs + np.random.rand(2, size, dim) * 1.2 - 0.6)
    return gold_scores, sys_pairs


def simulate_rho_curve(gold_scores, sys_pairs, n_vals, repeats):
    """Measure how Spearman's rho against gold degrades as noise is added.

    For each noise amplitude ``n`` in ``n_vals``, ``repeats`` independent
    copies of ``sys_pairs`` are perturbed with uniform noise in [-n, n).
    Each copy is scored by the dot product of its vector pairs, and its
    Spearman correlation with ``gold_scores`` is computed.

    Args:
        gold_scores: array of shape (size,) with the reference scores.
        sys_pairs: array of shape (2, size, dim) with the system vector pairs.
        n_vals: 1-D sequence of noise amplitudes.
        repeats: number of independent perturbations per amplitude (>= 1).

    Returns:
        (rho_means, rho_stds): arrays of shape (len(n_vals),) with the mean and
        (population) standard deviation of rho across the repeats.

    Raises:
        ValueError: if ``repeats`` is smaller than 1 (the mean would be NaN).
    """
    if repeats < 1:
        raise ValueError('repeats must be a positive integer, got %r' % (repeats,))
    n_vals = np.asarray(n_vals, dtype=float)
    rho_means = np.zeros(len(n_vals))
    rho_stds = np.zeros(len(n_vals))
    _, size, dim = sys_pairs.shape
    for i, n in enumerate(n_vals):
        noise = np.random.rand(2, repeats, size, dim) * n * 2 - n
        # Insert a "repeat" axis so one set of system vectors broadcasts
        # against `repeats` independent noise draws: (2, repeats, size, dim).
        perturbed_pairs = sys_pairs[:, np.newaxis, :, :] + noise
        perturbed_scores = (perturbed_pairs[0] * perturbed_pairs[1]).sum(axis=2)
        # Iterating over axis 0 walks the repeats; each row has shape (size,).
        rhos = [spearmanr(gold_scores, scores)[0] for scores in perturbed_scores]
        rho_means[i] = np.mean(rhos)
        rho_stds[i] = np.std(rhos)
    return rho_means, rho_stds


if __name__ == '__main__':
    # Perturbations per noise level, taken from the first CLI argument.
    # Change this and see what happens to the width of the shaded bands.
    repeats = int(sys.argv[1]) if len(sys.argv) >= 2 else 5
    dim = 50
    # Other benchmarks that could be simulated: simlex (999 pairs), ws353 (353).
    # name -> (number of pairs, line colour, shading colour incl. alpha)
    sizes = {'men': (3000, '#0066ff', '#4d94ff'),
             'mc': (30, '#ff3300', '#ffd6cc80'),
             'rg': (65, '#00ff00', '#ccffcc80')}
    n_vals = np.linspace(0, 3, 15)
    for name, (size, line_color, shade_color) in sizes.items():
        gold_scores, sys_pairs = make_synthetic_dataset(size, dim)
        print('Mean norm of system vectors: %.2f'
              % np.linalg.norm(sys_pairs, axis=2).ravel().mean())
        rho_means, rho_stds = simulate_rho_curve(gold_scores, sys_pairs,
                                                 n_vals, repeats)
        # The fmt string carries only the line style: a colour letter there
        # (e.g. 'k-') would conflict with the explicit color= keyword.
        pl.plot(n_vals, rho_means, '-', label=name, color=line_color)
        # Shade one standard deviation around the mean rho.
        pl.fill_between(n_vals, rho_means - rho_stds, rho_means + rho_stds,
                        color=shade_color)
    pl.xlabel('noise amplitude n')
    pl.ylabel("Spearman's rho vs. gold")
    pl.legend()
    pl.show()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment