Skip to content

Instantly share code, notes, and snippets.

@medewitt
Created September 16, 2019 15:57
Show Gist options
  • Save medewitt/209512a8c042c3cb28b960a39cd28b4c to your computer and use it in GitHub Desktop.
Save medewitt/209512a8c042c3cb28b960a39cd28b4c to your computer and use it in GitHub Desktop.
Exploring Loss Functions and Integrating Over the Loss
# Simulation settings ----
set.seed(42)
n <- 100L

# Simulate from the true model: y = 1 * x + 2 * treat + standard normal noise
x <- rnorm(n, mean = 5, sd = 1)
treat <- rep(c(0, 1), times = n / 2)
y <- rnorm(n, mean = 2 * treat + 1 * x)

# Package data for Stan; element names must match the Stan `data` block ----
stan_dat <- list(N = n, x = x, y = y, status = treat)
# Compile and run the model ----
library(rstan)
library(dplyr)

# Cache the compiled model on disk and let chains run on all detected cores.
rstan_options(auto_write = TRUE)
options(mc.cores = parallel::detectCores())

linear_regression <- stan_model("loss-function.stan")
fit1 <- sampling(
  linear_regression,
  data = stan_dat,
  chains = 2,
  iter = 1000,
  refresh = 0
)
# Look at output ----
print(fit1)

# Posterior draws of the treatment effect and its loss, one row per draw.
# `rstan::` is explicit because tidyr also exports an `extract()` that would
# mask rstan's if the tidyverse is attached.
draws <- rstan::extract(fit1, pars = c("treatment", "loss"))

# Columns are named with `=`: using `<-` inside data.frame() would assign
# globals as a side effect and overwrite the simulated `treat` vector.
# as.vector() drops the 1-d array dimension returned by extract().
results <- data.frame(
  treatment = as.vector(draws$treatment),
  loss = as.vector(draws$loss)
)
#' Piecewise loss for a treatment effect
#'
#' Same rule as the Stan `loss_function()`: 20 for non-positive values, the
#' value itself on (0, 1], and 1 / log(x) above 1. Vectorized, so it works
#' on whole grids or posterior draws as well as on single values.
#'
#' @param x Numeric vector of treatment effects.
#' @return Numeric vector the same length as `x`; `NA`/`NaN` inputs give `NA`.
my_loss <- function(x) {
  stopifnot(is.numeric(x))
  out <- rep(20, length(x))
  # which() drops NA comparisons so they cannot break the indexed assignment.
  small <- which(x > 0 & x <= 1)
  large <- which(x > 1)
  out[small] <- x[small]
  out[large] <- 1 / log(x[large])
  out[is.na(x)] <- NA_real_
  out
}
# Evaluate the loss on a grid and plot it next to the posterior draws ----
loss_grid <- seq(-3, 3, by = 0.1)
estimated_loss <- purrr::map_dbl(loss_grid, my_loss)

# Four panels side by side; the previous graphics settings are restored below.
old_par <- par(mfrow = c(1, 4))
hist(
  results$treatment,
  main = "Histogram of Treatment Effect",
  col = "grey",
  breaks = 30
)
plot(
  loss_grid,
  estimated_loss,
  main = "Loss Function",
  xlab = "Treatment effect",
  ylab = "Loss"
)
plot(x = results$treatment, y = results$loss, main = "Treatment vs Loss")
hist(results$loss, breaks = 30, main = "Probable Loss")
par(old_par)
// Linear regression of 'y' on 'x' plus a treatment indicator 'status';
// the input data are vectors 'x', 'status', and 'y', each of length 'N'.
functions {
  /**
   * Piecewise loss for a single treatment-effect value; same rule as the
   * R function `my_loss`.
   *
   * @param x a scalar treatment effect (one posterior draw)
   * @return 1 / log(x) when x > 1, x itself when 0 < x <= 1, otherwise 20
   */
  real loss_function(real x) {
    if (x > 1) {
      return 1 / log(x);
    }
    if (x > 0) {
      return x;
    }
    // Non-positive values (and anything failing both comparisons) cost 20.
    return 20;
  }
}
data {
  int<lower=0> N;     // number of observations
  vector[N] y;        // outcome
  vector[N] x;        // continuous covariate
  vector[N] status;   // treatment indicator (0/1 in the R simulation)
}
// The parameters accepted by the model: intercept 'alpha', coefficient
// 'beta' on 'x', treatment effect 'treatment', and residual SD 'sigma'.
parameters {
  real alpha;             // intercept
  real beta;              // coefficient on x
  real treatment;         // shift in the mean of y per unit of status
  real<lower=0> sigma;    // residual standard deviation
}
// The model to be estimated. We model the output
// 'y' as normally distributed with mean 'alpha + beta * x + treatment * status'
// and standard deviation 'sigma'.
model {
  // Linear predictor. No priors are stated, so each parameter keeps Stan's
  // implicit flat prior over its declared support.
  vector[N] mu = alpha + beta * x + treatment * status;
  y ~ normal(mu, sigma);
}
generated quantities {
  // Loss incurred by each posterior draw of the treatment effect.
  real loss;
  loss = loss_function(treatment);
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment