Created
October 30, 2018 00:57
-
-
Save mrdrozdov/5e43f90878db05fe196b9e599c467a08 to your computer and use it in GitHub Desktop.
dp.py
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| import argparse | |
| import json | |
| import numpy as np | |
| from tqdm import tqdm | |
class Node(object):
    """A single tree node identified by ``id`` and linked to ``parent`` by id.

    The string ``'None'`` in ``id``, ``parent`` or ``label`` is treated as a
    missing value and stored as ``None``; every other value is kept as given.
    ``children`` starts as an empty set of child ids.
    """

    @staticmethod
    def _none_if_placeholder(value):
        # Only the literal text 'None' is mapped to None; non-string values
        # (including None itself) pass through unchanged.
        if isinstance(value, str) and value == 'None':
            return None
        return value

    def __init__(self, id, parent, label, depth=None):
        super().__init__()
        self.id = self._none_if_placeholder(id)
        self.parent = self._none_if_placeholder(parent)
        self.label = self._none_if_placeholder(label)
        self.depth = depth
        self.children = set()
class Tree(object):
    """Tree of ``Node`` objects with label bookkeeping and pairwise-LCA lookup.

    Attributes:
        index: maps node id -> node, for every node currently in the tree.
        bylabel: maps label -> list of nodes carrying that label.
        setsbylabel: maps label -> set of ids of the nodes carrying that label.
        ancestors: maps a sorted ``(id_a, id_b)`` pair of leaves to the node
            that is their lowest common ancestor (filled by ``topdown_lca``).
    """

    def __init__(self):
        self.index = {}
        self.bylabel = {}
        self.setsbylabel = {}
        self.ancestors = {}

    @staticmethod
    def build(fn, max_labels=0, rng=None):
        """Read a tree file and return a fully indexed ``Tree``.

        Args:
            fn: path of a text file whose non-blank lines hold the
                whitespace-separated fields ``id parent label``.
            max_labels: when > 0, the labeled nodes are shuffled and only the
                first ``max_labels`` of them are kept; the rest are removed.
            rng: object providing ``shuffle`` (defaults to ``np.random``).

        Returns:
            The populated ``Tree``.

        Raises:
            ValueError: if a node id occurs twice, or if no node without a
                parent (i.e. no root) is left after pruning.
        """
        if rng is None:
            rng = np.random

        tr = Tree()
        labeled = []

        with open(fn) as f:
            for line in tqdm(f, desc='read'):
                parts = line.strip().split()  # id, parent, label
                if not parts:
                    continue  # tolerate blank lines (e.g. a trailing newline)
                node = Node(id=parts[0], parent=parts[1], label=parts[2])
                # Raise instead of assert so the check survives ``python -O``.
                if node.id in tr.index:
                    raise ValueError('duplicate node id: {!r}'.format(node.id))
                tr.index[node.id] = node
                if node.label is not None:
                    labeled.append(node)

        # Assign Depth
        for v in tqdm(tr.index.values(), desc='depth'):
            v.depth = tr.get_depth(v)

        # Assign Children
        for v in tqdm(tr.index.values(), desc='children'):
            if v.parent is not None:
                tr.index[v.parent].children.add(v.id)

        if max_labels > 0:
            rng.shuffle(labeled)

            # Limit
            print('before', len(tr.index))
            for x in labeled[max_labels:]:
                tr.remove(x)
            print('after', len(tr.index))

        # Assign Labels
        for v in tqdm(tr.index.values(), desc='labels'):
            if v.label is not None:
                tr.bylabel.setdefault(v.label, []).append(v)

        # Labels with a single node are dropped: purity() needs to draw two
        # distinct nodes from a label.
        for k in sorted(tr.bylabel):
            nodes = tr.bylabel[k]
            if len(nodes) == 1:
                tr.remove(nodes[0])
                del tr.bylabel[k]
        print('filtered-small-label', len(tr.index))

        for k in sorted(tr.bylabel):
            tr.setsbylabel[k] = {x.id for x in tr.bylabel[k]}

        # Assign Ancestor
        roots = [v for v in tr.index.values() if v.parent is None]
        if not roots:
            raise ValueError('tree has no root node (no node without a parent)')
        root = roots[-1]  # the last parent-less node in index order

        pbar = tqdm(desc='lca')
        try:
            tr.topdown_lca(root, pbar)
        finally:
            pbar.close()

        return tr

    def purity(self, N, rng=None):
        """Estimate dendrogram purity from ``N`` sampled same-label pairs.

        Each sample picks a label with probability proportional to its number
        of nodes, draws two distinct nodes of that label, looks up their
        lowest common ancestor, and scores the fraction of that ancestor's
        leaves which carry the sampled label.

        Args:
            N: number of pairs to sample.
            rng: object providing ``choice`` (defaults to ``np.random``).

        Returns:
            The mean of the per-sample purities (also printed as
            ``mean-purity``).
        """
        if rng is None:
            rng = np.random

        keys = list(self.bylabel.keys())
        sizes = [len(self.bylabel[k]) for k in keys]
        ntotal = sum(sizes)
        dist = [n / ntotal for n in sizes]

        purity_lst = []

        for _ in tqdm(range(N), desc='purity'):
            k = rng.choice(keys, p=dist)
            lst = self.bylabel[k]
            x1, x2 = rng.choice(lst, replace=False, size=2)
            # Pair keys are stored with the smaller id first.
            key = (x1.id, x2.id) if x1.id < x2.id else (x2.id, x1.id)
            a = self.ancestors[key]
            s1 = a.leaves  # Prediction
            s2 = self.setsbylabel[x1.label]  # Ground Truth
            purity_lst.append(len(s1 & s2) / len(s1))

        mean_purity = np.array(purity_lst).mean()
        print('mean-purity', mean_purity)
        return mean_purity

    def remove(self, x):
        """Remove node ``x`` and every ancestor left without children.

        The node is detached from its parent's ``children`` set and deleted
        from ``index``. If that leaves the parent childless, the parent is
        removed the same way, up to and including the root.
        """
        node = x
        while node is not None:
            parent = None
            if node.parent is not None:  # the root has no parent to detach from
                parent = self.index[node.parent]
                parent.children.remove(node.id)
            del self.index[node.id]
            # Keep climbing only while the parent has just become childless.
            node = parent if parent is not None and not parent.children else None

    def get_depth(self, x):
        """Return the number of edges between ``x`` and the root (root is 0).

        Walks up the parent chain and stops early at the first ancestor whose
        ``depth`` is already set.
        """
        hops = 0
        node = x
        while node.parent is not None:
            parent = self.index[node.parent]
            hops += 1
            if parent.depth is not None:
                return parent.depth + hops
            node = parent
        return hops

    def topdown_lca(self, root, pbar=None):
        """Record lowest common ancestors for all leaf pairs below ``root``.

        For every internal node this sets ``node.leaves`` to the ids of the
        leaves beneath it and stores the node in ``self.ancestors`` for each
        pair of leaves that come from two different children.

        Args:
            root: node at which to start.
            pbar: optional progress bar; ``update()`` is called once per
                internal node processed.

        Returns:
            The list of leaf nodes below ``root`` (``[root]`` if it is a leaf).
        """
        if not root.children:
            return [root]

        # One list of leaves per child subtree.
        groups = [self.topdown_lca(self.index[c], pbar=pbar) for c in root.children]
        ret = [leaf for group in groups for leaf in group]

        # Also assigns leaves.
        root.leaves = {x.id for x in ret}

        # Save the common ancestor. Keys are order-normalized, so each
        # unordered pair of child groups only needs to be visited once.
        for i, left in enumerate(groups):
            for right in groups[i + 1:]:
                for x1 in left:
                    for x2 in right:
                        if x1.id < x2.id:
                            key = (x1.id, x2.id)
                        else:
                            key = (x2.id, x1.id)
                        self.ancestors[key] = root

        if pbar is not None:
            pbar.update()

        return ret
def main():
    """Parse the command line, build the tree and print its sampled purity.

    Seeds that are not given on the command line are drawn at random and
    echoed with the other options so the run can be reproduced.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--tree', default='./experiments_out/aloi/Perch/run_1/tree.tsv', type=str,
                        help='path of the tree file (id, parent, label per line)')
    parser.add_argument('--max', default=0, type=int,
                        help='keep at most this many labeled nodes (0 keeps all)')
    parser.add_argument('--N', default=100, type=int,
                        help='number of node pairs sampled for the purity estimate')
    parser.add_argument('--rng_build', default=None, type=int,
                        help='seed used while building the tree')
    parser.add_argument('--rng_purity', default=None, type=int,
                        help='seed used while sampling pairs')
    options = parser.parse_args()

    # Use an integer bound: randint draws integers, so a float such as 1e8
    # only works through implicit coercion.
    if options.rng_build is None:
        options.rng_build = np.random.randint(0, 10 ** 8)
    if options.rng_purity is None:
        options.rng_purity = np.random.randint(0, 10 ** 8)

    print(json.dumps(vars(options), sort_keys=True, indent=4))

    tr = Tree.build(options.tree, max_labels=options.max,
                    rng=np.random.RandomState(seed=options.rng_build))
    tr.purity(N=options.N, rng=np.random.RandomState(seed=options.rng_purity))


if __name__ == '__main__':
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment