Last active
October 4, 2019 07:53
-
-
Save kawine/647747cef4f53ce1896f2a46b1402a61 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np | |
from scipy.stats import pearsonr, ttest_ind | |
from scipy.spatial.distance import cosine | |
# Words of interest: pair2joint only keeps a joint count when both words of a
# pair are in this set. Empty by default -- specify your vocabulary.
ANALOGY_VOCAB = set()
class pair2joint(object):
    """Load co-occurrence counts and calculate PMI and csPMI.

    Attributes:
        joint: dict mapping ``(word1, word2)`` tuples to their co-occurrence
            count; every stored pair is present in both orders.
        marginal: dict mapping a word to the sum of the counts of all pairs
            it appears in.
        total: float, half the sum of all marginal counts.
    """

    def __init__(self, fn='counts.txt'):
        """
        Read co-occurrence counts from ``fn``.

        counts.txt should be of the format `word1 word2 count(word1,word2)' per line.
        Blank lines are skipped; any other line that does not have exactly
        three whitespace-separated fields raises ValueError.
        """
        self.joint = {}
        self.marginal = {}

        # Context manager so the file handle is closed even if parsing fails.
        with open(fn) as counts_file:
            for line in counts_file:
                fields = line.split()
                if not fields:
                    continue

                a, b, freq = fields
                freq = int(freq)

                # Joint counts are only kept for pairs whose words are both in
                # ANALOGY_VOCAB, and are stored symmetrically.
                if a in ANALOGY_VOCAB and b in ANALOGY_VOCAB:
                    self.joint[(a, b)] = freq
                    self.joint[(b, a)] = freq

                # NOTE(review): the indentation of the original source was
                # lost; marginals are accumulated here for every line of the
                # file (not only in-vocabulary pairs) -- confirm this matches
                # the intended scope of the `if` above.
                self.marginal[a] = self.marginal.get(a, 0) + freq
                self.marginal[b] = self.marginal.get(b, 0) + freq

        # total number of word pairs = sum(marginal.values()) / 2
        # since there are two entries in self.marginal for each word pair
        self.total = sum(self.marginal.values()) / 2.0

    def PMI(self, a, b):
        """Return log(joint * total / (marginal[a] * marginal[b])).

        Raises KeyError if the pair or either word has no recorded count.
        """
        return (np.log(self.joint[(a, b)]) + np.log(self.total)
                - np.log(self.marginal[a]) - np.log(self.marginal[b]))

    def csPMI(self, a, b):
        """Return log(joint**2 / (marginal[a] * marginal[b])).

        Algebraically this equals PMI(a, b) + log(joint / total).
        Raises KeyError if the pair or either word has no recorded count.
        """
        return (2 * np.log(self.joint[(a, b)])
                - np.log(self.marginal[a]) - np.log(self.marginal[b]))

    def __getitem__(self, x):
        """Return the count of ``x`` divided by ``self.total``.

        ``x`` is either a ``(word1, word2)`` tuple (joint count) or a single
        word (marginal count). Raises KeyError for unknown keys.
        """
        # Dispatch on the key's type rather than len(x) == 2: a two-character
        # word such as "of" also has length 2 and must be treated as a word.
        if isinstance(x, tuple):
            return self.joint[x] / self.total
        return self.marginal[x] / self.total
# Module-level frequency table used by calc_stats(). Constructing it reads the
# pair2joint default file ('counts.txt') as soon as this module is executed.
FREQUENCIES = pair2joint()
def calc_stats(fn='questions-words.txt', frequencies=None):
    """
    Calculate stats for analogy categories given in the format of Mikolov et al. (questions-words.txt).
    Word pairs from capital-world, for example, would include ('Paris', 'France'), ('Berlin', 'Germany'), etc.
    Statistics are calculated for word pairs from that category.

    A line starting with ':' opens a new category; every other non-blank line
    must hold four words ``a b c d`` and contributes the pairs (a, b) and
    (c, d). Pairs with no recorded counts (KeyError) are skipped.

    For each category one line is printed with: the category name, the mean
    csPMI, the mean PMI, the median joint count and the csPMI variance.

    Args:
        fn: path of the analogy file.
        frequencies: object providing ``csPMI(a, b)``, ``PMI(a, b)`` and a
            ``joint`` mapping; defaults to the module-level FREQUENCIES.

    Returns:
        dict mapping each category name to the tuple
        ``(mean csPMI, mean PMI, median joint count, csPMI variance)``;
        all four are NaN for a category with no usable pairs.

    Raises:
        ValueError: if an analogy line appears before any ':' category
            header, or a non-blank line does not contain exactly four words.
    """
    if frequencies is None:
        frequencies = FREQUENCIES

    csPMI_values = {}
    PMI_values = {}
    joint_counts = {}
    category = None

    # Context manager so the file handle is closed even if parsing fails.
    with open(fn) as questions:
        for line in questions:
            if line[0] == ':':
                category = line[1:].strip()
                csPMI_values[category] = []
                PMI_values[category] = []
                joint_counts[category] = []
                continue

            words = line.split()
            if not words:
                continue
            if category is None:
                raise ValueError(
                    "analogy line found before any ':' category header")

            a, b, c, d = words
            for pair in ((a, b), (c, d)):
                # Look everything up first so a missing pair adds nothing to
                # any of the three lists.
                try:
                    cs_pmi = frequencies.csPMI(pair[0], pair[1])
                    pmi = frequencies.PMI(pair[0], pair[1])
                    count = frequencies.joint[pair]
                except KeyError:
                    continue
                csPMI_values[category].append(cs_pmi)
                PMI_values[category].append(pmi)
                joint_counts[category].append(count)

    stats = {}
    for c1 in csPMI_values:
        # NOTE(review): csPMI values are de-duplicated (by value) before
        # averaging, while PMI and joint counts keep one entry per occurrence
        # in the file -- kept as in the original; confirm this is intended.
        sample = list(set(csPMI_values[c1]))
        if sample:
            row = (np.mean(sample), np.mean(PMI_values[c1]),
                   np.median(joint_counts[c1]), np.var(sample))
        else:
            # No usable pair: report NaN without triggering NumPy warnings.
            row = (np.nan, np.nan, np.nan, np.nan)
        stats[c1] = row
        print(c1, row[0], row[1], row[2], row[3])
    return stats
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment