Last active
August 3, 2018 16:43
-
-
Save tyleransom/e48cc80dda667f034d8dad17b53fa26b to your computer and use it in GitHub Desktop.
Compare the run time of Julia and MATLAB optimization on a simple maximum-likelihood problem (normal linear model)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
"""
    normalMLE(b, Y, X, w=ones(size(Y)), d=ones(size(Y)), J=1)

Return the negative log-likelihood of a normal linear regression model whose
error variance may differ across `J` groups. The value is meant to be minimized.

# Arguments
- `b`: parameter vector of length `K + J`. `b[1:K]` are the regression
  coefficients and `b[K + k]` is the error standard deviation of group `k`.
  This is the same layout as the MATLAB `normalMLE`.
- `Y`: outcomes. Observation `i` is `Y[i]` for `i in 1:size(Y, 1)`.
- `X`: `K × N` covariate matrix with one *column* per observation. This is the
  transpose of the usual `N × K` design matrix, so the inner loop walks memory
  contiguously.
- `w`: observation weights. Needs at least `size(Y, 1)` entries.
- `d`: variance-group label of each observation. Observation `i` uses group
  `k` when `d[i] == k`. Observations whose label is not in `1:J` contribute
  nothing.
- `J`: number of variance groups.

# Throws
- `DimensionMismatch` if the sizes of `X`, `Y`, `w`, `d` and `b` are inconsistent.
- `ArgumentError` if `J < 1`.
"""
function normalMLE(b::Vector, Y::Array, X::Array, w::Array=ones(size(Y)),
                   d::Array=ones(size(Y)), J::Int64=1)
    n = size(Y, 1)
    K = size(X, 1)
    # These are real errors rather than `@assert`s (which may be disabled)
    # because the loop below runs under `@inbounds`. A size mismatch would
    # otherwise read out of bounds silently.
    J >= 1 || throw(ArgumentError("J must be at least 1, got $J"))
    size(X, 2) == n || throw(DimensionMismatch("X and Y must be the same length"))
    length(b) == K + J ||
        throw(DimensionMismatch("parameter vector has wrong number of elements"))
    (length(w) >= n && length(d) >= n) ||
        throw(DimensionMismatch("w and d must have an entry for every observation"))

    # Promote over every input that enters the sum, plus Float64 (log(2π)),
    # so `like` keeps one concrete type through the loop. The same holds when
    # `b` holds automatic-differentiation dual numbers.
    T = promote_type(eltype(b), eltype(X), eltype(Y), eltype(w), Float64)
    logtwopi = log(2π)
    like = zero(T)
    @inbounds for i in 1:n
        # linear index x_i' * beta, walking column i of X
        xb = zero(T)
        for j in 1:K
            xb += X[j, i] * b[j]
        end
        for k in 1:J
            # Skip other groups instead of multiplying by `d[i] == k`. That way
            # a degenerate σ in an unused group cannot inject 0 * Inf = NaN.
            d[i] == k || continue
            σ = b[K + k]
            z = (Y[i] - xb) / σ
            like += w[i] * 0.5 * (logtwopi + log(σ^2) + z^2)
        end
    end
    return like
end
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
function [like,grad]=normalMLE(b,restrMat,Y,X,W,d)
%NORMALMLE negative log likelihood (and its analytical gradient) of a linear
%   regression model with normal errors whose variance may differ by group.
%
%   [LIKE,GRAD] = NORMALMLE(B,RESTRMAT,Y,X,W,D)
%   LIKE = NORMALMLE(B,RESTRMAT,Y,X)
%   returns the negative weighted log likelihood LIKE (to be minimized) and
%   its gradient GRAD with respect to B.
%
%   B        (K+J) x 1 parameter vector. B(1:K) are the regression
%            coefficients and B(K+j) is the error standard deviation of
%            variance group j.
%   RESTRMAT R x 5 matrix of parameter restrictions (see APPLYRESTR), or an
%            empty matrix for unrestricted estimation.
%   Y        N x 1 vector of outcomes.
%   X        N x K matrix of covariates.
%   W        N x 1 vector of weights. It may be omitted or passed empty for
%            unit weights.
%   D        N x 1 vector of integers 1..J giving each observation's variance
%            group. It may be omitted or passed empty for homoskedastic
%            errors (J = 1).
%
%   This function does *not* automatically include a column of ones in X.
%   It also does *not* automatically drop NaNs.

% Copyright 2014 Tyler Ransom, Duke University
% Revision History:
%   September 25, 2014
%     Created
%   September 29, 2014
%     Added code for weighted estimation
%==========================================================================
% error checking
assert(size(X,1)==size(Y,1),'X and Y must be the same length');
% short-circuit || avoids touching d when it was not passed
if nargin<6 || isempty(d)
    d = ones(size(Y));
end
J = numel(unique(d));
assert( min(d)==1 && max(d)==J ,'d should contain integers numbered consecutively from 1 through J');
assert(size(X,2)+J==size(b,1),'parameter vector has wrong number of elements');
if nargin>=5 && ~isempty(W)
    assert(isvector(W) && length(W)==length(Y),'W must be a column vector the same size as Y');
    % force a column so a row-vector W cannot trigger implicit expansion below
    W = W(:);
else
    W = ones(size(Y));
end
% apply restrictions as defined in restrMat
if ~isempty(restrMat)
    b = applyRestr(restrMat,b);
end
% slice parameter vector
nb = numel(b)-J;             % number of regression coefficients
beta = b(1:nb);
wagesigma = b(nb+1:end);
% residuals do not depend on the group, so compute them once
resid = Y - X*beta;
like = 0;
grad = zeros(size(b));
for j=1:J
    % restrict to group j by indexing rather than multiplying by a 0/1 mask,
    % so 0*Inf cannot turn into NaN
    inj = (d==j);
    wj = W(inj);
    rj = resid(inj);
    sj = wagesigma(j);
    % negative log likelihood contribution of group j
    like = like + .5*sum(wj.*(log(2*pi)+log(sj^2)+(rj./sj).^2));
    % derivative with respect to beta
    grad(1:nb) = grad(1:nb) - X(inj,:)'*(wj.*rj)./(sj^2);
    % derivative with respect to sigma_j
    grad(nb+j) = sum(wj.*(1./sj-(rj.^2)./(sj^3)));
end
% apply restrictions to gradient
if ~isempty(restrMat)
    grad = applyRestrGrad(restrMat,grad);
end
end
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Simple Monte Carlo simulation of normal linear model | |
using Optim | |
include("normalMLE.jl") | |
"""
    datagen_test_normal(N=10_000, T=5)

Simulate a normal linear wage model for `N` individuals observed over `T`
periods, giving `n = N*T` rows.

Returns `(Xwage, lnWage, ωans, σans, n, ID, Tw)`:
- `Xwage`: the `n × 7` covariate matrix, whose first column is an intercept.
- `lnWage`: the simulated outcome.
- `ωans`, `σans`: the true coefficients and error standard deviation.
- `n`: the row count.
- `ID`, `Tw`: `N × T` matrices with `ID[i, t] = i` and `Tw[i, t] = t`.

The global RNG is reseeded with 1234 on every call, so the output is
reproducible. `srand` is the pre-0.7 Julia seeding API; Julia ≥ 0.7 uses
`Random.seed!`.
"""
function datagen_test_normal(N::Int64=10_000, T::Int64=5)
    srand(1234)
    n = N * T
    ID = collect(1:N) * ones(1, T)
    Tw = ones(N, 1) * collect(1:T)'

    # Covariates, drawn in the same left-to-right order as the original
    # matrix literal so the random stream is unchanged. Dotted operators are
    # spaced because `5.+3.*x` is ambiguous (a parse error on Julia >= 0.7).
    x1 = 5 .+ 3 .* randn(n, 1)
    x2 = rand(n, 1)
    x3 = 2.5 .+ 2 .* randn(n, 1)
    x4 = 1.5 .+ 0.5 .* randn(n, 1)
    x5 = rand(n, 1) .> 0.5          # dummy; hcat promotes it to Float64
    x6 = 5 .* rand(n, 1)
    Xwage = hcat(ones(n, 1), x1, x2, x3, x4, x5, x6)

    # lnWage coefficients and error standard deviation
    ωans = [-0.15, 0.10, 0.50, 0.10, -0.15, 1.0, 0.2]
    σans = 0.5

    # generate wages
    lnWage = Xwage * ωans + σans * randn(n, 1)

    # return generated data so that other functions have access
    return Xwage, lnWage, ωans, σans, n, ID, Tw
end
"""
    tester_normal()

Simulate the normal linear model with `datagen_test_normal`. Then estimate it
by maximum likelihood (Optim's L-BFGS with forward-mode automatic
differentiation) and print the estimates next to the true parameters.
Returns `nothing`.
"""
function tester_normal()
    # Simulate data
    N = 100_000
    T = 5
    @time Xwage, lnWage, ωans, σans, n, ID, Tw = datagen_test_normal(N, T)
    # true parameter vector: regression coefficients followed by σ
    θans = vcat(ωans, σans)

    # normalMLE expects one observation per column. Build that layout once,
    # outside the objective. Writing `Xwage'` inside the closure re-evaluates
    # the transpose on every function/gradient call, and on Julia 0.6 that is
    # a full copy of the design matrix each time.
    Xt = permutedims(Xwage, (2, 1))

    # starting values: all ones
    ωstart = ones(size(θans))

    # optimization
    funw = TwiceDifferentiable(arg -> normalMLE(arg, lnWage, Xt), ωstart;
                               autodiff = :forward)
    @time res = optimize(funw, ωstart, LBFGS(),
                         Optim.Options(show_trace = true, iterations = 100_000,
                                       g_tol = 1e-6, f_tol = 1e-6))
    ωest = res.minimizer

    # compare answers
    println(ωest)
    println(θans)
    println(hcat(ωest, θans))
    return nothing
end

tester_normal()
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
julia> include("test_normal.jl") | |
0.043029 seconds (329 allocations: 93.152 MiB, 30.10% gc time) | |
Iter Function value Gradient norm | |
0 3.884509e+07 7.627124e+07 | |
1 1.322658e+07 1.430237e-01 | |
2 9.660097e+06 5.090879e-03 | |
3 9.659634e+06 5.084218e-03 | |
4 9.659361e+06 5.090251e-03 | |
5 9.656354e+06 5.129389e-03 | |
6 9.645795e+06 5.092902e-03 | |
7 9.559772e+06 1.316707e-02 | |
8 9.530747e+06 1.366613e-02 | |
9 9.489493e+06 9.718868e-03 | |
10 9.468099e+06 6.769545e-03 | |
11 9.402472e+06 6.883134e-03 | |
12 9.356621e+06 2.873576e-02 | |
13 9.313011e+06 4.545268e-02 | |
14 9.209566e+06 1.213583e-01 | |
15 9.149203e+06 1.501674e-01 | |
16 9.019248e+06 2.372364e-01 | |
17 8.993140e+06 1.667321e-01 | |
18 8.943569e+06 1.233534e-01 | |
19 8.814717e+06 5.663536e-02 | |
20 8.791966e+06 3.974257e-02 | |
21 8.774412e+06 2.816288e-02 | |
22 8.717550e+06 2.257955e-02 | |
23 8.692524e+06 4.895406e-02 | |
24 8.638723e+06 8.078320e-02 | |
25 8.524785e+06 2.012978e-01 | |
26 8.424543e+06 2.381231e-01 | |
27 8.294458e+06 2.251018e-01 | |
28 8.213579e+06 1.092005e-01 | |
29 8.149474e+06 7.973280e-02 | |
30 8.137504e+06 8.302497e-02 | |
31 8.116011e+06 1.538212e-01 | |
32 8.108217e+06 1.163106e-01 | |
33 8.091333e+06 7.106232e-02 | |
34 8.069038e+06 1.682824e-01 | |
35 8.010527e+06 3.989050e-01 | |
36 7.942266e+06 7.174701e-01 | |
37 7.856672e+06 1.321825e+00 | |
38 7.798530e+06 1.758977e+00 | |
39 7.717447e+06 1.906455e+00 | |
40 7.631388e+06 1.295352e+00 | |
41 7.552964e+06 4.504293e-01 | |
42 7.540092e+06 3.469435e-01 | |
43 7.513918e+06 1.837015e-01 | |
44 7.502158e+06 9.668569e-02 | |
45 7.495602e+06 1.231136e-01 | |
46 7.449756e+06 6.853826e-01 | |
47 7.413579e+06 5.110502e-01 | |
48 7.338916e+06 1.477632e+00 | |
49 7.078971e+06 5.434213e+00 | |
50 7.032346e+06 7.019024e+00 | |
51 6.901474e+06 7.682133e+00 | |
52 6.877113e+06 8.384756e+00 | |
53 6.769102e+06 8.379340e+00 | |
54 6.703194e+06 9.658427e+00 | |
55 6.690388e+06 9.143556e+00 | |
56 6.612149e+06 6.786788e+00 | |
57 6.493374e+06 2.614001e+00 | |
58 6.447995e+06 3.433311e+01 | |
59 6.350980e+06 8.234535e+00 | |
60 6.144345e+06 1.273747e+01 | |
61 6.080728e+06 1.353925e+01 | |
62 6.005228e+06 1.137702e+01 | |
63 5.960653e+06 1.096689e+01 | |
64 5.905861e+06 1.922899e+01 | |
65 5.800736e+06 5.977742e+01 | |
66 5.771706e+06 6.316197e+01 | |
67 5.663032e+06 1.225153e+02 | |
68 5.617952e+06 1.403519e+02 | |
69 5.460814e+06 1.596954e+02 | |
70 5.429030e+06 1.553969e+02 | |
71 5.353603e+06 1.842261e+02 | |
72 5.331645e+06 1.629038e+02 | |
73 5.233476e+06 1.306780e+02 | |
74 5.135056e+06 5.338886e+01 | |
75 5.062761e+06 6.157298e+01 | |
76 4.977962e+06 1.742760e+02 | |
77 4.857330e+06 1.991023e+02 | |
78 4.794941e+06 3.268310e+02 | |
79 4.675018e+06 4.938033e+02 | |
80 4.616066e+06 7.721037e+02 | |
81 4.536332e+06 7.066519e+02 | |
82 4.436529e+06 4.811221e+02 | |
83 4.361249e+06 3.598480e+02 | |
84 4.257312e+06 3.579856e+02 | |
85 4.221750e+06 5.430751e+02 | |
86 4.147231e+06 8.183836e+02 | |
87 4.080882e+06 1.833556e+03 | |
88 4.047268e+06 1.365756e+03 | |
89 3.988578e+06 5.565152e+02 | |
90 3.807426e+06 1.041928e+03 | |
91 3.736149e+06 2.443530e+03 | |
92 3.667404e+06 3.493102e+03 | |
93 3.594433e+06 5.597260e+03 | |
94 3.505059e+06 5.115947e+03 | |
95 3.481843e+06 6.213013e+03 | |
96 3.367955e+06 5.722373e+03 | |
97 3.349488e+06 5.541825e+03 | |
98 3.258341e+06 7.257833e+03 | |
99 3.225907e+06 6.477123e+03 | |
100 3.152824e+06 4.756296e+03 | |
101 3.072572e+06 7.422809e+03 | |
102 3.036155e+06 9.333744e+03 | |
103 2.893977e+06 2.587831e+04 | |
104 2.858422e+06 2.664313e+04 | |
105 2.776338e+06 3.849096e+04 | |
106 2.766183e+06 1.659702e+04 | |
107 2.696631e+06 4.951878e+03 | |
108 2.634772e+06 4.960439e+03 | |
109 2.547868e+06 1.109479e+04 | |
110 2.491804e+06 1.798263e+04 | |
111 2.405218e+06 4.130854e+04 | |
112 2.312243e+06 9.378387e+04 | |
113 2.214750e+06 1.525488e+05 | |
114 2.159427e+06 1.742399e+05 | |
115 2.137466e+06 1.762220e+05 | |
116 2.060856e+06 1.255663e+05 | |
117 2.009149e+06 1.790481e+05 | |
118 1.962359e+06 1.096074e+05 | |
119 1.801608e+06 2.790790e+04 | |
120 1.667714e+06 9.962262e+04 | |
121 1.622656e+06 1.371471e+05 | |
122 1.472104e+06 3.146925e+05 | |
123 1.387777e+06 4.529169e+05 | |
124 1.312364e+06 3.499773e+05 | |
125 1.203337e+06 3.985834e+05 | |
126 1.119908e+06 3.141713e+05 | |
127 1.021093e+06 3.662530e+05 | |
128 9.415379e+05 2.809368e+05 | |
129 8.459852e+05 3.302823e+05 | |
130 8.086722e+05 2.718814e+05 | |
131 7.093739e+05 3.426377e+05 | |
132 6.753013e+05 4.039689e+05 | |
133 5.909335e+05 7.410199e+05 | |
134 5.704846e+05 1.656434e+06 | |
135 5.001744e+05 8.380706e+05 | |
136 4.186622e+05 2.837283e+05 | |
137 4.075684e+05 2.503589e+05 | |
138 3.783191e+05 1.828466e+05 | |
139 3.750234e+05 1.670079e+05 | |
140 3.651383e+05 1.231631e+05 | |
141 3.623230e+05 2.702870e+04 | |
142 3.620152e+05 1.142678e+04 | |
143 3.620019e+05 5.842072e+03 | |
144 3.620003e+05 1.631356e+03 | |
145 3.619997e+05 2.072261e+03 | |
146 3.619994e+05 2.166296e+03 | |
25.287109 seconds (119.07 k allocations: 21.564 GiB, 1.23% gc time) | |
[-0.149182, 0.0999476, 0.503172, 0.0996529, -0.150759, 0.999656, 0.199725, 0.499142] | |
[-0.15, 0.1, 0.5, 0.1, -0.15, 1.0, 0.2, 0.5] | |
[-0.149182 -0.15; 0.0999476 0.1; 0.503172 0.5; 0.0996529 0.1; -0.150759 -0.15; 0.999656 1.0; 0.199725 0.2; 0.499142 0.5] |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
% Simple simulation and estimation of normal linear model
% Plain CLEAR rather than CLEAR ALL: CLEAR ALL also flushes compiled
% functions, which slows the subsequent normalMLE calls being timed.
clear; clc;
tic;
seed = 1234;
rng(seed,'twister');
N = 1e5;
T = 5;
% generate the regressors
Xwage = [ones(N*T,1) 5+3*randn(N*T,1) rand(N*T,1) 2.5+2*randn(N*T,1) 1.5+0.5*randn(N*T,1) rand(N*T,1)>0.5 5*rand(N*T,1)];
% lnWage coefficients and error standard deviation
bwAns = [-0.15;0.10;0.50;0.10;-.15;1;0.2];
sigWans = .5;
% generate wages
lnWage = Xwage*bwAns+sigWans*randn(N*T,1);
disp(['Time spent on simulation: ',num2str(toc),' seconds']);
options=optimset('Disp','Iter','LargeScale','on','MaxFunEvals',2000000,'MaxIter',15000,'TolX',1e-6,'Tolfun',1e-6,'GradObj','on','DerivativeCheck','off','FinDiffType','central');
tic;
% Starting values. The first (truth-perturbed) draw is immediately replaced
% by uniform draws. It is kept only so the RNG stream, and therefore the
% starting values, match the recorded run.
bwEst = [bwAns;sigWans] + .5*rand(length(bwAns)+1,1).*[bwAns;sigWans] - .25*[bwAns;sigWans];
bwEst = rand(size([bwAns;sigWans]));
% Optimization: pass data through an anonymous function instead of the
% legacy string-name + trailing-extra-arguments form of fminunc
objfun = @(b) normalMLE(b,[],lnWage,Xwage,[]);
bwEst = fminunc(objfun,bwEst,options);
% For standard errors, request the Hessian as fminunc's 6th output, e.g.
%   [bwEst,~,~,~,~,H] = fminunc(objfun,bwEst,options);
%   se = sqrt(diag(inv(full(H))));
[bwEst cat(1,bwAns,sigWans)]
disp(['Time spent on estimation: ',num2str(toc),' seconds']);
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Time spent on simulation: 0.072481 seconds | |
Norm of First-order | |
Iteration f(x) step optimality CG-iterations | |
0 4.13512e+07 1.21e+08 | |
1 4.13512e+07 10 1.21e+08 2 | |
2 4.13512e+07 2.5 1.21e+08 0 | |
3 1.4052e+07 0.625 3.73e+07 0 | |
4 4.72701e+06 1.25 2.32e+07 2 | |
5 4.72701e+06 1.25 2.32e+07 2 | |
6 1.40648e+06 0.3125 6.85e+06 0 | |
7 697859 0.625 2.83e+06 4 | |
8 452775 0.579083 1.36e+06 3 | |
9 388340 0.0731217 4.28e+05 2 | |
10 368925 0.111206 1.5e+05 3 | |
11 362901 0.173123 2.59e+04 4 | |
12 362686 0.0237356 6.9e+03 4 | |
13 362675 0.00676573 2.15e+03 4 | |
14 362674 0.00137254 701 4 | |
15 362674 0.000888129 281 4 | |
Local minimum possible. | |
fminunc stopped because the final change in function value relative to | |
its initial value is less than the selected value of the function tolerance. | |
ans = | |
-0.1484 -0.1500 | |
0.0997 0.1000 | |
0.4999 0.5000 | |
0.1000 0.1000 | |
-0.1490 -0.1500 | |
1.0012 1.0000 | |
0.1995 0.2000 | |
0.4998 0.5000 | |
Time spent on estimation: 5.3942 seconds |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment