Skip to content

Instantly share code, notes, and snippets.

@ckrapu
Last active February 22, 2021 16:50
Show Gist options
  • Save ckrapu/befc49378a685b1a323db1f3f367bcf8 to your computer and use it in GitHub Desktop.
Save ckrapu/befc49378a685b1a323db1f3f367bcf8 to your computer and use it in GitHub Desktop.
vectorized-gibbs-sampler
def vector_randint(p, axis=1):
    """Draw one categorical sample per slice of ``p`` along ``axis``.

    Uses inverse-CDF sampling: each slice gets a single uniform threshold in
    ``[0, sum(p))`` and the sample is the first index whose running sum
    exceeds it, so index ``i`` is chosen with probability ``p_i / sum(p)``.
    (Taking ``argmax(u * p)`` with independent uniforms ``u`` is NOT
    equivalent: it under-samples small weights.) Weights need not be
    normalized but must be non-negative; zero-weight categories are never
    selected.

    Args:
        p: Array of non-negative category weights (e.g. probabilities).
        axis: Axis of ``p`` that enumerates the categories.

    Returns:
        Integer array of sampled category indices with ``axis`` removed
        (a NumPy integer scalar when ``p`` is 1-D).
    """
    p = np.asarray(p, dtype=float)
    cdf = np.cumsum(p, axis=axis)
    # Per-slice total, keeping the category axis (size 1) so it broadcasts
    # against ``cdf``.
    total = np.take(cdf, [-1], axis=axis)
    # One uniform threshold per slice, scaled so unnormalized weights work.
    threshold = np.random.uniform(size=total.shape) * total
    # First index with cdf > threshold == number of entries with
    # cdf <= threshold. The non-strict comparison skips zero-weight categories,
    # whose cdf equals that of the preceding category.
    idx = np.sum(cdf <= threshold, axis=axis)
    # Guard against floating-point rounding pushing the threshold up to the
    # total, which would otherwise yield an out-of-range index.
    return np.minimum(idx, p.shape[axis] - 1)
def batch_gibbs_sample(parents, children, point, parent_factors, is_observed, extra_factors=None):
    """Run one in-place Gibbs sweep over every variable for a batch of states.

    Each row of ``point`` is a joint assignment of discrete variables.
    Variables are visited in the iteration order of ``parents``. For each
    variable, the unnormalized log full conditional (fed to ``softmax``) is
    the sum of:

    * its own factor table indexed by its parents' current values,
    * for each child, the child's table indexed by the child's current value
      and its other parents' current values, ranging over this variable's
      states, and
    * ``extra_factors[var]`` if ``extra_factors`` is truthy.

    Every row in which the variable is not observed then resamples it.
    Because ``point`` is updated in place, later variables in the sweep see
    the values already drawn earlier in the same sweep.

    Args:
        parents: Mapping from variable index to a list of its parent indices.
        children: Mapping from variable index to a list of its child indices.
        point: Integer array of shape (n_points, n_vars) holding current
            states; modified in place.
        parent_factors: Indexable by variable; each entry is a log-factor
            table indexed as ``table[parent_1, ..., parent_m, value]``
            (a 1-D table for variables without parents).
        is_observed: Array of shape (n_points, n_vars); truthy entries mark
            observed values, which are held fixed.
        extra_factors: Optional container indexable by variable, giving an
            additive log-factor broadcastable to (n_points, n_states).

    Returns:
        ``point`` (the same array object, mutated in place).
    """
    n_points = point.shape[0]
    for var in parents.keys():
        # Resample only the rows where this variable is NOT observed. Convert
        # to bool first so a 0/1 integer mask is not used as integer indices.
        should_update = ~np.asarray(is_observed[:, var], dtype=bool)

        own_table = parent_factors[var]
        if len(parents[var]) == 0:
            # No parents: a 1-D marginal shared by all rows. Broadcast it to
            # (n_points, n_states) so the row mask below is valid even when no
            # child or extra term broadcasts it up.
            factors_long = np.broadcast_to(own_table, (n_points,) + np.shape(own_table))
        else:
            # Multi-dimensional advanced indexing requires a tuple; NumPy
            # treats a list of arrays as one index array along axis 0.
            factors_long = own_table[tuple(point[:, q] for q in parents[var])]

        for child in children[var]:
            # Index the child's table at the current values of its other
            # parents and of the child itself; None marks the axis of `var`,
            # which should span all of its states (presumably sequence_slice
            # maps None to a full slice — confirm).
            slice_args = [None if q == var else point[:, q] for q in parents[child]]
            slice_args.append(point[:, child])
            child_factors_long = parent_factors[child][sequence_slice(slice_args)]
            # NumPy's advanced-indexing rules put the broadcast (n_points)
            # axis first when the array indices are separated by a slice, but
            # keep it in place when they are contiguous, i.e. when `var` is
            # the child's first parent. That case yields
            # (n_states, n_points), so transpose it.
            if parents[child][0] == var:
                child_factors_long = child_factors_long.T
            factors_long = factors_long + child_factors_long

        if extra_factors:
            factors_long = factors_long + extra_factors[var]

        probs = softmax(factors_long[should_update], axis=1)
        point[should_update, var] = vector_randint(probs, axis=1)
    return point
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment