Skip to content

Instantly share code, notes, and snippets.

@Planeshifter
Last active August 29, 2015 14:01
Show Gist options
  • Save Planeshifter/f8582c7e6047f0461ca2 to your computer and use it in GitHub Desktop.
Save Planeshifter/f8582c7e6047f0461ca2 to your computer and use it in GitHub Desktop.
Multinomial Logistic Regression with Regularization
#include <RcppArmadillo.h>
// [[Rcpp::depends(RcppArmadillo)]]
#include <iostream>
using namespace Rcpp;
using namespace arma;
using namespace std;
// Regularized multinomial (softmax) logistic regression fitted by cyclic
// coordinate descent. Beta holds one coefficient row per class (K x p).
class MultLogit {
public:
// Builds the model from an n_obs x p design matrix X_input and an
// n_obs x K response matrix Y_input (rows act as class targets — presumably
// one-hot indicators; verify against callers). sigma_input is the prior
// scale used in the ridge-style penalty; T_input is the number of sweeps.
MultLogit(NumericMatrix X_input, NumericMatrix Y_input, double sigma_input, int T_input);
int K; // number of classes
int p; // number of features
int n_obs; // number of observations
int T; // number of iterations (full coordinate-descent sweeps)
double sigma; // prior sd
arma::mat X; // n_obs x p design matrix
arma::mat Y; // n_obs x K response matrix
arma::mat Y_hat; // n_obs x K fitted probabilities (filled by predict())
arma::mat Beta; // K x p coefficient matrix, initialized to zero
arma::mat logLikGradientMatrix; // K x p gradient (filled by logLikGradient_LA())
// Runs T sweeps of coordinate descent over the first K-1 rows of Beta.
void coordinate_descent();
// Probability that observation i belongs to class k under the current Beta.
double softmaxProb(int i, int k);
// Proposed additive update for a single coefficient.
// NOTE(review): declared (j, k) here but defined as (k, j) at the
// definition site; callers pass (class, feature) — align the names.
double tentative_step(int j, int k);
// Recomputes logLikGradientMatrix at the current Beta.
void logLikGradient_LA();
// Fills Y_hat with softmax probabilities for every observation.
void predict();
private:
};
// Initialization for the regularized multinomial regression class.
// Arguments:
//   X_input     - n_obs x p design matrix
//   Y_input     - n_obs x K response matrix
//   sigma_input - prior sd used in the penalty terms
//   T_input     - number of coordinate-descent sweeps
MultLogit::MultLogit(NumericMatrix X_input, NumericMatrix Y_input, double sigma_input, int T_input)
{
n_obs = X_input.nrow();
p = X_input.ncol();
K = Y_input.ncol();
T = T_input;
sigma = sigma_input;
X = as<arma::mat>(X_input);
Y = as<arma::mat>(Y_input);
// Coefficients start at zero: one row per class, one column per feature.
Beta = arma::zeros<arma::mat>(K, p);
}
// Calculates the probability for obs i to belong to class k
double MultLogit::softmaxProb(int i, int k){
double numerator;
// Rcout << "i:" << i << " and k: " << k << "\n";
arma::rowvec beta_k = Beta.row(k);
arma::rowvec x_i = X.row(i);
numerator = exp(dot(beta_k,x_i));
double denominator = 0;
for(int l=0;l<K;l++)
{
arma::rowvec beta_l = Beta.row(l);
denominator += exp(dot(beta_l,x_i));
}
return numerator / denominator;
}
// Fills Y_hat with the fitted class probabilities: entry (i, k) is the
// softmax probability that observation i belongs to class k.
void MultLogit::predict(){
Y_hat.set_size(n_obs, K);
for (int k = 0; k < K; k++)
{
for (int i = 0; i < n_obs; i++)
Y_hat(i, k) = softmaxProb(i, k);
}
}
void MultLogit::logLikGradient_LA(){
arma::mat gradient_sum(p,K);
gradient_sum.zeros();
for (int i = 0; i < n_obs; i++)
{
arma::rowvec p_i(K);
arma::mat x_i = X.row(i).t();
for (int k=0; k < K; k++)
{
p_i(k) = softmaxProb(i, k);
}
arma::mat y_i = Y.row(i);
//Rcout << y_i;
//Rcout << "\n" << x_i;
arma::mat KRON(K,p);
KRON = kron(y_i - p_i, x_i);
gradient_sum += KRON;
}
logLikGradientMatrix = gradient_sum.t();
}
double MultLogit::tentative_step(int k, int j){
double sum_X_squared = 0;
double suggested_step = 0;
for(int i=0;i<n_obs;i++)
{
sum_X_squared += pow(X(i,j),2.0);
}
//Rcout << "Sum_X_Squared: " << sum_X_squared << "\n";
double Q_kj = ((double)K-1)/(2 * (double)K) * sum_X_squared;
//Rcout << "Q_kj: " << Q_kj << "\n";
logLikGradient_LA();
double logLikGradient_kj = logLikGradientMatrix(k,j);
double Beta_kj = Beta(k,j);
double step_numerator = logLikGradient_kj - (double)2 * Beta_kj / sigma;
double step_denominator = Q_kj + (double)2 / sigma;
//Rcout << "Denominator: " << step_denominator;
// Rcout << step_numerator << "/" << step_denominator;
suggested_step = step_numerator / step_denominator;
return suggested_step;
}
// Runs T full sweeps of cyclic coordinate descent over Beta.
// Only the first K-1 class rows are updated; the last row stays at its
// initial zero value — presumably acting as the reference class, since the
// softmax is invariant to a common shift of all rows (TODO confirm intent).
// Fix: removed the stray ';' that followed the closing brace (an empty
// declaration — ill-formed before C++11 and a warning after).
void MultLogit::coordinate_descent() {
Rcpp::Rcout << "Coordinate Descent \n";
for (int t = 0; t < T; t++)
{
for (int j = 0; j < p; j++)
{
for (int k = 0; k < K - 1; k++)
{
// Each step re-evaluates the gradient at the current Beta.
Beta(k, j) += tentative_step(k, j);
}
}
}
}
// Exposes MultLogit to R via Rcpp modules.
// Fixes relative to the draft:
//  - .constructor<> now matches the actual 4-argument C++ constructor
//    (the trailing int iteration count T_input was missing, which would
//    fail to compile against the only declared constructor);
//  - removed the "Delta" field and "logLikGradient" method exposures,
//    which referenced members that do not exist on MultLogit and would
//    not compile.
RCPP_MODULE(yada){
class_<MultLogit>( "MultLogit" )
.constructor<NumericMatrix,NumericMatrix,double,int>()
.field("K",&MultLogit::K)
.field("n_obs",&MultLogit::n_obs)
.field("p",&MultLogit::p)
.field("X",&MultLogit::X)
.field("Y",&MultLogit::Y)
.field("Y_hat",&MultLogit::Y_hat)
.field("Beta",&MultLogit::Beta)
.field("logLikGradientMatrix",&MultLogit::logLikGradientMatrix)
.method("coordinate_descent",&MultLogit::coordinate_descent)
.method("tentative_step",&MultLogit::tentative_step)
.method("softmaxProb",&MultLogit::softmaxProb)
.method("logLikGradient_LA",&MultLogit::logLikGradient_LA)
.method("predict",&MultLogit::predict);
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment