Skip to content

Instantly share code, notes, and snippets.

@naoyat
Created November 13, 2012 12:51
Show Gist options
  • Save naoyat/4065614 to your computer and use it in GitHub Desktop.
Save naoyat/4065614 to your computer and use it in GitHub Desktop.
PRML図6.8のパラメータをsimulated annealingで求めたい
1; % Octave treats a file that begins with a function definition as a function
   % file and complains that the file name and the function name differ, so the
   % script starts with the no-op statement "1;".

% score68(th): objective minimised by the simulated-annealing search.
%
% th = [theta1 theta2 theta3 theta4 beta]: the four kernel parameters used by
% k() followed by the noise precision beta (row or column vector of length 5).
%
% The returned score is the sum of
%   * a quadratic penalty on the distance of th from the global mu, and
%   * absolute deviations of the GP predictive mean and of the upper edge of the
%     +/-2 sigma band at x = linspace(0, 1, 6) from hard-coded target values
%     (presumably read off PRML Fig. 6.8 -- TODO confirm).
%
% Reads the globals data_x, data_t, N and mu. Because k() takes its parameters
% from the global theta, that global is set to th for the duration of the call
% and restored afterwards.
function [score] = score68(th)
  global data_x;
  global data_t;
  global N;
  global mu;
  global theta;

  % Install the candidate where k() will see it; otherwise K and kvec would be
  % built from the caller's last accepted theta and only th(5) would matter.
  saved_theta = theta;
  theta = th;
  unwind_protect
    M = 6;
    x = linspace(0, 1, M);
    K = outer(data_x', data_x', @k);
    C_N = K + eye(N)/th(5);   % N x N covariance of the training targets
    alpha = C_N \ data_t;     % C_N^{-1} t, solved once for every grid point
    m = zeros(M,1);
    s = zeros(M,1);
    for i = 1:M
      x_ = x(i);
      kvec = zeros(N,1);
      for j = 1:N
        kvec(j) = k(data_x(j), x_);
      endfor
      c = k(x_, x_) + 1/th(5);
      m(i) = kvec' * alpha;
      % clamp round-off that could make the variance slightly negative
      % (sqrt would otherwise return a complex number)
      s(i) = sqrt(max(c - kvec' * (C_N \ kvec), 0));
    endfor
  unwind_protect_cleanup
    theta = saved_theta;
  end_unwind_protect

  lo = m - 2*s;   % lower edge of the +/-2 sigma band
  hi = m + 2*s;   % upper edge of the +/-2 sigma band

  % quadratic penalty keeping the parameters near mu
  prior_width = [0.15 10 1 0.3 20];
  score = sum(((th(:)' - mu(:)') ./ prior_width) .^ 2);

  % fit penalties: predictive mean and upper band edge against the targets
  mean_target = [0.45 0.93 0.57 -0.35 -0.79 -0.54]';
  hi_target   = [0.87 1.25 0.92 0 0.05 0.79]';
  score += sum(abs(m - mean_target));
  score += sum(abs(hi - hi_target));

  % Optional shape constraints relative to sin(2*pi*x); disabled.
  if false
    sine = sin(x*2*pi);
    if m(1) < sine(1)
      score *= 3;
    endif
    if sine(2) < m(2)
      score *= 2;
    endif
    if hi(3) < sine(3) || sine(3) < lo(3)
      score *= 3;
    endif
    if m(4) < sine(4)
      score *= 3;
    endif
    if m(5) < sine(5) || sine(5) < lo(5)
      score *= 3;
    endif
    if sine(6) < m(6)
      score *= 3;
    endif
  endif
endfunction
clear
data = load("curvefitting.txt");
data = data(1:7, :); %% keep the first 7 rows (drops the rightmost 3 points)
% Declare first, then assign. The "global x = value" form only initializes a
% global the first time it is created, so re-running the script in the same
% session would silently keep the previous run's values.
global data_x data_t N theta
data_x = data(:,1);
data_t = data(:,2);
N = length(data_x);
theta = [0 0 0 0 0];
% Kernel function: theta(1)*exp(-theta(2)/2*(x1-x2)^2) + theta(3) + theta(4)*x1'*x2.
% Its parameters come from the global theta (theta(1)..theta(4) only).
function [retval] = k(x1, x2)
  global theta;
  sq_dist = (x1 - x2)^2;
  rbf_part = theta(1)*exp(-theta(2)/2*sq_dist);
  retval = rbf_part + theta(3) + theta(4)*x1'*x2;
endfunction
% Acceptance probability for simulated annealing: always accept an improvement
% (en < e); otherwise accept with probability exp((e - en)/t).
function [p] = P(e, en, t)
  if en < e
    p = 1;
  else
    p = exp((e - en)/t);
  endif
endfunction

% Plot the GP predictive mean and the +/-2 sigma band on a 100-point grid over
% [0, 1] for the current global theta, together with the training data and
% sin(2*pi*x).
function plot_fit()
  global data_x data_t N theta;
  K = outer(data_x', data_x', @k);
  C_N = K + eye(N)/theta(5);
  alpha = C_N \ data_t;   % solved once instead of inv(C_N) per grid point
  M_ = 100;
  xA = linspace(0, 1, M_);
  mA = zeros(M_, 1);
  sA = zeros(M_, 1);
  for i = 1:M_
    x_ = xA(i);
    kvec = zeros(N,1);
    for j = 1:N
      kvec(j) = k(data_x(j), x_);
    endfor
    c = k(x_, x_) + 1/theta(5);
    mA(i) = kvec' * alpha;
    sA(i) = sqrt(max(c - kvec' * (C_N \ kvec), 0));   % clamp round-off
  endfor
  clf reset;
  hold on;
  axis([-0.05, 1.05, -1.3, 1.3])
  band_color = [0.7 0.7 0.7];
  fill([xA fliplr(xA)], [mA'-2*sA' fliplr(mA'+2*sA')], band_color)
  plot(xA, mA, 'b:')
  plot(data_x, data_t, 'o')   % drawn after fill so the points stay visible
  plot(xA, sin(xA*2*pi), 'r');
  hold off;
endfunction

global mu
mu = [0.3 25 0 0.15 40];
theta = mu;
sigma = [0.15 10 1 0.3 20];
score = 99999999;
iter_max = 1000000;
show_plots = true;   % each accepted step plots and waits on pause; set false to run unattended
for iter = 1:iter_max
  % Gaussian random-walk proposal; reject candidates with any negative parameter
  th = theta(:) + sigma(:) .* randn(5,1);
  if any(th < 0)
    continue
  endif
  curr_theta = th;
  curr_score = score68(curr_theta);
  % linear cooling; floored at eps so the final iteration never divides by zero
  temperature = max((iter_max - iter)/iter_max, eps);
  if rand < P(score, curr_score, temperature)
    score = curr_score;
    theta = curr_theta;
    printf("[%d] theta=[%g %g %g %g], beta=%g ; score=%g\n", iter, theta(1), theta(2), theta(3), theta(4), theta(5), score)
    fflush(stdout);
    if show_plots
      plot_fit();
      pause;
    endif
  endif
endfor
theta
score
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment