Last active
December 11, 2015 21:48
-
-
Save alphaville/4665099 to your computer and use it in GitHub Desktop.
x'Qx in C - Implementation for MATLAB - MEX file and compilation instructions
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
% Build script for the MEX functions.
% Every target is compiled with warnings enabled (-Wall), optimisation (-O)
% and the large-array API (-largeArrayDims).
mex_opts = {'CFLAGS=$CFLAGS -Wall', '-O', '-largeArrayDims'};

% CBLAS library providing cblas_dgemv / cblas_ddot (called from pcost.c).
% This is the macOS system path (.dylib); adjust it for other platforms.
cblas_lib = '/usr/lib/libcblas.dylib';

% SMDV_QUAD.m
disp('Compiling : SMDV_QUAD');
mex(mex_opts{:}, '-output', 'smdv_quad', 'mx_smdv_quad.c', 'smdv_quad.c');

% PCOST.m
disp('Compiling : PCOST');
mex(mex_opts{:}, '-output', 'pcost', 'pcost.c', 'smdv_quad.c', cblas_lib);

% DGRAD.m
% NOTE(review): dgrad.c and vec_util.c are not part of this gist - confirm
% they are present in the build directory.
disp('Compiling : DGRAD');
mex(mex_opts{:}, '-output', 'dgrad', 'dgrad.c', 'smdv_quad.c', 'vec_util.c', cblas_lib);
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
/*
 * File:   mex_gpad_definitions.h
 * Author: Pantelis Sopasakis (IMT Lucca)
 *
 * Created on January 30, 2013, 12:15 AM
 *
 * Shared definitions for the GPAD MEX sources: scalar/vector type aliases,
 * the CBLAS enumerations, and prototypes for the two CBLAS routines that are
 * used (cblas_dgemv, cblas_ddot). The CBLAS implementation itself is linked
 * in at build time.
 */
#ifndef MEX_GPAD_DEFINITIONS_H
#define MEX_GPAD_DEFINITIONS_H

#include <stdlib.h>

#ifdef __cplusplus
extern "C" {
#endif

/* Scalar type used by the GPAD routines. */
#define zreal double

/*
 * Vector type. Deliberately a macro rather than a typedef: smdv_quad takes a
 * `const zvect *x`, which must expand to `const double **` so callers can pass
 * the address of a `const double *`. A typedef would yield `double *const *`.
 */
#define zvect double*

/*
 * Apple's availability annotation is defined by <Availability.h>. When it is
 * not defined (e.g. non-Apple toolchains), expand to nothing instead of
 * leaving an unknown identifier that breaks the prototypes below.
 */
#ifdef __OSX_AVAILABLE_STARTING
#define GPAD_CBLAS_AVAILABILITY __OSX_AVAILABLE_STARTING(__MAC_10_2, __IPHONE_4_0)
#else
#define GPAD_CBLAS_AVAILABILITY
#endif

#ifndef CBLAS_ENUM_DEFINED_H
#define CBLAS_ENUM_DEFINED_H
enum CBLAS_ORDER {
    CblasRowMajor = 101, CblasColMajor = 102
};
enum CBLAS_TRANSPOSE {
    CblasNoTrans = 111, CblasTrans = 112, CblasConjTrans = 113, AtlasConj = 114
};
enum CBLAS_UPLO {
    CblasUpper = 121, CblasLower = 122
};
enum CBLAS_DIAG {
    CblasNonUnit = 131, CblasUnit = 132
};
enum CBLAS_SIDE {
    CblasLeft = 141, CblasRight = 142
};
#endif /* CBLAS_ENUM_DEFINED_H */

#ifndef CBLAS_DGEMV_EXTERNAL
#define CBLAS_DGEMV_EXTERNAL
/**
 * Double-precision matrix-vector product:
 *     Y := alpha * op(A) * X + beta * Y,
 * where op(A) is A or its transpose, depending on TransA.
 *
 * @param Order
 *      Specifies row-major (C) or column-major (Fortran) data ordering.
 * @param TransA
 *      Specifies whether to transpose matrix A.
 * @param M
 *      Number of rows in matrix A.
 * @param N
 *      Number of columns in matrix A.
 * @param alpha
 *      Scaling factor for the product op(A) * X.
 * @param A
 *      Matrix A.
 * @param lda
 *      Leading dimension of A: at least max(1, M) for column-major storage,
 *      at least max(1, N) for row-major storage.
 * @param X
 *      Vector X.
 * @param incX
 *      Stride within X. For example, if incX is 7, every 7th element is used.
 * @param beta
 *      Scaling factor for vector Y.
 * @param Y
 *      Vector Y (input and output).
 * @param incY
 *      Stride within Y. For example, if incY is 7, every 7th element is used.
 */
extern void cblas_dgemv(const enum CBLAS_ORDER Order,
        const enum CBLAS_TRANSPOSE TransA, const int M, const int N,
        const double alpha, const double *A, const int lda,
        const double *X, const int incX, const double beta,
        double *Y, const int incY) GPAD_CBLAS_AVAILABILITY;
#endif /* CBLAS_DGEMV_EXTERNAL */

/* Guard name kept as-is for compatibility; it protects cblas_ddot. */
#ifndef CBLAS_DSDOT_EXTERNAL
#define CBLAS_DSDOT_EXTERNAL
/**
 * Computes the dot product of two double-precision vectors.
 * @param N
 *      The number of elements in the vectors.
 * @param X
 *      Vector X.
 * @param incX
 *      Stride within X. For example, if incX is 7, every 7th element is used.
 * @param Y
 *      Vector Y.
 * @param incY
 *      Stride within Y. For example, if incY is 7, every 7th element is used.
 * @return
 *      The dot product x'y = y'x.
 */
extern double cblas_ddot(const int N, const double *X, const int incX,
        const double *Y, const int incY) GPAD_CBLAS_AVAILABILITY;
#endif /* CBLAS_DSDOT_EXTERNAL */

#ifdef __cplusplus
}
#endif

#endif /* MEX_GPAD_DEFINITIONS_H */
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| #include "smdv_quad.h" | |
| void mexFunction(int nlhs, mxArray* plhs[], int nrhs, const mxArray* prhs[]) { | |
| // y = smdv_quad(Q,x) | |
| check_smdv_quad_input(nrhs, prhs); | |
| const mxArray *Q, *x; | |
| double L; | |
| double *output; | |
| x = prhs[1]; | |
| Q = prhs[0]; | |
| const double *xVec = mxGetPr(x); | |
| smdv_quad(Q, &xVec, &L); | |
| plhs[0] = mxCreateDoubleMatrix(1, 1, mxREAL); | |
| output = mxGetPr(plhs[0]); | |
| output[0] = L; | |
| } |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#include <math.h>
#include "mex_gpad_definitions.h"
#include "smdv_quad.h"
| static void check_pcost_input(int nrhs, const mxArray *Q, const mxArray *R, const mxArray *Pf, const mxArray *N, const mxArray *X, const mxArray *U) { | |
| int Ndata; | |
| double *NdataPtr; | |
| NdataPtr = mxGetPr(N); | |
| Ndata = (int) (*NdataPtr); | |
| if (nrhs != 6) { | |
| mexErrMsgIdAndTxt("MATLAB:pcost:numInputs", "PCOST admits 6 input parameters."); | |
| } | |
| if (!mxIsSparse(Q)) { | |
| mexErrMsgIdAndTxt("MATLAB:pcost:badInput", "PCOST(Q,R,Pf,N,X,U): Q should be a sparse matrix"); | |
| } | |
| if (!mxIsSparse(R)) { | |
| mexErrMsgIdAndTxt("MATLAB:pcost:badInput", "PCOST(Q,R,Pf,N,X,U): R should be a sparse matrix"); | |
| } | |
| if (mxGetN(N) != 1 || mxGetM(N) != 1) { | |
| mexErrMsgIdAndTxt("MATLAB:pcost:badInput", "PCOST(Q,R,Pf,N,X,U): N should be a number - not a matrix!"); | |
| } | |
| if (mxIsSparse(X)) { | |
| mexErrMsgIdAndTxt("MATLAB:pcost:badInput", "PCOST(Q,R,Pf,N,X,U): X cannot be sparse"); | |
| } | |
| if (mxIsSparse(U)) { | |
| mexErrMsgIdAndTxt("MATLAB:pcost:badInput", "PCOST(Q,R,Pf,N,X,U): U cannot be sparse"); | |
| } | |
| if (mxGetN(X) < Ndata + 1) { | |
| mexErrMsgIdAndTxt("MATLAB:pcost:badInput", "PCOST(Q,R,Pf,N,X,U): The number of columns of X must be >= N+1"); | |
| } | |
| if (mxGetN(U) < Ndata + 1) { | |
| mexErrMsgIdAndTxt("MATLAB:pcost:badInput", "PCOST(Q,R,Pf,N,X,U): The number of columns of U must be >= N+1"); | |
| } | |
| //TODO: Test more the compatibility | |
| } | |
| void mexFunction(int nlhs, mxArray* plhs[], int nrhs, const mxArray* prhs[]) { | |
| const mxArray *Q = prhs[0], *R = prhs[1], *Pf = prhs[2], *N = prhs[3], *X = prhs[4], *U = prhs[5]; | |
| int k, nx, nu, Ndata; | |
| double *NdataPtr, *ellx, *ellu, *Xdata, *Udata, *output, *Pfdata, J; | |
| const double *currentState, *currentInput; | |
| check_pcost_input(nrhs, Q, R, Pf, N, X, U); | |
| Xdata = mxGetPr(X); | |
| Udata = mxGetPr(U); | |
| nx = (int) mxGetM(X); | |
| nu = (int) mxGetM(U); | |
| NdataPtr = mxGetPr(N); | |
| Ndata = (int) (*NdataPtr); | |
| currentState = Xdata; | |
| currentInput = Udata; | |
| ellx = malloc(sizeof (double)); | |
| ellu = malloc(sizeof (double)); | |
| J = 0; | |
| for (k = 0; k < Ndata; k++) { | |
| smdv_quad(Q, ¤tState, ellx); | |
| smdv_quad(R, ¤tInput, ellu); | |
| J += (*ellx) + (*ellu); | |
| currentState += nx; | |
| currentInput += nu; | |
| } | |
| if (mxIsSparse(Pf)) { | |
| smdv_quad(Pf, ¤tState, ellx); | |
| } else { | |
| double *z = calloc(nx,sizeof(double));// Clear allocation! | |
| Pfdata = mxGetPr(Pf); | |
| cblas_dgemv(CblasColMajor, CblasNoTrans, nx, nx, | |
| 1.0, Pfdata, nx, currentState, 1, 1.0, z, 1); // z = Pf * currentState | |
| *ellx = cblas_ddot(nx, z, 1, currentState, 1); // ellx = currentState' * z = currentState' * Pf * currentState | |
| free(z); | |
| } | |
| J += *ellx; | |
| plhs[0] = mxCreateDoubleMatrix(1, 1, mxREAL); | |
| output = mxGetPr(plhs[0]); | |
| output[0] = 0.500 * J; | |
| } |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| #include "smdv_quad.h" | |
| void smdv_quad(const mxArray *Q, const zvect *x, zreal* y) { | |
| const zreal *pQ, *px; | |
| mwIndex *irQ, *jcQ; | |
| mwSize col, total = 0; | |
| mwIndex star_row_idx, stop_row_idx, current_row_index; | |
| mwSize n; | |
| /* Get the starting positions of all four data arrays. */ | |
| pQ = mxGetPr(Q); | |
| irQ = mxGetIr(Q); | |
| jcQ = mxGetJc(Q); | |
| n = mxGetN(Q); | |
| /* Retrieve the values of x */ | |
| px = *x; | |
| /* Do le magic */ | |
| *y = 0; | |
| for (col = 0; col < n; col++) { // Iterate over all columns | |
| star_row_idx = jcQ[col]; | |
| stop_row_idx = jcQ[col + 1]; | |
| if (star_row_idx == stop_row_idx) | |
| continue; | |
| else { | |
| for (current_row_index = star_row_idx; current_row_index < stop_row_idx; current_row_index++) { | |
| (*y) += pQ[total++] * px[irQ[current_row_index]] * px[col]; | |
| } | |
| } | |
| } | |
| } | |
| void check_smdv_quad_input(int nrhs, const mxArray* prhs[]) { | |
| if (nrhs != 2) { | |
| mexErrMsgIdAndTxt("MATLAB:smdv_quad:numInputs", "SVDM_QUAD admits 2"); | |
| } | |
| const mxArray *Q = prhs[0], *x = prhs[1]; | |
| size_t nQ = mxGetN(Q), mQ = mxGetM(Q), nx = mxGetN(prhs[1]), mx = mxGetM(prhs[1]); | |
| if (!mxIsSparse(Q)) { | |
| mexErrMsgIdAndTxt("MATLAB:smdv_quad:badInput", "SVDM_QUAD(Q,x): Q should be a sparse matrix"); | |
| } | |
| if (nQ != mQ) { | |
| mexErrMsgIdAndTxt("MATLAB:smdv_quad:badInput", "SVDM_QUAD(Q,x): Q should be a **square** sparse matrix"); | |
| } | |
| if (!mxIsDouble(x) || nx != 1) { | |
| mexErrMsgIdAndTxt("MATLAB:smdv_quad:badInput", "SVDM_QUAD(Q,x): x should be a vector!"); | |
| } | |
| if (mx != nQ) { | |
| mexErrMsgIdAndTxt("MATLAB:smdv_quad:badInput", "SVDM_QUAD(Q,x): Q and x should have compatible dimensions!"); | |
| } | |
| } |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| /* | |
| * File: smdvquad.h | |
| * Author: Pantelis Sopasakis (IMT Lucca) | |
| * | |
| * Created on January 29, 2013, 11:08 PM | |
| */ | |
| #ifndef SMDVQUAD_H | |
| #define SMDVQUAD_H | |
| #ifdef __cplusplus | |
| extern "C" { | |
| #endif | |
| #include <stddef.h> | |
| #include "math.h" | |
| #include "mex.h" | |
| #include "mex_gpad_definitions.h" | |
| void smdv_quad(const mxArray *Q, const zvect *x, zreal* y); | |
| void check_smdv_quad_input(int nrhs, const mxArray* prhs[]); | |
| #ifdef __cplusplus | |
| } | |
| #endif | |
| #endif /* SMDVQUAD_H */ | |
Author
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Note: the build commands link against /usr/lib/libcblas.dylib (the macOS CBLAS library); adjust this path on other platforms.
See also: http://www.mathworks.co.uk/support/solutions/en/data/1-1B537/?solution=1-1B537