Skip to content

Instantly share code, notes, and snippets.

@GaZ3ll3
Created November 21, 2015 04:58
Show Gist options
  • Save GaZ3ll3/39c9c3ef95e1737d2c65 to your computer and use it in GitHub Desktop.
Save GaZ3ll3/39c9c3ef95e1737d2c65 to your computer and use it in GitHub Desktop.
use VML/VSL/
/*
* armadillo_support.h
*
* Created on: Nov 20, 2015
* Author: lurker
*/
#ifndef SRC_ARMADILLO_SUPPORT_H_
#define SRC_ARMADILLO_SUPPORT_H_
#include <cstdio>
#include <cstdlib>
#include <cmath>
#include <limits>
#include <stdexcept>
#include <armadillo>
#include "mkl.h"
using namespace arma;
namespace arma_support {
class rng_mkl {
public:
rng_mkl(int BRNG, int seed) {
_seed = seed;
_BRNG = BRNG;
vslNewStream(&_stream, BRNG, _seed);
}
~rng_mkl() {
vslDeleteStream(&_stream);
}
void setBRNG(int BRNG) {
_BRNG = BRNG;
}
template<typename eT>
inline void randi_fill(eT* mem, const uword N);
template<typename eT>
inline void randi_fill(eT* mem, const uword N, const int a, const int b);
template<typename eT>
inline void randu_fill(eT* mem, const uword N);
template<typename eT>
inline void randu_fill(eT* mem, const uword N, const eT a, const eT b);
template<typename eT>
inline void randn_fill(eT* mem, const uword N);
template<typename eT>
inline void randn_fill(eT* mem, const uword N, const eT a, const eT sigma);
template<typename eT>
inline void rande_fill(eT* mem, const uword N, const eT a, const eT beta);
template<typename eT>
inline void rande_fill(eT* mem, const uword N);
private:
VSLStreamStatePtr _stream;
int _seed;
int _BRNG;
};
/**
 * Static element-wise math wrappers over the Intel MKL Vector Math (VML)
 * library.
 *
 * Every member has the shape f(mem, N, a). It computes mem[i] = f(a[i]) for
 * i in [0, N) when the element type is float or double, and does nothing
 * for any other element type.
 */
class vml_mkl {
public:
  // --- exponential and logarithms -------------------------------------
  template<typename eT> static inline void exp(eT* mem, const uword N, const eT* a);
  template<typename eT> static inline void log(eT* mem, const uword N, const eT* a);
  template<typename eT> static inline void log10(eT* mem, const uword N, const eT* a);

  // --- trigonometric --------------------------------------------------
  template<typename eT> static inline void sin(eT* mem, const uword N, const eT* a);
  template<typename eT> static inline void cos(eT* mem, const uword N, const eT* a);
  template<typename eT> static inline void tan(eT* mem, const uword N, const eT* a);

  // --- inverse trigonometric ------------------------------------------
  template<typename eT> static inline void asin(eT* mem, const uword N, const eT* a);
  template<typename eT> static inline void acos(eT* mem, const uword N, const eT* a);
  template<typename eT> static inline void atan(eT* mem, const uword N, const eT* a);
};
template<typename eT>
inline void rng_mkl::randi_fill(eT* mem, const uword N) {
viRngUniform(_BRNG, _stream, N, mem, 0, std::numeric_limits<int>::max());
}
template<typename eT>
inline void rng_mkl::randi_fill(eT* mem, const uword N, const int a, const int b) {
viRngUniform(_BRNG, _stream, N, mem, a, b );
}
template<typename eT>
inline void rng_mkl::randu_fill(eT* mem, const uword N, const eT a,const eT b) {
if (is_float<eT>::value) {
typedef float T;
vsRngUniform( _BRNG, _stream, N, (T*)mem, a, b );
}
else if (is_double<eT>::value) {
typedef double T;
vdRngUniform( _BRNG, _stream, N, (T*)mem, a, b );
}
else {
return;
}
}
/** Fill mem[0..N) with standard uniform variates on [0, 1). */
template<typename eT>
inline void rng_mkl::randu_fill(eT* mem, const uword N) {
  const eT lower = static_cast<eT>(0);
  const eT upper = static_cast<eT>(1);
  randu_fill(mem, N, lower, upper);
}
template<typename eT>
inline void rng_mkl::randn_fill(eT* mem, const uword N, const eT a, const eT sigma) {
if(is_float<eT>::value) {
typedef float T;
vsRngGaussian(_BRNG , _stream, N, (T*)mem, a, sigma );
}
else if (is_double<eT>::value) {
typedef double T;
vdRngGaussian(_BRNG , _stream, N, (T*)mem, a, sigma );
}
else {
return;
}
}
/** Fill mem[0..N) with standard normal variates (mean 0, sigma 1). */
template<typename eT>
inline void rng_mkl::randn_fill(eT* mem, const uword N) {
  const eT mean = static_cast<eT>(0);
  const eT stddev = static_cast<eT>(1);
  randn_fill(mem, N, mean, stddev);
}
template<typename eT>
inline void rng_mkl::rande_fill(eT* mem, const uword N, const eT a, const eT beta) {
if(is_float<eT>::value) {
typedef float T;
vsRngExponential(_BRNG, _stream, N, (T*)mem, a, beta );
}
else if (is_double<eT>::value) {
typedef double T;
vdRngExponential(_BRNG, _stream, N, (T*)mem, a, beta );
}
else {
return;
}
}
/** Fill mem[0..N) with exponential variates: displacement 0, scale 1. */
template<typename eT>
inline void rng_mkl::rande_fill(eT* mem, const uword N) {
  const eT shift = static_cast<eT>(0);
  const eT scale = static_cast<eT>(1);
  rande_fill(mem, N, shift, scale);
}
/**
 * Element-wise exponential: mem[i] = exp(a[i]) for i in [0, N).
 *
 * Uses vsExp (float) or vdExp (double). For other element types the call is
 * a no-op and mem is left unchanged.
 */
template<typename eT>
inline void vml_mkl::exp(eT* mem, const uword N, const eT* a) {
  // Named casts keep the input pointer const; VML takes `const T a[]`.
  if (is_float<eT>::value) {
    vsExp(N, reinterpret_cast<const float*>(a), reinterpret_cast<float*>(mem));
  }
  else if (is_double<eT>::value) {
    vdExp(N, reinterpret_cast<const double*>(a), reinterpret_cast<double*>(mem));
  }
}
/**
 * Element-wise natural logarithm: mem[i] = ln(a[i]) for i in [0, N).
 *
 * Uses vsLn (float) or vdLn (double). For other element types the call is a
 * no-op and mem is left unchanged.
 */
template<typename eT>
inline void vml_mkl::log(eT* mem, const uword N, const eT* a) {
  // Named casts keep the input pointer const; VML takes `const T a[]`.
  if (is_float<eT>::value) {
    vsLn(N, reinterpret_cast<const float*>(a), reinterpret_cast<float*>(mem));
  }
  else if (is_double<eT>::value) {
    vdLn(N, reinterpret_cast<const double*>(a), reinterpret_cast<double*>(mem));
  }
}
/**
 * Element-wise base-10 logarithm: mem[i] = log10(a[i]) for i in [0, N).
 *
 * Uses vsLog10 (float) or vdLog10 (double). For other element types the call
 * is a no-op and mem is left unchanged.
 */
template<typename eT>
inline void vml_mkl::log10(eT* mem, const uword N, const eT* a) {
  // Named casts keep the input pointer const; VML takes `const T a[]`.
  if (is_float<eT>::value) {
    vsLog10(N, reinterpret_cast<const float*>(a), reinterpret_cast<float*>(mem));
  }
  else if (is_double<eT>::value) {
    vdLog10(N, reinterpret_cast<const double*>(a), reinterpret_cast<double*>(mem));
  }
}
/**
 * Element-wise sine: mem[i] = sin(a[i]) for i in [0, N).
 *
 * Uses vsSin (float) or vdSin (double). For other element types the call is
 * a no-op and mem is left unchanged.
 */
template<typename eT>
inline void vml_mkl::sin(eT* mem, const uword N, const eT* a) {
  // Named casts keep the input pointer const; VML takes `const T a[]`.
  if (is_float<eT>::value) {
    vsSin(N, reinterpret_cast<const float*>(a), reinterpret_cast<float*>(mem));
  }
  else if (is_double<eT>::value) {
    vdSin(N, reinterpret_cast<const double*>(a), reinterpret_cast<double*>(mem));
  }
}
/**
 * Element-wise cosine: mem[i] = cos(a[i]) for i in [0, N).
 *
 * Uses vsCos (float) or vdCos (double). For other element types the call is
 * a no-op and mem is left unchanged.
 */
template<typename eT>
inline void vml_mkl::cos(eT* mem, const uword N, const eT* a) {
  // Named casts keep the input pointer const; VML takes `const T a[]`.
  if (is_float<eT>::value) {
    vsCos(N, reinterpret_cast<const float*>(a), reinterpret_cast<float*>(mem));
  }
  else if (is_double<eT>::value) {
    vdCos(N, reinterpret_cast<const double*>(a), reinterpret_cast<double*>(mem));
  }
}
/**
 * Element-wise tangent: mem[i] = tan(a[i]) for i in [0, N).
 *
 * Uses vsTan (float) or vdTan (double). For other element types the call is
 * a no-op and mem is left unchanged.
 */
template<typename eT>
inline void vml_mkl::tan(eT* mem, const uword N, const eT* a) {
  // Named casts keep the input pointer const; VML takes `const T a[]`.
  if (is_float<eT>::value) {
    vsTan(N, reinterpret_cast<const float*>(a), reinterpret_cast<float*>(mem));
  }
  else if (is_double<eT>::value) {
    vdTan(N, reinterpret_cast<const double*>(a), reinterpret_cast<double*>(mem));
  }
}
/**
 * Element-wise arcsine: mem[i] = asin(a[i]) for i in [0, N).
 *
 * Uses vsAsin (float) or vdAsin (double). For other element types the call
 * is a no-op and mem is left unchanged.
 */
template<typename eT>
inline void vml_mkl::asin(eT* mem, const uword N, const eT* a) {
  // Named casts keep the input pointer const; VML takes `const T a[]`.
  if (is_float<eT>::value) {
    vsAsin(N, reinterpret_cast<const float*>(a), reinterpret_cast<float*>(mem));
  }
  else if (is_double<eT>::value) {
    vdAsin(N, reinterpret_cast<const double*>(a), reinterpret_cast<double*>(mem));
  }
}
/**
 * Element-wise arccosine: mem[i] = acos(a[i]) for i in [0, N).
 *
 * Uses vsAcos (float) or vdAcos (double). For other element types the call
 * is a no-op and mem is left unchanged.
 */
template<typename eT>
inline void vml_mkl::acos(eT* mem, const uword N, const eT* a) {
  // Named casts keep the input pointer const; VML takes `const T a[]`.
  if (is_float<eT>::value) {
    vsAcos(N, reinterpret_cast<const float*>(a), reinterpret_cast<float*>(mem));
  }
  else if (is_double<eT>::value) {
    vdAcos(N, reinterpret_cast<const double*>(a), reinterpret_cast<double*>(mem));
  }
}
/**
 * Element-wise arctangent: mem[i] = atan(a[i]) for i in [0, N).
 *
 * Uses vsAtan (float) or vdAtan (double). For other element types the call
 * is a no-op and mem is left unchanged.
 */
template<typename eT>
inline void vml_mkl::atan(eT* mem, const uword N, const eT* a) {
  // Named casts keep the input pointer const; VML takes `const T a[]`.
  if (is_float<eT>::value) {
    vsAtan(N, reinterpret_cast<const float*>(a), reinterpret_cast<float*>(mem));
  }
  else if (is_double<eT>::value) {
    vdAtan(N, reinterpret_cast<const double*>(a), reinterpret_cast<double*>(mem));
  }
}
}
#endif /* SRC_ARMADILLO_SUPPORT_H_ */
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment