Last active
December 16, 2015 14:19
-
-
Save zmaril/5447488 to your computer and use it in GitHub Desktop.
A simple probabilistic programming library based on rejection sampling; an approximation of Church.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
(ns hacklheber.core) | |
(defn flip
  "Simulates a (possibly biased) coin toss: returns true with probability
  p and false otherwise. Called with no argument the coin is fair
  (p = 0.5)."
  ([] (flip 0.5))
  ;; (rand) is uniform on [0, 1), so it falls below p with probability p.
  ([p] (< (rand) p)))
(defn- memo-bangs
  "Rewrites one [binding-form value-form] pair of a query's bindings.
  When the binding form is a symbol whose name ends in `!`, the value
  form is wrapped in (memoize ...); every other pair, including
  destructuring forms such as [a b] or {:keys [...]}, is returned
  unchanged."
  [[k v]]
  ;; Guard with symbol? first: calling `name` on a vector or map
  ;; destructuring form would throw during macroexpansion.
  (if (and (symbol? k) (= \! (last (name k))))
    [k `(memoize ~v)]
    [k v]))
(defn- find-clause
  "Looks up a keyword clause. `clauses` is a sequence of [keyword body]
  pairs; returns the body of the first pair whose keyword equals `key`,
  or nil when there is no such pair (or the pair has no body)."
  [clauses key]
  (let [[_ body] (first (drop-while #(not= key (first %)) clauses))]
    body))
(defmacro query-by-rejection
  "Draws one sample by rejection sampling.

  The first argument is a vector of let-style bindings, optionally mixed
  with keyword clauses:

    :where <expr>       acceptance condition, evaluated after all the
                        ordinary bindings. When it is falsey the whole
                        draw is discarded and retried. Defaults to true.
    :memobound [v ...]  dynamic vars that are rebound to (memoize v) for
                        the duration of each draw, so repeated calls with
                        the same arguments agree within that draw.

  Ordinary bindings whose name ends in `!` have their value wrapped in
  memoize. The body is evaluated only for an accepted draw, inside the
  same bindings, and its value is returned. If the condition can never
  hold, this loops forever."
  [[& bindings] & body-expr]
  (let [{clauses true pairs false} (group-by (comp keyword? first)
                                             (partition-all 2 bindings))
        ;; With no :where clause every draw is accepted.
        where (or (find-clause clauses :where) true)
        ;; [v1 (memoize v1) v2 (memoize v2) ...] -- a `binding` vector.
        memobound (->> (find-clause clauses :memobound)
                       (mapcat (juxt identity (fn [v] `(memoize ~v))))
                       vec)
        pairs (vec (mapcat memo-bangs pairs))]
    `(loop []
       (let [[accepted# result#]
             ;; Fresh memoized functions are created on every iteration,
             ;; so memoization persists within one draw but not across
             ;; draws.
             (binding ~memobound
               (let [~@pairs accepted# ~where]
                 [accepted# (when accepted# (do ~@body-expr))]))]
         (if accepted#
           result#
           (recur))))))
(defn normalized-frequencies
  "Takes a collection and returns a map from each distinct element to the
  fraction of the collection made up by that element, as a float. The
  values sum to (approximately) 1; an empty collection yields {}."
  [col]
  (let [freqs (frequencies col)
        ;; Named `total` rather than `count` to avoid shadowing
        ;; clojure.core/count.
        total (reduce + 0 (vals freqs))]
    (into {} (for [[k v] freqs]
               [k (float (/ v total))]))))
(defmacro sample-by-rejection
  "Takes a number n and the arguments of a query-by-rejection (binding
  vector followed by body), runs that query n times, and returns the n
  accepted results, in order, as a seq.

  The seq is fully realized before it is returned. A lazy seq would run
  the queries whenever a consumer later walked it, outside any dynamic
  `binding` the caller had in effect at the call site."
  [n & body]
  `(doall (repeatedly ~n (fn [] (query-by-rejection ~@body)))))
;; Example queries adapted from the Church wiki:
;; http://projects.csail.mit.edu/church/wiki/Conditioning
(defn ^:dynamic eye-color
  "Church example: returns one of the symbols blue, green or brown,
  chosen uniformly at random; `person` is ignored. It is dynamic so a
  query can rebind it (for example through :memobound) to a memoized
  version that keeps one colour per person."
  [person]
  (rand-nth '[blue green brown]))
;; Persistent randomized functions.
(defn persistent-eye-colors
  "Draws one [bob alice bob] eye-colour triple by rejection. Because
  eye-color is memobound, the two lookups for :bob always agree within a
  draw. The (flip 0.01) condition only exercises rejection: about 99% of
  draws are discarded before one is accepted.

  Wrapped in a function, like the other examples, so that loading the
  namespace does not run the query and throw the result away."
  []
  (query-by-rejection
   [bob-1 (eye-color :bob)
    alice-1 (eye-color :alice)
    bob-2 (eye-color :bob)
    :where (flip 0.01)
    :memobound [eye-color]]
   [bob-1 alice-1 bob-2]))
;; A complex query.
(defn complex-samples
  "Returns n samples (default 10000) of (list lung-cancer TB), each drawn
  by rejection and conditioned on cough, chest-pain and
  shortness-of-breath all being true."
  ([] (complex-samples 10000))
  ([n]
   (sample-by-rejection
    n
    [works-in-hospital (flip 0.01)
     smokes (flip 0.2)
     lung-cancer (or (flip 0.01)
                     (and smokes (flip 0.02)))
     TB (or (flip 0.005)
            (and works-in-hospital (flip 0.01)))
     cold (or (flip 0.2)
              (and works-in-hospital (flip 0.25)))
     stomach-flu (flip 0.1)
     other (flip 0.1)
     cough (or (and cold (flip 0.5))
               (and lung-cancer (flip 0.3))
               (and TB (flip 0.7))
               (and other (flip 0.01)))
     fever (or (and cold (flip 0.3))
               (and stomach-flu (flip 0.5))
               (and TB (flip 0.2))
               (and other (flip 0.01)))
     chest-pain (or (and lung-cancer (flip 0.4))
                    (and TB (flip 0.5))
                    (and other (flip 0.01)))
     shortness-of-breath (or (and lung-cancer (flip 0.4))
                             (and TB (flip 0.5))
                             (and other (flip 0.01)))
     :where (and cough chest-pain shortness-of-breath)]
    (list lung-cancer TB))))
(defn ^:dynamic strength
  "Returns 10 or 5 with equal probability; `person` is ignored. It is
  dynamic so a query can memobind it, giving each person one persistent
  strength per draw."
  [person]
  (if-not (flip) 5 10))
(defn lazy
  "Returns true with probability 1/3; `person` is ignored and every call
  makes a fresh draw."
  [person]
  (flip 1/3))
(defn contribution
  "Returns how hard `person` pulls on one occasion: (strength person),
  halved whenever (lazy person) comes up true. lazy is consulted before
  strength."
  [person]
  (let [slacking? (lazy person)
        full-strength (strength person)]
    (if slacking?
      (/ full-strength 2)
      full-strength)))
(defn total-pulling
  "Sums the contribution of every member of `team`, left to right; an
  empty team pulls 0."
  [team]
  (apply + (map contribution team)))
(defn winner
  "Returns the symbol team1 when team1's total pull is at least team2's
  (so ties go to team1), otherwise the symbol team2. team1's total is
  computed first."
  [team1 team2]
  (let [pull-1 (total-pulling team1)
        pull-2 (total-pulling team2)]
    (if (>= pull-1 pull-2)
      'team1
      'team2)))
;; Using persistent values defined outside the query.
(defn tug-of-war-sample
  "Returns n samples (default 10000) of (strength 'bob), each drawn by
  rejection and conditioned on `winner` returning 'team1 for both
  (bob mary) vs (tom sue) and (bob sue) vs (tom jim). strength is
  memobound, so every person keeps a single strength within a draw."
  ([] (tug-of-war-sample 10000))
  ([n]
   (sample-by-rejection
    n
    [:memobound [strength]
     :where (and (= 'team1 (winner '(bob mary) '(tom sue)))
                 (= 'team1 (winner '(bob sue) '(tom jim))))]
    (strength 'bob))))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment