Skip to content

Instantly share code, notes, and snippets.

@izikeros
Last active December 19, 2024 19:10
Show Gist options
  • Select an option

  • Save izikeros/60947c52c13e49f888abb2e1ec06c942 to your computer and use it in GitHub Desktop.

Select an option

Save izikeros/60947c52c13e49f888abb2e1ec06c942 to your computer and use it in GitHub Desktop.
[plot enhancements] Tools for improving the look and readability of plots: add_value_labels() and others #matplotlib

Here are some ideas for additional functions or classes that could enhance this module for creating good-looking reports with data visualizations:

  1. Color Palette Generator: A function that generates aesthetically pleasing color palettes for charts and graphs. This could include options for different types of palettes (e.g., sequential, diverging, qualitative) and considerations for color blindness.
def generate_color_palette(palette_type: str, num_colors: int, colorblind_safe: bool = True) -> List[str]:
    """Generate a color palette for data visualization."""
    ...
  1. Figure Layout Manager: A class to help manage the layout of multiple plots in a single figure, making it easier to create complex dashboards or report layouts.
class FigureLayoutManager:
    def __init__(self, fig_size: Tuple[float, float], grid_size: Tuple[int, int]):
        ...
    
    def add_plot(self, row: int, col: int, plot_function: Callable, *args, **kwargs):
        ...
    
    def adjust_layout(self):
        ...
  1. Custom Theme Applier: A function to apply a consistent custom theme across all plots in a report, including font styles, background colors, grid styles, etc.
def apply_custom_theme(fig: matplotlib.figure.Figure, theme: dict):
    """Apply a custom theme to a matplotlib figure."""
    ...
  1. Automated Legend Optimizer: A function to automatically position and format legends for optimal readability and aesthetics.
def optimize_legend(ax: matplotlib.axes.Axes, location: str = 'best'):
    """Optimize legend position and formatting."""
    ...
  1. Data-Driven Annotation Placer: A smart function that automatically places annotations on a plot based on data values and available space, avoiding overlaps.
def smart_annotate(ax: matplotlib.axes.Axes, x: List[float], y: List[float], labels: List[str]):
    """Intelligently place annotations on a plot."""
    ...
  1. Interactive Element Generator: Functions to add interactive elements to plots, such as hover tooltips or clickable data points (useful for web-based reports).
def add_hover_tooltip(ax: matplotlib.axes.Axes, x: List[float], y: List[float], tooltip_text: List[str]):
    """Add hover tooltips to data points."""
    ...
  1. Report Metadata Manager: A class to manage and display metadata about the report, such as data sources, last updated date, and version information.
class ReportMetadata:
    def __init__(self, title: str, data_source: str, last_updated: datetime):
        ...
    
    def add_to_figure(self, fig: matplotlib.figure.Figure):
        """Add metadata to the figure as text."""
        ...
  1. Axis Formatter: Functions to automatically format axis labels and ticks for better readability, handling different data types (dates, currencies, large numbers, etc.).
def format_axis(ax: matplotlib.axes.Axes, axis: str, data_type: str):
    """Format axis labels and ticks based on data type."""
    ...
  1. Statistical Annotation Adder: Functions to add statistical information to plots, such as mean lines, confidence intervals, or p-values.
def add_mean_line(ax: matplotlib.axes.Axes, data: List[float], color: str = 'red'):
    """Add a line representing the mean of the data."""
    ...

def add_confidence_interval(ax: matplotlib.axes.Axes, x: List[float], y: List[float], confidence: float = 0.95):
    """Add confidence interval to a line plot."""
    ...
  1. Export Utilities: Functions to export plots in various formats suitable for different types of reports (e.g., high-res for print, web-optimized for online reports).
def export_for_print(fig: matplotlib.figure.Figure, filename: str, dpi: int = 300):
    """Export a figure in high resolution for print."""
    ...

def export_for_web(fig: matplotlib.figure.Figure, filename: str, optimize: bool = True):
    """Export a figure optimized for web display."""
    ...

These additional functions and classes would greatly enhance the capabilities of the module, making it a comprehensive toolkit for creating professional-looking data visualizations and reports.

"""Tools for improving plots look and readability.
version 1.0.0
"""
from typing import Iterable, Optional, Union
from enum import Enum
import matplotlib
import matplotlib.axes
class FontSize(Enum):
    """Named font sizes; each member's value is the numeric size it stands for."""

    # Smallest preset.
    SMALL = 8
    # Middle preset.
    MEDIUM = 12
    # Largest preset.
    BIG = 16
def add_value_labels(
    ax: matplotlib.axes.Axes,
    spacing: int = 5,
    fmt: str = "{:.2f}",
    append: Optional[Iterable] = None,
    fs: Union[int, str, FontSize] = FontSize.BIG,
    show_percentage: bool = False,
    total_value: Optional[float] = None,
    horizontal: bool = True,
):
    """Add a value label at the end of each bar in a bar chart.

    Args:
        ax: Axes whose ``patches`` are the bars to annotate.
        spacing: Gap, in points, between the end of a bar and its label.
            The gap is applied away from the bar, so labels of negative
            bars are shifted in the negative direction.
        fmt: Format string applied to each value, e.g. "{:.2f}".
        append: Items appended (after a space) to the label of the bar at
            the same position. Any iterable is accepted; bars beyond its
            length simply get no suffix.
        fs: Font size: FontSize enum, int, or 'big|medium|small'.
        show_percentage: If True, show ``value / total_value * 100``
            followed by "%" instead of the absolute value.
        total_value: Reference value for percentages (required and
            non-zero if show_percentage is True).
        horizontal: If True, treat bars as horizontal (value is the bar
            width); otherwise as vertical (value is the bar height).

    Returns:
        None

    Raises:
        ValueError: If show_percentage is True and total_value is None
            or zero.
    """
    font_size = set_font_size(fs)

    if show_percentage:
        if total_value is None:
            raise ValueError("total_value must be provided when show_percentage is True")
        if total_value == 0:
            # Fail with a clear message instead of a ZeroDivisionError below.
            raise ValueError("total_value must be non-zero when show_percentage is True")

    # Materialize once so generators and other non-indexable iterables work.
    suffixes = [] if append is None else list(append)

    for i, rect in enumerate(ax.patches):
        if horizontal:
            value = rect.get_width()
            pos = (value, rect.get_y() + rect.get_height() / 2)
            ha, va = ("left", "center") if value >= 0 else ("right", "center")
            # Push the label away from the bar end on either side of zero.
            space = (spacing if value >= 0 else -spacing, 0)
        else:
            value = rect.get_height()
            pos = (rect.get_x() + rect.get_width() / 2, value)
            ha, va = ("center", "bottom") if value >= 0 else ("center", "top")
            space = (0, spacing if value >= 0 else -spacing)

        if show_percentage:
            label = f"{fmt.format((value / total_value) * 100)}%"
        else:
            label = fmt.format(value)

        # Bars without a matching suffix keep the plain label.
        if i < len(suffixes):
            label += f" {suffixes[i]}"

        ax.annotate(
            label,
            pos,
            xytext=space,
            textcoords="offset points",
            va=va,
            ha=ha,
            fontsize=font_size,
        )
def add_reference_lines_barh(ax: matplotlib.axes.Axes, vals: Iterable):
    """Draw a short dashed reference line and its value label next to each horizontal bar.

    Each bar in ``ax.patches`` is paired, in order, with one reference value.
    The line and label are green when the bar's width is at least the
    reference value and red otherwise.

    Args:
        ax: Axes containing the horizontal bars to annotate.
        vals: Reference values, one per bar. Any iterable is accepted.

    Raises:
        ValueError: If the number of reference values differs from the
            number of bars.
    """
    BAR_HEIGHT_RATIO = 0.2
    LABEL_OFFSET = (0.1, 21)
    LABEL_FONTSIZE = 10

    # Materialize so that generators (which have no len()) are supported.
    refs = list(vals)
    bars = list(ax.patches)
    if len(refs) != len(bars):
        raise ValueError("Length of vals must match the number of bars in the plot")
    if not bars:
        # Nothing to annotate; also avoids dividing by zero below.
        return

    dy = 1 / len(bars)
    half_len = dy * BAR_HEIGHT_RATIO

    # axvline takes ymin/ymax as fractions of the axes height, so the bar
    # centre (data coordinates) is mapped through the current y-limits.
    # This also holds when the limits carry margins or the axis is inverted.
    y_lo, y_hi = ax.get_ylim()
    y_span = y_hi - y_lo

    for rect, x_ref in zip(bars, refs):
        x_val = rect.get_width()
        y_value = rect.get_y() + rect.get_height() / 2
        y = (y_value - y_lo) / y_span
        color = "g" if x_val >= x_ref else "r"
        ax.axvline(
            x=x_ref,
            ymin=y - half_len,
            ymax=y + half_len,
            linestyle="--",
            color=color,
        )
        ax.annotate(
            f"{x_ref:.2f}",
            (x_ref, y_value),
            xytext=LABEL_OFFSET,
            textcoords="offset points",
            va="center",
            ha="center",
            fontsize=LABEL_FONTSIZE,
            color=color,
        )
def annotate_boxplot(box_plot):
    """Annotate each box of a boxplot with the y-value of its median line.

    The median line of each box is located purely by position in
    ``ax.get_lines()``: 7 lines are assumed per box, with the median at
    offset 5 within each group, and x tick positions are assumed to be
    1, 2, ..., n.
    NOTE(review): presumably this matches matplotlib's default boxplot
    artist order -- confirm for the plotting backend in use.

    Args:
        box_plot: Object exposing the plotted axes via its ``axes`` attribute.
    """
    MEDIAN_LINE_INTERVAL = 7
    MEDIAN_LINE_OFFSET = 5
    ANNOTATION_OFFSET = 0.20

    ax = box_plot.axes
    lines = ax.get_lines()

    for cat in ax.get_xticks():
        # Tick positions may come back as floats; a float cannot index the
        # list of lines, so convert to the nearest integer first.
        box_number = int(round(cat))
        # The median is the line at offset 5 within each group of 7 lines.
        idx = MEDIAN_LINE_OFFSET + (box_number - 1) * MEDIAN_LINE_INTERVAL
        median_line = lines[idx]
        y = round(median_line.get_ydata()[0], 3)
        # Second x point of the line; the label is placed a fixed offset past it.
        x = round(median_line.get_xdata()[1], 3)
        ax.text(
            x + ANNOTATION_OFFSET,
            y,
            f"{y}",
            ha="center",
            va="center",
            fontweight="bold",
            size=10,
            color="white",
            bbox=dict(facecolor="#445A64"),
        )
def set_font_size(fs: Union[int, str, FontSize]) -> int:
    """Resolve a font-size specification to its numeric size.

    Args:
        fs: An int (returned unchanged), a FontSize member (its value is
            returned), or the name of a FontSize member as a string.
            Names are matched case-insensitively and surrounding
            whitespace is ignored.

    Returns:
        The numeric font size.

    Raises:
        ValueError: If ``fs`` is a string that names no FontSize member,
            or is of an unsupported type.
    """
    if isinstance(fs, int):
        return fs
    if isinstance(fs, FontSize):
        return fs.value
    if isinstance(fs, str):
        try:
            return FontSize[fs.strip().upper()].value
        except KeyError:
            # Report unknown names the same way as any other invalid
            # specification rather than leaking the enum's KeyError.
            valid = ", ".join(member.name.lower() for member in FontSize)
            raise ValueError(
                f"Invalid font size name {fs!r}; expected one of: {valid}"
            ) from None
    raise ValueError("Invalid font size specification")
def annotate_points(df_o, x_col, y_col, label_col, ax, fontsize=18):
    """Annotate points on the plot with provided labels.

    Label alignment alternates from row to row ("right"/"top" for the
    first row, "left"/"bottom" for the second, and so on). The alternation
    follows row position, so it does not depend on the type or values of
    the dataframe index.

    Args:
        df_o: Dataframe containing the data to be plotted.
        x_col: Name of the column containing the x values.
        y_col: Name of the column containing the y values.
        label_col: Name of the column containing the labels.
        ax: Axes object to be used for plotting.
        fontsize: Font size of the labels. (Default value = 18)
    """
    for position, (_, point) in enumerate(df_o.iterrows()):
        # Use the row position, not the index label: labels may be strings
        # or dates (no "% 2") or non-consecutive integers.
        ha, va = ("left", "bottom") if position % 2 else ("right", "top")
        ax.annotate(
            point[label_col],
            (point[x_col], point[y_col]),
            xytext=(5, 5),
            textcoords="offset points",
            ha=ha,
            va=va,
            fontsize=fontsize,
        )
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment