Skip to content

Instantly share code, notes, and snippets.

@Yuxin-CV
Created April 19, 2024 15:37
Show Gist options
  • Save Yuxin-CV/8a48464319d47596a8bc790957cda9ff to your computer and use it in GitHub Desktop.
Save Yuxin-CV/8a48464319d47596a8bc790957cda9ff to your computer and use it in GitHub Desktop.
import pandas as pd
import numpy as np
from sklearn.datasets import load_wine
from sklearn.preprocessing import StandardScaler
from sklearn.pipeline import make_pipeline
from sklearn.cluster import KMeans
import matplotlib.pyplot as plt
import textwrap
import math
class ComplexRadar():
    """
    Create a complex radar chart with a different scale for each variable.

    One polar axes is stacked per variable plus a duplicate of the first one
    (``len(variables) + 1`` axes in total), each with its own radial limits,
    so every spoke can show its own tick scale. ``axes[0]`` carries the grid,
    the outer ring and the variable labels; ``axes[1]`` is a bare overlay on
    which all data is drawn after being rescaled into the first variable's
    range.

    Parameters
    ----------
    fig : figure object
        A matplotlib figure object to add the axes on
    variables : iterable
        The variable names, one per spoke (e.g. a list or ``dict.keys()``)
    ranges : list
        A list of (min, max) pairs, one per variable, in the same order
    n_ring_levels : int, defaults to 5
        Number of ordinate or ring levels to draw
    show_scales : bool, defaults to True
        Whether the tick labels of each variable's range are drawn
    format_cfg : dict, defaults to None
        Formatting overrides; keys not present in the defaults are ignored

    Raises
    ------
    ValueError
        If ``variables`` is empty or ``ranges`` does not have one entry per
        variable.
    """

    def __init__(self, fig, variables, ranges, n_ring_levels=5, show_scales=True, format_cfg=None):
        # Default formatting; any of these keys may be overridden via format_cfg.
        defaults = {
            # Axes
            # https://matplotlib.org/stable/api/figure_api.html
            'axes_args': {},
            # Tick labels on the scales
            # https://matplotlib.org/stable/api/_as_gen/matplotlib.pyplot.rgrids.html
            'rgrid_tick_lbls_args': {'fontsize': 8},
            # Radial (circle) lines
            # https://matplotlib.org/stable/api/_as_gen/matplotlib.pyplot.grid.html
            'rad_ln_args': {},
            # Angle lines
            # https://matplotlib.org/3.2.2/api/_as_gen/matplotlib.lines.Line2D.html#matplotlib.lines.Line2D
            'angle_ln_args': {},
            # Include last value (endpoint) on scale
            'incl_endpoint': False,
            # Variable labels (ThetaTickLabel)
            'theta_tick_lbls': {'va': 'top', 'ha': 'center'},
            'theta_tick_lbls_txt_wrap': 15,
            'theta_tick_lbls_brk_lng_wrds': False,
            'theta_tick_lbls_pad': 25,
            # Outer ring
            # https://matplotlib.org/stable/api/spines_api.html
            'outer_ring': {'visible': True, 'color': '#d6d6d6'},
        }
        overrides = format_cfg if format_cfg is not None else {}
        # Unknown keys in format_cfg are ignored.
        self.format_cfg = {k: overrides.get(k, v) for k, v in defaults.items()}

        # Materialize once: the names are both counted and iterated below.
        variables = list(variables)
        n_vars = len(variables)
        if n_vars == 0:
            raise ValueError("at least one variable is required")
        if len(ranges) != n_vars:
            raise ValueError(f"got {len(ranges)} ranges for {n_vars} variables")

        # One spoke angle (in degrees) per variable. linspace avoids the extra
        # element np.arange can produce with a non-integer float step.
        angles = np.linspace(0, 360, n_vars, endpoint=False)

        # One axes per variable plus a duplicate of the first one (n_vars + 1).
        # The distinct labels keep add_axes from reusing an existing axes.
        axes = [
            fig.add_axes([0.1, 0.1, 0.9, 0.9],
                         polar=True,
                         label=f"axes{i}",
                         **self.format_cfg['axes_args'])
            for i in range(n_vars + 1)
        ]

        for i, ax in enumerate(axes):
            # Clockwise rotation with the first variable at the top (N).
            ax.set_theta_zero_location('N')
            ax.set_theta_direction(-1)
            ax.set_axisbelow(True)

            # axes[0] and axes[1] both use the first variable's range;
            # axes[i] for i >= 2 uses ranges[i - 1].
            j = max(i - 1, 0)
            lo_hi = ranges[j]
            ax.set_ylim(*lo_hi)
            # With incl_endpoint=True a ring label is also drawn at the max.
            grid = np.linspace(*lo_hi, num=n_ring_levels,
                               endpoint=self.format_cfg['incl_endpoint'])
            gridlabel = ["{}".format(round(x, 2)) for x in grid]
            if gridlabel:
                gridlabel[0] = ""  # no value label at the centre
            # Each axes writes its scale along its own spoke (angles[j]).
            ax.set_rgrids(grid,
                          labels=gridlabel,
                          angle=angles[j],
                          **self.format_cfg['rgrid_tick_lbls_args'])
            # Re-apply the limits in case setting the grid ticks changed them.
            ax.set_ylim(*lo_hi)
            ax.spines["polar"].set_visible(False)
            ax.grid(visible=False)
            if not show_scales:
                ax.set_yticklabels([])

        # Only axes[0] keeps its background patch and spoke labels.
        for ax in axes[1:]:
            ax.patch.set_visible(False)
            ax.xaxis.set_visible(False)

        # Closed polygon: the first angle is repeated at the end (radians).
        self.angle = np.deg2rad(np.r_[angles, angles[0]])
        self.ranges = ranges
        self.ax = axes[0]
        self.ax1 = axes[1]
        self.plot_counter = 0

        # Draw (inner) circles and lines
        self.ax.yaxis.grid(**self.format_cfg['rad_ln_args'])
        # Draw outer circle
        self.ax.spines['polar'].set(**self.format_cfg['outer_ring'])
        # Draw angle lines
        self.ax.xaxis.grid(**self.format_cfg['angle_ln_args'])

        # ax1 duplicates axes[0]; strip everything but the plotted data and
        # keep it above the other axes.
        self.ax1.axis('off')
        self.ax1.set_zorder(9)

        # Outer variable labels, wrapped to the configured width.
        wrap_width = self.format_cfg['theta_tick_lbls_txt_wrap']
        break_long = self.format_cfg['theta_tick_lbls_brk_lng_wrds']
        labels = [textwrap.fill(str(v), wrap_width, break_long_words=break_long)
                  for v in variables]
        # Honour the configured alignment, defaulting to centred labels.
        lbl_kwargs = dict(self.format_cfg['theta_tick_lbls'])
        if 'ha' not in lbl_kwargs and 'horizontalalignment' not in lbl_kwargs:
            lbl_kwargs['ha'] = 'center'
        self.ax.set_thetagrids(angles, labels=labels, **lbl_kwargs)
        # NOTE(review): axis='both' also pads this axes' radial scale labels,
        # not only the variable labels.
        self.ax.tick_params(axis='both', pad=self.format_cfg['theta_tick_lbls_pad'])

    def _scale_data(self, data, ranges):
        """
        Map every value onto the first variable's scale.

        ``data[0]`` is returned unchanged. Each later value ``data[i]`` is
        linearly mapped from ``ranges[i]`` onto ``ranges[0]``, so that all
        values can be drawn on the single shared plotting axes.

        Raises
        ------
        ValueError
            If ``data`` is empty, ``data`` and ``ranges`` differ in length, a
            value lies outside its range, or a range used for rescaling has
            zero width.
        """
        if not data:
            raise ValueError("data must not be empty")
        if len(data) != len(ranges):
            raise ValueError(f"expected {len(ranges)} values, got {len(data)}")
        for i, (d, (y1, y2)) in enumerate(zip(data, ranges)):
            # Ranges may be given in either order (min, max) or (max, min).
            if not (y1 <= d <= y2 or y2 <= d <= y1):
                raise ValueError(
                    f"value {d!r} of variable {i} lies outside its range ({y1}, {y2})")
        x1, x2 = ranges[0]
        sdata = [data[0]]
        for d, (y1, y2) in zip(data[1:], ranges[1:]):
            if y1 == y2:
                raise ValueError(f"range ({y1}, {y2}) has zero width")
            sdata.append((d - y1) / (y2 - y1) * (x2 - x1) + x1)
        return sdata

    def plot(self, data, *args, **kwargs):
        """Plot one series as a closed line on the overlay axes and return the lines."""
        sdata = self._scale_data(data, self.ranges)
        lines = self.ax1.plot(self.angle, np.r_[sdata, sdata[0]], *args, **kwargs)
        self.plot_counter += 1
        return lines

    def fill(self, data, *args, **kwargs):
        """Fill the area enclosed by one series on the overlay axes and return the patches."""
        sdata = self._scale_data(data, self.ranges)
        return self.ax1.fill(self.angle, np.r_[sdata, sdata[0]], *args, **kwargs)

    def use_legend(self, *args, **kwargs):
        """Show a legend for the plotted series (forwards to ``Axes.legend``)."""
        return self.ax1.legend(*args, **kwargs)

    def set_title(self, title, pad=25, **kwargs):
        """Set the chart title on the main axes."""
        self.ax.set_title(title, pad=pad, **kwargs)

    def set_text(self, *args, **kwargs):
        """Add text to the main axes (forwards to ``Axes.text``) and return it."""
        return self.ax.text(*args, **kwargs)
# Compared models, one per column of the score lists in ``data``.
methods = ["EVA-02 (304M)", "EVA (1011M)"]
# Score per task, ordered like ``methods``.
data = {
    "fine-tuned IN-1K cls (1K val)": [90.0, 89.7],
    "fine-tuned IN-1K cls (1K variants)": [85.2, 84.0],
    "zero-shot IN-1K cls (1K val)": [80.4, 78.5],
    "zero-shot IN-1K cls (avg. 27 datasets)": [73.5, 71.4],
    "zero-shot video cls (avg. 4 datasets)": [67.7, 66.0],
    "zero-shot T2I (COCO)": [71.5, 68.5],
    "zero-shot T2I (Flickr30K)": [94.2, 91.6],
    "zero-shot I2T (COCO)": [85.2, 83.3],
    "zero-shot I2T (Flickr30K)": [98.9, 98.3],
    "object detection (COCO)": [64.5, 64.4],
    "object detection (LVIS)": [65.2, 62.2],
    "instance segmentation (COCO)": [55.8, 55.5],
    "instance segmentation (LVIS)": [57.3, 55.0],
    "semantic segmentation (COCO164K)": [53.7, 53.4],
    "semantic segmentation (ADE20K)": [62.0, 62.3],
}

# Each task's axis range is anchored at the score of one baseline column:
# [baseline - RANGE_BELOW, baseline + RANGE_ABOVE]. Deriving it from ``data``
# keeps keys and values in sync. The baseline column must hold numbers
# (not 'N/A').
RANGE_BASELINE_IDX = 1  # index into ``methods`` ("EVA (1011M)")
RANGE_BELOW = 1.5
RANGE_ABOVE = 3
data_min_max = {
    task: [scores[RANGE_BASELINE_IDX] - RANGE_BELOW,
           scores[RANGE_BASELINE_IDX] + RANGE_ABOVE]
    for task, scores in data.items()
}
# Replace every missing ('N/A') score with the lower bound of that task's
# axis range, then hand the cleaned table to the chart.
new_data = {}
for task, scores in data.items():
    floor = data_min_max[task][0]
    new_data[task] = [floor if score == 'N/A' else score for score in scores]
print(new_data)
data = new_data
# One (min, max) range and one spoke name per task, in the same order.
ranges = list(data_min_max.values())
variables = data.keys()
# Formatting overrides passed to ComplexRadar; any key not listed here keeps
# the class default.
format_cfg = dict(
    rad_ln_args=dict(visible=True),
    angle_ln_args=dict(visible=True),
    rgrid_tick_lbls_args=dict(fontsize=12),
    theta_tick_lbls_pad=30,
    outer_ring=dict(visible=False, color='#d6d6d6'),
    theta_tick_lbls=dict(va='center', ha='center', fontsize=14),
)
# Build the chart, draw one line + translucent area per method, add the
# legend and a manual annotation, then save PNG and PDF copies.
fig1 = plt.figure(figsize=(6, 6))
radar = ComplexRadar(fig1, variables, ranges, n_ring_levels=3, show_scales=True, format_cfg=format_cfg)

# Per-method colours and fill transparencies, indexed like ``methods``.
custom_colors = ['#ed2323', '#965fd4']
custom_alphas = [0.08, 0.08]
# Draw index 1 first, presumably so that index 0 ("EVA-02 (304M)") ends up on top.
for g in [1, 0]:
    scores = [row[g] for row in data.values()]
    radar.plot(scores, label=methods[g], color=custom_colors[g])
    radar.fill(scores, alpha=custom_alphas[g], color=custom_colors[g])

# A single legend row: one column per plotted series.
radar.use_legend(loc='lower right', bbox_to_anchor=(0.95, -0.28),
                 ncol=radar.plot_counter, fontsize=15, frameon=False)
# Manually positioned annotation, in figure coordinates.
plt.text(0.56, 0.56, 'EVA', transform=plt.gcf().transFigure, fontsize=12)
plt.savefig('./eva_radar.png', dpi=128, bbox_inches='tight')
plt.savefig('./eva_radar.pdf', dpi=128, bbox_inches='tight')
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment