Skip to content

Instantly share code, notes, and snippets.

@bayerj
Created November 10, 2017 10:28
Show Gist options
  • Save bayerj/c1a9e684ad0e3b9b2fe1dcdb6976c688 to your computer and use it in GitHub Desktop.
Save bayerj/c1a9e684ad0e3b9b2fe1dcdb6976c688 to your computer and use it in GitHub Desktop.
import tensorflow as tf
import edward as ed
from collections import OrderedDict
def edges(rvs):
    """Collect the parent -> child edges reachable upwards from ``rvs``.

    Walks the graph depth-first by following ``get_parents()`` from the
    given start nodes and records one ``(parent, child)`` tuple per edge.

    Args:
        rvs: Iterable of hashable nodes exposing ``get_parents()``. It is
            copied first, so the caller's container is never mutated.

    Returns:
        List of ``(parent, child)`` tuples in discovery order. Every node is
        expanded at most once, so an edge is not repeated when its child is
        reachable along several paths. Empty input yields an empty list.
    """
    pending = list(rvs)
    found = []
    visited = set()
    # Looping on the stack itself (instead of ``while True`` + pop) makes an
    # empty input return [] rather than raising IndexError.
    while pending:
        rv = pending.pop()
        if rv in visited:
            # Already expanded via another path; doing it again would only
            # duplicate its incoming edges.
            continue
        visited.add(rv)
        for parent in rv.get_parents():
            pending.append(parent)
            found.append((parent, rv))
    return found
# Build a chain of N unit-scale Normal variables z0 -> z1 -> ... -> z{N-1}
# (each one takes the previous z as its loc) plus a parallel list qz0..qz{N-1}
# that is offered to ed.copy as the replacement for the matching z.
N = 5
loc = tf.zeros(1)


def make_dist(loc, name):
    """Return an Edward Normal with the given loc and name and scale 1."""
    return ed.models.Normal(loc=loc, scale=1., name=name)


zs = [make_dist(loc, name='z0')]
for i in range(1, N):
    zs.append(make_dist(zs[i - 1], name='z{}'.format(i)))

qzs = [make_dist(loc, name='qz0')]
for i in range(1, N):
    # NOTE(review): qz_i is built on zs[i - 1], not qzs[i - 1], so every qz_i
    # (i >= 1) has a z as parent -- confirm this is intended for the experiment.
    qzs.append(make_dist(zs[i - 1], name='qz{}'.format(i)))

# Maps each z_i to the qz_i that should stand in for it, in chain order.
replace = OrderedDict(zip(zs, qzs))

# Alternative copy targets, left disabled:
#aggregated = make_dist(zs[-1], name='aggregate')
#aggregated = make_dist(sum(zs), name='aggregate')
aggregated = zs[-1]

# Copy the last z of the chain, asking Edward to swap zs for qzs on the way.
aggregated_y = ed.copy(aggregated, replace, scope='copy')
def _section(title, underline):
    """Print a blank line followed by a heading and its underline."""
    print('')
    print(title)
    print(underline)


print('We are replacing in {}'.format(aggregated.name))

# List the substitutions that were requested from ed.copy.
_section('Should Replacements', '-------------------')
for original, substitute in replace.items():
    print('{} <- {}'.format(original.name, substitute.name))

# Show what the copy actually depends on: first the names of all ancestors on
# one line, then every parent -> child edge found by walking up from the copy.
_section('Ancestors', '---------')
print(' '.join(ancestor.name for ancestor in aggregated_y.get_ancestors()))
this_edges = edges([aggregated_y])
for source, sink in this_edges:
    print('{} -> {}'.format(source.name, sink.name))

# A key of `replace` that is still among the ancestors of the copy was not
# swapped out; report it, otherwise confirm that it is gone.
_section('Replacement Errors', '------------------')
for original, substitute in replace.items():
    if original in aggregated_y.get_ancestors():
        print('{} is an ancestor of aggregated_y, but should have been replaced by {}'.format(original.name, substitute.name))
    else:
        print('{} is not in there anymore'.format(original.name))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment