Last active
December 27, 2020 00:54
-
-
Save AyrtonB/adb3b249e627efb1687027a8ce3fe52a to your computer and use it in GitHub Desktop.
`AxTransformer` converts data values (such as the dates shown in tick labels) into tick locations; `set_date_ticks` uses it to apply a custom date range to a plot's ticks (including those of a seaborn heatmap).
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np | |
import pandas as pd | |
import matplotlib.pyplot as plt | |
from collections.abc import Iterable | |
from sklearn import linear_model | |
class AxTransformer:
    """Map values shown in an axis' tick labels to that axis' tick locations.

    ``fit`` reads the (location, label text) pairs of an existing axis and fits
    a 1-D ``sklearn.linear_model.LinearRegression`` from label value to tick
    location. ``transform`` then predicts tick locations for arbitrary values,
    so new ticks can be placed "in data units" on axes whose coordinates are
    really positions (e.g. the row/column indices of a seaborn heatmap).

    NOTE(review): the mapping is only meaningful when label values vary
    linearly with tick position (e.g. rows that form a regular date range).

    Parameters
    ----------
    datetime_vals : bool, default False
        If truthy, tick values are parsed as datetimes and regressed on their
        nanoseconds-since-epoch representation; otherwise they are parsed as
        floats.
    """

    def __init__(self, datetime_vals=False):
        self.datetime_vals = datetime_vals
        self.lr = linear_model.LinearRegression()

    def process_tick_vals(self, tick_vals):
        """Convert tick values (scalar or iterable) to a 1-D numeric array.

        Datetime mode returns int64 nanoseconds since the Unix epoch. Inputs
        that are timezone-aware are converted to UTC first. The resolution is
        always normalised to ns, so values parsed from label strings and values
        coming from e.g. ``pd.date_range`` share the same scale. Numeric mode
        returns floats. The Unicode minus sign that matplotlib uses in rendered
        tick labels is accepted.
        """
        # Strings are iterable but must be treated as a single value.
        if not isinstance(tick_vals, Iterable) or isinstance(tick_vals, str):
            tick_vals = [tick_vals]

        if self.datetime_vals:
            dt_idx = pd.DatetimeIndex(pd.to_datetime(tick_vals))
            if dt_idx.tz is not None:
                # Drop the timezone (converting to UTC) so the int view is absolute.
                dt_idx = dt_idx.tz_convert(None)
            # Force ns resolution and a fixed-width integer: ``astype(int)`` is
            # platform- and resolution-dependent.
            return dt_idx.to_numpy(dtype='datetime64[ns]').astype(np.int64)

        # scikit-learn no longer coerces string arrays, and matplotlib renders
        # negative numbers with U+2212, so convert explicitly.
        return np.asarray(
            [v.replace('\u2212', '-') if isinstance(v, str) else v for v in tick_vals],
            dtype=float,
        )

    def fit(self, ax, axis='x'):
        """Fit the value -> location mapping from the labelled ticks of ``ax``.

        Parameters
        ----------
        ax : matplotlib Axes
            Axes whose ``get_{axis}axis()`` ticks carry the labels to learn from.
        axis : str, default 'x'
            Which axis to read (used to build the ``get_{axis}axis`` name).

        Returns
        -------
        AxTransformer
            ``self``, to allow chaining.

        Raises
        ------
        ValueError
            If fewer than two ticks have a non-empty label.
        """
        axis_obj = getattr(ax, f'get_{axis}axis')()

        # Pair every tick location with its label text. Ticks without a label
        # cannot be parsed, so skip them. ``zip`` also protects against label and
        # location lists of different lengths.
        labelled_ticks = [
            (loc, label.get_text())
            for loc, label in zip(axis_obj.get_ticklocs(), axis_obj.get_ticklabels())
            if label.get_text().strip()
        ]
        if len(labelled_ticks) < 2:
            raise ValueError(
                f'At least two labelled ticks are required on the {axis}-axis '
                f'to fit the transformer, found {len(labelled_ticks)}.'
            )

        tick_locs, tick_texts = zip(*labelled_ticks)
        tick_vals = self.process_tick_vals(list(tick_texts))
        self.lr.fit(tick_vals.reshape(-1, 1), np.asarray(tick_locs))
        return self

    def transform(self, tick_vals):
        """Return the predicted tick locations for ``tick_vals`` (1-D array)."""
        tick_vals = self.process_tick_vals(tick_vals)
        return self.lr.predict(tick_vals.reshape(-1, 1))
def set_date_ticks(ax, start_date, end_date, axis='y', date_format='%Y-%m-%d', **date_range_kwargs):
    """Replace the ticks of ``ax`` with a custom date range.

    The existing tick labels of the chosen axis are parsed as dates and used
    to fit an ``AxTransformer``. The transformer then places one tick per date
    of ``pd.date_range(start_date, end_date, **date_range_kwargs)``, labelled
    with ``date_format``.

    Parameters
    ----------
    ax : matplotlib Axes
        Axes whose current tick labels are parseable dates (e.g. a seaborn
        heatmap indexed by date).
    start_date, end_date
        Bounds passed to ``pd.date_range``.
    axis : str, default 'y'
        Axis to relabel (used to build the ``get_/set_{axis}...`` names).
    date_format : str, default '%Y-%m-%d'
        ``strftime`` format for the new labels.
    **date_range_kwargs
        Extra keyword arguments forwarded to ``pd.date_range`` (e.g. ``freq``).

    Returns
    -------
    matplotlib Axes
        The same ``ax``, modified in place.
    """
    dt_rng = pd.date_range(start_date, end_date, **date_range_kwargs)

    ax_transformer = AxTransformer(datetime_vals=True)
    ax_transformer.fit(ax, axis=axis)

    getattr(ax, f'set_{axis}ticks')(ax_transformer.transform(dt_rng))
    getattr(ax, f'set_{axis}ticklabels')(dt_rng.strftime(date_format))

    # Axes.tick_params discards the bottom/top/labelbottom keywords when
    # axis='y' (and left/right/labelleft when axis='x'). Pass the side keywords
    # that belong to the requested axis so ticks and labels are shown on the
    # primary side and hidden on the opposite one.
    default_side_kwargs = dict(bottom=True, top=False, labelbottom=True)
    side_kwargs = {
        'x': default_side_kwargs,
        'y': dict(left=True, right=False, labelleft=True),
    }.get(axis, default_side_kwargs)
    ax.tick_params(axis=axis, which='both', **side_kwargs)

    return ax
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.