Skip to content

Instantly share code, notes, and snippets.

@Dpananos
Created October 23, 2020 02:28
Show Gist options
  • Save Dpananos/5f9c026d3b21ec53638f6ce067c20184 to your computer and use it in GitHub Desktop.
Save Dpananos/5f9c026d3b21ec53638f6ce067c20184 to your computer and use it in GitHub Desktop.
Simulation for R squared CI coverage
import numpy as np
import matplotlib.pyplot as plt
from sklearn.linear_model import LinearRegression
from scipy.special import expit, logit
from itertools import product
import pandas as pd
import seaborn as sns
def make_regression_data(n, alpha, sigma):
    """Simulate a one-predictor linear model with no intercept.

    Draws ``n`` standard-normal predictor values, then builds the response
    as ``alpha * x`` plus zero-mean Gaussian noise with standard deviation
    ``sigma``.  Sampling goes through NumPy's global random state.

    Returns
    -------
    tuple
        ``(X, y)`` where ``X`` is the predictor as an ``(n, 1)`` column
        (the 2-D layout passed to the regression model) and ``y`` is the
        length-``n`` response.
    """
    predictor = np.random.normal(size=n)
    noise = np.random.normal(0, sigma, size=n)
    response = alpha * predictor + noise
    return predictor.reshape(-1, 1), response
def interval(ytest, ypred, z=1.96):
    """Return a normal-approximation confidence interval for R squared.

    A Wald interval is formed for the mean squared error
    (``mean +/- z * standard error`` of the squared residuals) and then
    mapped onto the R squared scale through ``1 - mse / var(ytest)``.

    Parameters
    ----------
    ytest : array-like
        Observed responses.
    ypred : array-like
        Predicted responses, aligned element-wise with ``ytest``.
    z : float, optional
        Positive standard-normal critical value.  The default of 1.96
        corresponds to a nominal 95% interval.

    Returns
    -------
    numpy.ndarray
        Two-element array ``[lower, upper]`` on the R squared scale.  The
        bounds are not clamped, so they may fall below 0 or above 1.

    Notes
    -----
    The result is non-finite when ``ytest`` has zero variance (division by
    zero) or when fewer than two observations are given (the ``ddof=1``
    standard deviation is undefined).
    """
    # Coerce so that plain Python sequences work as well as arrays.
    ytest = np.asarray(ytest)
    ypred = np.asarray(ypred)

    squared_error = np.power(ytest - ypred, 2)
    standard_error = squared_error.std(ddof=1) / np.sqrt(squared_error.size)
    mse_bounds = squared_error.mean() + np.array([-z, z]) * standard_error

    # A larger MSE means a smaller R squared, so the MSE bounds are
    # reversed before the transform to keep the output ordered low -> high.
    return 1 - mse_bounds[::-1] / np.var(ytest)
def do_fit(n, alpha, sigma):
    """Run one simulation replicate and return its R squared interval.

    Generates a fresh data set with ``make_regression_data``, fits an
    ordinary least-squares ``LinearRegression`` to it, and passes the
    observed responses together with the model's predictions on that same
    data to ``interval``.

    Returns
    -------
    The value returned by ``interval`` for this replicate.
    """
    features, target = make_regression_data(n, alpha, sigma)
    fitted = LinearRegression().fit(features, target)
    # Predictions are made on the data used for fitting (in-sample).
    return interval(target, fitted.predict(features))
def experiment(R2, n, num_sims=5000):
    """Estimate by simulation how the R squared interval behaves.

    Repeatedly simulates data whose population R squared equals ``R2``,
    computes an interval for each replicate via ``do_fit``, and summarises
    the resulting intervals.

    Parameters
    ----------
    R2 : float
        Target population R squared; must satisfy ``0 <= R2 < 1``.
    n : int
        Number of observations per simulated data set.
    num_sims : int, optional
        Number of simulated replicates (default 5000).

    Returns
    -------
    tuple
        ``(coverage, realistic)`` where ``coverage`` is the fraction of
        intervals that strictly contain ``R2`` and ``realistic`` is the
        fraction of intervals with a lower limit below 0 or an upper limit
        above 1.  Note that, despite its name, ``realistic`` therefore
        counts intervals that reach outside the valid R squared range.

    Raises
    ------
    ValueError
        If ``R2`` is outside ``[0, 1)``, where the slope below is undefined.
    """
    if not 0 <= R2 < 1:
        raise ValueError("R2 must lie in [0, 1), got {!r}".format(R2))

    sigma = 1
    # With a unit-variance predictor (as make_regression_data generates),
    # R2 = alpha**2 / (alpha**2 + sigma**2); this solves that for alpha.
    alpha = sigma * np.sqrt(1 / (1 - R2) - 1)

    confidence_intervals = np.zeros((num_sims, 2))
    for i in range(num_sims):
        confidence_intervals[i] = do_fit(n, alpha, sigma)

    lower_limit, upper_limit = confidence_intervals.T
    coverage = np.mean((lower_limit < R2) & (upper_limit > R2))
    realistic = np.mean((lower_limit < 0) | (upper_limit > 1))
    return coverage, realistic
# Simulation grid: ten R squared targets (0.01, 0.11, ..., 0.91) crossed
# with five sample sizes; every combination is run through experiment().
R2_values = np.arange(0.01, 0.99, 0.1)
sample_sizes = [50, 100, 250, 1000, 10000]
params = list(product(R2_values, sample_sizes))
results = [experiment(*p) for p in params]

# Collect the grid and both summary statistics in one long-format table.
df = pd.DataFrame(params, columns=['R2', 'sample_size'])
# Rounding strips the floating-point noise np.arange leaves in the labels.
df['R2'] = df.R2.round(2)
df['coverage'] = [x[0] for x in results]
df['realistic'] = [x[1] for x in results]

# Reshape to an R2-by-sample-size matrix.  DataFrame.pivot is called with
# keyword arguments because positional ones are rejected by pandas >= 2.0.
pivoted_coverage = df.pivot(index='R2', columns='sample_size', values='coverage')
pivoted_realistic = df.pivot(index='R2', columns='sample_size', values='realistic')

# Heatmap of coverage, with the colour scale centred on the nominal 95%.
fig, ax = plt.subplots(dpi=240)
sns.heatmap(pivoted_coverage, square=True, cmap='RdBu_r', center=0.95, ax=ax)
plt.tight_layout()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment