Created
December 9, 2019 23:22
-
-
Save Prince781/d250d2e449031f158c79879bc33b03cf to your computer and use it in GitHub Desktop.
discrete random variable mini-library
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| import random | |
| from functools import reduce | |
| import numbers | |
| from math import sqrt, floor, ceil | |
| # RV1 + RV2 = new RV | |
| # same idea for other arithmetic ops | |
class RV:
    """A discrete random variable.

    An RV is one of three kinds:

    * a *root* RV defined by an explicit ``{value: probability}`` table,
    * a *constant* RV wrapping a single real number, or
    * a *derived* RV built from one or two parent RVs and an operator
      (``rv1 + rv2``, ``-rv``, ``rv < 3``, ...).

    Derived RVs track the root RVs they depend on, so ``x + x`` is computed
    with both operands equal (values ``2*v`` only), while ``x + y`` for
    distinct roots treats ``x`` and ``y`` as independent.
    """

    # Absolute tolerance when checking that a probability table sums to 1;
    # exact comparison rejects valid tables (ten copies of 0.1 sum to
    # 0.9999999999999999).
    _PROB_SUM_TOL = 1e-9

    def __init__(self, values):
        """Creates a new discrete random variable.

        :param values: one of
            * a non-empty ``dict`` mapping each value to its probability;
              probabilities must be non-negative and sum to 1 (within a
              small floating-point tolerance),
            * a real number, giving a constant RV, or
            * a ``(rv1, op, rv2)`` tuple describing a derived RV whose value
              is ``op(rv1, rv2)``, or ``op(rv1)`` when ``rv2`` is ``None``.
        :raises AssertionError: if ``values`` is malformed.
        """
        assert values is not None
        assert isinstance(values, (dict, numbers.Real, tuple))
        # Every kind of RV carries the same attributes; unused ones stay None.
        self.__probs = None   # {value: probability}; cached for derived RVs
        self.__rand = None    # private PRNG (root RVs only)
        self.__state = None   # PRNG state before the most recent draw
        self.__const = None   # wrapped value (constant RVs only)
        self.__op = None      # combining function (derived RVs only)
        self.__rv1 = None     # first operand (derived RVs only)
        self.__rv2 = None     # second operand, or None for unary ops
        if isinstance(values, dict):
            assert values, "distribution must not be empty"
            assert all(p >= 0 for p in values.values()), \
                "probabilities must be non-negative"
            assert abs(sum(values.values()) - 1.0) <= RV._PROB_SUM_TOL, \
                "probabilities must sum to 1"
            # Copy so later mutation of the caller's dict cannot alter this RV.
            self.__probs = dict(values)
            self.__rand = random.Random()
            self.__state = self.__rand.getstate()
        elif isinstance(values, numbers.Real):
            self.__const = values
        else:
            assert len(values) == 3
            rv1, op, rv2 = values
            assert isinstance(rv1, RV)
            assert isinstance(rv2, RV) or rv2 is None
            assert op
            self.__op = op
            self.__rv1 = rv1
            self.__rv2 = rv2
        # For root RVs this is the value assumed while dist() enumerates
        # assignments; for the other kinds it is unused, but observing here
        # evaluates op once so an unusable operator fails at construction.
        self.__testval = self.observe()

    def __hash__(self):
        # __eq__ is overloaded to build a derived RV, so hash by identity;
        # this keeps RVs usable as set members (see __dependencies).
        return id(self)

    def __draw(self):
        """Draws one value from this root RV's PRNG at its current state."""
        r = self.__rand.random()  # uniform in [0, 1)
        last = None
        for val, prob in self.__probs.items():
            if prob <= 0:
                # Zero-probability values must never be drawn.
                continue
            last = val
            if r < prob:
                return val
            r -= prob
        # Only reachable through floating-point round-off when the
        # probabilities sum to slightly less than 1.
        return last

    def sample(self):
        """Samples this random variable.

        A root RV draws a fresh value; a derived RV re-samples every root RV
        it depends on and returns the combined value; a constant RV returns
        its constant.
        """
        if self.__rand is not None:
            # Remember the state *before* drawing so observe() can replay it.
            self.__state = self.__rand.getstate()
            return self.__draw()
        if self.__const is not None:
            return self.__const
        for root in self.__dependencies():
            root.sample()
        return self.observe()

    def observe(self):
        """Observes the current value without sampling.

        Root RVs replay the draw made by their most recent sample() (or the
        initial draw from construction), so repeated observations agree
        until the next sample().
        """
        if self.__rand is not None:
            self.__rand.setstate(self.__state)
            return self.__draw()
        if self.__const is not None:
            return self.__const
        x = self.__rv1.observe()
        if self.__rv2 is None:
            return self.__op(x)
        return self.__op(x, self.__rv2.observe())

    # ----- operator plumbing -------------------------------------------------

    @staticmethod
    def __lift(value):
        """Returns value unchanged if it is an RV, else wraps it in one."""
        return value if isinstance(value, RV) else RV(value)

    @staticmethod
    def __check_divisor(divisor):
        """Raises ZeroDivisionError if the RV divisor can take the value 0."""
        if 0 in divisor.domain():
            raise ZeroDivisionError('division by zero (0 in domain of denominator)')

    def __binary(self, other, op, reflected=False):
        """Builds the derived RV op(self, other), or op(other, self) when
        reflected is true (used by the reflected __r*__ operators)."""
        other = RV.__lift(other)
        if reflected:
            return RV((other, op, self))
        return RV((self, op, other))

    def __unary(self, op):
        """Builds the derived RV op(self)."""
        return RV((self, op, None))

    # ----- arithmetic --------------------------------------------------------

    def __add__(self, other):
        return self.__binary(other, lambda x, y: x + y)

    def __radd__(self, other):
        return self.__binary(other, lambda x, y: x + y, reflected=True)

    def __sub__(self, other):
        return self.__binary(other, lambda x, y: x - y)

    def __rsub__(self, other):
        return self.__binary(other, lambda x, y: x - y, reflected=True)

    def __mul__(self, other):
        return self.__binary(other, lambda x, y: x * y)

    def __rmul__(self, other):
        return self.__binary(other, lambda x, y: x * y, reflected=True)

    def __truediv__(self, other):
        other = RV.__lift(other)
        RV.__check_divisor(other)
        return self.__binary(other, lambda x, y: x / y)

    def __rtruediv__(self, other):
        RV.__check_divisor(self)
        return self.__binary(other, lambda x, y: x / y, reflected=True)

    def __floordiv__(self, other):
        other = RV.__lift(other)
        RV.__check_divisor(other)
        return self.__binary(other, lambda x, y: x // y)

    def __rfloordiv__(self, other):
        RV.__check_divisor(self)
        return self.__binary(other, lambda x, y: x // y, reflected=True)

    def __mod__(self, other):
        other = RV.__lift(other)
        RV.__check_divisor(other)
        return self.__binary(other, lambda x, y: x % y)

    def __rmod__(self, other):
        RV.__check_divisor(self)
        return self.__binary(other, lambda x, y: x % y, reflected=True)

    def __pow__(self, other):
        return self.__binary(other, lambda x, y: x ** y)

    def __rpow__(self, other):
        return self.__binary(other, lambda x, y: x ** y, reflected=True)

    def __neg__(self):
        return self.__unary(lambda x: -x)

    def __abs__(self):
        return self.__unary(lambda x: abs(x))

    # ----- logical / bitwise -------------------------------------------------
    # Note: | and & use Python's logical ``or``/``and`` (intended for
    # boolean-valued RVs), while ^ is bitwise.

    def __or__(self, other):
        return self.__binary(other, lambda x, y: x or y)

    def __ror__(self, other):
        return self.__binary(other, lambda x, y: x or y, reflected=True)

    def __and__(self, other):
        return self.__binary(other, lambda x, y: x and y)

    def __rand__(self, other):
        return self.__binary(other, lambda x, y: x and y, reflected=True)

    def __xor__(self, other):
        return self.__binary(other, lambda x, y: x ^ y)

    def __rxor__(self, other):
        return self.__binary(other, lambda x, y: x ^ y, reflected=True)

    def __invert__(self):
        # ``~`` on a bool yields -1/-2 (and is deprecated since Python 3.12);
        # treat it as logical negation for boolean-valued RVs.
        return self.__unary(lambda x: (not x) if isinstance(x, bool) else ~x)

    # ----- comparisons (each returns a boolean-valued RV) --------------------

    def __eq__(self, other):
        return self.__binary(other, lambda x, y: x == y)

    def __ne__(self, other):
        return self.__binary(other, lambda x, y: x != y)

    def __lt__(self, other):
        return self.__binary(other, lambda x, y: x < y)

    def __gt__(self, other):
        return self.__binary(other, lambda x, y: x > y)

    def __le__(self, other):
        return self.__binary(other, lambda x, y: x <= y)

    def __ge__(self, other):
        return self.__binary(other, lambda x, y: x >= y)

    # ----- rounding ----------------------------------------------------------

    def __round__(self, ndigits=None):
        return self.__unary(lambda x: round(x, ndigits))

    def __floor__(self):
        return self.__unary(lambda x: floor(x))

    def __ceil__(self):
        return self.__unary(lambda x: ceil(x))

    # ----- structure & distribution ------------------------------------------

    def parents(self):
        """Returns the immediate parents of this RV.

        :returns: a ``(rv1, rv2)`` tuple for a binary derived RV, a
            one-element set ``{rv1}`` for a unary one, and ``None`` for root
            and constant RVs.
        """
        if self.__rv1 is not None and self.__rv2 is not None:
            return (self.__rv1, self.__rv2)
        if self.__rv1 is not None:
            return {self.__rv1}
        return None

    def __domain(self):
        """Returns the probability table of a root or constant RV.

        A constant RV gets (and caches) the degenerate table ``{const: 1}``.
        """
        assert self.__probs is not None or self.__const is not None
        if self.__probs is not None:
            return self.__probs
        self.__probs = {self.__const: 1}
        return self.__probs

    def domain(self):
        """Returns the domain, the set of possible values."""
        return set(self.dist())

    def dist(self):
        """Returns the probability distribution as a ``{value: probability}`` dict.

        For a derived RV this is computed exactly by enumerating every joint
        assignment of the root RVs it depends on (distinct roots are treated
        as independent), then cached.  A copy is returned so callers cannot
        corrupt the cache.
        """
        if self.__probs is None:
            if self.__const is not None:
                self.__domain()
            else:
                self.__probs = self.__enumerate_joint()
        return dict(self.__probs)

    def exp(self):
        """Computes the expected value."""
        return sum(p * v for v, p in self.dist().items())

    def var(self):
        """Computes the variance."""
        mu = self.exp()
        return sum(p * (v - mu) ** 2 for v, p in self.dist().items())

    def std(self):
        """Computes the standard deviation."""
        return sqrt(self.var())

    def __test_observed(self):
        """Evaluates this RV using each root's pinned test value."""
        if self.__rand is not None:
            return self.__testval
        elif self.__const is not None:
            return self.__const
        else:
            x = self.__rv1.__test_observed()
            if self.__rv2 is not None:
                y = self.__rv2.__test_observed()
                return self.__op(x, y)
            else:
                return self.__op(x)

    def __set_testval(self, val):
        """Pins this root RV to val for the next __test_observed()."""
        assert val in self.__domain()
        self.__testval = val

    def __dependencies(self):
        """Returns the set of root/constant RVs this RV depends on."""
        roots = set()
        seen = set()
        stack = [self]
        while stack:
            node = stack.pop()
            # Shared sub-expressions are visited only once.
            if node in seen:
                continue
            seen.add(node)
            parents = node.parents()
            if parents:
                stack.extend(parents)
            else:
                roots.add(node)
        return roots

    def __enumerate_joint(self):
        """Computes this RV's distribution by exhaustive enumeration.

        The probability of each joint assignment of the root RVs is the
        product of the roots' individual probabilities; assignments that
        yield the same value have their probabilities summed.
        """
        from itertools import product
        roots = list(self.__dependencies())
        assert roots
        tables = [list(root.__domain().items()) for root in roots]
        joint = {}
        for assignment in product(*tables):
            weight = 1
            for root, (val, prob) in zip(roots, assignment):
                # Pin each root to its assigned value for __test_observed().
                root.__set_testval(val)
                weight *= prob
            outcome = self.__test_observed()
            joint[outcome] = joint.get(outcome, 0) + weight
        return joint
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment