Skip to content

Instantly share code, notes, and snippets.

@mrdrozdov
Created October 30, 2018 00:57
Show Gist options
  • Select an option

  • Save mrdrozdov/5e43f90878db05fe196b9e599c467a08 to your computer and use it in GitHub Desktop.

Select an option

Save mrdrozdov/5e43f90878db05fe196b9e599c467a08 to your computer and use it in GitHub Desktop.
dp.py
import argparse
import json
import numpy as np
from tqdm import tqdm
class Node(object):
    """A single tree node as read from one row of the tree file.

    Attributes:
        id: Identifier of this node.
        parent: Identifier of the parent node, or None when there is none.
        label: Label attached to the node, or None when it is unlabeled.
        depth: Depth of the node; stays None until a caller assigns it.
        children: Set of child node ids; starts empty and is filled by
            the caller.

    The values come from text, so the literal string 'None' in ``id``,
    ``parent`` or ``label`` is converted to the ``None`` object.
    """

    def __init__(self, id, parent, label, depth=None):
        super(Node, self).__init__()
        self.id = self._none_if_literal(id)
        self.parent = self._none_if_literal(parent)
        self.label = self._none_if_literal(label)
        self.depth = depth
        self.children = set()

    @staticmethod
    def _none_if_literal(value):
        """Map the string 'None' to None; return anything else unchanged."""
        if isinstance(value, str) and value == 'None':
            return None
        return value

    def __repr__(self):
        return 'Node(id={!r}, parent={!r}, label={!r}, depth={!r})'.format(
            self.id, self.parent, self.label, self.depth)
class Tree(object):
    """Rooted tree of nodes with label indexes and pairwise common ancestors.

    All attributes start empty and are populated by ``Tree.build``:

    * ``index``       -- node id -> node, for every node still in the tree.
    * ``bylabel``     -- label -> list of nodes carrying that label.
    * ``setsbylabel`` -- label -> set of ids of the nodes carrying that label.
    * ``ancestors``   -- (smaller id, larger id) -> lowest node whose subtree
      contains both leaves of the pair.
    """

    def __init__(self):
        self.index = {}
        self.bylabel = {}
        self.setsbylabel = {}
        self.ancestors = {}

    @staticmethod
    def _pair_key(id1, id2):
        """Return the two ids as an ordered tuple, smaller id first."""
        return (id1, id2) if id1 < id2 else (id2, id1)

    @staticmethod
    def build(fn, max_labels=0, rng=None):
        """Read a tree file and build every index needed by ``purity``.

        Args:
            fn: Path of a text file with one node per line, as
                whitespace-separated ``id parent label`` fields. The literal
                ``None`` marks a missing parent or label. Blank lines are
                skipped.
            max_labels: When > 0, keep only this many labeled nodes (chosen
                by shuffling with ``rng``) and remove the others.
            rng: Object providing ``shuffle``; defaults to ``np.random``.

        Returns:
            The populated ``Tree``.

        Raises:
            ValueError: If no parentless (root) node is left in the tree.
        """
        if rng is None:
            rng = np.random

        tr = Tree()
        labeled = []
        with open(fn) as f:
            for line in tqdm(f, desc='read'):
                parts = line.strip().split()  # id, parent, label
                if not parts:
                    continue
                node = Node(id=parts[0], parent=parts[1], label=parts[2])
                assert node.id not in tr.index, \
                    'duplicate node id: {!r}'.format(node.id)
                tr.index[node.id] = node
                if node.label is not None:
                    labeled.append(node)

        # Assign Depth
        for v in tqdm(tr.index.values(), desc='depth'):
            v.depth = tr.get_depth(v)

        # Assign Children
        for v in tqdm(tr.index.values(), desc='children'):
            if v.parent is not None:
                tr.index[v.parent].children.add(v.id)

        # Limit
        # NOTE(review): pruning presumably expects labeled nodes to be
        # leaves; removing a labeled internal node would orphan its children.
        if max_labels > 0:
            rng.shuffle(labeled)
            print('before', len(tr.index))
            for x in labeled[max_labels:]:
                tr.remove(x)
            print('after', len(tr.index))

        # Assign Labels
        for v in tqdm(tr.index.values(), desc='labels'):
            if v.label is not None:
                tr.bylabel.setdefault(v.label, []).append(v)

        # A label carried by a single node cannot supply a pair; drop it.
        for k in sorted(tr.bylabel):
            nodes = tr.bylabel[k]
            if len(nodes) == 1:
                tr.remove(nodes[0])
                del tr.bylabel[k]
        print('filtered-small-label', len(tr.index))

        for k in sorted(tr.bylabel):
            tr.setsbylabel[k] = set(x.id for x in tr.bylabel[k])

        # Assign Ancestor
        roots = [v for v in tr.index.values() if v.parent is None]
        if not roots:
            raise ValueError(
                'tree has no root node (missing from the file, or every '
                'node was removed by pruning)')
        root = roots[-1]  # the last parentless node wins when several exist

        pbar = tqdm(desc='lca')
        tr.topdown_lca(root, pbar)
        pbar.close()

        return tr

    def purity(self, N, rng=None):
        """Sample same-label leaf pairs and report the mean purity.

        Each of the ``N`` samples picks a label with probability
        proportional to its node count, picks two distinct nodes with that
        label, looks up their common ancestor, and measures the fraction of
        that ancestor's leaves that carry the label.

        Args:
            N: Number of pairs to sample.
            rng: Object providing ``choice``; defaults to ``np.random``.

        Returns:
            The mean of the sampled purities (NaN when ``N`` is 0). The
            value is also printed.
        """
        if rng is None:
            rng = np.random

        keys = list(self.bylabel.keys())
        sizes = [len(self.bylabel[k]) for k in keys]
        ntotal = float(sum(sizes))  # float so '/' never truncates
        dist = [n / ntotal for n in sizes]

        purity_lst = []
        for _ in tqdm(range(N), desc='purity'):
            # Sample positions instead of the objects themselves so numpy
            # never has to turn labels or nodes into arrays.
            k = keys[rng.choice(len(keys), p=dist)]
            lst = self.bylabel[k]
            i1, i2 = rng.choice(len(lst), replace=False, size=2)
            x1, x2 = lst[i1], lst[i2]

            a = self.ancestors[self._pair_key(x1.id, x2.id)]
            s1 = a.leaves  # Prediction
            s2 = self.setsbylabel[x1.label]  # Ground Truth
            purity_lst.append(len(s1 & s2) / float(len(s1)))

        if purity_lst:
            mean_purity = np.array(purity_lst).mean()
        else:
            mean_purity = float('nan')
        print('mean-purity', mean_purity)
        return mean_purity

    def remove(self, x):
        """Remove ``x`` and every ancestor left childless by the removal.

        Removing a node that is no longer in ``index`` is a no-op. The walk
        stops at the first ancestor that still has children, or after the
        root itself has been removed.
        """
        node = x
        while node is not None and node.id in self.index:
            # Look the parent up before deleting so a dangling parent id
            # raises KeyError without leaving the tree half-modified.
            parent = None if node.parent is None else self.index[node.parent]
            del self.index[node.id]
            if parent is None:
                return
            parent.children.remove(node.id)
            node = parent if not parent.children else None

    def get_depth(self, x):
        """Return the number of edges between ``x`` and the root.

        Walks up the parent chain and stops early at the first ancestor
        whose ``depth`` is already known. ``x.depth`` itself is not consulted
        and nothing is written back.
        """
        steps = 0
        node = x
        while node.parent is not None:
            parent = self.index[node.parent]
            steps += 1
            if parent.depth is not None:
                return parent.depth + steps
            node = parent
        return steps

    def topdown_lca(self, root, pbar=None):
        """Record the lowest common ancestor of every leaf pair under ``root``.

        For each internal node of the subtree this sets ``node.leaves`` to
        the ids of the leaves below it, and stores the node in
        ``self.ancestors`` for every pair of leaves that sit under two
        different children of it.

        The traversal uses an explicit stack rather than recursion, so deep
        trees do not hit the interpreter's recursion limit.

        Args:
            root: Node whose subtree is processed.
            pbar: Optional progress object; ``update()`` is called once per
                recorded leaf pair.

        Returns:
            List of all leaf nodes under ``root`` (``[root]`` if it has no
            children).
        """
        finished = {}  # node id -> leaf nodes of that fully processed subtree
        stack = [(root, False)]
        while stack:
            node, children_done = stack.pop()

            if not node.children:
                finished[node.id] = [node]
                continue

            if not children_done:
                # Revisit this node once all of its children are finished.
                stack.append((node, True))
                stack.extend((self.index[c], False) for c in node.children)
                continue

            groups = [finished.pop(c) for c in node.children]
            leaves = [leaf for group in groups for leaf in group]

            # Also assigns leaves.
            node.leaves = set(leaf.id for leaf in leaves)

            # Save the common ancestor. Leaves from two different children
            # meet at this node; each unordered group pair is visited once.
            for i, group1 in enumerate(groups):
                for group2 in groups[i + 1:]:
                    for x1 in group1:
                        for x2 in group2:
                            key = self._pair_key(x1.id, x2.id)
                            self.ancestors[key] = node
                            if pbar is not None:
                                pbar.update()

            finished[node.id] = leaves

        return finished[root.id]
if __name__ == '__main__':
    # Command-line entry point: build the tree from a file, then print the
    # mean purity over randomly sampled same-label pairs.
    parser = argparse.ArgumentParser()
    parser.add_argument('--tree', default='./experiments_out/aloi/Perch/run_1/tree.tsv', type=str,
                        help='tree file with one "id parent label" row per node')
    parser.add_argument('--max', default=0, type=int,
                        help='keep at most this many labeled nodes (0 keeps all)')
    parser.add_argument('--N', default=100, type=int,
                        help='number of node pairs to sample')
    parser.add_argument('--rng_build', default=None, type=int,
                        help='seed used while building the tree (random if omitted)')
    parser.add_argument('--rng_purity', default=None, type=int,
                        help='seed used while sampling pairs (random if omitted)')
    options = parser.parse_args()

    # Draw any missing seed now so it appears in the printed options and the
    # run can be repeated. Integer bounds and an explicit int() keep the
    # value inside randint's documented contract and JSON-serializable.
    if options.rng_build is None:
        options.rng_build = int(np.random.randint(0, 10 ** 8))
    if options.rng_purity is None:
        options.rng_purity = int(np.random.randint(0, 10 ** 8))

    print(json.dumps(vars(options), sort_keys=True, indent=4))

    tr = Tree.build(options.tree, max_labels=options.max, rng=np.random.RandomState(seed=options.rng_build))
    tr.purity(N=options.N, rng=np.random.RandomState(seed=options.rng_purity))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment