Created
February 16, 2017 04:58
-
-
Save willtownes/fec2de381fbd6fc357a0d99b6e234dfd to your computer and use it in GitHub Desktop.
Adaptive Rejection Sampling without Derivatives
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
### Adaptive Rejection Sampling (derivative-free variant)
# by Will Townes
rexp_trunc <- function(n, slope = -1, lo = 0, hi = Inf) {
  # Draw n samples from the truncated exponential distribution on [lo, hi],
  # i.e. the density proportional to exp(slope * x) restricted to [lo, hi].
  # The default is the standard exponential.
  #
  # n:     number of draws.
  # slope: rate parameter of exp(slope * x). It must be > 0 when lo = -Inf
  #        and < 0 when hi = Inf. On a finite interval any slope is allowed;
  #        slope == 0 gives the uniform distribution on [lo, hi].
  # lo:    lower truncation point, may be -Inf.
  # hi:    upper truncation point, may be +Inf.
  #
  # Returns a numeric vector of length n (inverse-CDF sampling from runif).
  u <- runif(n)
  if (lo == -Inf) {
    stopifnot(slope > 0 && hi < Inf)
    return(hi + log(u) / slope)
  }
  if (hi == Inf) {
    stopifnot(slope < 0 && lo > -Inf)
    return(lo + log(u) / slope)
  }
  if (slope == 0) {
    # Flat envelope piece: the density is constant on [lo, hi].
    return(lo + u * (hi - lo))
  }
  if (slope < 0) {
    # expm1 of a non-positive argument cannot overflow.
    lo + log1p(u * expm1(slope * (hi - lo))) / slope
  } else {
    # Same quantile function as lo + log1p(u*expm1(slope*L))/slope, but
    # rewritten around hi so expm1 gets a non-positive argument; the original
    # form returns Inf once slope * (hi - lo) exceeds ~709.
    hi + log1p((1 - u) * expm1(-slope * (hi - lo))) / slope
  }
}
ars_wt_formula <- function(a, b, c, d, numeric_zero = 1e-8) {
  # Normalization constant (integral) of exp(a * x + b) over [c, d]:
  #   (1/a) * exp(b) * (exp(a * d) - exp(a * c)).
  #
  # a: slope of the line, b: intercept.
  # c, d: left and right endpoints; requires c <= d. Either may be infinite;
  #       the result is Inf when the tail has unbounded mass.
  # numeric_zero: slopes with |a| below this are treated as flat, giving
  #       exp(b) * (d - c).
  dmc <- d - c
  stopifnot(dmc >= 0)
  if (abs(a) < numeric_zero) {
    return(exp(b) * dmc)
  }
  # Evaluate exp() at the endpoint where the integrand is largest and fold the
  # other endpoint in through expm1. Folding b into the exponent avoids
  # Inf * 0 = NaN when exp(b) overflows while exp(a*c), exp(a*d) underflow
  # (or vice versa). expm1 also avoids cancellation for narrow intervals.
  if (a > 0) {
    exp(b + a * d) * -expm1(-a * dmc) / a
  } else {
    exp(b + a * c) * expm1(a * dmc) / a
  }
}
calc_wts_inner <- function(j, xpts, xstar, a, b) {
  # Weights (un-normalized masses) of the left and right sub-intervals of
  # interval j, returned as c(left, right).
  #
  # Interval j is [xpts[j], xpts[j + 1]], split at the breakpoint xstar[j].
  # The left piece uses line j - 1 (slope a[j - 1], intercept b[j - 1]); the
  # right piece uses line j + 1. A piece whose breakpoint sits on an edge is
  # empty and gets weight 0, so it can never be sampled. This is also what
  # keeps edge intervals from indexing lines outside the range.
  #
  # xstar, a, b have one entry per interval; xpts has one more (boundaries).
  brk <- xstar[j]
  left_edge <- xpts[j]
  right_edge <- xpts[j + 1]

  # Breakpoint on the left edge: only the right piece carries mass.
  if (brk == left_edge) {
    return(c(0, ars_wt_formula(a[j + 1], b[j + 1], brk, right_edge)))
  }

  # Breakpoint on the right edge: only the left piece carries mass.
  if (brk == right_edge) {
    return(c(ars_wt_formula(a[j - 1], b[j - 1], left_edge, brk), 0))
  }

  # Interior breakpoint: both pieces carry mass.
  c(
    ars_wt_formula(a[j - 1], b[j - 1], left_edge, brk),
    ars_wt_formula(a[j + 1], b[j + 1], brk, right_edge)
  )
}
calc_wts <- function(xpts, xstar, a, b, lo, hi, idx = seq_along(a)) {
  # Un-normalized weights of every sub-interval of the envelope.
  #
  # Layout of the returned vector (length 2 * length(idx) + 2):
  #   [1]            left tail [lo, xpts[1]] under line 1 (lo may be -Inf)
  #   [2j], [2j+1]   left/right pieces of interval j (from calc_wts_inner)
  #   [last]         right tail [xpts[n + 1], hi] under line n (hi may be Inf)
  n_int <- length(idx)
  left_tail <- ars_wt_formula(a[1], b[1], lo, xpts[1])
  right_tail <- ars_wt_formula(a[n_int], b[n_int], xpts[n_int + 1], hi)
  # One column per interval, rows = (left piece, right piece); as.vector
  # flattens column-major, which interleaves left/right per interval.
  pieces <- vapply(idx, calc_wts_inner, numeric(2), xpts, xstar, a, b)
  c(left_tail, as.vector(pieces), right_tail)
}
subinterval_to_interval <- function(j) {
  # Map a sub-interval index (position in the weight vector) to its interval:
  # 1 -> 0 (left tail), 2 and 3 -> 1, ..., 2n + 2 -> n + 1 (right tail).
  j %/% 2
}
get_xstar <- function(xpts, a, b, idx = seq_along(a)) {
  # Breakpoints of the upper envelope, one per interval.
  #
  # xpts: grid points (interval boundaries), length(idx) + 1 of them.
  # a, b: slope and intercept of the secant line on each interval.
  # idx:  interval indices; assumed to be seq_len(number of intervals).
  #
  # Inside interval j the envelope is min(line j - 1, line j + 1), so the
  # breakpoint is where those two lines intersect:
  #   x = -(b[j + 1] - b[j - 1]) / (a[j + 1] - a[j - 1]).
  # The first interval has no line to its left, so its breakpoint is its left
  # endpoint; the last has no line to its right, so its breakpoint is its
  # right endpoint.
  nIvl <- length(idx)
  xstar <- xpts[idx]
  xstar[nIvl] <- xpts[nIvl + 1]
  if (nIvl > 2) {
    inner <- 2:(nIvl - 1)
    xstar[inner] <- -diff(b, lag = 2) / diff(a, lag = 2)
  }
  # If log f is exactly linear across x[j-1]..x[j+2], lines j-1 and j+1
  # coincide and the formula gives 0/0 = NaN. which() below ignores NaN, so
  # without this step a NaN breakpoint later breaks if(). Any point of the
  # interval is valid here because both lines agree; use the left endpoint.
  undefined <- which(is.na(xstar))
  xstar[undefined] <- xpts[undefined]
  # In nearly linear regions rounding can push the intersection slightly
  # outside its interval, which would yield negative weights; clamp it back.
  too_low <- which(xstar < xpts[idx])
  xstar[too_low] <- xpts[too_low]
  too_hi <- which(xstar > xpts[idx + 1])
  xstar[too_hi] <- xpts[too_hi + 1]
  xstar
}
ars <- function(func, nSample = 1, xpts, lo = 0, hi = 1, logscale = TRUE,
                verbose = TRUE) {
  # Draw nSample values from the univariate density proportional to f(x)
  # using derivative-free adaptive rejection sampling. The envelope of
  # log f is built from secant lines through the grid points. Each
  # rejected proposal is added to the grid, which tightens the envelope.
  #
  # func:     f(x), or log f(x) when logscale = TRUE. It must be vectorised
  #           over x, and f must be log-concave on [lo, hi].
  # nSample:  number of draws to return.
  # xpts:     initial grid; needs at least 3 distinct points, all inside
  #           [lo, hi] and all with f(x) > 0. If lo = -Inf the first two
  #           points must be on the rising side of log f. If hi = Inf the
  #           last two must be on the falling side. Otherwise the envelope
  #           tail has infinite mass.
  # lo, hi:   bounds of the support; either may be infinite.
  # logscale: see func.
  # verbose:  currently unused; kept for backward compatibility.
  #
  # Returns a numeric vector of length nSample.
  h <- if (logscale) func else function(x) log(func(x))
  # Duplicate grid points would create zero-width intervals (0/0 slopes).
  xpts <- sort(unique(xpts))
  if (length(xpts) < 3) {
    stop("ars: xpts must contain at least 3 distinct points", call. = FALSE)
  }
  if (lo > xpts[1] || hi < xpts[length(xpts)]) {
    stop("ars: all xpts must lie within [lo, hi]", call. = FALSE)
  }
  ypts <- h(xpts)
  if (!isTRUE(all(ypts > -Inf))) {
    stop("ars: f(x) must be positive at every point of xpts", call. = FALSE)
  }

  # sample.int() only says "NA in probability vector" for an improper
  # envelope; fail early with an actionable message instead.
  check_weights <- function(w) {
    if (!all(is.finite(w))) {
      stop("ars: envelope has infinite or undefined mass; if lo or hi is ",
           "infinite, make sure xpts straddle the mode of f", call. = FALSE)
    }
    w
  }

  # Intervals are indexed by their left endpoint; one fewer than points.
  nIvl <- length(xpts) - 1
  idx <- seq_len(nIvl)
  a <- diff(ypts) / diff(xpts)
  b <- ypts[idx] - a * xpts[idx]
  xstar <- get_xstar(xpts, a, b, idx)
  # w has 2 * nIvl + 2 entries: two sub-intervals per interval plus the
  # left and right outer tails. sample.int normalizes it.
  w <- check_weights(calc_wts(xpts, xstar, a, b, lo, hi, idx))

  res <- rep(NA_real_, nSample)
  nRes <- 0
  while (nRes < nSample) {
    j <- sample.int(2 * nIvl + 2, 1, prob = w) # sub-interval index
    i <- subinterval_to_interval(j)            # interval, 0..nIvl+1
    is_right <- as.logical(j %% 2)
    if (is_right) {
      # Right piece of interval i (i = 0 is the left tail): line i + 1.
      c_slope <- a[i + 1]; c_icpt <- b[i + 1]; c_hi <- xpts[i + 1]
      c_lo <- if (i > 0) xstar[i] else lo
    } else {
      # Left piece of interval i (i = nIvl + 1 is the right tail): line i - 1.
      c_slope <- a[i - 1]; c_icpt <- b[i - 1]; c_lo <- xpts[i]
      c_hi <- if (i <= nIvl) xstar[i] else hi
    }
    # x is a valid draw from the piecewise-exponential envelope.
    x <- rexp_trunc(1, slope = c_slope, lo = c_lo, hi = c_hi)
    hx <- h(x)
    # Adding rexp(1) is equivalent to subtracting log(runif(1)). isTRUE
    # treats an undefined log-density (NaN) as a rejection.
    accpt <- isTRUE(c_slope * x + c_icpt <= hx + rexp(1))
    if (accpt) {
      nRes <- nRes + 1
      res[nRes] <- x
    } else if (is.finite(hx) && !(x %in% xpts)) {
      # Adapt: insert x after position i (keeps xpts sorted) and recompute
      # the secant lines on either side of it. Points with non-finite h or
      # duplicating an existing grid point are not inserted, because they
      # would produce infinite or 0/0 slopes.
      nIvl <- nIvl + 1
      idx <- seq_len(nIvl)
      xpts <- append(xpts, x, i)
      ypts <- append(ypts, hx, i)
      a <- append(a, NA, i)
      b <- append(b, NA, i)
      if (i < nIvl) { # new interval to the right of x (covers left tail i = 0)
        a[i + 1] <- (ypts[i + 2] - ypts[i + 1]) / (xpts[i + 2] - xpts[i + 1])
        b[i + 1] <- ypts[i + 1] - a[i + 1] * xpts[i + 1]
      }
      if (i > 0) { # new interval to the left of x (covers right tail)
        a[i] <- (ypts[i + 1] - ypts[i]) / (xpts[i + 1] - xpts[i])
        b[i] <- ypts[i] - a[i] * xpts[i]
      }
      # TODO: update only the parts of xstar and w that change.
      xstar <- get_xstar(xpts, a, b, idx)
      w <- check_weights(calc_wts(xpts, xstar, a, b, lo, hi, idx))
    }
  }
  res
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment