Created
January 4, 2013 23:00
-
-
Save wolever/4458294 to your computer and use it in GitHub Desktop.
A simple Postgres aggregate function for calculating a trimmed mean, excluding values outside N standard deviations from the mean: `tmean(v, standard_deviations)` (for example: `tmean(rating, 1.75)`).
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
-- Demo: compare the plain average with trimmed means at two cut-offs.
DROP TABLE IF EXISTS foo;
CREATE TEMPORARY TABLE foo (x FLOAT);

-- Four clustered samples plus one outlier (100).
INSERT INTO foo (x)
VALUES
    (1),
    (2),
    (3),
    (4),
    (100);

-- At 2.0 standard deviations the outlier is still inside the band, so the
-- result equals avg(x); at 1.5 it is excluded and only 1..4 are averaged.
SELECT
    avg(x),
    tmean(x, 2.0),
    tmean(x, 1.5)
FROM foo;
--  avg | tmean | tmean
-- -----+-------+-------
--   22 |    22 |   2.5
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
-- Transition state carried by the tmean aggregate from one input row to the next.
-- CASCADE also drops every object that depends on the type, so the functions and
-- the aggregate built on top of it have to be recreated after this statement.
DROP TYPE IF EXISTS tmean_stype CASCADE;

CREATE TYPE tmean_stype AS (
    deviations DOUBLE PRECISION,   -- cut-off, in standard deviations from the mean
    count      INTEGER,            -- number of values accumulated so far
    acc        DOUBLE PRECISION,   -- running sum of the values
    acc2       DOUBLE PRECISION,   -- running sum of the squared values
    vals       DOUBLE PRECISION[]  -- every accumulated value, kept for the final filtering pass
);
-- Transition function for tmean: folds one input value into the running state.
--   $1  state accumulated so far
--   $2  next input value
--   $3  cut-off in standard deviations (stored into the state with each value)
-- Returns the updated state.
-- A NULL value leaves the state untouched, so NULL inputs are skipped instead of
-- turning acc/acc2 into NULL and forcing the whole aggregate to NULL.
CREATE OR REPLACE FUNCTION tmean_sfunc(tmean_stype, float, float)
RETURNS tmean_stype AS $$
    SELECT
        CASE
            WHEN $2 IS NULL THEN $1
            ELSE ROW(
                $3,
                $1.count + 1,
                $1.acc + $2,
                $1.acc2 + ($2 * $2),
                array_append($1.vals, $2)
            )::tmean_stype
        END;
$$ LANGUAGE SQL;

-- Final function for tmean: averages the accumulated values that lie within
-- `deviations` standard deviations of the overall mean (bounds inclusive).
--   $1  final transition state
-- Returns NULL when no value was accumulated or when no value falls inside
-- the band.
CREATE OR REPLACE FUNCTION tmean_finalfunc(tmean_stype)
RETURNS float AS $$
DECLARE
    v_mean     FLOAT;
    v_variance FLOAT;
    v_spread   FLOAT;
    v_result   FLOAT;
BEGIN
    -- Nothing accumulated (empty input or only NULLs): avoid dividing by zero.
    IF $1.count = 0 THEN
        RETURN NULL;
    END IF;

    v_mean := $1.acc / $1.count;

    -- E[x^2] - E[x]^2 can come out slightly negative through floating-point
    -- rounding when the values are (nearly) identical; clamp so sqrt() cannot fail.
    v_variance := GREATEST(($1.acc2 / $1.count) - (v_mean * v_mean), 0);
    v_spread := sqrt(v_variance) * $1.deviations;
    -- RAISE NOTICE 'mean: % variance: % spread: %', v_mean, v_variance, v_spread;

    -- avg() yields NULL when no value survives the filter.
    SELECT avg(kept.val)
    INTO v_result
    FROM unnest($1.vals) AS kept (val)
    WHERE kept.val >= v_mean - v_spread
      AND kept.val <= v_mean + v_spread;

    RETURN v_result;
END;
$$ LANGUAGE plpgsql;
-- tmean(value, standard_deviations): trimmed-mean aggregate.
-- The initial state holds -1 as a placeholder cut-off (tmean_sfunc stores the
-- caller's value as rows are folded in), zeroed counters and an empty value list.
CREATE AGGREGATE tmean(double precision, double precision)
(
    stype     = tmean_stype,
    initcond  = '(-1, 0, 0, 0, {})',
    sfunc     = tmean_sfunc,
    finalfunc = tmean_finalfunc
);
You can run into a division-by-zero error if there's no data available.
To prevent this, add the following right after BEGIN:
IF $1.count = 0 THEN
return NULL;
END IF;
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Hello David Wolever. I read your example code for calculating the trimmed mean, and it has been very useful to me. But what if I want to trim by a percentage instead, such as 10% or 20%? How could I do that?