Created
August 17, 2012 08:22
-
-
Save jrnold/3376975 to your computer and use it in GitHub Desktop.
compiling stan model as shared object
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
// Code generated by Stan version 1.0 | |
#include <stan/model/model_header.hpp> | |
namespace foo_namespace { | |
using std::vector; | |
using std::string; | |
using std::stringstream; | |
using stan::agrad::var; | |
using stan::model::prob_grad_ad; | |
using stan::math::get_base1; | |
using stan::io::dump; | |
using std::istream; | |
using namespace stan::math; | |
using namespace stan::prob; | |
using namespace stan::agrad; | |
typedef Eigen::Matrix<double,Eigen::Dynamic,1> vector_d; | |
typedef Eigen::Matrix<double,1,Eigen::Dynamic> row_vector_d; | |
typedef Eigen::Matrix<double,Eigen::Dynamic,Eigen::Dynamic> matrix_d; | |
typedef Eigen::Matrix<stan::agrad::var,Eigen::Dynamic,1> vector_v; | |
typedef Eigen::Matrix<stan::agrad::var,1,Eigen::Dynamic> row_vector_v; | |
typedef Eigen::Matrix<stan::agrad::var,Eigen::Dynamic,Eigen::Dynamic> matrix_v; | |
class foo : public prob_grad_ad { | |
private: | |
public: | |
foo(stan::io::var_context& context__) | |
: prob_grad_ad::prob_grad_ad(0) { | |
static const char* function__ = "foo_namespace::foo(%1%)"; | |
size_t pos__; | |
std::vector<int> vals_i__; | |
std::vector<double> vals_r__; | |
// validate data | |
// validate transformed data | |
set_param_ranges(); | |
} // dump ctor | |
void set_param_ranges() { | |
num_params_r__ = 0U; | |
param_ranges_i__.clear(); | |
++num_params_r__; | |
} | |
void transform_inits(const stan::io::var_context& var_context__, | |
std::vector<int>& params_i__, | |
std::vector<double>& params_r__) { | |
params_r__.clear(); | |
params_i__.clear(); | |
stan::io::writer<double> writer__(params_r__,params_i__); | |
size_t pos__; | |
std::vector<double> vals_r__; | |
std::vector<int> vals_i__; | |
if (!(var_context__.contains_r("y"))) | |
throw std::runtime_error("variable y missing"); | |
if (var_context__.dims_r("y").size() != 0) | |
throw std::runtime_error("require 0 dimensions for variable y"); | |
vals_r__ = var_context__.vals_r("y"); | |
pos__ = 0U; | |
double y(0); | |
y = vals_r__[pos__++]; | |
writer__.scalar_unconstrain(y); | |
params_r__ = writer__.data_r(); | |
params_i__ = writer__.data_i(); | |
} | |
var log_prob(vector<var>& params_r__, | |
vector<int>& params_i__) { | |
var lp__(0.0); | |
// model parameters | |
stan::io::reader<var> in__(params_r__,params_i__); | |
var y = in__.scalar_constrain(lp__); | |
// transformed parameters | |
// validate transformed parameters | |
// model body | |
lp__ += stan::prob::normal_log<true>(y, 0, 1); | |
return lp__; | |
} // log_prob() | |
void get_param_names(std::vector<std::string>& names__) { | |
names__.resize(0); | |
names__.push_back("y"); | |
} | |
void get_dims(std::vector<std::vector<size_t> >& dimss__) { | |
dimss__.resize(0); | |
std::vector<size_t> dims__; | |
dims__.resize(0); | |
dimss__.push_back(dims__); | |
} | |
void write_array(std::vector<double>& params_r__, | |
std::vector<int>& params_i__, | |
std::vector<double>& vars__) { | |
vars__.resize(0); | |
stan::io::reader<double> in__(params_r__,params_i__); | |
static const char* function__ = "foo_namespace::write_array(%1%)"; | |
// read-transform, write parameters | |
double y = in__.scalar_constrain(); | |
vars__.push_back(y); | |
// declare and define transformed parameters | |
double lp__ = 0.0; | |
// validate transformed parameters | |
// write transformed parameters | |
// declare and define generated quantities | |
// validate generated quantities | |
// write generated quantities | |
} | |
void write_csv_header(std::ostream& o__) { | |
stan::io::csv_writer writer__(o__); | |
writer__.comma(); | |
o__ << "y"; | |
writer__.newline(); | |
} | |
void write_csv(std::vector<double>& params_r__, | |
std::vector<int>& params_i__, | |
std::ostream& o__) { | |
stan::io::reader<double> in__(params_r__,params_i__); | |
stan::io::csv_writer writer__(o__); | |
static const char* function__ = "foo_namespace::write_csv(%1%)"; | |
// read-transform, write parameters | |
double y = in__.scalar_constrain(); | |
writer__.write(y); | |
// declare, define and validate transformed parameters | |
double lp__ = 0.0; | |
// write transformed parameters | |
// declare and define generated quantities | |
// validate generated quantities | |
// write generated quantities | |
writer__.newline(); | |
} | |
}; // model | |
} // namespace | |
int main(int argc, const char* argv[]) { | |
try { | |
stan::gm::nuts_command<foo_namespace::foo>(argc,argv); | |
} catch (std::exception& e) { | |
std::cerr << std::endl << "Exception: " << e.what() << std::endl; | |
std::cerr << "Diagnostic information: " << std::endl << boost::diagnostic_information(e) << std::endl; | |
return -1; | |
} | |
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
// Minimal Stan program: one real parameter and one sampling statement.
parameters { | |
// y: scalar real parameter, declared without bounds
real y; | |
} | |
model { | |
// sampling statement: y is given a normal distribution with arguments 0 and 1
y ~ normal(0, 1); | |
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
library(rstan) | |
library(stringr) | |
##' Compile a Stan-generated C++ model file into a shared object.
##'
##' Reads the C++ source, prepends the rstan include, appends the Rcpp module
##' definition code for the model class, writes the result to a temporary
##' .cpp file and builds it with R CMD SHLIB.
##'
##' Side effects: sets the environment variables supplied by the "rstan"
##' plugin, and CLINK_CPPFLAGS when the plugin lists LinkingTo packages.
##'
##' @param cppfile path of the C++ file to compile
##' @param outfile path of the shared object to create
##' @param verbose if TRUE, print the environment variables and include
##'   flags that are set
##' @return (invisibly) the exit status of R CMD SHLIB as returned by
##'   system2(); a warning is raised when it is non-zero
stan_shared <- function(cppfile, outfile, verbose=TRUE) {
    cppcode <- readLines(cppfile)
    ## Get the model name from the cpp code.
    ## Assumes only 1 class in the cpp file; if several lines match, the
    ## first one is used.
    class.matches <- na.omit(str_match(cppcode, "class\\s+(.*?)\\s+:"))
    if (nrow(class.matches) == 0) {
        stop("no class definition found in ", cppfile)
    }
    model.name <- class.matches[1, 2]
    ## Patch cpp code to be used with Rcpp
    newcppcode <-
        paste("#include <rstan/rstaninc.hpp>",
              paste(cppcode, collapse="\n"),
              rstan:::get_Rcpp_module_def_code(model.name),
              sep = "\n")
    ## Write out new cpp code to a tempfile because R CMD SHLIB compiles
    ## files on disk, not code held in a character vector.
    newcppfile <- tempfile(fileext=".cpp")
    writeLines(newcppcode, newcppfile)
    ## Before compiling set all the environment variables for linking,
    ## etc. This is what inline:::cxxfunction appears to do.
    settings <- getPlugin("rstan")
    ### Copied from Rcpp::cxxfunction
    if (!is.null(env <- settings$env)) {
        do.call(Sys.setenv, env)
        if (isTRUE(verbose)) {
            cat(" >> setting environment variables: \n")
            writeLines(sprintf("%s = %s", names(env), env))
        }
    }
    LinkingTo <- settings$LinkingTo
    if (!is.null(LinkingTo)) {
        ## find.package() replaces the deprecated .find.package()
        paths <- find.package(LinkingTo, quiet = TRUE)
        if (length(paths)) {
            flag <- paste(paste0("-I\"", paths, "/include\""),
                          collapse = " ")
            Sys.setenv(CLINK_CPPFLAGS = flag)
            if (isTRUE(verbose)) {
                cat(sprintf("\n >> LinkingTo : %s\n", paste(LinkingTo,
                                                            collapse = ", ")))
                cat("CLINK_CPPFLAGS = ", flag, "\n\n")
            }
        }
    }
    ### End of cxxfunction copied
    ## Also see inline:::compileCode for some platform indep hacks.
    ## Use R CMD SHLIB to compile
    ## Could also do R CMD COMPILE and then R CMD SHLIB
    R <- file.path(R.home(component = "bin"), "R")
    status <- system2(R, c("CMD", "SHLIB",
                           sprintf("-o %s", shQuote(outfile)),
                           shQuote(newcppfile)))
    ## Surface build failures instead of leaving them to the later dyn.load()
    if (status != 0) {
        warning("R CMD SHLIB exited with status ", status)
    }
    invisible(status)
}
## Build the shared object foo.so from foo.cpp.
## The ".so" suffix is platform specific and would need to be generalized
## for cross-platform use.
stan_shared("foo.cpp", "foo.so")

## Load the freshly built shared object into this R session
dyn.load("foo.so")

## The list of loaded DLLs should now include foo.so
getLoadedDLLs()

## Load the Rcpp Module; see section 3.3 of
## "Exposing C++ functions and classes with Rcpp modules"
mod <- Module("foo", getDynLib("foo"))

## Pull the exposed class "foo" out of the module and instantiate it
foo <- mod$foo
stanmodel_object <- new(foo)
##' Names of the classes exposed by an Rcpp module
##'
##' Extracted from getMethods("show", "Module"); as far as I know there is
##' no public facing code to view the classes / functions in a Module.
##'
##' @param object an Rcpp Module object
##' @return the names of the module's class info list; if the module pointer
##'   is the "bad pointer" sentinel, a message is written and the result of
##'   writeLines() is returned instead
module_classes <- function(object) {
    ptr <- Rcpp:::.getModulePointer(object, FALSE)
    ## Usable pointer: query the class info and return its names
    if (!identical(ptr, Rcpp:::.badModulePointer)) {
        return(names(.Call(Rcpp:::Module__classes_info, ptr)))
    }
    ## Bad pointer: report which module / package is uninitialized
    env <- as.environment(object)
    msg <- sprintf("Uninitialized module named \"%s\" from package \"%s\"",
                   get("moduleName", envir = env),
                   get("packageName", envir = env))
    writeLines(msg)
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment