Skip to content

Instantly share code, notes, and snippets.

@mavroprovato
Last active January 12, 2021 23:01
Show Gist options
  • Save mavroprovato/124d2d8996588ee19853bd176f18e853 to your computer and use it in GitHub Desktop.
Save mavroprovato/124d2d8996588ee19853bd176f18e853 to your computer and use it in GitHub Desktop.
import os
import matplotlib.pyplot as plt
import pandas as pd
import requests
import seaborn as sns
# Apply seaborn's theme so that the matplotlib figures drawn below use its styling.
sns.set_theme()
# The URL the dataset is downloaded from when no local copy exists (see get_data)
DATASET_URL = 'https://covid.ourworldindata.org/data/owid-covid-data.csv'
# The local file the downloaded dataset is cached in and read from (see get_data)
DATASET_FILE = 'data/owid-covid-data.csv'
def main():
    """Load the dataset, print the totals and draw the graphs for Greece and Belgium."""
    first, second = 'GRC', 'BEL'
    since = '2020-09-01'
    dataset = get_data()
    print_total_deaths(dataset, first, second)
    # One pair of graphs for the raw daily values and one for the smoothed values.
    for metric, heading in (
        ('new_deaths_per_million', 'Deaths per million'),
        ('new_deaths_smoothed_per_million', 'Deaths per million - 7 day average'),
    ):
        graph_per_day(dataset, first, second, metric, heading, since)
def print_total_deaths(data: pd.DataFrame, country_1: str, country_2: str):
    """Print the highest recorded ``total_deaths_per_million`` value for two countries.

    One line is printed per country, in the order given. The value is rounded
    to the nearest integer. Note that, despite the "Total Deaths" label, the
    figure comes from the per-million column.

    Args:
        data: Frame with at least the ``iso_code`` and
            ``total_deaths_per_million`` columns.
        country_1: ISO code of the first country.
        country_2: ISO code of the second country.
    """
    for country in (country_1, country_2):
        peak = data.loc[data.iso_code == country, 'total_deaths_per_million'].max()
        # A code with no rows (or only missing values) yields NaN, which round()
        # cannot convert to an integer; report it instead of raising.
        shown = 'n/a' if pd.isna(peak) else round(peak)
        print(f'Total Deaths {country}', shown)
def _save_figure(ax, filename: str):
    """Tighten the layout of the figure that owns *ax* and save it to *filename*."""
    fig = ax.get_figure()
    fig.tight_layout()
    fig.savefig(filename)


def graph_per_day(data: pd.DataFrame, country_1: str, country_2: str, column='new_deaths_per_million',
                  title='Deaths per million', cutoff_date: str = None, threshold: float = 12):
    """Plot a per-day column for two countries, plus the ratio between them.

    Two figures are saved in the current working directory and then shown:

    * ``{column}.png`` - the values of ``column`` per date for both countries.
    * ``{column}_times_better.png`` - the ratio ``country_2 / country_1`` per
      date, together with a constant line at ``threshold``.

    The dates on which the ratio exceeds ``threshold`` are printed as well.

    Args:
        data: Frame with at least the ``date``, ``iso_code`` and ``column`` columns.
        country_1: ISO code of the first country (the denominator of the ratio).
        country_2: ISO code of the second country (the numerator of the ratio).
        column: Name of the column to plot.
        title: Title of the first figure; the second one gets
            ``' - Times Better'`` appended.
        cutoff_date: Only rows on or after this date are used. ``None`` keeps
            every date.
        threshold: Ratio above which a date is reported as "better".
    """
    selected = data.iso_code.isin([country_1, country_2])
    if cutoff_date is not None:
        # None is the declared default but is not a usable comparison value for
        # the date column, so the date filter is applied only when a cutoff is given.
        selected = selected & (data.date >= cutoff_date)
    # Reshape to one row per date and one column per country.
    table = data.loc[selected, ['date', 'iso_code', column]].pivot(
        index='date', columns='iso_code', values=column)

    ax = table[[country_1, country_2]].plot(title=title)
    ax.set_xlabel('Date')
    ax.legend(title='Country')
    _save_figure(ax, f'{column}.png')

    # Vectorised division; gives the same values as dividing row by row.
    table['DIFF'] = table[country_2] / table[country_1]
    # Constant column so that the threshold is drawn as a dotted line on the plot.
    table['THRESHOLD'] = threshold
    ax = table[['DIFF', 'THRESHOLD']].plot(title=f'{title} - Times Better', style=['-', '.'], legend=False)
    ax.set_xlabel('Date')
    _save_figure(ax, f'{column}_times_better.png')

    print('Dates better:', ', '.join(table[table['DIFF'] > threshold].index.astype(str)))
    plt.show()
def get_data() -> pd.DataFrame:
    """Return the dataset, downloading it to ``DATASET_FILE`` first if it is not there.

    The file is fetched from ``DATASET_URL`` only when ``DATASET_FILE`` does
    not exist yet; later calls read the local copy.

    Returns:
        The contents of the CSV file, with the ``date`` column converted to
        datetime values.

    Raises:
        requests.RequestException: If the download fails, times out or the
            server answers with an HTTP error status.
    """
    if not os.path.isfile(DATASET_FILE):
        response = requests.get(DATASET_URL, timeout=60)
        # Fail here rather than caching an error page as if it were the CSV.
        response.raise_for_status()
        directory = os.path.dirname(DATASET_FILE)
        if directory:
            os.makedirs(directory, exist_ok=True)
        # Store the bytes exactly as served, so the file does not depend on the
        # platform's default text encoding or newline translation.
        with open(DATASET_FILE, 'wb') as f:
            f.write(response.content)
    data = pd.read_csv(DATASET_FILE)
    data['date'] = pd.to_datetime(data['date'])
    return data
if __name__ == '__main__':
    # Run the analysis only when the file is executed as a script, not on import.
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment