{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Gibbs benchmark with Julia\n",
"\n",
"*Julia Gibbs sampling benchmark, a gist by kleinschmidt (last active July 21, 2017).*\n",
"\n",
"With Julia we get the best of both worlds: C-like speed and clean, expressive code.\n",
"\n",
"The benchmark comes from [this blog post](https://darrenjw.wordpress.com/2011/07/16/gibbs-sampler-in-various-languages-revisited/) and is a Gibbs sampler for two variables. The unnormalized target density is\n",
"\n",
"$$f(x,y) = kx^2 \\exp(-xy^2 - y^2 + 2y - 4x)$$\n",
"\n",
"The conditional distributions are\n",
"\n",
"$$x | y \\sim Ga(3, y^2+4)$$\n",
"$$y | x \\sim N \\left( \\frac{1}{1+x}, \\frac{1}{2(1+x)} \\right)$$\n",
"\n",
"where $Ga$ is parameterized by shape and rate, and $N$ by mean and variance.\n",
"\n",
"I suspect that most of the time is spent in loop overhead (especially for R and Python) and in RNG calls, so Julia should do pretty well."
]
},
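{
"cell_type": "markdown",
"metadata": {},
"source": [
"For reference, a short derivation of these conditionals (for $x > 0$): keep only the factors of $f$ that involve the variable being updated, then complete the square in $y$.\n",
"\n",
"$$f(x | y) \\propto x^{2} \\exp\\{-(y^2+4)\\,x\\} = x^{3-1} \\exp\\{-(y^2+4)\\,x\\} \\quad\\Rightarrow\\quad x | y \\sim Ga(3,\\, y^2+4)$$\n",
"\n",
"$$f(y | x) \\propto \\exp\\{-(1+x)\\,y^2 + 2y\\} \\propto \\exp\\left\\{-(1+x)\\left(y - \\frac{1}{1+x}\\right)^2\\right\\} \\quad\\Rightarrow\\quad y | x \\sim N\\left(\\frac{1}{1+x}, \\frac{1}{2(1+x)}\\right)$$\n",
"\n",
"Matching the last kernel to $\\exp\\{-(y-\\mu)^2/(2\\sigma^2)\\}$ gives $\\mu = 1/(1+x)$ and $\\sigma^2 = 1/(2(1+x))$, so a sampler parameterized by the standard deviation needs $\\sigma = 1/\\sqrt{2x+2}$."
]
},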
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"collapsed": false
},
"outputs": [
{
"data": {
"text/plain": [
"gibbs (generic function with 1 method)"
]
},
"execution_count": 1,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"using Distributions\n",
"\n",
"# One Gibbs sweep: draw x | y, then y | x\n",
"function gibbs_step(x, y)\n",
"    # Gamma(shape, scale): scale 1/(y^2+4) is the same as rate y^2+4\n",
"    x = rand(Gamma(3, 1/(y^2+4)))\n",
"    # Normal(mean, std): the std is the square root of the variance 1/(2(1+x))\n",
"    y = rand(Normal(1/(x+1), 1/sqrt(2*x+2)))\n",
"    x, y\n",
"end\n",
"\n",
"# Run N iterations of `thin` sweeps each, keeping the last draw of each iteration\n",
"function gibbs(N, thin)\n",
"    xs, ys = zeros(N), zeros(N)\n",
"    x, y = 0.0, 0.0\n",
"    for i in 1:N\n",
"        for j in 1:thin\n",
"            x, y = gibbs_step(x, y)\n",
"        end\n",
"        xs[i] = x\n",
"        ys[i] = y\n",
"    end\n",
"    xs, ys\n",
"end"
]
},
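{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a cross-check of these parameterizations, here is a minimal, dependency-free sketch of the same update using only Base functions. It assumes the integer Gamma shape of 3, so a $Ga(3, \\lambda)$ draw can be built as the sum of three exponential draws, and it passes the standard deviation $1/\\sqrt{2x+2}$ to the normal draw, as the C and R versions below do. The helpers `rand_gamma3`, `gibbs_step_base`, and `conditional_y_variance` are hypothetical names for illustration.\n",
"\n",
"```julia\n",
"# Hypothetical, dependency-free sketch of one Gibbs sweep (Base functions only).\n",
"\n",
"# Ga(3, rate) as the sum of three Exp(rate) draws; valid because the shape is an integer.\n",
"# 1 - rand() lies in (0, 1], so each log is finite.\n",
"rand_gamma3(rate) = -(log(1 - rand()) + log(1 - rand()) + log(1 - rand())) / rate\n",
"\n",
"function gibbs_step_base(x, y)\n",
"    x = rand_gamma3(y^2 + 4)                  # x | y ~ Ga(3, rate y^2 + 4)\n",
"    y = 1 / (1 + x) + randn() / sqrt(2x + 2)  # y | x ~ N(1/(1+x), variance 1/(2(1+x)))\n",
"    x, y\n",
"end\n",
"\n",
"# Sample variance of n draws of y | x with x held fixed; should approach 1/(2(1+x)).\n",
"function conditional_y_variance(x, n)\n",
"    ys = [1 / (1 + x) + randn() / sqrt(2x + 2) for i in 1:n]\n",
"    m = sum(ys) / n\n",
"    sum(abs2, ys .- m) / (n - 1)\n",
"end\n",
"\n",
"conditional_y_variance(1.0, 10^6)  # close to 1/(2(1+1)) = 0.25\n",
"```\n",
"\n",
"If the variance $1/(2x+2)$ were passed where a standard deviation is expected (as the second argument of Distributions' `Normal(μ, σ)`), the $y$ draws would be too concentrated, because $1/(2x+2) < 1/\\sqrt{2x+2}$ for every $x > 0$."
]
},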
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
"collapsed": false
},
"outputs": [
{
"data": {
"text/plain": [
"BenchmarkTools.Trial: \n",
" memory estimate: 781.44 KiB\n",
" allocs estimate: 5\n",
" --------------\n",
" minimum time: 3.536 s (0.00% GC)\n",
" median time: 3.545 s (0.00% GC)\n",
" mean time: 3.545 s (0.00% GC)\n",
" maximum time: 3.554 s (0.00% GC)\n",
" --------------\n",
" samples: 2\n",
" evals/sample: 1"
]
},
"execution_count": 2,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"using BenchmarkTools\n",
"# call once first so that compilation time is not included in the benchmark\n",
"gibbs(1, 1)\n",
"bt_jl = @benchmark gibbs(50_000, 1_000)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Writing Julia output to a file\n",
"\n",
"We can add a method that takes an `IO` object and writes each sample to it as it is drawn, instead of storing the chains in arrays:"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {
"collapsed": false
},
"outputs": [
{
"data": {
"text/plain": [
"gibbs (generic function with 2 methods)"
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"function gibbs(N, thin, out::IO)\n",
"    x, y = 0.0, 0.0\n",
"    println(out, \"i x y\")\n",
"    for i in 1:N\n",
"        for j in 1:thin\n",
"            x, y = gibbs_step(x, y)\n",
"        end\n",
"        @printf(out, \"%d %f %f\\n\", i, x, y)\n",
"    end\n",
"end"
]
},
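{
"cell_type": "markdown",
"metadata": {},
"source": [
"Because this method only relies on the generic `IO` interface, the same streaming code can write to a file, to `STDOUT`, or to an in-memory `IOBuffer`, which is handy for checking the output format without touching the disk. The sketch below assumes a hypothetical `write_trace` helper and a deterministic stand-in for `gibbs_step` (it is not the Gibbs update), so the output is predictable; it uses `println` rather than `@printf`, so it needs no packages.\n",
"\n",
"```julia\n",
"# Hypothetical helper: stream one \"i x y\" row per iteration to any IO object,\n",
"# without storing the chain in arrays.\n",
"function write_trace(out::IO, N, update, state)\n",
"    println(out, \"i x y\")\n",
"    x, y = state\n",
"    for i in 1:N\n",
"        x, y = update(x, y)\n",
"        println(out, i, \" \", x, \" \", y)\n",
"    end\n",
"end\n",
"\n",
"# Deterministic stand-in for gibbs_step (illustration only).\n",
"toy_step(x, y) = (x + 1, y - 1)\n",
"\n",
"buf = IOBuffer()\n",
"write_trace(buf, 3, toy_step, (0.0, 0.0))\n",
"s = String(take!(buf))\n",
"print(s)\n",
"# i x y\n",
"# 1 1.0 -1.0\n",
"# 2 2.0 -2.0\n",
"# 3 3.0 -3.0\n",
"```\n",
"\n",
"Replacing `buf` with the file handle `f` from `open(\"trace.tab\", \"w\") do f ... end` (the file name is arbitrary) writes the same text to disk."
]
},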
{
"cell_type": "code",
"execution_count": 4,
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"i x y\n",
"1 0.670719 0.547819\n",
"2 0.930058 0.371557\n",
"3 0.837985 0.349067\n",
"4 0.223886 1.040146\n",
"5 1.042599 -0.115794\n",
"6 0.727126 0.663877\n",
"7 0.666670 0.873323\n",
"8 0.760967 0.491605\n",
"9 0.640121 0.590486\n",
"10 0.276760 0.447089\n"
]
},
{
"data": {
"text/plain": [
"BenchmarkTools.Trial: \n",
" memory estimate: 384 bytes\n",
" allocs estimate: 4\n",
" --------------\n",
" minimum time: 3.582 s (0.00% GC)\n",
" median time: 3.589 s (0.00% GC)\n",
" mean time: 3.589 s (0.00% GC)\n",
" maximum time: 3.596 s (0.00% GC)\n",
" --------------\n",
" samples: 2\n",
" evals/sample: 1"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# call the method once so that compilation time is not included in the benchmark\n",
"gibbs(10, 2, STDOUT)\n",
"\n",
"bt_jl_file = @benchmark open(\"data_jl.tab\", \"w\") do f gibbs(50_000, 1_000, f) end"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {
"collapsed": false,
"scrolled": true
},
"outputs": [
{
"data": {
"text/plain": [
"BenchmarkTools.TrialRatio: \n",
" time: 0.9873731837839069\n",
" gctime: 1.0\n",
" memory: 2083.8333333333335\n",
" allocs: 1.25"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"ratio(minimum(bt_jl), minimum(bt_jl_file))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Writing the results to disk instead of storing them in memory is only slightly slower (the in-memory version takes about 98.7% of the time of the file version), but it allocates far less memory (384 bytes vs. 781.44 KiB)."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# C\n",
"\n",
"This is the C implementation using GSL (GNU Scientific Library) from the original blog post. The timings below agree with those obtained with the shell's `time` command, the method used in the original post. For reference, the code is\n",
"\n",
"```c\n",
"#include <stdio.h>\n",
"#include <math.h>\n",
"#include <stdlib.h>\n",
"#include <gsl/gsl_rng.h>\n",
"#include <gsl/gsl_randist.h>\n",
"\n",
"int main(void)\n",
"{\n",
"    int N=50000;\n",
"    int thin=1000;\n",
"    int i,j;\n",
"    gsl_rng *r = gsl_rng_alloc(gsl_rng_mt19937);\n",
"    double x=0;\n",
"    double y=0;\n",
"    printf(\"Iter x y\\n\");\n",
"    for (i=0;i<N;i++) {\n",
"        for (j=0;j<thin;j++) {\n",
"            x=gsl_ran_gamma(r,3.0,1.0/(y*y+4));\n",
"            y=1.0/(x+1)+gsl_ran_gaussian(r,1.0/sqrt(2*x+2));\n",
"        }\n",
"        printf(\"%d %f %f\\n\",i,x,y);\n",
"    }\n",
"    gsl_rng_free(r);\n",
"    return 0;\n",
"}\n",
"```\n"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {
"collapsed": false,
"scrolled": false
},
"outputs": [
{
"data": {
"text/plain": [
"BenchmarkTools.Trial: \n",
" memory estimate: 4.92 KiB\n",
" allocs estimate: 145\n",
" --------------\n",
" minimum time: 4.823 s (0.00% GC)\n",
" median time: 4.835 s (0.00% GC)\n",
" mean time: 4.835 s (0.00% GC)\n",
" maximum time: 4.846 s (0.00% GC)\n",
" --------------\n",
" samples: 2\n",
" evals/sample: 1"
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"bt_c = @benchmark run(pipeline(Cmd(`./gibbs`, ignorestatus=true), \"datac.tab\")) samples=3"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {
"collapsed": false,
"scrolled": false
},
"outputs": [
{
"data": {
"text/plain": [
"BenchmarkTools.TrialRatio: \n",
" time: 0.7331577308566487\n",
" gctime: 1.0\n",
" memory: 158.76825396825396\n",
" allocs: 0.034482758620689655"
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"ratio(minimum(bt_jl), minimum(bt_c))"
]
},
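{
"cell_type": "markdown",
"metadata": {},
"source": [
"The Julia version runs in about 73% of the time of the C version (minimum times of 3.536 s vs. 4.823 s). Since the C program also writes its results to a file, the file-writing Julia version (3.582 s) is the closer comparison, and it still runs in about 74% of the C time."
]
},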
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# R"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {
"collapsed": false
},
"outputs": [
{
"data": {
"text/plain": [
"RCall.RObject{RCall.ClosSxp}\n",
"function (N = 50000, thin = 1000) \n",
"{\n",
" mat = gibbs(N, thin)\n",
" write.table(mat, \"data.tab\", row.names = FALSE)\n",
"}\n"
]
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"using RCall\n",
"R\"\"\"\n",
"gibbs=function(N,thin)\n",
"{\n",
"    mat=matrix(0,ncol=3,nrow=N)\n",
"    mat[,1]=1:N\n",
"    x=0\n",
"    y=0\n",
"    for (i in 1:N) {\n",
"        for (j in 1:thin) {\n",
"            x=rgamma(1,3,y*y+4)\n",
"            y=rnorm(1,1/(x+1),1/sqrt(2*x+2))\n",
"        }\n",
"        mat[i,2:3]=c(x,y)\n",
"    }\n",
"    mat=data.frame(mat)\n",
"    names(mat)=c(\"Iter\",\"x\",\"y\")\n",
"    mat\n",
"}\n",
"\n",
"writegibbs=function(N=50000,thin=1000)\n",
"{\n",
"    mat=gibbs(N,thin)\n",
"    write.table(mat,\"data.tab\",row.names=FALSE)\n",
"}\n",
"\"\"\""
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {
"collapsed": false
},
"outputs": [
{
"data": {
"text/html": [
"190.374805168"
],
"text/plain": [
"190.374805168"
]
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"@elapsed R\"writegibbs(50000, 1000)\""
]
},
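{
"cell_type": "markdown",
"metadata": {},
"source": [
"That is about a 53x slowdown relative to the Julia versions (190.4 s vs. roughly 3.5 to 3.6 s) and about a 40x slowdown relative to C (4.8 s). Note that this is a single `@elapsed` run rather than the minimum over several benchmark samples."
]
}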
],
"metadata": {
"kernelspec": {
"display_name": "Julia 0.6.0",
"language": "julia",
"name": "julia-0.6"
},
"language_info": {
"file_extension": ".jl",
"mimetype": "application/julia",
"name": "julia",
"version": "0.6.0"
}
},
"nbformat": 4,
"nbformat_minor": 2
}