Created
December 21, 2021 13:50
-
-
Save noam1023/c7750a104d97afb7e2c7aa1474b1f29b to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# check the results of DSLAB HW2 | |
# expected input: | |
# check_hw2.py industry2cluster_123456789_987654312.csv company2cluster_123456789_987654312.csv golden.csv | |
import random | |
import sys | |
import csv | |
def read_industries(fn) -> dict:
    """Read the industry -> cluster mapping from a CSV file.

    The first row is treated as a header and skipped. Every following
    non-empty row must hold exactly two fields: the industry name and an
    integer cluster id (whitespace around the id is ignored).

    :param fn: CSV file name
    :return: dict{industry name -> cluster id (int)}
    :raises ValueError: if a row does not have exactly two fields, or the
        cluster id is not an integer
    """
    mapping = {}
    # newline='' lets the csv module do its own line-ending handling, so that
    # quoted fields containing line breaks are parsed correctly.
    with open(fn, 'r', newline='') as csv_file:
        reader = csv.reader(csv_file, quoting=csv.QUOTE_ALL)
        next(reader, None)  # skip header line (tolerates a completely empty file)
        for row in reader:
            if not row:  # blank line: the csv reader yields an empty list
                continue
            industry, cluster_id = row
            mapping[industry] = int(cluster_id.strip())
    return mapping
def read_golden_industry(fn) -> dict:
    """Read the golden company -> industry mapping from a CSV file.

    No header row is skipped: every non-empty row must hold exactly two
    fields, the company id and the industry name (the name may contain
    commas when it is quoted). Company ids are kept as strings, exactly as
    they appear in the file.

    :param fn: CSV file name
    :return: dict{company id (str) -> industry name (str)}
    :raises ValueError: if a non-empty row does not have exactly two fields
    """
    mapping = {}
    with open(fn, 'r', newline='') as csv_file:
        reader = csv.reader(csv_file, quoting=csv.QUOTE_ALL)
        for row in reader:
            if not row:  # blank line: the csv reader yields an empty list
                continue
            company_id, industry = row
            mapping[company_id] = industry
    return mapping
def incr(aDict, key):
    """Increment, in place, the counter stored under ``key`` in ``aDict``.

    A key seen for the first time ends up with the value 1 (the previous
    implementation stored 0 for the first occurrence, so every count was
    one too low).

    :param aDict: dict{key -> int count}; modified in place
    :param key: the item that was observed once more
    """
    aDict[key] = aDict.get(key, 0) + 1
def check(industry2cluster_fn, company2_cluster_fn, golden_company_industry_fn):
    """
    Given the data files from the user, compare the results using the true industry of each company.

    For every row of the company -> cluster file, the company's true industry
    is looked up in the golden file and translated to a cluster through the
    user's industry -> cluster file. The row counts as correct when that
    cluster equals the cluster predicted for the company.

    :param industry2cluster_fn: CSV (with header row) of industry name, cluster id
    :param company2_cluster_fn: CSV (with header row) of company id, predicted cluster id
    :param golden_company_industry_fn: CSV (no header row) of company id, true industry name
    :return: percentage of correctly mapped companies (out of ALL data rows,
        including companies missing from the golden file), multiplied by the
        cluster-size penalty factor (which is never taken below 0.7)
    :raises KeyError: if a true industry is missing from the industry -> cluster file
    :raises IndexError: if a predicted cluster id is outside 0..20
    """
    num_clusters = 20
    num_matching = 0
    total_data_points = 0
    golden_company_industry = read_golden_industry(golden_company_industry_fn)
    industry2cluster = read_industries(industry2cluster_fn)
    # One dict (industry -> number of companies) per cluster id. Slot 0 is
    # padding so that cluster ids 1..20 can be used directly as indices.
    # (can't use [{}]*21 because it would be the same dict 21 times)
    pred_company2cluster = [dict() for _ in range(num_clusters + 1)]

    with open(company2_cluster_fn, 'r', newline='') as company_csv:
        reader = csv.reader(company_csv, quoting=csv.QUOTE_ALL)
        next(reader, None)  # skip header line
        for row in reader:
            comp_id, pred_cluster_id = row
            pred_cluster_id = int(pred_cluster_id.strip())
            total_data_points += 1
            try:
                true_industry = golden_company_industry[comp_id]
            except KeyError:
                print("company ID %d not found" % int(comp_id))
                continue
            clusterID_from_industry = industry2cluster[true_industry]
            num_matching += clusterID_from_industry == pred_cluster_id
            # A negative id would silently index from the END of the list and
            # be counted for the wrong cluster, so reject it explicitly.
            if not 0 <= pred_cluster_id <= num_clusters:
                raise IndexError("cluster id %d of company %s is outside 0..%d"
                                 % (pred_cluster_id, comp_id, num_clusters))
            incr(pred_company2cluster[pred_cluster_id], true_industry)
            # The following line is a 'circular reasoning' since we use
            # the true_industry both as key and value
            # incr(pred_company2cluster[clusterID_from_industry], true_industry)

    # A file without data rows scores 0 instead of dividing by zero.
    success_pct = 100 * num_matching / total_data_points if total_data_points else 0.0
    print("correctly identified: %d = %d%%" % (num_matching, int(success_pct)))
    # get rid of the zero-th element which was just to keep indexing nice
    factor = check_cluster_sizes(pred_company2cluster[1:])
    factor = max(factor, 0.7)  # don't be too cruel
    print("Penalty factor due to cluster sizes: %d%%" % (factor * 100))
    return success_pct * factor
def check_cluster_sizes(pred: list, min_industries=4, max_industries=15, penalty_scale=0.005):
    """Compute a penalty factor from the number of industries present in each cluster.

    We want that for each of the 20 clusters the number of industries is in
    the range 4 to 15 (which is 3% and 10% of 147 industries).
    This is a bit confusing, since a cluster holds companies, while the range
    requirement is on a feature (the industry) of those companies.

    Every industry below the minimum or above the maximum, summed over all
    clusters, reduces the factor by ``penalty_scale``.

    :param pred: list of dict. For each cluster, a dictionary of
        industry -> count of companies that REALLY have this industry
    :param min_industries: smallest acceptable number of industries in a cluster
    :param max_industries: largest acceptable number of industries in a cluster
    :param penalty_scale: reduction of the factor per industry outside the range
    :return: factor in [0, 1.0] according to how well this requirement is fulfilled
    """
    # for each cluster, how many industries are represented in it
    cluster_num_industries = [len(industries) for industries in pred]
    print("Cluster sizes:", cluster_num_industries)

    outliers = [n for n in cluster_num_industries
                if n < min_industries or n > max_industries]
    if outliers:
        print("outliers %", outliers)

    # Computed unconditionally: with no outliers the sum is simply 0.
    distance_from_range = sum(
        min_industries - n if n < min_industries else n - max_industries
        for n in outliers
    )
    penalty = distance_from_range * penalty_scale
    # Clamp so the documented [0, 1.0] contract holds even for huge outliers.
    return max(0.0, 1.0 - penalty)
def _gen_industry2cluster(golden_company_industry_fn):
    """Generate a fake result file: industry (string) -> cluster (int [1..20]).

    The industries are the distinct values found in the golden file; each is
    assigned a random cluster. The result is written to
    'industry2cluster.csv' in the current directory.

    :param golden_company_industry_fn: CSV file of company id, industry name
    """
    golden_company_industry = read_golden_industry(golden_company_industry_fn)
    industries = set(golden_company_industry.values())
    print("found %d industries" % len(industries))
    with open('industry2cluster.csv', 'w', newline='') as fout:
        # QUOTE_NONNUMERIC quotes the industry name (escaping any embedded
        # quote characters) and leaves the integer cluster id bare.
        writer = csv.writer(fout, quoting=csv.QUOTE_NONNUMERIC, lineterminator='\n')
        # read_industries() skips the first row as a header, so without this
        # row the first industry would be dropped when the file is read back.
        # NOTE(review): the column names are arbitrary; confirm against the HW spec.
        writer.writerow(['industry', 'cluster'])
        for name in industries:
            writer.writerow([name, random.randint(1, 20)])
def _gen_company2cluster(num_companies): | |
""" create fake table of company ID -> cluster ID""" | |
import random | |
offset = 1500000 | |
with open('company2cluster.csv', 'w') as fout: | |
for i in range(num_companies): | |
fout.write('%d,%d\n' % (i + offset, random.randint(1, 20))) | |
def is_sane(row):
    """Tell whether a parsed CSV row looks well-formed (labeled.csv contains a lot of garbage).

    A row is accepted when its first field parses as an integer and the row
    has exactly 8 fields. An empty row raises IndexError.
    """
    first_field = row[0]
    try:
        int(first_field)
    except ValueError:
        # first column is not a number -> garbage line
        return False
    else:
        return len(row) == 8
def _gen_golden_company_industry(name):
    """Create 'golden.csv' holding only the company id and industry columns.

    Reads the CSV file ``name``, skips its header row, and for every row
    accepted by is_sane() writes column 0 (company id) and column 6
    (industry) to 'golden.csv' in the current directory. Rows that are
    rejected, or that make the csv parser fail, are reported on stdout and
    skipped.

    :param name: file name of the full labeled CSV
    """
    from io import StringIO
    output = StringIO()
    with open(name, 'r', newline='') as fin:
        reader = csv.reader(fin, quoting=csv.QUOTE_ALL)
        next(reader)  # skip header line
        i = 0  # number of rows parsed so far (header excluded), used in messages
        # The outer loop re-enters the reader after a csv.Error, so that one
        # unparsable line does not abort the whole conversion.
        while True:
            try:
                for row in reader:
                    i = i + 1
                    try:
                        if is_sane(row):
                            # Double any embedded quote so that the quoted
                            # field stays valid CSV.
                            industry = row[6].replace('"', '""')
                            output.write('%s,"%s"\n' % (row[0], industry))
                        else:
                            print("line %d is malformed. skipping: %s" % (i, row))
                    except IndexError:
                        print("line %d caused index error" % i)
                break  # reader exhausted without a parse error
            except csv.Error:
                print("skipping line %d with error in it" % i)
    with open('golden.csv', 'w') as ofile:
        ofile.write(output.getvalue())
if __name__ == "__main__":
    # _gen_industry2cluster(sys.argv[3])
    # _gen_company2cluster(200000)
    # _gen_golden_company_industry("/home/cnoam/courses_teaching/94290/labeled.csv")
    # usage: check_hw2 industry_cluster.csv company2cluster.csv true_labels_starting_at_150k.csv
    if len(sys.argv) < 4:
        # A string argument makes sys.exit print it to stderr and exit with status 1.
        sys.exit("usage: check_hw2.py industry2cluster.csv company2cluster.csv golden.csv")
    score = check(sys.argv[1], sys.argv[2], sys.argv[3])
    # any error above will throw, causing exit with error, so no need to do anything
    print("score=%d" % score)
    # sys.exit instead of the bare exit(), which is only injected by the site
    # module and is not guaranteed to exist.
    sys.exit(0)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment