Skip to content

Instantly share code, notes, and snippets.

Show Gist options
  • Select an option

  • Save darthsuogles/15528ad827ab6dab4b4582c9ec9c135d to your computer and use it in GitHub Desktop.

Select an option

Save darthsuogles/15528ad827ab6dab4b4582c9ec9c135d to your computer and use it in GitHub Desktop.
import re

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import regex
import seaborn as sns
# Strips currency markers before numeric parsing. The stdlib `re` module fully
# supports this simple character class; the third-party `regex` package the
# original used adds no value here.
CURRENCY_REGEX = re.compile(r'[\$€,]')
# Load the publicly available World Economic Outlook dataset
class WorldEconomicOutlook(object):
    """Loader for the IMF World Economic Outlook (October 2017) table.

    The file is parsed exactly once, in ``__init__``; the ``dataset``
    property hands back that cached DataFrame on every access.
    """

    def __init__(self):
        self._fname = "data/WEOOct2017all.tsv"
        self._df = self._parse()

    @property
    def dataset(self):
        # Cached frame from construction time -- no re-parsing on access.
        return self._df

    def _parse(self):
        """Read and clean the WEO TSV file into a DataFrame."""
        print("""\
IMF World Economic Outlook Dataset
>> http://www.imf.org/external/pubs/ft/weo/2017/02/weodata/index.aspx
""")
        frame = pd.read_csv(
            self._fname,
            # Despite the original "xls" extension, the file is actually TSV.
            delimiter="\t",
            # The first row in the file is the header.
            header=0,
            # The last row is only an attribution line
            # ("International Monetary Fund, World Economic Outlook Database,
            # October 2017"), so skip it.
            skipfooter=1,
            # Encoding, checked with the `file` command.
            encoding="ISO-8859-1")
        # Some empty lines survive parsing; drop rows that are entirely NA.
        frame.dropna(axis='index', how='all', inplace=True)
        frame['WEO Country Code'] = frame['WEO Country Code'].astype('int')
        return frame
# Parse the dataset from disk once; `df` is the working table for the script.
IMF_WEO = WorldEconomicOutlook()
df = IMF_WEO.dataset
# WEO country code to country name (one row per country; code is the index)
country_code2name = df[['WEO Country Code', 'ISO', 'Country']].drop_duplicates().set_index('WEO Country Code')
# WEO related metadata: one row per subject (economic indicator) describing it
weo_subject_cols = ['WEO Subject Code', 'Subject Descriptor', 'Units', 'Scale']
weo_subject_info_df = df[weo_subject_cols].drop_duplicates()
# Partition the columns: names that parse as a plausible year (1900..2300)
# hold per-year values; everything else is a descriptive feature column.
feature_cols = []
year_cols = []
for column in df.columns:
    try:
        year = int(column)
    except ValueError:
        # Non-numeric column name -- an ordinary feature column.
        feature_cols.append(column)
        continue
    if 1900 <= year <= 2300:
        year_cols.append(year)
    else:
        # Numeric, but not a believable year.
        feature_cols.append(column)
def get_values_for_year(year):
    """Return a country x variable matrix of values for a single year.

    Selects the given year's column from the global ``df`` and pivots it so
    rows are WEO country codes and columns are WEO subject codes.
    """
    col = str(year)
    renamed = (df[['WEO Country Code', 'WEO Subject Code', col]]
               .drop_duplicates()
               .rename(columns={
                   "WEO Country Code": "country",
                   "WEO Subject Code": "variable",
                   col: "value",
               }))
    return renamed.pivot(index="country", columns="variable", values="value")
def try_get_currency_number(num_repr):
    """Parse a currency-formatted string into a float, or NaN on failure.

    Strips currency symbols and thousands separators first; non-string
    inputs (e.g. NaN cells) raise TypeError and also map to NaN.
    """
    try:
        cleaned = CURRENCY_REGEX.sub('', num_repr)
        return float(cleaned)
    except (ValueError, TypeError):
        return np.nan
# 2017 country x variable matrix, with currency strings coerced to floats.
year_srs = (get_values_for_year(2017)
            .applymap(try_get_currency_number))
# Check out distribution of individual variables like this
# year_srs['LE'].apply(np.log).plot('hist'); plt.show()
# Looks like it's fine to just apply `log` to all of them
year_srs_log = year_srs.apply(np.log)
# Covariance and correlation with pandas ignores NA
# year_srs_log.cov(); year_srs_log.corr()
# We have a lot of missing values; count them per variable (column).
missing_values = year_srs_log.isna().sum(axis=0)
# This looks like a reasonable thing to start with
year_srs_proc = (year_srs_log
                 # Replace infinity values (log of zero) with NaN
                 .replace([np.inf, -np.inf], np.nan)
                 # Keep only columns with fewer than 10% missing values.
                 # BUG FIX: the original divided by `X.shape[0]`, but `X` is
                 # only defined *later* in the script -- a NameError at
                 # runtime. The intended denominator is the row count here.
                 .loc[:, missing_values / year_srs_log.shape[0] < 0.1]
                 # Remove rows that still contain missing values
                 .dropna(how='any'))
from sklearn import preprocessing
from sklearn.feature_selection import VarianceThreshold

# Drop near-constant variables, then standardize what remains.
X = year_srs_proc
variance_filter = VarianceThreshold(threshold=0.2)
X_var_pruned = variance_filter.fit_transform(X)
X_scaled = preprocessing.scale(X_var_pruned)
# Gaussian mixture with Dirichlet Process
from sklearn.mixture import BayesianGaussianMixture

# The Dirichlet-process prior lets the model effectively use fewer than
# `max_num_components` clusters, so this is only an upper bound.
max_num_components = 72
bnp = BayesianGaussianMixture(
    weight_concentration_prior_type="dirichlet_process",
    n_components=max_num_components,
    reg_covar=0,
    init_params='random',
    max_iter=1500,
    mean_precision_prior=.8,
    random_state=2702,
)
# `fit` returns the estimator itself, so fit and predict chain naturally.
cluster_ids = bnp.fit(X_scaled).predict(X_scaled)
# Attach the hard cluster assignments back onto the processed frame.
year_srs_proc['bnp_cluster_id'] = cluster_ids
# Pair each country's cluster id with its human-readable name; both frames
# are indexed by WEO country code, so the join is on the index.
_cluster_col = year_srs_proc[['bnp_cluster_id']]
cluster_ids_with_country = (
    _cluster_col.join(country_code2name)
                .sort_values('bnp_cluster_id'))
"""
Generate cluster overview plots with leaflet
"""
import folium
import webbrowser
import webcolors
from colorsys import hls_to_rgb
def get_distinct_colors(vals):
    """Map each value to a visually distinct hex color.

    Hues are spread evenly around the color wheel; lightness and saturation
    get a small random jitter so adjacent hues separate a bit further.
    """
    step = 360. / len(vals)
    colors = {}
    for pos, val in enumerate(vals):
        hue = (pos * step) / 360.
        lightness = (50 + np.random.rand() * 10) / 100.
        saturation = (90 + np.random.rand() * 10) / 100.
        percents = ['{:2.2f}%'.format(channel * 100)
                    for channel in hls_to_rgb(hue, lightness, saturation)]
        colors[val] = webcolors.rgb_percent_to_hex(percents)
    return colors
# Build a custom colormap: one distinct color per observed cluster id.
cmap = get_distinct_colors(set(cluster_ids))

# Base leaflet layer, zoomed far out so the whole world is visible.
world_map = folium.Map(
    location=[45.523, -122.675],
    tiles='cartodbpositron',
    zoom_start=2,
    no_wrap=True,
)
# Add coloring function
def style_function(feature):
    """Folium style callback: fill each country by its cluster's color.

    The tail of the geojson feature id is matched against the 3-letter ISO
    code (presumably the ids in data/countries.geo.json end with or equal
    the ISO code -- TODO confirm). Countries without a cluster assignment
    fall back to a faint black fill.
    """
    country_iso = feature['id'][-5:]
    opacity = 0.618
    try:
        _where = cluster_ids_with_country['ISO'] == country_iso
        cluster_id = cluster_ids_with_country[_where]['bnp_cluster_id'].values[0]
        color = cmap[cluster_id]
    except (KeyError, IndexError):
        # BUG FIX: the original used a bare `except:` (which also swallows
        # KeyboardInterrupt/SystemExit) and the invalid CSS color '#black'.
        # Catch only the expected lookup failures (missing 'id' key, no
        # matching country row, unknown cluster id) and use a valid black.
        color = '#000000'
        opacity = 0.05
    return {
        'fillOpacity': opacity,
        'weight': 0.1,
        'fillColor': color
    }
# Attach a geojson layer with our custom coloring function
cluster_layer = folium.GeoJson(
    "data/countries.geo.json",
    name='cluster_map',
    style_function=style_function,
)
cluster_layer.add_to(world_map)

# Persist the rendered map and open it in the default browser.
world_map.save("map.html")
webbrowser.open("map.html")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment