/*
 * GitHub gist klgraham/c1bc8fb6accb97e5aa6f: "Probability monad in Java 8".
 * Created March 25, 2016 19:10.
 *
 * Note from GitHub: this file may contain hidden or bidirectional Unicode text that
 * could be interpreted or compiled differently than it appears. To review it, open
 * the file in an editor that reveals hidden Unicode characters.
 */
/**
 * How to build the probability monad in Java 8 using lambdas, functional
 * interfaces, and streams.
 *
 * Created by klgraham on 10/25/15.
 */
import java.util.*; | |
import java.util.function.Function; | |
import java.util.function.Predicate; | |
import java.util.stream.Collectors; | |
import java.util.stream.IntStream; | |
import java.util.stream.Stream; | |
public class Distributions | |
{ | |
/** | |
* Using streams | |
* | |
* A stream is a sequence of elements. We can convert a standard Java collection | |
* into a stream and then perform various operations on the stream. | |
*/ | |
public static void main(String[] args) | |
{ | |
System.out.println("Sampling 10 doubles"); | |
uniform.sample(10).forEach(System.out::println); | |
System.out.println("\nSampling 10 mapped doubles"); | |
uniform.map(p -> p * 2).sample(10).forEach(System.out::println); | |
System.out.println("\nSampling from uniform boolean distribution"); | |
tf(0.5).sample(10).forEach(System.out::println); | |
System.out.println("\nSampling from uniform bernoulli distribution, with p = 0.7"); | |
bernoulli(0.7).sample(10).forEach(System.out::println); | |
System.out.println("\nProbability of a uniform double being less than 0.5:"); | |
System.out.println(uniform.prob(u -> u < 0.5)); | |
System.out.println("\nSample 10 Uniform variables above 0.3:"); | |
uniform.given(u -> u > 0.3).sample(10).forEach(System.out::println); | |
System.out.println("\nGet 3 lists of 3 uniform variables"); | |
uniform.repeat(3).sample(3).forEach(System.out::println); | |
System.out.println("\n6-sided die"); | |
Distribution<Integer> die6 = discreteUniform(Arrays.asList(1, 2, 3, 4, 5, 6)); | |
System.out.println(die6.sample(10)); | |
System.out.println("Prob(3) = " + die6.prob(p -> p == 3)); | |
System.out.println("\nPair of 6-sided dice"); | |
Distribution<Integer> dice = die6.repeat(2).map(p -> p.get(0) + p.get(1)); | |
System.out.println("Prob(7) = " + dice.prob(p -> p == 7)); | |
System.out.println("Prob(11) = " + dice.prob(p -> p == 11)); | |
System.out.println("Prob(4) = " + dice.prob(p -> p == 4)); | |
System.out.println("\nPair of 6-sided dice, via flatmap"); | |
Distribution<Integer> dice1 = die6.flatMap(d1 -> die6.map(d2 -> d1 + d2)); | |
System.out.println("Prob(7) = " + dice1.prob(p -> p == 7)); | |
System.out.println("Prob(11) = " + dice1.prob(p -> p == 11)); | |
System.out.println("Prob(4) = " + dice1.prob(p -> p == 4)); | |
System.out.println("\nMonty Hall problem"); | |
System.out.println("Prob. that switching doors finds the prize: " + | |
montyHall().prob(pair -> pair._1 == pair._2)); | |
System.out.println("\nNormal Distribution"); | |
System.out.println("Mean: " + normal().mean()); | |
System.out.println("StdDev: " + normal().stdDev()); | |
} | |
// public Distributions() { | |
// } | |
/** | |
* Uniform distribution [0, 1] | |
*/ | |
static Distribution<Double> uniform = new Distribution<Double>() { | |
private Random r = new Random(); | |
@Override | |
Double get() | |
{ | |
return r.nextDouble(); | |
} | |
}; | |
/** | |
* Boolean distribution | |
* @param p probability of true | |
* @return | |
*/ | |
static Distribution<Boolean> tf(double p) { | |
return uniform.map(n -> n < p); | |
} | |
/** | |
* Bernoulli distribution | |
* 1 is success or a hit and 0 is failure or a miss | |
* @param p probability of 1 | |
* @return distribution of 1s and 0s | |
*/ | |
static Distribution<Integer> bernoulli(double p) { | |
return tf(p).map(b -> b ? 1 : 0); | |
} | |
static Distribution<Double> normal() { | |
return new Distribution<Double>() { | |
private Random r = new Random(); | |
@Override | |
Double get() | |
{ | |
return r.nextGaussian(); | |
} | |
}; | |
} | |
/** | |
* Discrete distribution | |
* @param values random values the distribution can take | |
* @param <A> | |
* @return | |
*/ | |
static <A> Distribution<A> discreteUniform(Collection<A> values) { | |
List<A> vec = new ArrayList<A>(values); | |
return uniform.map(x -> vec.get((int) (x * vec.size()))); | |
} | |
static Distribution<Integer> removePriceAndChoice(Set<Integer> doors, int p, int c) { | |
Set<Integer> d = new HashSet<Integer>(doors); | |
d.remove(p); | |
d.remove(c); | |
return discreteUniform(d); | |
} | |
static Distribution<Tuple<Integer, Integer>> montyHall() | |
{ | |
Set<Integer> doors = new HashSet<>(); | |
doors.addAll(Arrays.asList(1, 2, 3)); | |
Distribution<Integer> prize = discreteUniform(doors); | |
Distribution<Integer> choice = discreteUniform(doors); | |
Distribution<Tuple<Integer, Integer>> mh = | |
prize.flatMap(p -> choice. // random prize location | |
flatMap(c -> removePriceAndChoice(doors, p, c). // random choice | |
flatMap(o -> removePriceAndChoice(doors, c, o). // open one of other doors | |
map(s -> new Tuple<Integer, Integer>(p, s))))); // switch | |
return mh; | |
} | |
} | |
/**
 * Probability distribution represented as a sampler: {@link #get()} draws one value,
 * and the combinators ({@link #map}, {@link #flatMap}, {@link #given}, {@link #repeat})
 * build new samplers on top of existing ones. Probabilities and moments are
 * estimated by Monte Carlo simulation.
 *
 * @param <A> type of the random variable
 */
abstract class Distribution<A>
{
    /** Number of samples drawn by the Monte Carlo estimators (prob, mean, variance). */
    private static final int N = 10000;

    /**
     * Draws one random value of type A.
     *
     * @return a single sample
     */
    abstract A get();

    /**
     * Draws n samples by calling {@link #get()} n times.
     *
     * @param n number of samples; must not be negative
     * @return a new list of the samples, in draw order
     * @throws IllegalArgumentException if n is negative
     */
    List<A> sample(Integer n)
    {
        if (n < 0) {
            throw new IllegalArgumentException("sample size must not be negative: " + n);
        }
        return IntStream.range(0, n).mapToObj(i -> get()).collect(Collectors.toList());
    }

    /**
     * Maps one Distribution into another by applying f to every sample.
     *
     * @param f mapping function
     * @param <B> type of variables in output distribution
     * @return distribution of f(x) for x drawn from this distribution
     */
    <B> Distribution<B> map(Function<A, B> f)
    {
        Distribution<A> dist = this;
        return new Distribution<B>()
        {
            @Override
            B get()
            {
                return f.apply(dist.get());
            }
        };
    }

    /**
     * FlatMaps one Distribution into another: draws x from this distribution, then
     * draws the result from f(x).
     *
     * @param f function mapping one value to a distribution
     * @param <B> type of variables in output distribution
     * @return the composed distribution
     */
    <B> Distribution<B> flatMap(Function<A, Distribution<B>> f)
    {
        Distribution<A> dist = this;
        return new Distribution<B>()
        {
            @Override
            B get()
            {
                return f.apply(dist.get()).get();
            }
        };
    }

    /**
     * Estimates the probability that a sample satisfies the predicate, using N draws.
     *
     * @param predicate event whose probability is estimated
     * @return fraction of N samples for which the predicate holds
     */
    double prob(Predicate<A> predicate)
    {
        long hits = sample(N).stream().filter(predicate).count();
        return (double) hits / N;
    }

    /**
     * Conditions this distribution on a predicate using rejection sampling: every
     * draw keeps sampling this distribution until a value satisfies the predicate.
     * <p>
     * There is no attempt limit. A predicate that is never satisfied makes
     * {@code get()} loop forever; one that is rarely satisfied makes it slow.
     *
     * @param predicate condition every returned sample satisfies
     * @return the conditional distribution
     */
    Distribution<A> given(Predicate<A> predicate)
    {
        Distribution<A> dist = this;
        return new Distribution<A>() {
            @Override
            A get() {
                // Draw fresh on every call; a cached draw would make the result constant.
                A candidate;
                do {
                    candidate = dist.get();
                } while (!predicate.test(candidate));
                return candidate;
            }
        };
    }

    /**
     * Creates a distribution of lists, each holding n independent samples.
     *
     * @param n length of each list
     * @return distribution of n-sample lists
     */
    Distribution<List<A>> repeat(int n)
    {
        Distribution<A> dist = this;
        return new Distribution<List<A>>() {
            @Override
            List<A> get()
            {
                return dist.sample(n);
            }
        };
    }

    /**
     * Estimates the mean from N samples. Samples must be numeric: {@link Number}s,
     * or values whose {@code toString()} parses as a double.
     *
     * @return sample mean
     * @throws NumberFormatException if a sample is neither a Number nor numeric text
     */
    double mean()
    {
        double sum = 0;
        for (A v : sample(N)) {
            sum += toDouble(v);
        }
        return sum / N;
    }

    /**
     * Estimates the (unbiased, N - 1 denominator) variance from N samples.
     * Samples must be numeric, as for {@link #mean()}.
     *
     * @return sample variance
     * @throws NumberFormatException if a sample is neither a Number nor numeric text
     */
    double variance()
    {
        // Welford's online algorithm. It avoids the catastrophic cancellation of
        // sumSq - sum^2 / N when the mean is large compared to the spread.
        double runningMean = 0;
        double m2 = 0;
        int count = 0;
        for (A v : sample(N)) {
            double x = toDouble(v);
            count++;
            double delta = x - runningMean;
            runningMean += delta / count;
            m2 += delta * (x - runningMean);
        }
        return m2 / (count - 1);
    }

    /**
     * Estimates the standard deviation as the square root of {@link #variance()}.
     *
     * @return sample standard deviation
     */
    double stdDev()
    {
        return Math.sqrt(variance());
    }

    /** Converts a sample to double: Numbers directly, anything else by parsing its string form. */
    private static double toDouble(Object v)
    {
        if (v instanceof Number) {
            return ((Number) v).doubleValue();
        }
        return Double.parseDouble(v.toString());
    }
}
class Tuple<T, U> | |
{ | |
public final T _1; | |
public final U _2; | |
public Tuple(T arg1, U arg2) { | |
super(); | |
this._1 = arg1; | |
this._2 = arg2; | |
} | |
@Override | |
public String toString() { | |
return String.format("(%s, %s)", _1, _2); | |
} | |
@Override | |
public boolean equals(Object o) { | |
if (this == o) return true; | |
if (o == null || getClass() != o.getClass()) return false; | |
Tuple<?, ?> tuple = (Tuple<?, ?>) o; | |
if (!_1.equals(tuple._1)) return false; | |
return _2.equals(tuple._2); | |
} | |
@Override | |
public int hashCode() { | |
int result = _1.hashCode(); | |
result = 31 * result + _2.hashCode(); | |
return result; | |
} | |
} | |
//class Histogram | |
//{ | |
// public static <T> Map<T, Long> frequencies(Stream<T> stream) | |
// { | |
// return stream. | |
// collect(Collectors.groupingBy(Function.identity(), Collectors.counting())); | |
// } | |
// | |
// public static <T> Map<T, Double> histogram(Stream<T> stream) | |
// { | |
// int N = (int)stream.count(); | |
// Map<T, Double> hist = new HashMap<>(); | |
// Map<T, Long> freqs = frequencies(stream); | |
// | |
// for (Map.Entry<T, Long> entry : freqs.entrySet()) | |
// { | |
// hist.put(entry.getKey(), (double)entry.getValue() / (double)N); | |
// } | |
// | |
// return hist; | |
// } | |
//} |
// Sign up for free to join this conversation on GitHub.
// Already have an account? Sign in to comment.