Last active
March 15, 2024 19:05
-
-
Save yanatan16/5420795 to your computer and use it in GitHub Desktop.
Simultaneous Perturbation Stochastic Approximation code in python
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
'''
Simultaneous Perturbation Stochastic Approximation
Author: Jon Eisen
License: MIT
This code defines and runs SPSA using iterators.
A quick intro to iterators:
Iterators are like arrays except that we don't store the whole array, we just
store how to get to the next element. In this way, we can create infinite
iterators. In python, iterators can act very similar to arrays.
numpy (a number processing library) is not used here so that pypy (an alternate
python implementation which is faster) can be used.
'''
import random
from itertools import count

try:
    from itertools import izip  # Python 2: lazy pairing of iterables
except ImportError:
    izip = zip  # Python 3: the builtin zip is already lazy
def identity(x):
    '''Return the argument unchanged (used as the default, no-op constraint).'''
    return x
def SPSA(y, t0, a, c, delta, constraint=identity):
    '''
    Build a generator of Simultaneous Perturbation Stochastic Approximation
    iterates.

    y - a function of theta that returns a scalar
    t0 - the starting value of theta
    a - an iterable of a_k values
    c - an iterable of c_k values
    delta - a function of no parameters which creates the delta vector
    constraint - a function of theta that returns theta

    One theta is yielded per (a_k, c_k) pair; iteration ends when either
    gain sequence runs out, and never ends if both are infinite.
    '''
    current = t0
    # izip pairs the gains lazily, so infinite gain sequences are fine here
    for step_size, perturb_size in izip(a, c):
        # Estimate the gradient at the current point
        gradient = estimate_gk(y, current, delta, perturb_size)

        # Stochastic-approximation step, one coordinate at a time
        stepped = []
        for coord, slope in izip(current, gradient):
            stepped.append(coord - step_size * slope)

        # Apply the constraint before handing the iterate to the caller
        current = constraint(stepped)
        yield current
def estimate_gk(y, theta, delta, ck):
    '''
    Estimate the gradient g_k at theta from two evaluations of y.

    y - a function of theta that returns a scalar
    theta - the current point (a sequence of numbers)
    delta - a function of no parameters which creates the delta vector
    ck - the perturbation size c_k

    Returns a list with one gradient component per entry of the delta vector.
    '''
    # Draw the perturbation direction once; it is reused for every step below
    perturbation = delta()

    # theta shifted forwards along the perturbation ...
    plus = []
    for coord, dk in zip(theta, perturbation):
        plus.append(coord + ck * dk)

    # ... and backwards along it
    minus = []
    for coord, dk in zip(theta, perturbation):
        minus.append(coord - ck * dk)

    # Evaluate the forward point first, then the backward point
    y_plus = y(plus)
    y_minus = y(minus)

    # Every component shares the same numerator; only dk varies
    difference = y_plus - y_minus
    two_ck = 2 * ck
    return [difference / (two_ck * dk) for dk in perturbation]
def standard_ak(a, A, alpha):
    '''
    Return an endless generator of a_k values in the standard form
    a / (k + 1 + A) ** alpha for k = 0, 1, 2, ...
    '''
    # Counting from 1 folds the "+ 1" of the formula into the counter
    for step in count(1):
        yield a / (step + A) ** alpha
def standard_ck(c, gamma):
    '''
    Return an endless generator of c_k values in the standard form
    c / (k + 1) ** gamma for k = 0, 1, 2, ...
    '''
    # Counting from 1 folds the "+ 1" of the formula into the counter
    for step in count(1):
        yield c / step ** gamma
class Bernoulli(object):
    '''
    Bernoulli perturbation distribution.

    p is the dimension (length of every generated vector)
    +/- r are the alternate values each component can take

    Calling an instance with no arguments returns a new list of p values,
    each chosen independently with random.choice from (-r, +r).
    '''
    def __init__(self, r=1, p=2):
        self.p = p
        self.r = r

    def __call__(self):
        # Needs the module-level `import random`.
        # range (not xrange) keeps this working on both Python 2 and 3.
        outcomes = (-self.r, self.r)
        return [random.choice(outcomes) for _ in range(self.p)]
class LossFunction:
    '''
    Base class for loss functions; the measurement y is L plus epsilon.

    This class does not define L(theta) or epsilon(theta) itself; y calls
    both on self, so they have to be supplied elsewhere (e.g. by a subclass).
    '''
    def y(self, theta):
        # Evaluate the loss first, then the noise term, and add them
        loss_value = self.L(theta)
        noise_value = self.epsilon(theta)
        return loss_value + noise_value
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from spsa import * | |
from itertools import islice, izip, tee | |
import random | |
def nth(iterable, n, default=None):
    '''Return the item at zero-based position n, or default if there is none.'''
    # Skip the first n items; the first one left over (if any) is the answer
    for item in islice(iterable, n, None):
        return item
    return default
class SkewedQuarticLoss(LossFunction):
    '''
    Skewed Quartic Loss function.

    Initialize with vector length p and noise standard deviation sigma.
    Functions L, y, and epsilon are available (y comes from LossFunction).

    With b = B * theta, where B is the p x p upper-triangular matrix whose
    non-zero entries are all 1/p, the loss is
        L(theta) = b.b + sum(0.1 * b_i**3 + 0.01 * b_i**4)
    '''
    def __init__(self, p, sigma):
        x = 1. / p
        # Row j holds 1/p in columns i >= j and 0 before them
        self.B = [[x if i >= j else 0 for i in range(p)] for j in range(p)]
        self.sigma = sigma
        # Kept so existing readers of the sigmasq attribute keep working
        self.sigmasq = sigma ** 2

    @staticmethod
    def _dot(u, v):
        '''Inner product of two equal-length sequences of numbers.'''
        return sum(ui * vi for ui, vi in zip(u, v))

    def L(self, theta):
        '''Return the noise-free loss at theta.'''
        bt = [self._dot(row, theta) for row in self.B]
        return self._dot(bt, bt) + sum(.1 * b**3 + .01 * b**4 for b in bt)

    def epsilon(self, theta):
        '''Return one zero-mean Gaussian noise sample; theta is not used.'''
        # random.gauss takes the standard deviation, not the variance
        return random.gauss(0, self.sigma)
def run_spsa(n=1000, replications=40):
    '''
    Run independent SPSA replications on a 20-dimensional skewed quartic loss.

    n - number of SPSA iterations performed in each replication
    replications - number of independent replications to run

    Returns a list holding, for each replication, the noise-free loss L at
    the n-th iterate (at the starting point when n is 0). You can calculate
    means/variances from this data.
    Raises ValueError if n is negative.
    '''
    if n < 0:
        raise ValueError('n must be non-negative')

    p = 20
    loss = SkewedQuarticLoss(p, sigma=1)
    theta0 = [1] * p
    delta = Bernoulli(p=p)

    losses = []
    for _ in range(replications):
        # Each replication gets fresh gain sequences, so the number of runs
        # depends only on `replications` and nothing is buffered between runs
        a_gains = standard_ak(a=1, A=100, alpha=.602)
        c_gains = standard_ck(c=1, gamma=.101)
        theta_iter = SPSA(a=a_gains, c=c_gains, y=loss.y, t0=theta0, delta=delta)
        # nth is zero-based, so the n-th iterate sits at index n - 1
        terminal_theta = nth(theta_iter, n - 1) if n > 0 else theta0
        losses.append(loss.L(terminal_theta))
    return losses
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment