Skip to content

Instantly share code, notes, and snippets.

@bryangoodrich
Last active August 29, 2015 14:05
Show Gist options
  • Save bryangoodrich/cdc57f6fad4ddd99a854 to your computer and use it in GitHub Desktop.
Save bryangoodrich/cdc57f6fad4ddd99a854 to your computer and use it in GitHub Desktop.
Gradient Descent for Logistic Classifier
# Fit a logistic-regression classifier by minimising the average
# cross-entropy (negative log-likelihood) cost with stats::optim().
#
# Despite its name, this does not run a hand-written gradient-descent loop:
# it hands the cost and its analytic gradient to optim(), which by default
# uses the quasi-Newton "BFGS" method.
#
# Arguments:
#   X             Design matrix with m rows and p columns. Include a column
#                 of 1s yourself if an intercept term is wanted.
#   y             Response of length m (typically 0/1 labels); may be a plain
#                 vector or an m x 1 matrix.
#   initial_theta Starting coefficients, length p.
#   method        Method passed to optim() (default "BFGS"). Gradient-free
#                 methods such as "Nelder-Mead" ignore the supplied gradient.
#   ...           Further arguments passed to optim(), e.g. control = list().
#
# Returns: the list returned by optim() (par, value, counts, convergence,
#   message).
gradientDescent <- function(X, y, initial_theta, method = "BFGS", ...) {
  X <- as.matrix(X)
  # as.numeric() flattens an m x 1 matrix and accepts a plain vector, so the
  # observation count no longer relies on nrow(y) (NULL for a vector, which
  # made every cost evaluation zero-length).
  y <- as.numeric(y)
  m <- length(y)
  stopifnot(
    m > 0,
    nrow(X) == m,
    length(initial_theta) == ncol(X)
  )

  # Average negative log-likelihood. plogis(z, log.p = TRUE) is
  # log(sigmoid(z)) and plogis(z, lower.tail = FALSE, log.p = TRUE) is
  # log(1 - sigmoid(z)); both are evaluated without underflowing to log(0),
  # so a large |X %*% theta| yields a finite cost instead of Inf/NaN.
  costFunction <- function(theta) {
    z <- drop(X %*% theta)
    -sum(y * plogis(z, log.p = TRUE) +
           (1 - y) * plogis(z, lower.tail = FALSE, log.p = TRUE)) / m
  }

  # Analytic gradient of the cost: (1/m) * t(X) %*% (sigmoid(X theta) - y),
  # returned as a plain length-p vector.
  gradFunction <- function(theta) {
    z <- drop(X %*% theta)
    drop(crossprod(X, plogis(z) - y)) / m
  }

  optim(initial_theta, costFunction, gradFunction, method = method, ...)
}
# Return the embedded data set as a numeric matrix.
#
# The values are stored inline via structure(c(...), .Dim = c(100L, 3L)), so
# the result is a 100 x 3 matrix filled column by column from the literal:
#   column 1 - values 1-100    (continuous, roughly 30 to 100)
#   column 2 - values 101-200  (continuous, roughly 30 to 100)
#   column 3 - values 201-300  (0/1 class labels)
# No dimnames are set. Callers in this file use columns 1-2 as features and
# column 3 as the response.
# NOTE(review): the origin and meaning of these columns are not recorded
# here; document the data source if known.
loadInput <- function() {
structure(c(34.623659624517, 30.2867107682261, 35.8474087699387,
60.1825993862098, 79.0327360507101, 45.0832774766834, 61.1066645368477,
75.0247455673889, 76.0987867022626, 84.4328199612004, 95.8615550709357,
75.0136583895825, 82.3070533739948, 69.3645887597094, 39.5383391436722,
53.9710521485623, 69.0701440628303, 67.9468554771162, 70.6615095549944,
76.978783727475, 67.3720275457088, 89.6767757507208, 50.534788289883,
34.2120609778679, 77.9240914545704, 62.2710136700463, 80.1901807509566,
93.114388797442, 61.830206023126, 38.7858037967942, 61.379289447425,
85.4045193941164, 52.1079797319398, 52.0454047683183, 40.2368937354511,
54.6351055542482, 33.9155001090689, 64.1769888749449, 74.7892529594154,
34.1836400264419, 83.9023936624915, 51.5477202690618, 94.4433677691785,
82.3687537571392, 51.0477517712887, 62.2226757612019, 77.1930349260136,
97.7715992800023, 62.0730637966765, 91.5649744980744, 79.9448179406693,
99.2725269292572, 90.5467141139985, 34.5245138532001, 50.2864961189907,
49.5866772163203, 97.6456339600777, 32.5772001680931, 74.248691367216,
71.7964620586338, 75.3956114656803, 35.2861128152619, 56.2538174971162,
30.058822446698, 44.6682617248089, 66.5608944724295, 40.4575509837516,
49.0725632190884, 80.27957401467, 66.7467185694404, 32.7228330406032,
64.0393204150601, 72.3464942257992, 60.4578857391896, 58.840956217268,
99.8278577969213, 47.2642691084817, 50.4581598028599, 60.4555562927153,
82.2266615778557, 88.9138964166533, 94.834506724302, 67.3192574691753,
57.2387063156986, 80.3667560017127, 68.4685217859111, 42.0754545384731,
75.4777020053391, 78.6354243489802, 52.3480039879411, 94.0943311251679,
90.4485509709636, 55.4821611406959, 74.4926924184304, 89.8458067072098,
83.4891627449824, 42.2617008099817, 99.3150088051039, 55.340017560037,
74.7758930009277, 78.0246928153624, 43.894997524001, 72.9021980270836,
86.3085520954683, 75.3443764369103, 56.3163717815305, 96.5114258848962,
46.5540135411654, 87.420569719268, 43.5333933107211, 38.2252780579509,
30.6032632342801, 76.481963302356, 97.7186919618861, 76.0368108511588,
89.207350137502, 52.7404697301677, 46.6785741067313, 92.9271378936483,
47.5759636497553, 42.8384383202918, 65.7993659274524, 48.855811527642,
44.2095285986629, 68.9723599933059, 69.9544579544759, 44.8216289321835,
38.8006703371321, 50.2561078924462, 64.9956809553958, 72.807887313171,
57.0519839762712, 63.1276237688172, 69.4328601204522, 71.1677480218488,
52.2138858806112, 98.8694357422061, 80.9080605867082, 41.5734152282443,
75.2377203360134, 56.3080462160533, 46.8562902634998, 65.5689216055905,
40.6182551597062, 45.82270145776, 52.0609919483668, 70.4582000018096,
86.7278223300282, 96.7688241241398, 88.696292545466, 74.1631193504376,
60.9990309984499, 43.3906018065003, 60.3963424583717, 49.8045388132306,
59.8089509945327, 68.861572724206, 95.5985476138788, 69.8245712265719,
78.4535622451505, 85.7599366733162, 47.0205139472342, 39.2614725105802,
49.5929738672369, 66.4500861455891, 41.0920980793697, 97.5351854890994,
51.8832118207397, 92.1160608134408, 60.9913940274099, 43.3071730643006,
78.0316880201823, 96.227592967614, 73.0949980975804, 75.8584483127904,
72.3692519338389, 88.4758649955978, 75.8098595298246, 42.5084094357222,
42.7198785371646, 69.8037888983547, 45.6943068025075, 66.5893531774792,
59.5142819801296, 90.9601478974695, 85.5943071045201, 78.8447860014804,
90.4245389975396, 96.6474271688564, 60.7695052560259, 77.1591050907389,
87.508791764847, 35.5707034722887, 84.8451368493014, 45.3582836109166,
48.3802857972818, 87.1038509402546, 68.7754094720662, 64.9319380069486,
89.5298128951328, 0, 0, 0, 1, 1, 0, 1, 1, 1, 1, 0, 0, 1, 1, 0,
1, 1, 0, 1, 1, 0, 1, 0, 0, 1, 1, 1, 0, 0, 0, 1, 1, 0, 1, 0, 0,
0, 1, 0, 0, 1, 0, 1, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 1,
0, 1, 1, 1, 0, 0, 0, 0, 0, 1, 0, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1,
0, 0, 1, 1, 1, 1, 1, 1, 0, 1, 1, 0, 1, 1, 0, 1, 1, 1, 1, 1, 1,
1), .Dim = c(100L, 3L))
}
# Begin Computation
# Fit the logistic classifier to the embedded data set with two optimisers
# and auto-print each optim() result.
input <- loadInput()
X <- cbind(1, input[, 1:2]) # Adding unit column vector (intercept term)
y <- input[, 3, drop = FALSE] # Labels kept as an m x 1 matrix
# One starting coefficient per column of X, so adding or removing feature
# columns above does not require editing this line.
initial_theta <- matrix(0, nrow = ncol(X))
gradientDescent(X, y, initial_theta, control = list(maxit = 1000))
# Nelder-Mead is gradient-free and runs with optim()'s default iteration cap
# for that method (maxit = 500); check $convergence in the printed result
# (0 means the optimiser reported convergence).
gradientDescent(X, y, initial_theta, method = "Nelder-Mead")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment