Created
September 1, 2018 16:08
-
-
Save javipus/2bce7098b5f70ed80300c02dac99d468 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import time | |
import os | |
import numpy as np | |
import pandas as pd | |
import matplotlib.pyplot as plt | |
""" | |
Understanding upper bounds for Brier scores through random predictors. | |
""" | |
class Data(object):
    """
    Data generating process.

    For every event a latent probability ``p`` is drawn from the chosen
    distribution (clipped to [0, 1]); the binary event then equals 1 with
    probability ``p``.
    """

    def __init__(self, model='uniform', **kwds):
        """
        Define model.

        Parameters
        ----------
        model : str
            One of 'centrist', 'uniform', 'normal', 'gamma', 'beta'.
        **kwds
            Extra keyword arguments forwarded to the numpy sampler for
            ``model`` (e.g. ``loc``/``scale`` for 'normal', ``a``/``b`` for
            'beta'). 'centrist' accepts none.
        """
        models = {
            'centrist': lambda size: np.full(size, .5),
            'uniform': np.random.uniform,
            'normal': np.random.normal,
            'gamma': np.random.gamma,
            'beta': np.random.beta,
        }
        self.model = model
        # The sampler is looked up at call time via ``self.model``; raw draws
        # are clipped so they are always valid probabilities.
        self._model = lambda N: models[self.model](size=N, **kwds).clip(min=0, max=1)

    def generate(self, N=1):
        """
        Generate events.

        Parameters
        ----------
        N : int
            Number of events to draw.

        Returns
        -------
        numpy.ndarray
            Integer array of ``N`` binary outcomes where ``events[i]`` is 1
            with probability ``self.p[i]``. The latent probabilities and the
            outcomes are also stored on ``self.p`` and ``self.events``.
        """
        self.p = self._model(N=N)
        # Vectorised Bernoulli draw: u ~ U[0, 1), so P(u < p) == p exactly
        # (p == 0 never fires, p == 1 always fires).
        self.events = (np.random.uniform(size=self.p.shape) < self.p).astype(int)
        return self.events
class Predictor(object):
    """
    Prediction generator.

    A deliberately naive forecaster: every prediction is an independent draw
    from the chosen distribution (clipped to [0, 1]) and the data is ignored.
    """

    def __init__(self, model, **kwds):
        """
        Define model.

        Parameters
        ----------
        model : str
            One of 'centrist', 'uniform', 'normal', 'gamma', 'beta'.
        **kwds
            Extra keyword arguments forwarded to the numpy sampler for
            ``model``. 'centrist' accepts none.
        """
        # TODO non random models that actually use (some of the information in) the data for prediction
        models = {
            'centrist': lambda size: np.full(size, .5),
            'uniform': np.random.uniform,
            'normal': np.random.normal,
            'gamma': np.random.gamma,
            'beta': np.random.beta,
        }
        self.model = model
        # Clip so every prediction is a valid probability.
        self._model = lambda N: models[self.model](size=N, **kwds).clip(min=0, max=1)

    def predict(self, data, N=1):
        """
        Generate predictions.

        Parameters
        ----------
        data : object
            Accepted for interface symmetry; currently ignored.
        N : int
            Number of probability forecasts to draw.

        Returns
        -------
        numpy.ndarray
            ``N`` forecast probabilities in [0, 1].
        """
        # Ignore data
        return self._model(N=N)

    def score(self, events, predictions):
        """
        Brier score: mean squared difference between outcomes and forecasts.

        Parameters
        ----------
        events : array-like
            Binary outcomes (0/1).
        predictions : array-like
            Forecast probabilities, same length as ``events``.

        Returns
        -------
        numpy.float64
            The Brier score (lower is better).

        Raises
        ------
        ValueError
            If the lengths differ or there is nothing to score.
        """
        # asarray lets plain lists be scored too (list - list is a TypeError).
        events = np.asarray(events, dtype=float)
        predictions = np.asarray(predictions, dtype=float)
        if len(events) != len(predictions):
            raise ValueError('Length mismatch: {} events != {} predictions'.format(len(events), len(predictions)))
        if len(events) == 0:
            raise ValueError('Cannot compute a Brier score for zero events')
        return np.mean((events - predictions) ** 2)
def _describe_model(name, kwds):
    """Return a ``name(k = v, ...)`` label for a model name and its sampler kwargs."""
    return '{}({})'.format(name, ', '.join('{} = {}'.format(k, val) for k, val in kwds.items()))


def main(N=int(1e5), savefile=None):
    """
    Run pairwise comparisons.

    Every data model is pitted against every predictor model. For each pair,
    ``N`` fresh events and ``N`` random predictions are drawn and their Brier
    score is recorded. The resulting table is written to ``savefile`` as CSV.

    Parameters
    ----------
    N : int
        Number of events (and predictions) per (data, predictor) pair.
    savefile : str or None
        Output CSV path; defaults to ``results.csv`` next to this script.

    Returns
    -------
    pandas.DataFrame
        Float table of scores; rows are data models (index name 'Data'),
        columns are predictor models (column name 'Predictor').
    """
    if savefile is None:
        savefile = os.path.join(os.path.dirname(__file__), 'results.csv')
    MODELS = {
        'centrist': {},
        'uniform': {},
        'normal': {'loc': .5, 'scale': .5/3},
        'gamma': {'shape': .5, 'scale': .5},
        'beta': {'a': .5, 'b': .5},
    }
    names = list(MODELS)
    # Start from a float NaN table instead of an all-None (object dtype) one
    # so the scores are numeric for plotting and arithmetic straight away.
    scores = pd.DataFrame(np.nan, index=names, columns=names, dtype=float)
    scores.index.name = 'Data'
    scores.columns.name = 'Predictor'
    for data_model, data_kwds in MODELS.items():
        data = Data(data_model, **data_kwds)
        print('\nPitting a {} universe against...'.format(_describe_model(data_model, data_kwds)))
        for predictor_model, predictor_kwds in MODELS.items():
            t0 = time.perf_counter()
            print(' a {} predictor'.format(_describe_model(predictor_model, predictor_kwds)))
            predictor = Predictor(predictor_model, **predictor_kwds)
            # A fresh sample of events is drawn for every predictor.
            events = data.generate(N)
            predictions = predictor.predict(data, N)
            this_score = predictor.score(events, predictions)
            scores.loc[data_model, predictor_model] = this_score
            print(' Done! Score: {:.3f} -- Time elapsed: {:.3f} seconds'.format(this_score, time.perf_counter() - t0))
    scores.to_csv(savefile)
    return scores
if __name__ == '__main__':
    # Keep every artefact next to this script, independent of the working
    # directory the script is launched from.
    OUTDIR = os.path.dirname(__file__)
    SAVEFILE = os.path.join(OUTDIR, 'results.csv')
    # Reuse cached results when present; delete results.csv to recompute.
    if os.path.isfile(SAVEFILE):
        scores = pd.read_csv(SAVEFILE, index_col=0)
    else:
        scores = main()
    # One line per predictor (column); the x axis enumerates data models (rows).
    scores.plot(ls='-', marker='x')
    plt.xticks(range(scores.shape[0]), scores.index.tolist())
    plt.ylabel('Brier score')
    plt.legend(title='Predictor')
    plt.tight_layout()
    plt.savefig(os.path.join(OUTDIR, 'results.png'), dpi=500, bbox_inches='tight')
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment