import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.lines import Line2D
import seaborn as sns

# Palette shared by the whole figure: 'bg' is the figure/axes background,
# 'data' the individual points and text, 'data2' the per-unit mean markers.
colors = {'bg': '#3D3C42', 'data': '#FEFBF6', 'data2': '#A6D1E6'}

# data
url = 'https://gist.githubusercontent.com/Thiagobc23/0bc5c5f530e8f8fd98710fe1ccb407ce/raw/4e084668a83ab9d0a0ace1425742835a0563bcef/quality.csv'

df = pd.read_csv(url)

# Per-unit averages. numeric_only=True is explicit because pandas >= 2.0 no
# longer drops non-numeric columns silently: a bare .mean() raises TypeError
# if the CSV carries any text column besides 'Unit'.
df_unit = df.groupby('Unit').mean(numeric_only=True).reset_index()

# sort prod units by avg score (best unit first); `order` drives the row
# order of the chart
df_unit = df_unit.sort_values('Quality Score', ascending=False)
order = df_unit['Unit'].tolist()

# Style and Subplots
# Two stacked axes on the dark background: the tall one holds the chart, the
# short one underneath is used only to host the custom legend.
fig, (ax, ax_leg) = plt.subplots(
    nrows=2,
    ncols=1,
    figsize=(12, 10),
    facecolor=colors['bg'],
    gridspec_kw={'height_ratios': [6, 1]},
)
for axes in (ax, ax_leg):
    axes.set_facecolor(colors['bg'])

# Plot jitter chart: one point per test result, one row per unit, rows
# arranged according to `order`.
sns.stripplot(
    data=df,
    x="Quality Score",
    y="Unit",
    order=order,
    jitter=0.2,
    size=4,
    alpha=0.8,
    color=colors['data'],
    ax=ax,
)

# Draw lines for mean value
# Each unit gets three '|' markers: at the mean and at mean +/- 2 score
# points -- presumably so the three overlap into one visibly thicker tick
# (NOTE(review): confirm intent). zorder=99 keeps them above the points.
mean_marker_style = dict(marker='|', s=2500, color=colors['data2'], alpha=1, zorder=99)
for offset in (0, 2, -2):
    ax.scatter(df_unit['Quality Score'] + offset, df_unit['Unit'], **mean_marker_style)

# Scales and Ticks
# Inverted y-range (bottom > top) so the first unit in `order` is drawn at the
# top, with half a row of padding on each side.
ax.set_ylim(len(order)-0.5, -0.5)
# NOTE(review): upper limit is 1001 rather than 1000, presumably so the last
# tick/points at 1000 are not clipped -- confirm.
ax.set_xlim(0,1001)
xticks = np.arange(0,1001,100)
# set_xticks/set_yticks accept Text keyword arguments only together with
# `labels`; passing `color=` alone raises on current matplotlib. The tick
# colour is applied by tick_params below instead.
ax.set_xticks(xticks)
ax.set_yticks(order)
ax.tick_params(axis='both', which='major', labelsize=12, colors=colors['data'])

# Labels and Title
# The trailing newline in the title leaves a blank line between it and the chart.
ax.set_title(
    'Quality score by production unit\n',
    loc='left',
    fontsize=20,
    color=colors['data'],
)
# Axis labels are blanked out.
ax.set(ylabel='', xlabel='')

# custom Legend
# Proxy artists: the handles are built by hand rather than taken from the
# plotted artists. The first handle's line is drawn in the background colour,
# so only its round marker is visible.
point_handle = Line2D(
    [0], [0],
    marker='o',
    color=colors['bg'],
    label='One test result',
    markerfacecolor=colors['data'],
    markersize=10,
)
mean_handle = Line2D(
    [0], [0],
    marker='|',
    color=colors['data2'],
    label='Mean quality score for the unit',
    linestyle='None',
    markersize=25,
)
legend_elements = [point_handle, mean_handle]

legend = ax_leg.legend(handles=legend_elements, loc='upper center', ncol=5, frameon=False)
for legend_text in legend.get_texts():
    legend_text.set_color(colors['data'])
    legend_text.set_fontsize(12)

# clean second axis: it only carries the legend, so it needs no ticks
ax_leg.set(xticks=[], yticks=[])

# Remove all four spines from every axes in the figure.
sns.despine(fig=fig, top=True, right=True, left=True, bottom=True)

plt.tight_layout()
# facecolor is passed explicitly so the saved image keeps the dark background.
plt.savefig('jitter.png', facecolor=colors['bg'])