Skip to content

Instantly share code, notes, and snippets.

@alphaville
Last active December 11, 2015 21:48
Show Gist options
  • Select an option

  • Save alphaville/4665099 to your computer and use it in GitHub Desktop.

Select an option

Save alphaville/4665099 to your computer and use it in GitHub Desktop.
x'Qx in C - Implementation for MATLAB - MEX file and compilation instructions
% Build script for the MEX functions smdv_quad, pcost and dgrad.
% Every target is compiled with extra warnings (-Wall appended to the
% default CFLAGS), optimisation (-O) and the large-array API
% (-largeArrayDims), under which the mwIndex/mwSize types used in
% smdv_quad.c are 64-bit on 64-bit platforms.
% SMDV_QUAD.m
disp('Compiling : SMDV_QUAD');
mex CFLAGS='$CFLAGS -Wall' -O -output smdv_quad -largeArrayDims mx_smdv_quad.c smdv_quad.c
% PCOST.m
% pcost.c calls cblas_dgemv/cblas_ddot (dense terminal weight Pf), so it is
% linked against a CBLAS library.
% NOTE(review): /usr/lib/libcblas.dylib is a macOS path; adjust it on other
% platforms.
disp('Compiling : PCOST');
mex CFLAGS='$CFLAGS -Wall' -O -output pcost -largeArrayDims pcost.c smdv_quad.c /usr/lib/libcblas.dylib
% DGRAD
% NOTE(review): dgrad.c and vec_util.c are not among the sources shown
% alongside this script, and this target prints no 'Compiling :' banner
% unlike the others.
mex CFLAGS='$CFLAGS -Wall' -O -output dgrad -largeArrayDims dgrad.c smdv_quad.c vec_util.c /usr/lib/libcblas.dylib
/*
 * File: mex_gpad_definitions.h
 * Author: Pantelis Sopasakis (IMT Lucca)
 *
 * Created on January 30, 2013, 12:15 AM
 *
 * Shared scalar/vector aliases and a minimal set of CBLAS declarations
 * (enums, cblas_dgemv, cblas_ddot) so that the MEX sources can be compiled
 * without including a full cblas.h.
 */
#ifndef MEX_GPAD_DEFINITIONS_H
#define MEX_GPAD_DEFINITIONS_H
#include <stdlib.h>

/*
 * Scalar and vector aliases. These are intentionally macros, not typedefs:
 * smdv_quad() takes `const zvect *x`, which must expand to `const double **`
 * so that callers can pass the address of a `const double *`. A typedef
 * would instead produce `double *const *`.
 */
#define zreal double
#define zvect double*

/*
 * Apple's headers annotate CBLAS with an availability attribute. Use it when
 * it is available (e.g. pulled in by <stdlib.h> on macOS), and expand to
 * nothing elsewhere so the declarations below compile on any platform.
 */
#ifdef __OSX_AVAILABLE_STARTING
#define MEX_GPAD_CBLAS_AVAILABILITY __OSX_AVAILABLE_STARTING(__MAC_10_2, __IPHONE_4_0)
#else
#define MEX_GPAD_CBLAS_AVAILABILITY
#endif

#ifdef __cplusplus
extern "C" {
#endif

/* Same guard name as the system cblas.h, so the enums are defined once. */
#ifndef CBLAS_ENUM_DEFINED_H
#define CBLAS_ENUM_DEFINED_H
enum CBLAS_ORDER {
    CblasRowMajor = 101, CblasColMajor = 102
};
enum CBLAS_TRANSPOSE {
    CblasNoTrans = 111, CblasTrans = 112, CblasConjTrans = 113, AtlasConj = 114
};
enum CBLAS_UPLO {
    CblasUpper = 121, CblasLower = 122
};
enum CBLAS_DIAG {
    CblasNonUnit = 131, CblasUnit = 132
};
enum CBLAS_SIDE {
    CblasLeft = 141, CblasRight = 142
};
#endif /* CBLAS_ENUM_DEFINED_H */

#ifndef CBLAS_DGEMV_EXTERNAL
#define CBLAS_DGEMV_EXTERNAL
/**
 * Matrix-vector product (double precision):
 *     Y := alpha * op(A) * X + beta * Y,
 * where op(A) is A or A' depending on TransA.
 *
 * @param Order
 * Specifies row-major (C) or column-major (Fortran/MATLAB) data ordering.
 * @param TransA
 * Specifies whether to transpose matrix A.
 * @param M
 * Number of rows in matrix A.
 * @param N
 * Number of columns in matrix A.
 * @param alpha
 * Scaling factor for the product of matrix A and vector X.
 * @param A
 * Matrix A.
 * @param lda
 * Leading dimension of A: for column-major storage the number of rows of
 * the stored array (lda >= max(1, M)); for row-major storage the number of
 * columns (lda >= max(1, N)).
 * @param X
 * Vector X.
 * @param incX
 * Stride within X. For example, if incX is 7, every 7th element is used.
 * @param beta
 * Scaling factor for vector Y.
 * @param Y
 * Vector Y (input and output).
 * @param incY
 * Stride within Y. For example, if incY is 7, every 7th element is used.
 */
extern void cblas_dgemv(const enum CBLAS_ORDER Order,
        const enum CBLAS_TRANSPOSE TransA, const int M, const int N,
        const double alpha, const double *A, const int lda,
        const double *X, const int incX, const double beta,
        double *Y, const int incY) MEX_GPAD_CBLAS_AVAILABILITY;
#endif /* CBLAS_DGEMV_EXTERNAL */

#ifndef CBLAS_DSDOT_EXTERNAL
#define CBLAS_DSDOT_EXTERNAL
/**
 * Dot product of two double-precision vectors.
 * @param N
 * The number of elements in the vectors.
 * @param X
 * Vector X.
 * @param incX
 * Stride within X. For example, if incX is 7, every 7th element is used.
 * @param Y
 * Vector Y.
 * @param incY
 * Stride within Y. For example, if incY is 7, every 7th element is used.
 * @return
 * The dot product x'y = y'x.
 */
extern double cblas_ddot(const int N, const double *X, const int incX,
        const double *Y, const int incY) MEX_GPAD_CBLAS_AVAILABILITY;
#endif /* CBLAS_DSDOT_EXTERNAL */

#ifdef __cplusplus
}
#endif

#endif /* MEX_GPAD_DEFINITIONS_H */
#include "smdv_quad.h"
/*
 * MEX gateway: y = smdv_quad(Q, x)
 *
 * Evaluates the quadratic form x' * Q * x for a sparse square Q and a column
 * vector x, returning it to MATLAB as a 1x1 double. The arguments are
 * validated by check_smdv_quad_input() before they are used.
 */
void mexFunction(int nlhs, mxArray* plhs[], int nrhs, const mxArray* prhs[]) {
    const mxArray *matQ;
    const double *xData;
    double quadValue;

    check_smdv_quad_input(nrhs, prhs);

    matQ = prhs[0];
    xData = mxGetPr(prhs[1]);
    smdv_quad(matQ, &xData, &quadValue);

    plhs[0] = mxCreateDoubleMatrix(1, 1, mxREAL);
    *mxGetPr(plhs[0]) = quadValue;
}
#include "mex_gpad_definitions.h"
#include "smdv_quad.h"
static void check_pcost_input(int nrhs, const mxArray *Q, const mxArray *R, const mxArray *Pf, const mxArray *N, const mxArray *X, const mxArray *U) {
int Ndata;
double *NdataPtr;
NdataPtr = mxGetPr(N);
Ndata = (int) (*NdataPtr);
if (nrhs != 6) {
mexErrMsgIdAndTxt("MATLAB:pcost:numInputs", "PCOST admits 6 input parameters.");
}
if (!mxIsSparse(Q)) {
mexErrMsgIdAndTxt("MATLAB:pcost:badInput", "PCOST(Q,R,Pf,N,X,U): Q should be a sparse matrix");
}
if (!mxIsSparse(R)) {
mexErrMsgIdAndTxt("MATLAB:pcost:badInput", "PCOST(Q,R,Pf,N,X,U): R should be a sparse matrix");
}
if (mxGetN(N) != 1 || mxGetM(N) != 1) {
mexErrMsgIdAndTxt("MATLAB:pcost:badInput", "PCOST(Q,R,Pf,N,X,U): N should be a number - not a matrix!");
}
if (mxIsSparse(X)) {
mexErrMsgIdAndTxt("MATLAB:pcost:badInput", "PCOST(Q,R,Pf,N,X,U): X cannot be sparse");
}
if (mxIsSparse(U)) {
mexErrMsgIdAndTxt("MATLAB:pcost:badInput", "PCOST(Q,R,Pf,N,X,U): U cannot be sparse");
}
if (mxGetN(X) < Ndata + 1) {
mexErrMsgIdAndTxt("MATLAB:pcost:badInput", "PCOST(Q,R,Pf,N,X,U): The number of columns of X must be >= N+1");
}
if (mxGetN(U) < Ndata + 1) {
mexErrMsgIdAndTxt("MATLAB:pcost:badInput", "PCOST(Q,R,Pf,N,X,U): The number of columns of U must be >= N+1");
}
//TODO: Test more the compatibility
}
/*
 * MEX gateway: J = pcost(Q, R, Pf, N, X, U)
 *
 * Computes the finite-horizon quadratic cost
 *     J = 1/2 * ( sum_{k=0}^{N-1} (x_k' Q x_k + u_k' R u_k) + x_N' Pf x_N ),
 * where x_k and u_k are the (k+1)-th columns of X and U.
 * Q and R are evaluated with smdv_quad (sparse). Pf is evaluated with
 * smdv_quad when sparse, or with CBLAS dgemv + ddot when dense.
 */
void mexFunction(int nlhs, mxArray* plhs[], int nrhs, const mxArray* prhs[]) {
    const mxArray *Q, *R, *Pf, *N, *X, *U;
    int k, nx, nu, Ndata;
    double ellx = 0.0, ellu = 0.0, J = 0.0;
    const double *currentState, *currentInput;

    /* prhs has exactly nrhs entries: check the count before indexing it. */
    if (nrhs != 6) {
        mexErrMsgIdAndTxt("MATLAB:pcost:numInputs", "PCOST admits 6 input parameters.");
    }
    Q = prhs[0];
    R = prhs[1];
    Pf = prhs[2];
    N = prhs[3];
    X = prhs[4];
    U = prhs[5];
    check_pcost_input(nrhs, Q, R, Pf, N, X, U);

    nx = (int) mxGetM(X);
    nu = (int) mxGetM(U);
    Ndata = (int) (*mxGetPr(N));
    currentState = mxGetPr(X);
    currentInput = mxGetPr(U);

    /* Stage costs; X and U are column-major, so each step advances one column. */
    for (k = 0; k < Ndata; k++) {
        smdv_quad(Q, &currentState, &ellx);
        smdv_quad(R, &currentInput, &ellu);
        J += ellx + ellu;
        currentState += nx;
        currentInput += nu;
    }

    /* Terminal cost: currentState now points at column N of X. */
    if (mxIsSparse(Pf)) {
        smdv_quad(Pf, &currentState, &ellx);
    } else if (nx > 0) {
        double *z = calloc((size_t) nx, sizeof *z);
        if (z == NULL) {
            mexErrMsgIdAndTxt("MATLAB:pcost:outOfMemory", "PCOST: out of memory");
        }
        cblas_dgemv(CblasColMajor, CblasNoTrans, nx, nx,
                1.0, mxGetPr(Pf), nx, currentState, 1, 1.0, z, 1); /* z = Pf * x_N */
        ellx = cblas_ddot(nx, z, 1, currentState, 1);             /* x_N' * Pf * x_N */
        free(z);
    } else {
        /* Empty state vector: the quadratic form is zero (and BLAS needs lda >= 1). */
        ellx = 0.0;
    }
    J += ellx;

    plhs[0] = mxCreateDoubleMatrix(1, 1, mxREAL);
    *mxGetPr(plhs[0]) = 0.500 * J;
}
#include "smdv_quad.h"
/*
 * Evaluates the quadratic form y = x' * Q * x for a sparse matrix Q stored in
 * MATLAB's compressed-sparse-column layout (mxGetPr / mxGetIr / mxGetJc).
 *
 * Only the stored entries of Q are visited. Each entry Q(i,j) contributes
 * Q(i,j) * x(i) * x(j) to the sum. The caller must ensure that x has an entry
 * for every row and column index of Q.
 *
 * @param Q  sparse double matrix
 * @param x  address of a pointer to the first element of x
 * @param y  output; overwritten with x' * Q * x
 */
void smdv_quad(const mxArray *Q, const zvect *x, zreal* y) {
    const zreal *values = mxGetPr(Q);       /* nonzero values, column by column */
    const mwIndex *rowOf = mxGetIr(Q);      /* row index of each nonzero */
    const mwIndex *colStart = mxGetJc(Q);   /* column j: nonzeros colStart[j] .. colStart[j+1]-1 */
    const zreal *xv = *x;
    const mwSize ncols = mxGetN(Q);
    mwSize j;

    *y = 0;
    for (j = 0; j < ncols; ++j) {
        mwIndex p;
        for (p = colStart[j]; p < colStart[j + 1]; ++p) {
            *y += values[p] * xv[rowOf[p]] * xv[j];
        }
    }
}
/*
 * Validates the arguments of y = smdv_quad(Q, x).
 *
 * On the first problem found it raises a MATLAB error via mexErrMsgIdAndTxt
 * (which does not return). Requirements: exactly two inputs; Q is a square,
 * real, sparse double matrix; x is a dense, real double column vector with as
 * many rows as Q.
 *
 * x must be dense because smdv_quad indexes mxGetPr(x) by row number. For a
 * sparse x that pointer holds only the nonzeros, so the reads would go out of
 * bounds.
 */
void check_smdv_quad_input(int nrhs, const mxArray* prhs[]) {
    const mxArray *Q, *x;
    size_t nQ, mQ, nx, mx;

    if (nrhs != 2) {
        mexErrMsgIdAndTxt("MATLAB:smdv_quad:numInputs", "SMDV_QUAD admits 2 input parameters.");
    }
    Q = prhs[0];
    x = prhs[1];
    nQ = mxGetN(Q);
    mQ = mxGetM(Q);
    nx = mxGetN(x);
    mx = mxGetM(x);

    if (!mxIsSparse(Q)) {
        mexErrMsgIdAndTxt("MATLAB:smdv_quad:badInput", "SMDV_QUAD(Q,x): Q should be a sparse matrix");
    }
    /* smdv_quad reads only mxGetPr(Q), i.e. real double data. */
    if (!mxIsDouble(Q) || mxIsComplex(Q)) {
        mexErrMsgIdAndTxt("MATLAB:smdv_quad:badInput", "SMDV_QUAD(Q,x): Q should be a real sparse matrix of doubles");
    }
    if (nQ != mQ) {
        mexErrMsgIdAndTxt("MATLAB:smdv_quad:badInput", "SMDV_QUAD(Q,x): Q should be a **square** sparse matrix");
    }
    if (!mxIsDouble(x) || mxIsComplex(x) || mxIsSparse(x) || nx != 1) {
        mexErrMsgIdAndTxt("MATLAB:smdv_quad:badInput", "SMDV_QUAD(Q,x): x should be a vector!");
    }
    if (mx != nQ) {
        mexErrMsgIdAndTxt("MATLAB:smdv_quad:badInput", "SMDV_QUAD(Q,x): Q and x should have compatible dimensions!");
    }
}
/*
 * File: smdvquad.h
 * NOTE(review): the sources include this header as "smdv_quad.h"; confirm
 * which file name is actually used.
 * Author: Pantelis Sopasakis (IMT Lucca)
 *
 * Created on January 29, 2013, 11:08 PM
 *
 * Sparse-matrix / dense-vector quadratic form x' * Q * x for MEX files.
 */
#ifndef SMDVQUAD_H
#define SMDVQUAD_H

/*
 * Standard headers manage their own C/C++ linkage, so they stay outside the
 * extern "C" block. mex.h is assumed to do the same (it is meant to be
 * included from both C and C++).
 */
#include <stddef.h>
#include <math.h>
#include "mex.h"

#ifdef __cplusplus
extern "C" {
#endif

/* Kept inside extern "C": its CBLAS prototypes must have C linkage. */
#include "mex_gpad_definitions.h"

/**
 * Computes y = x' * Q * x for a sparse square matrix Q.
 *
 * @param Q  sparse double matrix
 * @param x  address of a pointer to the first element of x (length >= size of Q)
 * @param y  output; overwritten with the value of the quadratic form
 */
void smdv_quad(const mxArray *Q, const zvect *x, zreal* y);

/**
 * Validates the arguments of the MATLAB call y = smdv_quad(Q, x).
 * Raises a MATLAB error via mexErrMsgIdAndTxt if they are invalid.
 *
 * @param nrhs  number of right-hand-side arguments
 * @param prhs  the right-hand-side arguments
 */
void check_smdv_quad_input(int nrhs, const mxArray* prhs[]);

#ifdef __cplusplus
}
#endif
#endif /* SMDVQUAD_H */
@alphaville
Copy link
Copy Markdown
Author

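Note: pcost and dgrad are linked against /usr/lib/libcblas.dylib, the macOS system CBLAS library; change this path when building on another platform.
For more information, see: http://www.mathworks.co.uk/support/solutions/en/data/1-1B537/?solution=1-1B537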

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment