This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import os | |
import base64 | |
import shlex | |
from pathlib import Path | |
from dataclasses import dataclass | |
from typing import Any | |
import click | |
import runpod | |
from dotenv import load_dotenv, dotenv_values, find_dotenv |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import logging | |
import os | |
import sys | |
from dataclasses import dataclass, field | |
from typing import Optional | |
import torch | |
import transformers | |
from datasets import load_dataset | |
from torchvision.transforms import ( |
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from sir_models.utils import eval_on_select_dates_and_k_days_ahead | |
from sir_models.utils import smape | |
from sklearn.metrics import mean_absolute_error | |
# Forecast horizon, in days.
K = 30
# Last allowed evaluation start: K days before the final training date
# (presumably so a full K-day-ahead window still has ground truth — confirm).
final_train_date = train_subset.date.iloc[-1]
last_day = final_train_date - pd.Timedelta(days=K)
# Daily range from 2020-06-01 to last_day, keeping every 20th date.
eval_dates = pd.date_range(start='2020-06-01', end=last_day)[::20]
def eval_hidden_moscow(train_df, t, train_t, eval_t): | |
weights = { |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from sir_models.fitters import HiddenCurveFitter | |
from sir_models.models import SEIRHidden | |
# NOTE(review): "stepwize" is a misspelling of "stepwise"; the module-level name is
# kept as-is because code further down may reference it.
# Passed to SEIRHidden as `stepwise_size` (presumably a step width in days — confirm).
stepwize_size = 60
# Per-compartment weights: D counts as much as I and R combined
# (presumably used to weight the fitter's residuals — confirm).
weights = dict(I=0.25, R=0.25, D=0.5)
model = SEIRHidden(stepwise_size=stepwize_size)
fitter = HiddenCurveFitter( |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def smape_resid_transform(true, pred, eps=1e-5):
    """Return the signed, symmetric relative residual ``(true - pred) / (|true| + |pred| + eps)``.

    Works element-wise on scalars or NumPy arrays. ``eps`` keeps the
    denominator non-zero when both ``true`` and ``pred`` are 0.
    """
    denominator = np.abs(true) + np.abs(pred) + eps
    return (true - pred) / denominator
class HiddenCurveFitter(BaseFitter): | |
... | |
def residual(self, params, t_vals, data, model): | |
model.params = params | |
initial_conditions = model.get_initial_conditions(data) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def sigmoid(x, xmin, xmax, a, b, c, r):
    """Smoothly blend from level ``a`` to level ``b`` as ``x`` moves through ``[xmin, xmax]``.

    ``x`` is rescaled to ``x_scaled = (x - xmin) / (xmax - xmin)`` and the
    result is ``a + (b - a) * s`` with ``s = 1 / (1 + exp(-r * (x_scaled - c)))``.
    It equals ``(a + b) / 2`` at ``x_scaled == c``. For ``r > 0`` it tends to
    ``a`` below that point and to ``b`` above it, and ``r`` sets the sharpness.

    This is algebraically identical to
    ``(a*exp(c*r) + b*exp(r*x_scaled)) / (exp(c*r) + exp(r*x_scaled))``.
    That form overflows to ``inf / inf = NaN`` once ``r * x_scaled`` exceeds
    ~709, e.g. when ``x`` lies far outside ``[xmin, xmax]``. Evaluating the
    logistic weight through ``tanh`` avoids the overflow, because ``tanh``
    saturates at +/-1.

    Args:
        x: Point(s) to evaluate; scalar or NumPy array.
        xmin, xmax: Interval used to rescale ``x``. ``xmax == xmin`` divides by zero.
        a: Level approached for ``x_scaled`` well below ``c`` (when ``r > 0``).
        b: Level approached for ``x_scaled`` well above ``c`` (when ``r > 0``).
        c: Midpoint of the transition, in rescaled units.
        r: Steepness of the transition.

    Returns:
        The blended value(s), with the same shape as the broadcast inputs.
    """
    x_scaled = (x - xmin) / (xmax - xmin)
    # logistic(z) == 0.5 * (1 + tanh(z / 2)); tanh never overflows.
    weight = 0.5 * (1.0 + np.tanh(0.5 * r * (x_scaled - c)))
    return a + (b - a) * weight
def stepwise_soft(t, coefficients, r=20, c=0.5): | |
t_arr = np.array(list(coefficients.keys())) | |
min_index = np.min(t_arr) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Run BarebonesSEIR with the starting values from get_fit_params() over the
# training period and plot predicted vs. observed total deaths.
# NOTE(review): no fitting step is visible here, so this shows the untuned model.
model = BarebonesSEIR()
model.params = model.get_fit_params()
train_t = np.arange(len(train_subset))
train_initial_conditions = model.get_initial_conditions(train_subset)
S, E, I, R, D = model.predict(train_t, train_initial_conditions)

plt.figure(figsize=(10, 7))
plt.plot(train_subset.date, train_subset['total_dead'], label='ground truth')
plt.plot(
    train_subset.date,
    D,
    label='predicted',
    color='black',
    linestyle='dashed',
)
plt.legend()
plt.title('Total deaths')
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class BarebonesSEIR: | |
def __init__(self, params=None): | |
self.params = params | |
def get_fit_params(self): | |
params = lmfit.Parameters() | |
params.add("population", value=12_000_000, vary=False) | |
params.add("epidemic_started_days_ago", value=10, vary=False) | |
params.add("r0", value=4, min=3, max=5, vary=True) | |
params.add("alpha", value=0.0064, min=0.005, max=0.0078, vary=True) # CFR |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Add a 7-row trailing moving average of every column of `df`, rounded to 5
# decimals, as `<name>_ma7` columns alongside the originals.
df_smoothed = df.rolling(7).mean().round(5)
df_smoothed.columns = [col + '_ma7' for col in df_smoothed.columns]
full_df = pd.concat([df, df_smoothed], axis=1)
# Rows without a complete 7-row window come out of rolling().mean() as NaN;
# fill those gaps from the corresponding raw column.
for column in full_df.columns:
    if column.endswith('_ma7'):
        # Slice off the exact suffix. The previous `column.strip('_ma7')`
        # removed any of the characters '_', 'm', 'a', '7' from BOTH ends,
        # mangling names such as 'area_ma7' -> 're' or 'm_total_ma7' -> 'total'.
        original_column = column[:-len('_ma7')]
        full_df[column] = full_df[column].fillna(full_df[original_column])
NewerOlder