Created
October 8, 2012 06:07
-
-
Save tungwaiyip/3850967 to your computer and use it in GitHub Desktop.
Factor data structure using numpy
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import itertools
import operator

import numpy as np
def compute_joint_distribution(factor_lst):
    """Multiply all factors together and normalize the product.

    `factor_lst` must be a non-empty sequence of objects that support `*`
    and expose a `normalize()` method. The factors are combined left to
    right, `normalize()` is called on the result, and the result is returned.

    Note: when `factor_lst` holds exactly one factor, no product is formed,
    so that very object is normalized and returned.
    """
    assert factor_lst
    first, rest = factor_lst[0], factor_lst[1:]
    joint = first
    for factor in rest:
        joint = operator.mul(joint, factor)
    joint.normalize()
    return joint
def compute_marginal(V, factor_lst, **evdience):
    """Return the marginal over the variables in `V` given some evidence.

    Each factor in `factor_lst` is first restricted with
    `observe_evidence(**evdience)`. The restricted factors are combined with
    `compute_joint_distribution`, and every variable of the joint that is
    not contained in `V` is summed out via `marginalize`.
    """
    observed = []
    for factor in factor_lst:
        observed.append(factor.observe_evidence(**evdience))

    joint = compute_joint_distribution(observed)

    # Everything the caller did not ask for gets summed out.
    to_sum_out = []
    for name in joint.V:
        if name in V:
            continue
        to_sum_out.append(name)
    return joint.marginalize(to_sum_out)
class Factor(object):
    """A table of values indexed by a list of named discrete variables.

    Attributes:
        V:      list of variable names, one per axis of `val`.
        val:    numpy array holding the table; axis k belongs to V[k].
        degree: number of variables (len(V)).
        cards:  dict mapping variable name -> cardinality (axis length).
        axis:   dict mapping variable name -> axis position in `val`.
    """

    def __init__(self, V=None, shape=None, array=None):
        """Build a factor over variables `V`.

        Args:
            V: iterable of variable names (defaults to no variables).
            shape: cardinality of each variable, in the order of `V`.
            array: the values; reshaped to `shape` when `shape` is given,
                otherwise used with its own shape.

        When neither `shape` nor `array` is given the table has length 0
        along every axis. When only `shape` is given it is filled with zeros.

        Raises:
            ValueError: if the number of axes differs from len(V).
        """
        # A fresh list per instance: avoids the shared mutable default and
        # decouples the factor from the caller's sequence.
        V = [] if V is None else list(V)
        if shape is not None:
            if array is None:
                val = np.zeros(shape)
            else:
                val = np.array(array).reshape(shape)
        elif array is not None:
            val = np.array(array)
        else:
            val = np.zeros([0] * len(V))
        if len(V) != val.ndim:
            raise ValueError("val's number of axis should be the same as V's size: %s" % len(V))
        self.V = V
        self.val = val
        self.degree = len(V)
        # V to cardinality lookup
        self.cards = dict(zip(V, val.shape))
        # V to axis lookup
        self.axis = dict(zip(V, range(self.degree)))

    def copy(self):
        """Return an independent factor with the same variables and values."""
        # The constructor takes (V, shape, array); it copies `array` itself.
        return Factor(self.V[:], self.val.shape, self.val)

    def iter_indices(self):
        """Iterate over every index tuple of `val` in row-major order."""
        return itertools.product(*[range(n) for n in self.val.shape])

    def axes(self, V):
        """Return the axis positions of the variables named in `V`.

        Raises KeyError for a name that is not a variable of this factor.
        """
        return [self.axis[v] for v in V]

    def normalize(self):
        """Scale `val` in place so that its entries sum to 1.

        A table with a non-floating dtype is converted to float first so
        that the division is a true division. A table that sums to zero is
        not special-cased (numpy yields nan/inf).
        """
        if not np.issubdtype(self.val.dtype, np.inexact):
            self.val = self.val.astype(float)
        self.val /= self.val.sum()

    def __mul__(self, B):
        """Return the factor product of `self` and `B`.

        The result ranges over the variables of `self` followed by the
        variables of `B` that `self` lacks (in `B`'s order). Each entry is
        the product of the matching entries of the two operands.

        Raises:
            ValueError: if a shared variable has different cardinalities.
        """
        if not isinstance(B, Factor):
            return NotImplemented
        A = self
        for v in A.V:
            if v in B.cards and A.cards[v] != B.cards[v]:
                raise ValueError("Mismatch of cardinality: {0}".format(v))

        # C's variables = A's variables + the variables only B has.
        new_V = [v for v in B.V if v not in A.cards]
        c_V = A.V[:] + new_V

        # A already occupies the leading axes of C; pad trailing length-1
        # axes so it broadcasts over the variables that only B has.
        a_val = A.val.reshape(A.val.shape + (1,) * len(new_V))

        # Reorder B's axes into C's variable order, then insert length-1
        # axes for the variables B does not have.
        b_order = [B.axis[v] for v in c_V if v in B.cards]
        b_val = B.val.transpose(b_order) if b_order else B.val
        b_val = b_val.reshape([B.cards.get(v, 1) for v in c_V])

        c_val = (a_val * b_val).astype(float)
        return Factor(c_V, c_val.shape, c_val)

    def marginalize(self, V):
        """Sum out the variables named in `V` and return the reduced factor.

        Summing out every variable yields a factor with no variables whose
        `val` is a 0-d array holding the grand total.
        """
        val = self.val
        # Highest axis first so the remaining axis numbers stay valid.
        for i in sorted(set(self.axes(V)), reverse=True):
            val = val.sum(i)
        bV = [v for v in self.V if v not in V]  # may be empty
        return Factor(bV, np.shape(val), val)

    def observe_evidence(self, **evidence):
        """Return a copy with entries that contradict the evidence zeroed.

        Each keyword is `variable=index`. Along that variable's axis every
        slot other than `index` is set to 0. Keywords naming variables this
        factor does not have are ignored. `self` is not modified.
        """
        B = self.copy()
        for v, n in evidence.items():
            if v not in B.axis:
                continue
            i = B.axis[v]
            # indices to everything
            indices = [slice(None)] * B.degree
            # the lower part 0:n (numpy requires a tuple of slices here)
            indices[i] = slice(0, n)
            B.val[tuple(indices)] = 0
            # the upper part n+1:
            indices[i] = slice(n + 1, None)
            B.val[tuple(indices)] = 0
        return B

    def __repr__(self):
        """Render the factor as a text table: one row per index tuple."""
        if not self.V:
            return "<Null factor>"
        # headings
        headings = self.V + ['Value']
        col_widths = [max(len(str(h)), 5) for h in headings]
        col_widths[-1] = 12
        out = [
            ' '.join('{0:>{1}}'.format(v, w) for v, w in zip(headings, col_widths)),
            ' '.join('-' * w for w in col_widths),
        ]
        # data
        for indices in self.iter_indices():
            line = ['{0:{1}}'.format(i, w) for w, i in zip(col_widths, indices)]
            line.append('{0:{1}f}'.format(float(self.val[indices]), col_widths[-1]))
            out.append(' '.join(line))
        return '\n'.join(out)
# Example factors: A ranges over X, Y, Z with values 0.1 .. 1.2,
# and B ranges over A, Y with values 0.1 .. 0.6.
A = Factor(V=['X', 'Y', 'Z'], shape=(2, 3, 2), array=np.arange(1, 13) / 10.0)
B = Factor(V=['A', 'Y'], shape=(2, 3), array=np.arange(1, 7) / 10.0)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment