Created
March 29, 2016 18:57
-
-
Save tfmorris/0f3878a6fa1c91dc6787aee06f55cac8 to your computer and use it in GitHub Desktop.
Online variance and standard deviation using Welford's algorithm and Java 8 Streams - just a sketch; only lightly tested!
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import java.util.Collections; | |
import java.util.EnumSet; | |
import java.util.IntSummaryStatistics; | |
import java.util.Set; | |
import java.util.function.BiConsumer; | |
import java.util.function.BinaryOperator; | |
import java.util.function.Function; | |
import java.util.function.Supplier; | |
import java.util.function.ToIntFunction; | |
import java.util.stream.Collector; | |
import java.util.stream.Collectors; | |
/**
 * Online accumulator which extends the Java 8 {@link IntSummaryStatistics}
 * class to also compute variance and standard deviation using Welford's
 * algorithm.
 * <p>
 * Partial results (for example, from the partitions of a parallel stream) are
 * merged in {@link #combine(IntAccumulator)} using the pairwise variance update
 * of Chan et al., so an instance can serve as the mutable container of a
 * {@link Collector}; see {@link #summarizingIntStdDev(ToIntFunction)}.
 * <p>
 * This class performs no synchronization of its own. A stream
 * {@code collect} gives each thread its own instance; any other concurrent use
 * needs external locking.
 *
 * @author Tom Morris <[email protected]>
 */
public class IntAccumulator extends IntSummaryStatistics {

    /** Online estimate of the mean of all values seen so far (0.0 when empty). */
    private double mean = 0.0;

    /** Sum of squared deviations from {@link #mean} (Welford's "M2"). */
    private double m2 = 0.0;

    /**
     * Records a value: updates the inherited count/sum/min/max, then the
     * running mean and sum of squared deviations.
     *
     * @param value the input value
     */
    @Override
    public void accept(int value) {
        // After super.accept the inherited count already includes this value,
        // so the division below is by the new count (never zero).
        super.accept(value);
        final double delta = value - mean;
        mean += delta / getCount();
        // Welford's update deliberately multiplies the old-mean delta by the
        // new-mean deviation; the statement order above matters.
        m2 += delta * (value - mean);
    }

    /**
     * Merges another summary into this one.
     * <p>
     * A plain {@link IntSummaryStatistics} carries no sum of squared
     * deviations, so its contribution to the variance cannot be recovered;
     * only another {@code IntAccumulator} is accepted.
     *
     * @param other the statistics to merge; must be an {@code IntAccumulator}
     * @throws IllegalArgumentException if {@code other} is not an
     *         {@code IntAccumulator} (checked before any state is modified)
     * @throws NullPointerException if {@code other} is {@code null}
     */
    @Override
    public void combine(IntSummaryStatistics other) {
        if (other != null && !(other instanceof IntAccumulator)) {
            throw new IllegalArgumentException(
                    "IntAccumulator can only be combined with another IntAccumulator, not "
                            + other.getClass().getName());
        }
        // A null argument falls through to an NPE in combine(IntAccumulator).
        combine((IntAccumulator) other);
    }

    /**
     * Merges the values summarized by {@code other} into this accumulator.
     * <p>
     * Uses the parallel variance update described at
     * https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Parallel_algorithm
     * <p>
     * Empty operands are handled explicitly: merging two empty accumulators
     * with the general formula would compute 0/0, leaving NaN in the mean and
     * M2, and that NaN would then propagate into every later update.
     *
     * @param other another accumulator (may be this same instance)
     * @throws NullPointerException if {@code other} is {@code null}
     */
    public void combine(IntAccumulator other) {
        // Snapshot both sides before super.combine changes the inherited
        // count; other may even be this very instance.
        final long count = getCount();
        final long otherCount = other.getCount();
        final double otherMean = other.mean;
        final double otherM2 = other.m2;

        super.combine(other);

        if (otherCount == 0) {
            return; // nothing to merge
        }
        if (count == 0) {
            // This side was empty: adopt the other side's moments as-is.
            mean = otherMean;
            m2 = otherM2;
            return;
        }
        final double totalCount = (double) count + otherCount;
        final double delta = otherMean - mean;
        mean += delta * (otherCount / totalCount);
        m2 += otherM2 + delta * delta * ((double) count * otherCount / totalCount);
    }

    /**
     * Returns the online (incrementally updated) estimate of the mean. It may
     * differ slightly from {@link #getAverage()} due to rounding, but it is
     * not derived from a running sum and so won't overflow the way that sum
     * can.
     *
     * @return the mean of the recorded values, or 0.0 if none were recorded
     */
    public double getMeanEstimate() {
        return mean;
    }

    /**
     * Returns the sample variance (divisor {@code n - 1}) of the recorded
     * values.
     *
     * @return the sample variance, or 0.0 when fewer than two values have been
     *         recorded
     */
    public double getSampleVariance() {
        final long count = getCount();
        return count < 2 ? 0.0 : m2 / (count - 1);
    }

    /**
     * Returns the sample standard deviation, i.e. the square root of
     * {@link #getSampleVariance()}.
     *
     * @return the sample standard deviation, or 0.0 when fewer than two values
     *         have been recorded
     */
    public double getSampleStdDev() {
        return Math.sqrt(getSampleVariance());
    }

    /**
     * Returns a {@link Collector} that maps each element to an {@code int} and
     * summarizes the results into an {@code IntAccumulator}; the stream analogue
     * of {@code Collectors.summarizingInt} with variance support.
     *
     * @param <T> type of the input elements
     * @param mapper function extracting the {@code int} value of each element
     * @return a collector producing an {@code IntAccumulator}
     */
    public static <T> Collector<T, ?, IntAccumulator> summarizingIntStdDev(ToIntFunction<? super T> mapper) {
        return Collector.<T, IntAccumulator>of(
                IntAccumulator::new,
                (acc, t) -> acc.accept(mapper.applyAsInt(t)),
                (left, right) -> {
                    left.combine(right);
                    return left;
                },
                Collector.Characteristics.IDENTITY_FINISH);
    }
}
/**
 * Package-private stand-in for the JDK's internal {@code CollectorImpl} helper
 * in {@code java.util.stream.Collectors}, which cannot be used from outside
 * that package. It only stores the five components of a {@link Collector} and
 * returns them from the matching accessor methods.
 *
 * @param <T> type of the elements consumed by the collector
 * @param <A> mutable accumulation (intermediate) type
 * @param <R> result type produced by the finisher
 */
class CollectorImpl<T, A, R> implements Collector<T, A, R> {

    /** Unmodifiable characteristics set containing only {@code IDENTITY_FINISH}. */
    static final Set<Collector.Characteristics> CH_ID =
            Collections.unmodifiableSet(EnumSet.of(Collector.Characteristics.IDENTITY_FINISH));

    private final Supplier<A> supplierFn;
    private final BiConsumer<A, T> accumulatorFn;
    private final BinaryOperator<A> combinerFn;
    private final Function<A, R> finisherFn;
    private final Set<Characteristics> traits;

    /**
     * Builds a collector from all five components, stored exactly as given.
     */
    CollectorImpl(Supplier<A> supplier, BiConsumer<A, T> accumulator, BinaryOperator<A> combiner,
            Function<A, R> finisher, Set<Characteristics> characteristics) {
        supplierFn = supplier;
        accumulatorFn = accumulator;
        combinerFn = combiner;
        finisherFn = finisher;
        traits = characteristics;
    }

    /**
     * Builds a collector whose finisher returns the accumulation container
     * itself, cast unchecked from {@code A} to {@code R}; intended for use with
     * {@link #CH_ID}.
     */
    CollectorImpl(Supplier<A> supplier, BiConsumer<A, T> accumulator, BinaryOperator<A> combiner,
            Set<Characteristics> characteristics) {
        this(supplier, accumulator, combiner, uncheckedIdentity(), characteristics);
    }

    @Override
    public Supplier<A> supplier() {
        return supplierFn;
    }

    @Override
    public BiConsumer<A, T> accumulator() {
        return accumulatorFn;
    }

    @Override
    public BinaryOperator<A> combiner() {
        return combinerFn;
    }

    @Override
    public Function<A, R> finisher() {
        return finisherFn;
    }

    @Override
    public Set<Characteristics> characteristics() {
        return traits;
    }

    /** Identity function that reinterprets its argument as {@code O} without checking. */
    @SuppressWarnings("unchecked")
    private static <I, O> Function<I, O> uncheckedIdentity() {
        return input -> (O) input;
    }
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment