Created
July 26, 2023 08:27
-
-
Save kururu-abdo/030e725fb1f36585d4700cc8e9e06bed to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def create_waffle_chart(categories, values, height, width, colormap, value_sign=''):
    """Draw a waffle chart of ``values`` on a ``height`` x ``width`` tile grid.

    Each category is given a number of tiles proportional to its share of
    ``sum(values)``. Tiles are filled column by column, top to bottom. The
    tile counts are printed, a legend with one entry per category is placed
    below the chart, and the figure is displayed with ``plt.show()``.

    Args:
        categories: Sequence of category labels, one per entry of ``values``.
        values: Sequence (list, array or pandas Series) of numbers; they must
            sum to a positive total.
        height: Number of tile rows.
        width: Number of tile columns.
        colormap: Matplotlib colormap used for the tiles and the legend.
            ``None`` falls back to ``plt.cm.coolwarm``.
        value_sign: Unit shown next to each value in the legend. ``'%'`` is
            appended after the value; any other sign is placed before it.

    Raises:
        ValueError: If ``categories`` and ``values`` differ in length, or if
            the values do not sum to a positive number.
    """
    # Materialise both inputs so they can be indexed positionally no matter
    # what container (e.g. a label-indexed pandas Series) the caller passed.
    category_list = list(categories)
    value_list = list(values)
    if len(category_list) != len(value_list):
        raise ValueError(
            'categories and values must have the same length '
            f'(got {len(category_list)} and {len(value_list)})'
        )

    total_values = sum(value_list)
    if not total_values > 0:
        raise ValueError('values must sum to a positive number')

    # Proportion of each category with respect to the total.
    category_proportions = [float(value) / total_values for value in value_list]

    total_num_tiles = width * height
    print('Total number of tiles is', total_num_tiles)

    # Number of tiles allotted to each category.
    tiles_per_category = [
        int(round(proportion * total_num_tiles))
        for proportion in category_proportions
    ]
    for category, tiles in zip(category_list, tiles_per_category):
        print(f'{category}: {tiles}')

    waffle_chart = _build_waffle_grid(tiles_per_category, height, width)

    if colormap is None:
        colormap = plt.cm.coolwarm

    # matshow opens its own, appropriately sized figure, so no separate
    # plt.figure() call is needed (it would only leave an empty figure).
    image = plt.matshow(waffle_chart, cmap=colormap)
    ax = image.axes
    plt.colorbar(image)

    # Minor ticks sit on the tile borders; the white grid drawn on them
    # separates the tiles visually.
    ax.set_xticks(np.arange(-.5, width, 1), minor=True)
    ax.set_yticks(np.arange(-.5, height, 1), minor=True)
    ax.grid(which='minor', color='w', linestyle='-', linewidth=2)

    # Hide the major ticks and their labels.
    ax.set_xticks([])
    ax.set_yticks([])

    # Build the legend. Colours are looked up through the image's own
    # normalisation so each patch matches the tiles of its category
    # (tiles of the i-th category hold the value i + 1).
    legend_handles = []
    for i, (category, value) in enumerate(zip(category_list, value_list)):
        if value_sign == '%':
            label_str = f'{category} ({value}{value_sign})'
        else:
            label_str = f'{category} ({value_sign}{value})'
        color_val = image.cmap(float(image.norm(i + 1)))
        legend_handles.append(mpatches.Patch(color=color_val, label=label_str))

    ax.legend(
        handles=legend_handles,
        loc='lower center',
        ncol=len(category_list),
        bbox_to_anchor=(0., -0.2, 0.95, .1),
    )
    plt.show()


def _build_waffle_grid(tiles_per_category, height, width):
    """Return a (height, width) float array of 1-based category ids."""
    # boundaries[c] is the number of tiles used up once category c is done.
    boundaries = np.cumsum(tiles_per_category)
    tile_numbers = np.arange(height * width)
    # A tile belongs to the first category whose boundary lies beyond it;
    # categories with zero tiles are skipped automatically.
    category_ids = np.searchsorted(boundaries, tile_numbers, side='right') + 1
    # Rounding can leave a few tiles unassigned; give them to the last
    # category instead of inventing category ids that do not exist.
    category_ids = np.minimum(category_ids, len(tiles_per_category))
    # Tiles are numbered column by column, so reshape as (col, row) and
    # transpose to the (row, col) layout matshow expects.
    return category_ids.reshape((width, height)).T.astype(float)
# Inputs for the hand-built waffle chart: a 10 x 40 grid coloured with
# coolwarm, using the index labels of df_dsn as categories and its
# 'Total' column as the corresponding values.
height = 10
width = 40
colormap = plt.cm.coolwarm
categories = df_dsn.index.values
values = df_dsn['Total']

create_waffle_chart(
    categories=categories,
    values=values,
    height=height,
    width=width,
    colormap=colormap,
)
from pywaffle import Waffle

# One legend entry per row of df_dsn: the index label followed by its
# 'Total' value in parentheses.
legend_labels = [
    f"{k} ({v})" for k, v in zip(df_dsn.index.values, df_dsn.Total)
]
legend_options = {
    'labels': legend_labels,
    'loc': 'lower left',
    'bbox_to_anchor': (0, -0.1),
    'ncol': 3,
}

# Build the figure with pywaffle's Waffle figure class: a 20 x 30 grid
# fed with the 'Total' column and coloured with the 'tab20' colormap.
fig = plt.figure(
    FigureClass=Waffle,
    rows=20,
    columns=30,
    values=df_dsn['Total'],
    cmap_name='tab20',
    legend=legend_options,
)

# Display the waffle chart.
plt.show()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment