Created
April 19, 2024 15:37
-
-
Save Yuxin-CV/8a48464319d47596a8bc790957cda9ff to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import pandas as pd | |
import numpy as np | |
from sklearn.datasets import load_wine | |
from sklearn.preprocessing import StandardScaler | |
from sklearn.pipeline import make_pipeline | |
from sklearn.cluster import KMeans | |
import matplotlib.pyplot as plt | |
import textwrap | |
import math | |
class ComplexRadar():
    """
    Create a complex radar chart with different scales for each variable.

    Matplotlib polar axes have a single radial scale, so this class stacks
    ``len(variables) + 1`` polar axes on the same figure rectangle:

    * axes 0 (``self.ax``) shows variable 0's ring labels, the grid lines,
      the outer ring and the variable names;
    * axes 1 (``self.ax1``) is an invisible overlay (axis off, zorder 9)
      that uses variable 0's range and receives every plotted line/area;
    * axes i (i >= 2) only shows the ring labels for variable i - 1.

    Before plotting, values of variables 1..n-1 are linearly rescaled into
    variable 0's range (see ``_scale_data``).

    Parameters
    ----------
    fig : figure object
        A matplotlib figure object to add the axes on
    variables : list
        A list of variable names, used as the angular (theta) labels
    ranges : list
        A list of tuples (min, max) for each variable, in the same order
        as ``variables``
    n_ring_levels: int, defaults to 5
        Number of ordinate or ring levels to draw
    show_scales: bool, defaults to True
        Indicates whether the ranges for each variable are plotted
    format_cfg: dict, defaults to None
        A dictionary with formatting configurations. Every top-level key
        given here replaces the matching default entry as a whole; keys that
        are not among the defaults are ignored.
    """

    def __init__(self, fig, variables, ranges, n_ring_levels=5, show_scales=True, format_cfg=None):
        # Default formatting
        self.format_cfg = {
            # Axes
            # https://matplotlib.org/stable/api/figure_api.html
            'axes_args': {},
            # Tick labels on the scales
            # https://matplotlib.org/stable/api/_as_gen/matplotlib.pyplot.rgrids.html
            'rgrid_tick_lbls_args': {'fontsize': 8},
            # Radial (circle) lines
            # https://matplotlib.org/stable/api/_as_gen/matplotlib.pyplot.grid.html
            'rad_ln_args': {},
            # Angle lines
            # https://matplotlib.org/3.2.2/api/_as_gen/matplotlib.lines.Line2D.html#matplotlib.lines.Line2D
            'angle_ln_args': {},
            # Include last value (endpoint) on scale
            'incl_endpoint': False,
            # Variable labels (ThetaTickLabel)
            'theta_tick_lbls': {'va': 'top', 'ha': 'center'},
            'theta_tick_lbls_txt_wrap': 15,
            'theta_tick_lbls_brk_lng_wrds': False,
            'theta_tick_lbls_pad': 25,
            # Outer ring
            # https://matplotlib.org/stable/api/spines_api.html
            'outer_ring': {'visible': True, 'color': '#d6d6d6'},
        }
        if format_cfg is not None:
            # Shallow merge: user-supplied top-level entries win, unknown keys are dropped.
            self.format_cfg = {k: (format_cfg[k] if k in format_cfg else default)
                               for k, default in self.format_cfg.items()}

        # Materialize once: variables is iterated several times below and may be a dict view.
        var_labels = list(variables)
        n_vars = len(var_labels)

        # Evenly spaced angles in degrees. linspace(endpoint=False) always yields
        # exactly n_vars values; arange with a float step can emit an extra one
        # through floating-point round-off.
        angles = np.linspace(0, 360, n_vars, endpoint=False)

        # One polar axes per variable plus one extra: axes[1] duplicates the
        # first variable's scale and is used as the plotting overlay.
        axes = [fig.add_axes([0.1, 0.1, 0.9, 0.9],
                             polar=True,
                             label="axes{}".format(i),
                             **self.format_cfg['axes_args'])
                for i in range(n_vars + 1)]

        # Ensure clockwise rotation (first variable at the top N)
        for ax in axes:
            ax.set_theta_zero_location('N')
            ax.set_theta_direction(-1)
            ax.set_axisbelow(True)

        # Writing the ranges on each axes
        for i, ax in enumerate(axes):
            # axes[0] and axes[1] both belong to variable 0; axes[i] (i >= 2) to variable i-1
            j = 0 if i in (0, 1) else i - 1
            ax.set_ylim(*ranges[j])
            # Set endpoint to True if you like to have values right before the last circle
            grid = np.linspace(*ranges[j], num=n_ring_levels,
                               endpoint=self.format_cfg['incl_endpoint'])
            gridlabel = ["{}".format(round(x, 2)) for x in grid]
            gridlabel[0] = ""  # remove values from the center
            ax.set_rgrids(grid,
                          labels=gridlabel,
                          angle=angles[j],
                          **self.format_cfg['rgrid_tick_lbls_args'])
            # Setting ticks may expand the view limits, so pin them again.
            ax.set_ylim(*ranges[j])
            ax.spines["polar"].set_visible(False)
            ax.grid(visible=False)
            if not show_scales:
                ax.set_yticklabels([])

        # Make every axes except the first transparent and hide its angle axis
        for ax in axes[1:]:
            ax.patch.set_visible(False)
            ax.xaxis.set_visible(False)

        # Setting the attributes; repeat the first angle to close the polygon
        self.angle = np.deg2rad(np.r_[angles, angles[0]])
        self.ranges = ranges
        self.ax = axes[0]
        self.ax1 = axes[1]
        self.plot_counter = 0

        # Draw (inner) circles and lines
        self.ax.yaxis.grid(**self.format_cfg['rad_ln_args'])
        # Draw outer circle
        self.ax.spines['polar'].set(**self.format_cfg['outer_ring'])
        # Draw angle lines
        self.ax.xaxis.grid(**self.format_cfg['angle_ln_args'])

        # ax1 is the duplicate of axes[0] (self.ax)
        # Remove everything from ax1 except the plot itself
        self.ax1.axis('off')
        self.ax1.set_zorder(9)

        # Create the outer labels for each variable, wrapped to the configured width
        self.ax.set_thetagrids(angles, labels=var_labels)
        wrapped = ['\n'.join(textwrap.wrap(
                       str(name),
                       self.format_cfg['theta_tick_lbls_txt_wrap'],
                       break_long_words=self.format_cfg['theta_tick_lbls_brk_lng_wrds']))
                   for name in var_labels]
        self.ax.set_xticklabels(wrapped, **self.format_cfg['theta_tick_lbls'])
        # Honour the configured horizontal alignment (falling back to 'center')
        # instead of unconditionally overriding it.
        tick_ha = self.format_cfg['theta_tick_lbls'].get('ha', 'center')
        for label in self.ax.get_xticklabels():
            label.set_ha(tick_ha)
        self.ax.tick_params(axis='both', pad=self.format_cfg['theta_tick_lbls_pad'])

    def _scale_data(self, data, ranges):
        """Map every value onto the first variable's range.

        ``data[0]`` is returned unchanged because it already lives in
        ``ranges[0]``; each ``data[i]`` (i >= 1) is linearly rescaled from
        ``ranges[i]`` to ``ranges[0]`` so all values can be drawn on ``ax1``.

        Raises
        ------
        ValueError
            If ``data`` and ``ranges`` differ in length, a value lies outside
            its range, or a range other than the first has zero width.
        """
        if len(data) != len(ranges):
            raise ValueError(
                "expected {} values, got {}".format(len(ranges), len(data)))
        for i, (d, (y1, y2)) in enumerate(zip(data, ranges)):
            if not (min(y1, y2) <= d <= max(y1, y2)):
                raise ValueError(
                    "value {!r} of variable {} is outside its range ({}, {})".format(d, i, y1, y2))
            if i > 0 and y1 == y2:
                raise ValueError("range of variable {} has zero width".format(i))
        x1, x2 = ranges[0]
        return [data[0]] + [(d - y1) / (y2 - y1) * (x2 - x1) + x1
                            for d, (y1, y2) in zip(data[1:], ranges[1:])]

    def plot(self, data, *args, **kwargs):
        """Plot a closed line through the (rescaled) values on the overlay axes."""
        sdata = self._scale_data(data, self.ranges)
        self.ax1.plot(self.angle, np.r_[sdata, sdata[0]], *args, **kwargs)
        self.plot_counter += 1

    def fill(self, data, *args, **kwargs):
        """Plot a filled area through the (rescaled) values on the overlay axes."""
        sdata = self._scale_data(data, self.ranges)
        self.ax1.fill(self.angle, np.r_[sdata, sdata[0]], *args, **kwargs)

    def use_legend(self, *args, **kwargs):
        """Show a legend for the plotted series (forwarded to ``ax1.legend``)."""
        self.ax1.legend(*args, **kwargs)

    def set_title(self, title, pad=25, **kwargs):
        """Set a title on the main axes."""
        self.ax.set_title(title, pad=pad, **kwargs)

    def set_text(self, *args, **kwargs):
        """Add text to the main axes (forwarded to ``Axes.text``) and return it."""
        return self.ax.text(*args, **kwargs)
# Model names; position g in every score list of `data` belongs to methods[g].
methods = ["EVA-02 (304M)", "EVA (1011M)"]

# Benchmark name -> [EVA-02 (304M) score, EVA (1011M) score], in display order.
data = {
    benchmark: [eva02, eva]
    for benchmark, eva02, eva in (
        ("fine-tuned IN-1K cls (1K val)", 90.0, 89.7),
        ("fine-tuned IN-1K cls (1K variants)", 85.2, 84.0),
        ("zero-shot IN-1K cls (1K val)", 80.4, 78.5),
        ("zero-shot IN-1K cls (avg. 27 datasets)", 73.5, 71.4),
        ("zero-shot video cls (avg. 4 datasets)", 67.7, 66.0),
        ("zero-shot T2I (COCO)", 71.5, 68.5),
        ("zero-shot T2I (Flickr30K)", 94.2, 91.6),
        ("zero-shot I2T (COCO)", 85.2, 83.3),
        ("zero-shot I2T (Flickr30K)", 98.9, 98.3),
        ("object detection (COCO)", 64.5, 64.4),
        ("object detection (LVIS)", 65.2, 62.2),
        ("instance segmentation (COCO)", 55.8, 55.5),
        ("instance segmentation (LVIS)", 57.3, 55.0),
        ("semantic segmentation (COCO164K)", 53.7, 53.4),
        ("semantic segmentation (ADE20K)", 62.0, 62.3),
    )
}
# Radial axis limits [low, high] per benchmark: each axis spans from 1.5 below
# to 3 above a baseline score (the baselines equal the EVA (1011M) scores).
data_min_max = {
    benchmark: [baseline - 1.5, baseline + 3]
    for benchmark, baseline in (
        ("fine-tuned IN-1K cls (1K val)", 89.7),
        ("fine-tuned IN-1K cls (1K variants)", 84.0),
        ("zero-shot IN-1K cls (1K val)", 78.5),
        ("zero-shot IN-1K cls (avg. 27 datasets)", 71.4),
        ("zero-shot video cls (avg. 4 datasets)", 66.0),
        ("zero-shot T2I (COCO)", 68.5),
        ("zero-shot T2I (Flickr30K)", 91.6),
        ("zero-shot I2T (COCO)", 83.3),
        ("zero-shot I2T (Flickr30K)", 98.3),
        ("object detection (COCO)", 64.4),
        ("object detection (LVIS)", 62.2),
        ("instance segmentation (COCO)", 55.5),
        ("instance segmentation (LVIS)", 55.0),
        ("semantic segmentation (COCO164K)", 53.4),
        ("semantic segmentation (ADE20K)", 62.3),
    )
}
# Replace missing ('N/A') scores with the lower limit of that benchmark's axis,
# since the radar scaling needs numeric values.
new_data = {
    benchmark: [data_min_max[benchmark][0] if score == 'N/A' else score
                for score in scores]
    for benchmark, scores in data.items()
}
print(new_data)
data = new_data

# One radar axis per benchmark. Ranges are looked up by name so they stay
# aligned with `variables` even if `data_min_max` is declared in another order.
variables = list(data)
ranges = [data_min_max[benchmark] for benchmark in variables]
# Formatting overrides for ComplexRadar (each key replaces that default entry).
format_cfg = {
    'rad_ln_args': {'visible': True},
    'angle_ln_args': {'visible': True},
    'rgrid_tick_lbls_args': {'fontsize': 12},
    'theta_tick_lbls_pad': 30,
    'outer_ring': {'visible': False, 'color': '#d6d6d6'},
    'theta_tick_lbls': {'va': 'center', 'ha': 'center', 'fontsize': 14},
}

fig1 = plt.figure(figsize=(6, 6))
radar = ComplexRadar(fig1, variables, ranges, n_ring_levels=3, show_scales=True, format_cfg=format_cfg)

custom_colors = ['#ed2323', '#965fd4']
custom_alphas = [0.08, 0.08]
# Build the score rows once; column g holds the scores of methods[g].
score_rows = list(data.values())
# Plot methods[1] first, then methods[0]; artists with equal zorder are drawn
# in the order they are added, so methods[0] ends up on top.
for g in [1, 0]:
    scores = [row[g] for row in score_rows]
    radar.plot(scores, label=methods[g], color=custom_colors[g])
    radar.fill(scores, alpha=custom_alphas[g], color=custom_colors[g])

# One legend column per plotted series.
radar.use_legend(loc='lower right', bbox_to_anchor=(0.95, -0.28),
                 ncol=radar.plot_counter, fontsize=15, frameon=False)
# Annotation positioned in figure-fraction coordinates.
plt.text(0.56, 0.56, 'EVA', transform=fig1.transFigure, fontsize=12)
fig1.savefig('./eva_radar.png', dpi=128, bbox_inches='tight')
fig1.savefig('./eva_radar.pdf', dpi=128, bbox_inches='tight')
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment