Skip to content

Instantly share code, notes, and snippets.

@gallir
Created April 3, 2021 22:08
Show Gist options
  • Save gallir/e8a0577c2a3067e248d2baf83fd4a165 to your computer and use it in GitHub Desktop.
Save gallir/e8a0577c2a3067e248d2baf83fd4a165 to your computer and use it in GitHub Desktop.
COVID-19 vaccination forecasting with fbprophet
#! /usr/bin/env python3
import pandas as pd
import pycountry
from datetime import date, timedelta
from fbprophet import Prophet
from concurrent.futures import ProcessPoolExecutor
from .base import Series
COVID_URL = "https://raw.githubusercontent.com/owid/covid-19-data/master/public/data/owid-covid-data.csv"

# ISO 3166-1 alpha-3 code -> lower-case alpha-2 code, used to translate the
# ``iso_code`` column of the downloaded CSV (e.g. "ESP" -> "es").  Built with a
# comprehension so no loop variable leaks into the module namespace.
COUNTRIES_TO_2 = {country.alpha_3: country.alpha_2.lower() for country in pycountry.countries}

# NOTE(review): not referenced in the visible code; confirm whether still needed.
MIN_START = date.fromisoformat('2021-01-01')
class Covid(Series):
    """Per-country COVID-19 vaccination series extended with a Prophet forecast.

    Observed data is read from the CSV at ``COVID_URL``, restricted to
    ``[start, end)``, and for every country and value column a Prophet model
    extrapolates it ``days`` days past ``end`` on a ``freq`` date grid.
    """

    schema = {
        "timestamp": "timestamp",
        "item_id": "string",
        "vaccines": "float",
    }
    # Value columns to forecast; process_country skips those absent from the data.
    columns = ["vaccines", "fully_vaccinated"]
    timestamp_format = "yyyy-MM-dd"

    def __init__(self, start, end, freq='W-MON', top=0, days=90, isos=None, use_cache=False):
        """Set up the output date grid.

        Args:
            start: first date (inclusive) of the observed data.
            end: last date (exclusive) of the observed data.
            freq: pandas frequency alias of the output grid.
            top: forwarded to ``Series``.
            days: forecast horizon, in days past ``end``.
            isos: lower-case alpha-2 country codes to produce; ``None`` means
                every country present in the downloaded data.
            use_cache: forwarded to ``Series``.
        """
        super().__init__("covid", self.schema, start, end, freq, top=top, use_cache=use_cache)
        self.isos = isos
        prediction_end = end + timedelta(days=days)
        # Grid from start up to (but excluding) end + days.
        try:
            grid = pd.date_range(start=start, end=prediction_end, freq=freq, inclusive="left")
        except TypeError:
            # pandas < 1.4 only accepts the older ``closed`` spelling
            # (``closed`` was removed in pandas 2.0).
            grid = pd.date_range(start=start, end=prediction_end, freq=freq, closed="left")
        # Dates handed to Prophet.predict, which expects a ``ds`` column.
        self.futures = pd.DataFrame(grid, columns=['ds'])
        # Per-country template onto which predicted columns are merged.
        self.base = pd.DataFrame(grid, columns=['timestamp'])

    def read(self):
        """Populate ``self.df``.

        ``Series.read`` is called first (presumably loading a cached copy); if
        it leaves ``self.df`` unset, the data is downloaded, forecast and
        stored via ``store_cache``.
        """
        super().read()
        if self.df is not None:
            return
        print(f"Getting COVID data from {self.start} to {self.end}")
        df = self.get_data()
        self.df = self.predict(df)
        self.store_cache()

    def get_data(self):
        """Download the CSV and return country rows within ``[start, end)``.

        Returns:
            DataFrame with ``timestamp``, ``iso`` (lower-case alpha-2 code) and
            ``vaccines`` columns.
        """
        df = pd.read_csv(
            COVID_URL,
            usecols=[
                "date",
                "iso_code",
                "total_vaccinations_per_hundred",
                # "people_fully_vaccinated_per_hundred",
            ],
            parse_dates=["date"],
        )
        df = df.rename(columns={
            "date": "timestamp",
            "iso_code": "iso",
            "total_vaccinations_per_hundred": "vaccines",
            # "people_fully_vaccinated_per_hundred": "fully_vaccinated",
        })
        in_range = (df.timestamp >= self.start.isoformat()) & (df.timestamp < self.end.isoformat())
        # Explicit copy so the column assignment below does not act on a slice.
        df = df.loc[in_range].copy()
        # Codes with no alpha-2 equivalent keep their original (longer) value and
        # are dropped by the length filter.
        df["iso"] = df["iso"].fillna("").replace(COUNTRIES_TO_2)
        df = df[df['iso'].str.len() == 2]
        return df

    def _empty_country(self, iso):
        """Return an all-zero frame on the output grid for country *iso*."""
        iso_df = self.base.copy()
        iso_df['iso'] = iso
        for c in self.columns:
            iso_df[c] = 0.0
        return iso_df

    def predict(self, df):
        """Forecast every requested country in parallel and stack the results.

        Args:
            df: observed data as returned by ``get_data``.

        Returns:
            DataFrame with ``timestamp``, ``iso`` and the forecast value
            columns, sorted by ``timestamp`` then ``iso``. Requested countries
            absent from ``df`` are included with all-zero values; countries
            whose forecast raised are reported and left out.
        """
        data_isos = set(df['iso'].unique())
        if self.isos is None:
            self.isos = data_isos

        frames = []
        pending = []
        with ProcessPoolExecutor(max_workers=4) as ex:
            for iso in self.isos:
                if iso not in data_isos:
                    print(f"Country {iso} not found")
                    frames.append(self._empty_country(iso))
                    continue
                df_c = df.loc[df['iso'] == iso]
                pending.append((iso, ex.submit(self.process_country, iso, df_c)))

        # Collect results here rather than in done-callbacks: callbacks run on
        # the executor's management thread (racing with the appends above) and
        # any exception they hit is only logged, silently dropping the country.
        for iso, future in pending:
            try:
                frames.append(future.result())
            except Exception as exc:  # best effort: keep the other countries
                print(f"Country {iso} failed: {exc!r}")

        if not frames:
            return pd.DataFrame(columns=["timestamp", "iso", *self.columns])
        all_df = pd.concat(frames)
        all_df.sort_values(["timestamp", "iso"], inplace=True)
        return all_df

    def process_country(self, iso, data_df):
        """Fit one Prophet model per value column for a single country.

        Args:
            iso: country code written into the ``iso`` column of the result.
            data_df: that country's rows from ``get_data``.

        Returns:
            A copy of ``self.base`` with one predicted column per value column
            present in ``data_df`` (negative predictions and dates before the
            first positive observation set to 0) plus ``iso``.
        """
        print(f"Processing {iso}")
        futures = self.futures
        # Copy so self.base is never mutated (matters when called in-process).
        base_df = self.base.copy()
        for c in self.columns:
            if c not in data_df.columns:
                # NOTE(review): the column is then absent from this country's
                # frame, while _empty_country fills it with zeros; confirm intent.
                continue
            pdf = data_df[['timestamp', c]].copy()
            # Forward-fill gaps; values before the first observation become 0.
            pdf[c] = pdf[c].ffill().fillna(0)
            if self.freq != 'D':
                pdf = pdf.resample(self.freq, on="timestamp", label="left").max()
                # With ``on=`` the resample key becomes the index; restore it
                # as a column so the rename below produces ``ds``.
                if 'timestamp' not in pdf.columns:
                    pdf = pdf.reset_index()
            pdf = pdf.rename(columns={c: 'y', 'timestamp': 'ds'})
            # NOTE(review): per the Prophet docs cap/floor only take effect
            # with growth='logistic'; with the linear growth below they are
            # likely ignored.
            if c == "fully_vaccinated":
                pdf['cap'] = 100
            elif c == "vaccines":
                pdf['cap'] = 200
            pdf['floor'] = 0
            # Fit only on positive observations.
            pdf = pdf.loc[pdf.y > 0]
            if pdf.shape[0] < 2:
                # Not enough points to fit a model.
                base_df[c] = 0.0
                continue
            min_date = pdf['ds'].min()
            prediction_dates = futures.loc[futures.ds >= min_date]
            model = Prophet(yearly_seasonality=False, daily_seasonality=False, growth='linear')
            model.fit(pdf)
            forecast = model.predict(prediction_dates)
            res = forecast[['ds', 'yhat']].rename(columns={'yhat': c, 'ds': 'timestamp'})
            base_df = base_df.merge(res, on="timestamp", how='left')
            # Grid dates before min_date have no prediction; clamp negatives.
            base_df[c] = base_df[c].fillna(0).clip(lower=0)
        base_df['iso'] = iso
        return base_df
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment