Created September 16, 2019 15:57
-
-
Save medewitt/209512a8c042c3cb28b960a39cd28b4c to your computer and use it in GitHub Desktop.
Exploring Loss Functions and Integrating Over the Loss
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Parameters ----
set.seed(42)
n <- 100L

# True data-generating model ----
# y = true_treatment * treat + true_beta * x + N(0, 1) noise, where the
# covariate x ~ Normal(5, 1) and treat alternates 0, 1, 0, 1, ...
# rep_len() yields exactly n values even when n is odd; the previous
# rep(c(0, 1), n / 2) silently dropped one value for odd n.
true_treatment <- 2
true_beta <- 1
x <- rnorm(n, 5, 1)
treat <- rep_len(c(0, 1), n)
y <- rnorm(n, true_treatment * treat + true_beta * x)
# Bundle the simulated data for Stan. The list names must match the
# declarations in the Stan `data` block: N, x, y and status.
stan_dat <- list(N = n, x = x, y = y, status = treat)
# Compile and run model ----
library(rstan)
library(dplyr)
rstan_options(auto_write = TRUE)
options(mc.cores = parallel::detectCores())

# Compile the Stan program, then draw 2 chains of 1000 iterations each with
# progress output suppressed (refresh = 0).
linear_regression <- stan_model("loss-function.stan")
fit1 <- sampling(
  linear_regression,
  data = stan_dat,
  chains = 2,
  iter = 1000,
  refresh = 0
)

# Summarise the posterior
print(fit1)
# Collect the posterior draws of the treatment effect and of the
# generated quantity `loss` into one data frame, one row per draw.
# - rstan::extract() is namespace-qualified so a later library(tidyr)
#   cannot mask it with tidyr::extract().
# - Columns are named with `=`. The previous `treat <- ...` inside
#   data.frame() silently overwrote the global `treat` and `loss` objects.
# Both parameters come from one extract() call, so row i of `treatment`
# and row i of `loss` belong to the same draw.
draws <- rstan::extract(fit1, pars = c("treatment", "loss"))
results <- data.frame(
  treatment = as.vector(draws$treatment),
  loss = as.vector(draws$loss)
)
#' Piecewise loss for treatment-effect values
#'
#' Same piecewise rule as the Stan `loss_function()`:
#'   x > 1       -> 1 / log(x)  (grows without bound as x -> 1 from above)
#'   0 < x <= 1  -> x
#'   x <= 0      -> 20
#'
#' Vectorized, so it can be applied to a whole vector of draws directly.
#' Scalar calls behave as before.
#'
#' @param x A numeric vector of treatment-effect values.
#' @return A double vector the same length as `x`, `NA` where `x` is `NA`.
my_loss <- function(x) {
  stopifnot(is.numeric(x))
  out <- rep(20, length(x))
  out[is.na(x)] <- NA_real_
  # which() drops NA comparisons, so missing inputs keep their NA.
  mid <- which(x > 0 & x <= 1)
  out[mid] <- x[mid]
  # Only evaluate log() where x > 1, avoiding NaN warnings for x <= 0.
  high <- which(x > 1)
  out[high] <- 1 / log(x[high])
  out
}
# Plot ----
# Evaluate the loss on a grid, then draw four panels side by side.
# par() returns the previous settings; they are restored at the end so the
# 1 x 4 layout does not leak into later plots.
loss_grid <- seq(-3, 3, 0.1)
estimated_loss <- purrr::map_dbl(loss_grid, my_loss)

old_par <- par(mfrow = c(1, 4))
hist(results$treatment, main = "Histogram of Treatment Effect",
     col = "grey", breaks = 30)
plot(loss_grid, estimated_loss, main = "Loss Function",
     xlab = "treatment effect", ylab = "loss")
plot(x = results$treatment, y = results$loss, main = "Treatment vs Loss")
hist(results$loss, breaks = 30, main = "Probable Loss")
par(old_par)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
// Linear regression of y on a covariate x and a status indicator, with a
// piecewise loss evaluated on every posterior draw of the status
// coefficient `treatment`.
functions {
  /**
   * Piecewise loss for a single (scalar) treatment-effect value.
   *
   *   x > 1       -> 1 / log(x)  (grows without bound as x -> 1 from above)
   *   0 < x <= 1  -> x
   *   x <= 0      -> 20
   *
   * @param x one real value, e.g. one draw of the treatment effect
   * @return the loss for x
   */
  real loss_function(real x) {
    if (x > 1) {
      return 1 / log(x);
    } else if (x > 0) {
      return x;
    }
    return 20.0;
  }
}
data {
  int<lower=0> N;                      // number of observations
  vector[N] x;                         // covariate
  vector<lower=0, upper=1>[N] status;  // treatment indicator (0/1 in the R driver)
  vector[N] y;                         // outcome
}
// Regression parameters: intercept, slope on x, coefficient on status
// (the treatment effect), and residual standard deviation.
parameters {
  real alpha;           // intercept
  real beta;            // slope on x
  real treatment;       // coefficient on status
  real<lower=0> sigma;  // residual standard deviation
}
// Likelihood: y is normal with mean mu = alpha + beta * x + treatment * status
// and standard deviation sigma. No priors are stated, so every parameter
// gets Stan's implicit uniform prior over its declared support.
model {
  vector[N] mu = alpha + beta * x + treatment * status;
  y ~ normal(mu, sigma);
}
// Evaluate the loss once per posterior draw of the treatment effect.
generated quantities {
  real loss;
  loss = loss_function(treatment);
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment