import regex
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

CURRENCY_REGEX = regex.compile(r'[\$€,]')


# Load the publicly available World Economic Outlook dataset
class WorldEconomicOutlook(object):
    def __init__(self):
        self._fname = "data/WEOOct2017all.tsv"
        self._df = self._parse()

    @property
    def dataset(self):
        # The property returns the DataFrame parsed once at construction time.
        return self._df

    def _parse(self):
        print("""\
IMF World Economic Outlook Dataset
>> http://www.imf.org/external/pubs/ft/weo/2017/02/weodata/index.aspx
""")
        df = pd.read_csv(self._fname,
                         # Despite the original "xls" extension, the file is actually TSV.
                         delimiter="\t",
                         # The first row in the file is the header.
                         header=0,
                         # Drop the last row, which only contains the attribution line
                         # "International Monetary Fund, World Economic Outlook Database, October 2017".
                         skipfooter=1,
                         # skipfooter is only supported by the Python parsing engine.
                         engine="python",
                         # Encoding, checked with the `file` command.
                         encoding="ISO-8859-1")
        # Some empty lines are not identified as such; drop them all after parsing.
        df.dropna(axis='index', how='all', inplace=True)
        df['WEO Country Code'] = df['WEO Country Code'].astype('int')
        return df


IMF_WEO = WorldEconomicOutlook()
df = IMF_WEO.dataset

# WEO country code to country name
country_code2name = (df[['WEO Country Code', 'ISO', 'Country']]
                     .drop_duplicates()
                     .set_index('WEO Country Code'))
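# Sanity check on the assumption that each WEO country code maps to a single
# ISO / country-name pair; this only makes the assumption explicit and flags
# any dataset vintage where it does not hold.
if not country_code2name.index.is_unique:
    print("Warning: some WEO country codes map to more than one ISO/Country row")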
# WEO related metadata
weo_subject_cols = ['WEO Subject Code', 'Subject Descriptor', 'Units', 'Scale']
weo_subject_info_df = df[weo_subject_cols].drop_duplicates()

feature_cols = []
year_cols = []
for col in df.columns:
    try:
        maybe_year = int(col)
        if maybe_year < 1900 or maybe_year > 2300:
            raise ValueError
        year_cols.append(maybe_year)
    except ValueError:
        # Capture columns that cannot be interpreted as a year
        feature_cols.append(col)


def get_values_for_year(year):
    year = str(year)
    year_srs = df[['WEO Country Code', 'WEO Subject Code', year]].drop_duplicates()
    year_srs.rename(columns={
        "WEO Country Code": "country",
        "WEO Subject Code": "variable",
        year: "value"
    }, inplace=True)
    return year_srs.pivot(index="country", columns="variable", values="value")


def try_get_currency_number(num_repr):
    try:
        return float(CURRENCY_REGEX.sub('', num_repr))
    except (ValueError, TypeError):
        return np.nan
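# A quick illustration of the currency parser (the inputs here are made up;
# actual WEO cells may look different): separators and currency symbols are
# stripped before the float conversion, and anything unparsable becomes NaN.
assert try_get_currency_number("1,234.5") == 1234.5
assert np.isnan(try_get_currency_number(None))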
year_srs = (get_values_for_year(2017)
            .applymap(try_get_currency_number))

# Check the distribution of individual variables like this:
# year_srs['LE'].apply(np.log).plot(kind='hist'); plt.show()
# It looks fine to simply apply `log` to all of them.
year_srs_log = year_srs.apply(np.log)

# Covariance and correlation in pandas ignore NA:
# X.cov(); X.corr()

# We have a lot of missing values
missing_values = year_srs_log.isna().sum(axis=0)

# A reasonable starting point for cleaning:
year_srs_proc = (year_srs_log
                 # Replace infinite values with NaN
                 .replace([np.inf, -np.inf], np.nan)
                 # Drop columns with more than 10% missing values
                 .loc[:, missing_values / year_srs_log.shape[0] < 0.1]
                 # Drop rows with any remaining missing values
                 .dropna(how='any'))
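# Shape after filtering (the exact numbers depend on the WEO vintage, so this
# is just an inspection step rather than an assertion):
print("Countries x indicators after filtering:", year_srs_proc.shape)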
from sklearn import preprocessing
from sklearn.feature_selection import VarianceThreshold

X = year_srs_proc
# Drop near-constant indicators; note that VarianceThreshold returns a plain
# numpy array, so the column labels are lost from this point on.
X_var_pruned = VarianceThreshold(threshold=0.2).fit_transform(X)
X_scaled = preprocessing.scale(X_var_pruned)
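# preprocessing.scale standardizes each column to zero mean and unit variance;
# a quick numerical check of that invariant (the tolerance is a judgment call):
assert np.allclose(X_scaled.mean(axis=0), 0.0, atol=1e-6)
assert np.allclose(X_scaled.std(axis=0), 1.0, atol=1e-6)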
# Gaussian mixture with a Dirichlet process prior
from sklearn.mixture import BayesianGaussianMixture

max_num_components = 72
bnp = BayesianGaussianMixture(weight_concentration_prior_type="dirichlet_process",
                              n_components=max_num_components,
                              reg_covar=0,
                              init_params='random',
                              max_iter=1500,
                              mean_precision_prior=.8,
                              random_state=2702)
bnp.fit(X_scaled)
cluster_ids = bnp.predict(X_scaled)
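# With a Dirichlet-process prior the model typically concentrates the weight
# on far fewer than max_num_components clusters; count the components that
# carry non-negligible weight (the 1e-2 cutoff is an arbitrary choice):
effective_components = np.sum(bnp.weights_ > 1e-2)
print("Effective number of clusters:", effective_components)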
year_srs_proc['bnp_cluster_id'] = cluster_ids
cluster_ids_with_country = (year_srs_proc[['bnp_cluster_id']]
                            # Join on index
                            .join(country_code2name)
                            .sort_values('bnp_cluster_id'))
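# Inspect the membership of one cluster, e.g. the largest one (cluster ids are
# arbitrary labels assigned by the mixture model, so this is purely for
# eyeballing the grouping):
largest_cluster = cluster_ids_with_country['bnp_cluster_id'].value_counts().idxmax()
print(cluster_ids_with_country[cluster_ids_with_country['bnp_cluster_id'] == largest_cluster])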
| """ | |
| Generate cluster overview plots with leaflet | |
| """ | |
| import folium | |
| import webbrowser | |
| import webcolors | |
| from colorsys import hls_to_rgb | |
| def get_distinct_colors(vals): | |
| colors = {} | |
| for val, idx in zip(vals, np.arange(0., 360., 360. / len(vals))): | |
| h = idx / 360. | |
| l = (50 + np.random.rand() * 10) / 100. | |
| s = (90 + np.random.rand() * 10) / 100. | |
| rgb_repr = ['{:2.2f}%'.format(c * 100) for c in hls_to_rgb(h, l, s)] | |
| colors[val] = webcolors.rgb_percent_to_hex(rgb_repr) | |
| return colors | |
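# Small self-check of the colour helper (the values 0, 1, 2 are placeholders):
# every distinct input value should receive its own hex colour, with hues
# spread evenly around the circle and randomized lightness/saturation.
_demo_cmap = get_distinct_colors({0, 1, 2})
assert len(_demo_cmap) == 3
assert all(c.startswith('#') for c in _demo_cmap.values())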
# Build a custom colormap
cmap = get_distinct_colors(set(cluster_ids))

# Create a Leaflet map base layer
world_map = folium.Map(location=[45.523, -122.675],
                       tiles='cartodbpositron',
                       zoom_start=2,
                       no_wrap=True)
# Styling function that colors each country by its cluster assignment
def style_function(feature):
    country_iso = feature['id'][-5:]
    opacity = 0.618
    try:
        _where = cluster_ids_with_country['ISO'] == country_iso
        cluster_id = cluster_ids_with_country[_where]['bnp_cluster_id'].values[0]
        color = cmap[cluster_id]
    except (IndexError, KeyError):
        # Countries absent from the clustering are drawn faintly in black
        color = 'black'
        opacity = 0.05
    return {
        'fillOpacity': opacity,
        'weight': 0.1,
        'fillColor': color
    }
# Attach a GeoJSON layer with our custom coloring function
folium.GeoJson("data/countries.geo.json",
               name='cluster_map',
               style_function=style_function).add_to(world_map)

# Save and show the map
world_map.save("map.html")
webbrowser.open("map.html")