Skip to content

Instantly share code, notes, and snippets.

@tungwaiyip
Created October 8, 2012 06:07
Show Gist options
  • Save tungwaiyip/3850967 to your computer and use it in GitHub Desktop.
Save tungwaiyip/3850967 to your computer and use it in GitHub Desktop.
# Factor data structure using numpy
import itertools
import operator
import numpy as np
def compute_joint_distribution(factor_lst):
    """Multiply all factors together and normalize the product.

    Args:
        factor_lst: non-empty iterable of factor objects. Each must support
            ``*`` with another factor (returning a new factor) and provide an
            in-place ``normalize()`` method.

    Returns:
        A new, normalized factor. The input factors are left untouched.

    Raises:
        ValueError: if ``factor_lst`` is empty.
    """
    from copy import deepcopy

    factor_lst = list(factor_lst)
    # An explicit raise survives ``python -O``; an ``assert`` would not.
    if not factor_lst:
        raise ValueError("factor_lst must contain at least one factor")

    joint = factor_lst[0]
    remaining = factor_lst[1:]
    if not remaining:
        # With a single factor no product is formed, so ``joint`` would still
        # be the caller's object; copy it so normalize() cannot mutate it.
        joint = deepcopy(joint)
    for factor in remaining:
        joint = joint * factor
    joint.normalize()
    return joint
def compute_marginal(V, factor_lst, **evdience):
    """Return the marginal over the variables in ``V`` given some evidence.

    Each factor is first restricted with ``observe_evidence(**evdience)``,
    the restricted factors are combined with ``compute_joint_distribution``,
    and every variable of the joint that is not in ``V`` is summed out.

    Args:
        V: container of the variables to keep.
        factor_lst: the factors to combine.
        **evdience: observed values, passed through to ``observe_evidence``.

    Returns:
        Whatever ``marginalize`` returns for the joint factor.
    """
    observed = []
    for factor in factor_lst:
        observed.append(factor.observe_evidence(**evdience))

    joint = compute_joint_distribution(observed)

    summed_out = []
    for name in joint.V:
        if name in V:
            continue
        summed_out.append(name)
    return joint.marginalize(summed_out)
class Factor(object):
    """A table of values indexed by an ordered list of discrete variables.

    Attributes:
        V: list of variable names; ``V[i]`` labels axis ``i`` of ``val``.
        val: numpy array holding one value per joint assignment.
        degree: number of variables (``len(V)``).
        cards: dict mapping each variable to the length of its axis.
        axis: dict mapping each variable to its axis index in ``val``.
    """

    def __init__(self, V=None, shape=None, array=None):
        """Create a factor.

        Args:
            V: sequence of variable names (defaults to no variables).
            shape: shape of the value table, one entry per variable. When
                omitted, a table with zero-length axes is created.
            array: values, in any layout that can be reshaped to ``shape``.
                When omitted but ``shape`` is given, the table is all zeros.

        Raises:
            ValueError: if the table's number of axes differs from ``len(V)``.
        """
        # Fresh list per instance: avoids a shared mutable default and keeps
        # the factor independent of the caller's sequence.
        V = [] if V is None else list(V)
        if shape is not None:
            if array is None:
                val = np.zeros(shape)
            else:
                val = np.array(array).reshape(shape)
        else:
            val = np.zeros([0] * len(V))
        if len(V) != val.ndim:
            raise ValueError("val's number of axis should be the same as V's size: %s" % len(V))
        self.V = V
        self.val = val
        self.degree = len(V)
        # V to cardinality lookup
        self.cards = dict(zip(V, val.shape))
        # V to axis lookup
        self.axis = dict(zip(V, range(self.degree)))

    def copy(self):
        """Return an independent copy (variable list and values are copied)."""
        # The constructor copies the array via np.array().
        return Factor(list(self.V), self.val.shape, self.val)

    def iter_indices(self):
        """Iterate over every index tuple of ``val`` in row-major order."""
        return itertools.product(*[range(n) for n in self.val.shape])

    def axes(self, V):
        """Return the axis index of each variable in ``V`` (KeyError if unknown)."""
        return [self.axis[v] for v in V]

    def normalize(self):
        """Rescale the table in place (on this object) so its values sum to 1."""
        # float() forces true division even for integer tables; rebinding
        # instead of ``/=`` avoids numpy's refusal to divide ints in place.
        total = float(self.val.sum())
        self.val = np.asarray(self.val / total)

    def __mul__(self, B):
        """Return the factor product of ``self`` and ``B``.

        The result is defined over this factor's variables followed by the
        variables of ``B`` that this factor does not have; each entry is the
        product of the matching entries of the two operands.

        Raises:
            ValueError: if a shared variable has different cardinalities.
        """
        if not isinstance(B, Factor):
            return NotImplemented
        A = self
        for v in A.V:
            if v in B.cards and A.cards[v] != B.cards[v]:
                raise ValueError("Mismatch of cardinality: {0}".format(v))

        # Result variables: A's, then B's extra ones in B's own order.
        extra = [v for v in B.V if v not in A.cards]
        c_V = list(A.V) + extra

        # Give A a length-1 axis for every variable only B has.
        a_val = A.val.reshape(A.val.shape + (1,) * len(extra))

        # Reorder B's axes to follow c_V, then insert length-1 axes for the
        # variables only A has, so numpy broadcasting lines the tables up.
        b_order = [B.axis[v] for v in c_V if v in B.axis]
        b_val = np.transpose(B.val, b_order) if b_order else B.val
        b_val = b_val.reshape([B.cards.get(v, 1) for v in c_V])

        c_val = a_val * b_val
        return Factor(c_V, c_val.shape, c_val)

    def marginalize(self, V):
        """Return a new factor with the variables in ``V`` summed out.

        Summing out every variable yields a factor with no variables.
        """
        # set() guards against a variable being listed twice.
        drop = tuple(sorted(set(self.axes(V))))
        val = self.val.sum(axis=drop) if drop else self.val
        val = np.asarray(val)  # a full reduction returns a numpy scalar
        kept = [v for v in self.V if v not in V]  # may be empty
        return Factor(kept, val.shape, val)

    def observe_evidence(self, **evidence):
        """Return a copy with entries that contradict the evidence set to 0.

        Args:
            **evidence: ``variable=index`` pairs. Variables this factor does
                not have are ignored. An index outside the variable's range
                matches nothing, so the whole table becomes 0.
        """
        B = self.copy()
        for v, n in evidence.items():
            if v in B.axis:
                # Select everything on the other axes and, on v's axis, every
                # position except the observed one.
                index = [slice(None)] * B.degree
                index[B.axis[v]] = np.arange(B.cards[v]) != n
                B.val[tuple(index)] = 0
        return B

    def __repr__(self):
        """Render the factor as a text table, one row per assignment."""
        if not self.V:
            return "<Null factor>"
        # headings
        headings = [str(v) for v in self.V] + ['Value']
        col_widths = [max(len(h), 5) for h in headings]
        col_widths[-1] = 12
        out = [
            ' '.join('{0:>{1}}'.format(h, w) for h, w in zip(headings, col_widths)),
            ' '.join('-' * w for w in col_widths),
        ]
        # data; zip stops at the index columns, the value column is appended
        for indices in self.iter_indices():
            line = ['{0:{1}}'.format(i, w) for w, i in zip(col_widths, indices)]
            line.append('{0:{1}f}'.format(float(self.val[indices]), col_widths[-1]))
            out.append(' '.join(line))
        return '\n'.join(out)
# Example factors: A is defined over X, Y, Z and B over A, Y.
A = Factor(
    ['X', 'Y', 'Z'],
    (2, 3, 2),
    np.arange(1, 13) / 10.0,
)
B = Factor(
    ['A', 'Y'],
    (2, 3),
    np.arange(1, 7) / 10.0,
)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment