Skip to content

Instantly share code, notes, and snippets.

@Prince781
Created December 9, 2019 23:22
Show Gist options
  • Save Prince781/d250d2e449031f158c79879bc33b03cf to your computer and use it in GitHub Desktop.
Save Prince781/d250d2e449031f158c79879bc33b03cf to your computer and use it in GitHub Desktop.
discrete random variable mini-library
import numbers
import random
from functools import reduce
from itertools import product
from math import ceil, floor, isclose, sqrt
# Arithmetic on RVs yields new RVs: RV1 + RV2 is itself an RV, and the other
# arithmetic, comparison and logical operators work the same way.
class RV:
    """A discrete random variable.

    An RV is one of three kinds:

    * a root variable, built from a ``{value: probability}`` dict, which owns
      its own ``random.Random`` generator;
    * a constant, built from a real number;
    * a derived variable, built from an ``(rv1, op, rv2)`` tuple whose value is
      ``op(rv1, rv2)`` (or ``op(rv1)`` when ``rv2`` is ``None``).

    The operators below combine RVs (or an RV and a plain number) into derived
    RVs. Derived RVs that share a root variable stay correlated: ``X + X`` is
    always twice ``X``, never the sum of two independent draws.
    """

    def __init__(self, values):
        """Creates a new discrete random variable.

        :param values: a non-empty ``{value: probability}`` dict whose
            probabilities sum to 1 (within floating-point tolerance), a real
            number (a constant RV), or an ``(rv1, op, rv2)`` tuple describing a
            derived RV (``rv2`` is ``None`` for a unary ``op``).
        :raises AssertionError: if ``values`` is not one of the forms above.
        """
        assert values is not None
        assert isinstance(values, (dict, numbers.Real, tuple))
        # Every RV has all slots; only those relevant to its kind are set.
        self.__probs = None
        self.__rand = None
        self.__state = None
        self.__const = None
        self.__op = None
        self.__rv1 = None
        self.__rv2 = None
        if isinstance(values, dict):
            assert values
            # Exact float equality would reject valid tables such as ten
            # outcomes of 0.1 each (which sum to 0.9999999999999999).
            assert isclose(sum(values.values()), 1.0), \
                "probabilities must sum to 1"
            self.__probs = values
            self.__rand = random.Random()
            self.__state = self.__rand.getstate()
        elif isinstance(values, numbers.Real):
            self.__const = values
        else:
            assert len(values) == 3
            rv1, op, rv2 = values
            assert isinstance(rv1, RV)
            assert isinstance(rv2, RV) or rv2 is None
            assert op
            self.__op = op
            self.__rv1 = rv1
            self.__rv2 = rv2
        self.__testval = self.observe()

    def __hash__(self):
        # Defining __eq__ would otherwise make RV unhashable (and __eq__ builds
        # an RV rather than comparing), so hash by identity; RVs are kept in
        # sets when walking the dependency graph.
        return id(self)

    def __pick(self, r):
        """Maps a uniform draw ``r`` in [0, 1) to a value of this root RV."""
        fallback = None
        for value, prob in self.__probs.items():
            if prob > 0:
                fallback = value
            # Strict '<' so a zero-probability outcome is never selected.
            if r < prob:
                return value
            r -= prob
        # Rounding can leave a sliver of [0, 1) uncovered when the table sums
        # to slightly less than 1; give it to the last possible outcome.
        return fallback

    def sample(self):
        """Samples this random variable and returns the new value.

        A root RV draws from its own generator. A derived RV resamples every
        root variable it depends on and then recomputes its value.
        """
        if self.__rand is not None:
            # Save the pre-draw state so observe() can replay this draw.
            self.__state = self.__rand.getstate()
            return self.__pick(self.__rand.random())
        if self.__const is not None:
            return self.__const
        for root in self.__dependencies():
            root.sample()
        return self.observe()

    def observe(self):
        """Observes the current value without sampling."""
        if self.__rand is not None:
            # Replay the last draw from the saved generator state.
            self.__rand.setstate(self.__state)
            return self.__pick(self.__rand.random())
        if self.__const is not None:
            return self.__const
        x = self.__rv1.observe()
        if self.__rv2 is not None:
            return self.__op(x, self.__rv2.observe())
        return self.__op(x)

    @staticmethod
    def __lift(value):
        """Returns ``value`` if it is already an RV, else a constant RV for it."""
        return value if isinstance(value, RV) else RV(value)

    def __binary(self, other, op):
        """Returns the derived RV ``op(self, other)``."""
        return RV((self, op, RV.__lift(other)))

    def __reflected(self, other, op):
        """Returns the derived RV ``op(other, self)`` (for ``__r*__`` operators)."""
        return RV((RV.__lift(other), op, self))

    def __unary(self, op):
        """Returns the derived RV ``op(self)``."""
        return RV((self, op, None))

    def __radd__(self, other):
        return self.__reflected(other, lambda x, y: x + y)

    def __add__(self, other):
        return self.__binary(other, lambda x, y: x + y)

    def __rsub__(self, other):
        return self.__reflected(other, lambda x, y: x - y)

    def __sub__(self, other):
        return self.__binary(other, lambda x, y: x - y)

    def __mul__(self, other):
        return self.__binary(other, lambda x, y: x * y)

    def __rmul__(self, other):
        return self.__reflected(other, lambda x, y: x * y)

    def __truediv__(self, other):
        other = RV.__lift(other)
        if 0 in other.domain():
            raise ZeroDivisionError('division by zero (0 in domain of denominator)')
        return RV((self, lambda x, y: x / y, other))

    def __rtruediv__(self, other):
        # Here self is the denominator.
        if 0 in self.domain():
            raise ZeroDivisionError('division by zero (0 in domain of denominator)')
        return self.__reflected(other, lambda x, y: x / y)

    def __pow__(self, other):
        return self.__binary(other, lambda x, y: x ** y)

    def __rpow__(self, other):
        return self.__reflected(other, lambda x, y: x ** y)

    def __neg__(self):
        return self.__unary(lambda x: -x)

    def __abs__(self):
        return self.__unary(lambda x: abs(x))

    # Note: | and & use Python's logical ``or`` / ``and`` on the values, while
    # ^ and ~ are bitwise.
    def __or__(self, other):
        return self.__binary(other, lambda x, y: x or y)

    def __ror__(self, other):
        return self.__reflected(other, lambda x, y: x or y)

    def __and__(self, other):
        return self.__binary(other, lambda x, y: x and y)

    def __rand__(self, other):
        return self.__reflected(other, lambda x, y: x and y)

    def __xor__(self, other):
        return self.__binary(other, lambda x, y: x ^ y)

    def __rxor__(self, other):
        return self.__reflected(other, lambda x, y: x ^ y)

    def __invert__(self):
        return self.__unary(lambda x: ~x)

    # Comparisons return RVs of booleans, not plain bools.
    def __eq__(self, other):
        return self.__binary(other, lambda x, y: x == y)

    def __ne__(self, other):
        return self.__binary(other, lambda x, y: x != y)

    def __lt__(self, other):
        return self.__binary(other, lambda x, y: x < y)

    def __gt__(self, other):
        return self.__binary(other, lambda x, y: x > y)

    def __le__(self, other):
        return self.__binary(other, lambda x, y: x <= y)

    def __ge__(self, other):
        return self.__binary(other, lambda x, y: x >= y)

    def __round__(self, ndigits=None):
        return self.__unary(lambda x: round(x, ndigits))

    def __floor__(self):
        return self.__unary(lambda x: floor(x))

    def __ceil__(self):
        return self.__unary(lambda x: ceil(x))

    def parents(self):
        """Returns the immediate parents of this RV.

        :returns: ``(rv1, rv2)`` for a binary derived RV, ``{rv1}`` for a unary
            one, or ``None`` for a root or constant RV.
        """
        if self.__rv1 is not None and self.__rv2 is not None:
            return (self.__rv1, self.__rv2)
        if self.__rv1 is not None:
            return {self.__rv1}
        return None

    def __domain(self):
        """Returns the ``{value: probability}`` table of a root/constant RV (or a
        derived RV whose distribution has already been computed)."""
        assert self.__probs is not None or self.__const is not None
        if self.__probs is not None:
            return self.__probs
        self.__probs = {self.__const: 1}
        return self.__probs

    def domain(self):
        """Returns the domain, the set of possible values."""
        return set(self.dist())

    def dist(self):
        """Returns the probability distribution as a ``{value: probability}`` dict.

        For a derived RV, the distribution is computed on first use and then
        cached. It is built by enumerating every joint assignment of the root
        variables it depends on. The returned dict is this RV's internal table;
        do not mutate it.
        """
        if self.__probs is not None or self.__const is not None:
            return self.__domain()

        def update_func(domain_values, rvs, ds, counters, counter):
            # The roots are combined as independent: weight the value this RV
            # takes under the current assignment by the product of the roots'
            # probabilities.
            testval = self.__test_observed()
            weight = reduce(lambda a, b: a * b,
                            [rv.__domain()[values[idx]]
                             for rv, values, idx in zip(rvs, ds, counters)])
            domain_values[testval] = domain_values.get(testval, 0) + weight

        self.__probs = self.__compute_dependencies({}, update_func)
        return self.__probs

    def exp(self):
        """Computes the expected value."""
        return sum(p * v for v, p in self.dist().items())

    def var(self):
        """Computes the variance."""
        mu = self.exp()
        return sum(p * (v - mu) ** 2 for v, p in self.dist().items())

    def std(self):
        """Computes the standard deviation."""
        return sqrt(self.var())

    def __test_observed(self):
        """Evaluates this RV using each root's test value instead of its draw."""
        if self.__rand is not None:
            return self.__testval
        elif self.__const is not None:
            return self.__const
        else:
            x = self.__rv1.__test_observed()
            if self.__rv2 is not None:
                y = self.__rv2.__test_observed()
                return self.__op(x, y)
            else:
                return self.__op(x)

    def __set_testval(self, val):
        """Pins the value used by ``__test_observed`` for this root RV."""
        assert val in self.__domain()
        self.__testval = val

    def __dependencies(self):
        """Returns the set of root (parentless) RVs this RV depends on."""
        root_variables = set()
        visited = set()
        frontier = [self]
        while frontier:
            node = frontier.pop()
            # Shared sub-graphs (e.g. X + X, or diamonds) are walked only once.
            if node in visited:
                continue
            visited.add(node)
            parents = node.parents()
            if parents:
                frontier.extend(parents)
            else:
                root_variables.add(node)
        return root_variables

    def __compute_dependencies(self, initial_domain, update_func):
        """Calls ``update_func`` once per joint assignment of the root variables.

        Before each call, every root's test value is set to its value under the
        current assignment, so ``__test_observed`` evaluates this RV for it.
        Root 0 varies fastest.

        :param initial_domain: accumulator passed to ``update_func`` and returned.
        :param update_func: ``f(acc, roots, domains, counters, counter)`` where
            ``counters[i]`` indexes ``domains[i]`` for root ``i``.
        :returns: ``initial_domain`` after all updates.
        """
        root_variables = list(self.__dependencies())
        assert root_variables
        domains = [list(rv.__domain()) for rv in root_variables]
        domain_values = initial_domain
        index_ranges = [range(len(values)) for values in domains]
        # product() varies its last argument fastest; reverse so root 0 does.
        for reversed_counters in product(*reversed(index_ranges)):
            counters = list(reversed(reversed_counters))
            for rv, values, idx in zip(root_variables, domains, counters):
                rv.__set_testval(values[idx])
            update_func(domain_values, root_variables, domains, counters, 0)
        return domain_values
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment