@razhangwei · Last active August 2, 2020 21:15
A submodule for useful utility functions for ML research projects #Python #Utils
  • logging
  • misc
  • numpy
  • pandas
  • plotting
  • torch
# -*- coding:utf-8 -*-
import os
import os.path as osp
import time
from logging import (
INFO,
getLevelName,
getLogger,
Formatter,
basicConfig,
StreamHandler,
)
import sys
LOG_DEFAULT_FORMAT = "[ %(asctime)s][%(module)s.%(funcName)s] %(message)s"
LOG_DEFAULT_LEVEL = INFO
def strftime(t=None):
"""Get string of current time"""
return time.strftime("%Y%m%d-%H%M%S", time.localtime(t or time.time()))
def init_logging(logging_dir, filename=None, level=None, log_format=None):
if not osp.exists(logging_dir):
os.makedirs(logging_dir)
filename = filename or strftime() + ".log"
log_format = log_format or LOG_DEFAULT_FORMAT
global LOG_DEFAULT_LEVEL
if isinstance(level, str):
level = getLevelName(level.upper())
elif level is None:
level = LOG_DEFAULT_LEVEL
LOG_DEFAULT_LEVEL = level
basicConfig(
filename=osp.join(logging_dir, filename),
format=log_format,
level=level,
)
def get_logger(name, level=None, log_format=None, print_to_std=True):
"""
Get the logger
level: if None, then use default=INFO
log_format: if None, use default format
print_to_std: default=True
"""
if level is None:
level = LOG_DEFAULT_LEVEL
elif isinstance(level, str):
level = getLevelName(level.upper())
if not log_format:
log_format = LOG_DEFAULT_FORMAT
logger = getLogger(name)
logger.setLevel(level)
if print_to_std:
handler = StreamHandler(sys.stdout)
handler.setLevel(level)
handler.setFormatter(Formatter(log_format))
logger.addHandler(handler)
return logger
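# Example (sketch): one way to use the helpers above in a training script.
# The "logs" directory name is illustrative, not something the gist prescribes.
def _example_logging():
    init_logging("logs")  # file handler: logs/<timestamp>.log, INFO level
    logger = get_logger(__name__)  # also echoes to stdout by default
    logger.info("experiment started")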
import json
import os
import time
import shutil
def set_rand_seed(seed):
import numpy as np
import torch
np.random.seed(seed)
torch.manual_seed(seed)
def set_printoptions(precision=4, linewidth=160):
import numpy as np
import torch
np.set_printoptions(precision=precision, linewidth=linewidth)
torch.set_printoptions(precision=precision, linewidth=linewidth)
class RelativeChangeMonitor(object):
"""Tracks a loss; `stop` is True once the relative change between the last
two losses (w.r.t. the best loss seen) falls below `tol`, and `save` is True
when the last registered loss is the best so far."""
def __init__(self, tol):
self.tol = tol
# self._best_loss = float('inf')
# self._curr_loss = float('inf')
self._losses = []
self._best_loss = float('inf')
@property
def save(self):
return len(self._losses) > 0 and self._losses[-1] == self._best_loss
@property
def stop(self):
return len(self._losses) > 1 and abs(
(self._losses[-1] - self._losses[-2]) / self._best_loss) < self.tol
def register(self, loss):
self._losses.append(loss)
self._best_loss = min(self._best_loss, loss)
class EarlyStoppingMonitor(object):
"""Tracks a loss; `stop` is True after the loss has failed to improve for
more than `patience` consecutive registrations, and `save` is True when the
last registered loss is the best so far."""
def __init__(self, patience):
self._patience = patience
self._best_loss = float('inf')
self._curr_loss = float('inf')
self._n_fails = 0
@property
def save(self):
return self._curr_loss == self._best_loss
@property
def stop(self):
return self._n_fails > self._patience
def register(self, loss):
self._curr_loss = loss
if loss < self._best_loss:
self._best_loss = loss
self._n_fails = 0
else:
self._n_fails += 1
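# Example (sketch): a training loop driven by EarlyStoppingMonitor. The
# `evaluate` and `save_model` callables are hypothetical placeholders.
def _example_early_stopping(evaluate, save_model, max_epochs=100, patience=5):
    monitor = EarlyStoppingMonitor(patience)
    for epoch in range(max_epochs):
        val_loss = evaluate(epoch)  # hypothetical: returns validation loss
        monitor.register(val_loss)
        if monitor.save:  # the latest loss is the best seen so far
            save_model(epoch)  # hypothetical checkpointing call
        if monitor.stop:  # more than `patience` epochs without improvement
            break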
class Timer(object):
def __init__(self, name=None):
self.name = name
def __enter__(self):
self.tstart = time.time()
def __exit__(self, type, value, traceback):
if self.name:
print("[{}] ".format(self.name), end="")
dt = time.time() - self.tstart
if dt < 60:
print("Elapsed: {:.4f} sec.".format(dt))
elif dt < 3600:
print("Elapsed: {:.4f} min.".format(dt / 60))
elif dt < 86400:
print("Elapsed: {:.4f} hour.".format(dt / 3600))
else:
print("Elapsed: {:.4f} day.".format(dt / 86400))
def makedirs(dirs):
assert isinstance(dirs, list), "Argument dirs needs to be a list"
for dir in dirs:
if not os.path.isdir(dir):
os.makedirs(dir)
def export_json(obj, path):
with open(path, "w") as fout:
json.dump(obj, fout, indent=4)
def export_csv(df, path, append=False, index=False):
if not os.path.isdir(os.path.dirname(path)):
os.makedirs(os.path.dirname(path))
mode = "a" if append else "w"
with open(path, mode) as f:
df.to_csv(f, header=f.tell() == 0, index=index)
def counting_proc_to_event_seq(count_proc):
"""Convert a counting process sample to event sequence
Args:
count_proc (list of ndarray): each array in the list contains the
timestamps of events occurred on that dimension.
Returns:
(list of 2-tuples): each tuple is of (t, c), where c denotes the event
type
"""
seq = []
for i, ts in enumerate(count_proc):
seq += [(t, i) for t in ts]
seq = sorted(seq, key=lambda x: x[0])
return seq
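# Example (sketch): merging per-dimension timestamp arrays into one
# time-ordered sequence of (timestamp, event_type) pairs.
def _example_event_seq():
    import numpy as np
    count_proc = [np.array([0.5, 2.0]), np.array([1.0])]
    return counting_proc_to_event_seq(count_proc)
    # -> [(0.5, 0), (1.0, 1), (2.0, 0)]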
class AverageMeter(object):
"""Computes and stores the average and current value"""
def __init__(self):
self.reset()
def reset(self):
self.val = 0
self.avg = 0
self.sum = 0
self.count = 0
def update(self, val, n=1):
self.val = val
self.sum += val * n
self.count += n
self.avg = self.sum / self.count
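# Example (sketch): tracking a running average of a per-batch metric, using
# the batch size as the weight of each update.
def _example_average_meter(losses, batch_sizes):
    meter = AverageMeter()
    for loss, n in zip(losses, batch_sizes):
        meter.update(loss, n)
    return meter.avg  # average over all samples seen so far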
def save_checkpoint(state, output_folder, is_best, filename='checkpoint.tar'):
import torch
torch.save(state, os.path.join(output_folder, filename))
if is_best:
shutil.copyfile(os.path.join(output_folder, filename),
os.path.join(output_folder, 'model_best.tar'))
def get_freer_gpu(by="n_proc"):
"""Return the GPU index which has the largest available memory
Returns:
int: the index of selected GPU.
"""
from pynvml import (nvmlInit, nvmlDeviceGetCount,
nvmlDeviceGetHandleByIndex,
nvmlDeviceGetComputeRunningProcesses,
nvmlDeviceGetMemoryInfo)
nvmlInit()
n_devices = nvmlDeviceGetCount()
gpu_id, gpu_state = None, None
for i in range(0, n_devices):
handle = nvmlDeviceGetHandleByIndex(i)
if by == "n_proc":
temp = -len(nvmlDeviceGetComputeRunningProcesses(handle))
elif by == "free_mem":
temp = nvmlDeviceGetMemoryInfo(handle).free
else:
raise ValueError("`by` can only be 'n_proc', 'free_mem'.")
if gpu_id is None or gpu_state < temp:
gpu_id, gpu_state = i, temp
return gpu_id
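# Example (sketch): picking the least-busy GPU before building a model.
# Requires `pynvml` and at least one NVIDIA GPU; the use of torch here is
# illustrative, not required by get_freer_gpu itself.
def _example_pick_gpu():
    import torch
    return torch.device("cuda:{}".format(get_freer_gpu(by="free_mem")))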
def savefig(fig, path, save_pickle=False):
"""save matplotlib figure
Args:
fig (matplotlib.figure.Figure): figure object
path (str): [description]
save_pickle (bool, optional): Defaults to True. Whether to pickle the
figure object as well.
"""
fig.savefig(path, bbox_inches="tight")
if save_pickle:
import matplotlib
import pickle
# the IPython `inline` backend breaks figure pickling/unpickling; if it
# is active, switch backends before saving
if "inline" in matplotlib.get_backend():
raise RuntimeError(
"the IPython `inline` backend breaks figure pickling/unpickling; "
"use `matplotlib.use(...)` to switch to another backend first."
)
else:
with open(path + ".pkl", 'wb') as f:
pickle.dump(fig, f)
import numpy as np
from scipy.sparse import csr_matrix, csc_matrix
from scipy.special import logsumexp
def normalize_matrix(A, axis=1):
"""
Normalize A along the given axis (e.g., each row sums to 1 for axis=1).
The normalization is done in place.
Parameters
----------
A : numpy array (float), csr_matrix, or csc_matrix, shape (n, m)
The matrix to be normalized.
axis : int
The axis along which to normalize (must be 1 for csr_matrix, 0 for csc_matrix).
"""
if isinstance(A, np.ndarray):
axis_sums = A.sum(axis=axis)
A /= np.expand_dims(axis_sums, axis=axis)
elif isinstance(A, csr_matrix):
if axis != 1:
raise ValueError("axis must be 1 for csr_matrix.")
row_sums = np.array(A.sum(axis=1))[:, 0]
row_indices, _ = A.nonzero()
A.data /= row_sums[row_indices]
elif isinstance(A, csc_matrix):
if axis != 0:
raise ValueError("axis must be 0 for csc_matrix.")
col_sums = np.array(A.sum(axis=0))[:, 0]
_, col_indices = A.nonzero()
A.data /= col_sums[col_indices]
else:
raise NotImplementedError("Not implemented for type=%s" % type(A))
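# Example (sketch): row-normalizing a dense float matrix and a CSR matrix so
# that each row sums to 1; note that the normalization happens in place.
def _example_normalize():
    A = np.array([[1.0, 3.0], [2.0, 2.0]])
    normalize_matrix(A, axis=1)  # A becomes [[0.25, 0.75], [0.5, 0.5]]
    S = csr_matrix(np.array([[1.0, 1.0], [0.0, 4.0]]))
    normalize_matrix(S, axis=1)  # S.data is rescaled row-wise
    return A, S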
import numpy as np
import pandas as pd
def explode(df, col: str):
"""Explode an array-valued column
Args:
df (DataFrame):
col (str):
Returns:
DataFrame: the exploded DataFrame, whose length equals the total
length of all the arrays in the `col` column.
"""
# Flatten columns of lists
col_flat = [x for arr in df[col] for x in arr]
# Row numbers to repeat
lens = df[col].apply(len)
ilocations = np.arange(len(df)).repeat(lens)
# Replicate rows and add flattened column of lists
col_indices = [i for i, c in enumerate(df.columns) if c != col]
new_df = df.iloc[ilocations, col_indices].copy()
new_df[col] = col_flat
return new_df
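# Example (sketch): expanding a list-valued column into one row per element.
def _example_explode():
    df = pd.DataFrame({"id": [1, 2], "tags": [["a", "b"], ["c"]]})
    return explode(df, "tags")  # rows: (1, "a"), (1, "b"), (2, "c")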
def applyParallel(dfGrouped, func):
"""parallel apply after group
Args:
dfGrouped (DataFrameGroupBy): the object after calling `groupby(...)`
func (Callable): the function to apply
Returns:
List: results, one for each group key.
"""
from multiprocessing import Pool
with Pool() as p:
ret_list = p.map(func, (group for name, group in dfGrouped))
return ret_list
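# Example (sketch): computing a per-group aggregate in parallel. The function
# handed to applyParallel must be picklable, i.e. defined at module top level;
# `_group_mean` below is a hypothetical helper for illustration.
def _group_mean(group):
    return group["value"].mean()

def _example_apply_parallel():
    df = pd.DataFrame({"key": ["a", "a", "b"], "value": [1.0, 3.0, 5.0]})
    return applyParallel(df.groupby("key"), _group_mean)  # -> [2.0, 5.0]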
def multiindex_pivot(df, index, column, values):
"""pivot with index using multiple column
> From <https://github.com/pandas-dev/pandas/issues/23955>
Args:
df (DataFrame): [description]
index (Union[str, List[str]]):
column (str):
values (Union[str, List[str]]):
Returns:
DataFrame: [description]
"""
if pd.__version__ >= "1.1":
print(
"[warning]: Since version 1.1.0, Pandas has supported pivot with multiple indices."
)
assert isinstance(column, str), "column needs to be a string."
if isinstance(index, str):
index = [index]
if isinstance(values, str):
values = [values]
tuples_index = list(map(tuple, df[index].values))
df = df.assign(tuples_index=tuples_index)
df = df.pivot(index="tuples_index", columns=columns, values=values)
new_index = pd.MultiIndex.from_tuples(df.index, names=index)
df.index = new_index
return df
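# Example (sketch): pivoting on two index columns at once, which plain
# `DataFrame.pivot` did not support before pandas 1.1.
def _example_multiindex_pivot():
    df = pd.DataFrame({
        "subject": [1, 1, 2, 2],
        "session": ["a", "b", "a", "b"],
        "metric": ["acc"] * 4,
        "score": [0.9, 0.8, 0.7, 0.6],
    })
    return multiindex_pivot(
        df, index=["subject", "session"], column="metric", values="score"
    )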
import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
def savefig(fig, path, save_pickle=False):
"""save matplotlib figure
Args:
fig (matplotlib.figure.Figure): figure object
path (str): [description]
save_pickle (bool, optional): Defaults to True. Whether to pickle the
figure object as well.
"""
fig.savefig(path, bbox_inches="tight", transparent=True)
if save_pickle:
import matplotlib
import pickle
# the IPython `inline` backend breaks figure pickling/unpickling; if it
# is active, switch backends before saving
if "inline" in matplotlib.get_backend():
raise RuntimeError(
"the IPython `inline` backend breaks figure pickling/unpickling; "
"use `matplotlib.use(...)` to switch to another backend first."
)
else:
with open(path + ".pkl", "wb") as f:
pickle.dump(fig, f)
def config_mpl_style(rc=None, scale=1):
my_rc = {
"font.family": "serif",
"axes.titlesize": 24 * scale,
"axes.labelweight": "bold",
"axes.labelsize": 20 * scale,
"xtick.labelsize": 16 * scale,
"ytick.labelsize": 16 * scale,
"legend.title_fontsize": 16 * scale,
"legend.fontsize": 14 * 1,
"lines.markersize": 10,
"lines.linewidth": 3,
"pdf.fonttype": 42,
"ps.fonttype": 42,
}
if rc is not None:
assert isinstance(rc, dict)
my_rc.update(rc)
plt.style.use(my_rc)
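# Example (sketch): applying the publication-style defaults above, overriding
# one rc key and enlarging the scaled font sizes by 20%.
def _example_style():
    config_mpl_style(rc={"font.family": "sans-serif"}, scale=1.2)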
def heatmap(
A,
B=None,
labels=None,
title="",
table_like=False,
color_bar=False,
ax=None,
):
"""Draw heatmap along with dots diagram for visualizing a weight matrix."""
if ax is None:
ax = plt.gca()
# square shaped
ax.set_aspect("equal", "box")
# turn off the frame
ax.set_frame_on(False)
# want a more natural, table-like display
if table_like:
ax.invert_yaxis()
ax.xaxis.tick_top()
# put the major ticks at the middle of each cell
ax.set_yticks(np.arange(A.shape[0]) + 0.5, minor=False)
ax.set_xticks(np.arange(A.shape[1]) + 0.5, minor=False)
# turn off all ticks
ax.xaxis.set_tick_params(top=False, bottom=False)
ax.yaxis.set_tick_params(left=False, right=False)
# add labels
if labels is not None:
ax.set_xticklabels(labels, rotation=90)
ax.set_yticklabels(labels)
ax.set_title(title)
# draw heatmap
A_normed = (A - A.min()) / (A.max() - A.min())
heatmap = ax.pcolor(A_normed, cmap=plt.cm.Greys)
# add dots
if B is not None:
assert B.shape == A.shape
for (y, x), w in np.ndenumerate(B):
r = 0.35 * np.sqrt(w / B.max())
circle = plt.Circle(
(x + 0.5, y + 0.5), radius=r, color="darkgreen"
)
ax.add_artist(circle)
# add colorbar
if color_bar:
ax.get_figure().colorbar(heatmap, ticks=[0, 1], orientation="vertical")
def plot_cross_validation(
x,
scores,
show_error=True,
allow_missing=False,
xlabel="",
ylabel="",
title="",
xscale="log",
yscale=None,
ax=None,
):
"""Plot cross-validation curve with respect to some parameters.
Parameters
----------
x : array-like, shape (n_params, )
The values of the parameter being varied.
scores: array-like, shape (n_params, n_folds)
Each row stores the CV results for one parameter value. Note that it
may contain np.nan.
"""
if ax is None:
ax = plt.gca()
# axes style
ax.grid()
# plot curve
if allow_missing:
y = np.nanmean(scores, axis=1)
err = np.nanstd(scores, axis=1)
else:
idx = ~np.isnan(scores).any(axis=1)
x = np.asarray(x)[idx]
y = np.mean(scores[idx], axis=1)
err = np.std(scores[idx], axis=1)
# style
fmt = "o-"
# plot curve
if show_error:
ax.errorbar(x, y, err, fmt=fmt)
else:
ax.plot(x, y, fmt=fmt)
# set labels and title
ax.set_xlabel(xlabel)
ax.set_ylabel(ylabel)
ax.set_title(title)
# set axis limit
if xscale is None:
ax.set_xlim(xmin=0, xmax=max(x))
elif xscale == "log":
ax.set_xlim(xmin=min(x) / 2, xmax=max(x) * 2)
# set axis scale
if xscale is not None:
ax.set_xscale(xscale)
if yscale is not None:
ax.set_yscale(yscale)
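# Example (sketch): plotting 5-fold CV scores against a log-spaced
# regularization strength; the random scores are purely illustrative.
def _example_cv_plot():
    lambdas = np.logspace(-3, 1, 5)
    scores = np.random.rand(5, 5)  # shape (n_params, n_folds)
    fig, ax = plt.subplots()
    plot_cross_validation(lambdas, scores, xlabel="lambda", ylabel="score", ax=ax)
    return fig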
def ternaryplot(
root_proba,
colors=["red", "green", "blue"],
markers=["s", "D", "o"],
fontsize=20,
):
assert (min(np.reshape(root_proba, np.size(root_proba))) >= 0) & (
max(np.reshape(root_proba, np.size(root_proba))) <= 1
)
import ternary
figure, tax = ternary.figure(scale=1)
figure.set_size_inches(10, 10)
tax.boundary(linewidth=2.0)
tax.gridlines(color="blue", multiple=0.05, linewidth=0.5)
tax.bottom_axis_label("Normal", fontsize=fontsize, color="brown")
tax.right_axis_label("early MCI", fontsize=fontsize, color="brown")
tax.left_axis_label("clinical MCI", fontsize=fontsize, color="brown")
# plot the prediction boundary
p = (1.0 / 3, 1.0 / 3, 1.0 / 3)
p1 = (0, 0.5, 0.5)
p2 = (0.5, 0, 0.5)
p3 = (0.5, 0.5, 0)
tax.line(p, p1, linestyle="--", color="brown", linewidth=3)
tax.line(p, p2, linestyle="--", color="brown", linewidth=3)
tax.line(p, p3, linestyle="--", color="brown", linewidth=3)
# plot scatter plot of the points
tax.scatter(
root_proba, s=1, linewidth=3.5, marker=markers[0], color=colors[0]
)
tax.ticks(axis="lbr", multiple=0.1, linewidth=1)
tax.clear_matplotlib_ticks()
tax.show()
return figure, tax
def barplot(
x,
hue,
hue_labels=None,
xlabel=None,
ylabel=None,
title=None,
ax=None,
cmap_name="Accent",
):
"""Plot the proportion of hue for each value of x.
Parameters
----------
x : array-like
Group id.
hue : array-like
Hue id
"""
if ax is None:
ax = plt.gca()
n_group = len(set(x))
n_hue = len(set(hue))
hue2id = {h: i for i, h in enumerate(list(set(hue)))}
data = np.zeros([n_group, n_hue])
for i in range(len(x)):
data[x[i], hue2id[hue[i]]] += 1
data = data / data.sum(axis=1)[:, None]
# set the style
default_color_list = ["lightgreen", "dodgerblue", "orangered", "black"]
if n_hue <= len(default_color_list):
colors = default_color_list[:n_hue]
else:
colors = plt.cm.get_cmap(cmap_name)(np.linspace(0, 1, n_hue))
width = 0.5 / n_hue
left = np.arange(n_group)
for i in range(n_hue):
_ = ax.bar(left + width * i, data[:, i], width, color=colors[i])
ax.set_ylabel(ylabel)
ax.set_xlabel(xlabel)
ax.set_xticks(left + n_hue * width / 2)
ax.set_xticklabels(map(str, range(n_group)))
# add ticks
ax.yaxis.set_tick_params(left=True, right=True)
ax.xaxis.set_tick_params(top=True, bottom=True)
ax.set_xlim(xmin=-1 + width * n_hue, xmax=n_group)
# add legend
if hue_labels is not None and len(hue_labels) == n_hue:
ax.legend(
labels=hue_labels, loc="center left", bbox_to_anchor=(1, 0.5)
)
ax.set_title(title)
def networkplot(
weights,
labels=None,
max_node_size=3000,
min_node_size=100,
max_width=10,
min_width=1,
arrowsize=80,
colorbar=True,
x_margins=0,
scale=1.0,
ax=None,
):
import networkx as nx
assert weights.shape[0] == weights.shape[1]
assert labels is None or weights.shape[0] == len(labels)
if not ax:
ax = plt.figure().gca()
G = nx.DiGraph()
for i in range(weights.shape[0]):
for j in range(weights.shape[1]):
if weights[i][j] != 0:
G.add_edge(i, j, weight=weights[i][j])
pos = nx.spring_layout(G, seed=0)
# node properties
node_size = np.sqrt(np.abs(weights).sum(0)) # weight sum of outgoing edges
node_size *= max_node_size / node_size.max()
node_size = np.maximum(node_size, min_node_size)
# node_size = max_node_size
# edge properties
edge_color = np.asarray([e[2]["weight"] for e in G.edges(data=True)])
width = np.sqrt(np.abs(edge_color) / np.abs(edge_color).max()) * max_width
width = np.maximum(width, min_width)
if weights.min() >= 0:
edge_vmax = weights.max()
edge_vmin = 0
cmap = plt.cm.Reds
elif weights.max() <= 0:
edge_vmax = 0
edge_vmin = weights.min()
cmap = plt.cm.Blues_r
else:
edge_vmax = np.abs(weights).max()
edge_vmin = -edge_vmax
# edge_vmax = weights.max()
# edge_vmin = weights.min()
cmap = plt.cm.RdYlBu_r
nx.draw_networkx(
G,
pos,
node_size=node_size,
node_color="#008D0A",
labels={k: labels[k] for k in pos} if labels is not None else None,
ax=ax,
)
# change arrow size
arrowsize = width / width.max() * arrowsize
for i, e in enumerate(G.edges()):
nx.draw_networkx_edges(
G,
pos,
edgelist=[e],
arrowsize=arrowsize[i],
width=width[i],
edge_color=edge_color[i : i + 1],
edge_cmap=cmap,
edge_vmin=edge_vmin,
edge_vmax=edge_vmax,
node_size=node_size,
)
if colorbar:
sm = mpl.cm.ScalarMappable(
cmap=cmap, norm=plt.Normalize(vmin=edge_vmin, vmax=edge_vmax)
)
sm._A = []
plt.colorbar(sm, ax=ax)
ax.margins(x=x_margins)
ax.axis("off")
return ax
import os
import random
import shutil
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
class ResidualLayer(nn.Module):
"""A two-layer feed-forward block with a residual connection; the skip path
is the identity when the input and output sizes match, and a bias-free linear
projection otherwise."""
def __init__(
self, in_features, out_features, hidden_size=0, activation=None
):
super().__init__()
hidden_size = hidden_size or in_features
self.net1 = nn.Sequential(
nn.Linear(in_features, hidden_size),
activation or nn.ReLU(),
nn.Linear(hidden_size, out_features),
)
if in_features == out_features:
self.net2 = lambda x: x
else:
self.net2 = nn.Linear(in_features, out_features, bias=False)
def forward(self, x):
return self.net1(x) + self.net2(x)
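# Example (sketch): a residual block mapping 32 -> 64 features with a GELU
# non-linearity; the skip path is a linear projection because the sizes differ.
def _example_residual():
    layer = ResidualLayer(32, 64, hidden_size=128, activation=nn.GELU())
    x = torch.randn(8, 32)
    return layer(x)  # shape (8, 64)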
def save_checkpoint(state, output_folder, is_best, filename="checkpoint.tar"):
import torch
torch.save(state, os.path.join(output_folder, filename))
if is_best:
shutil.copyfile(
os.path.join(output_folder, filename),
os.path.join(output_folder, "model_best.tar"),
)
def split_dataset(dataset, ratio: float):
n = len(dataset)
lengths = [int(n * ratio), n - int(n * ratio)]
return torch.utils.data.random_split(dataset, lengths)
def split_dataloader(dataloader, ratio: float):
dataset = dataloader.dataset
n = len(dataset)
lengths = [int(n * ratio), n - int(n * ratio)]
datasets = torch.utils.data.random_split(dataset, lengths)
copied_fields = ["batch_size", "num_workers", "collate_fn", "drop_last"]
dataloaders = []
for d in datasets:
dataloaders.append(
DataLoader(
dataset=d, **{k: getattr(dataloader, k) for k in copied_fields}
)
)
return tuple(dataloaders)
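# Example (sketch): a 90/10 random split of a toy TensorDataset, and the same
# ratio applied to an existing DataLoader (whose batch size, workers, and
# collate_fn are copied onto the two new loaders).
def _example_split():
    from torch.utils.data import TensorDataset
    dataset = TensorDataset(torch.randn(100, 4), torch.randint(0, 2, (100,)))
    train_set, val_set = split_dataset(dataset, ratio=0.9)
    loader = DataLoader(dataset, batch_size=16)
    train_loader, val_loader = split_dataloader(loader, ratio=0.9)
    return train_set, val_set, train_loader, val_loader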
class KeyBucketedBatchSampler(torch.utils.data.Sampler):
"""Pseduo bucketed batch sampler.
Sample in a way that
Args:
keys (List[int]): keys by which the same or nearby keys are allocated
in the same or nearby batches.
batch_size (int):
drop_last (bool, optional): Whether to drop the last incomplete batch.
Defaults to False.
shuffle_same_key (bool, optional): Whether to shuffle the instances of
the same keys. Defaults to False.
"""
def __init__(
self, keys, batch_size, drop_last=False, shuffle_same_key=True
):
self.keys = keys
self.batch_size = batch_size
self.drop_last = drop_last
self.shuffle_same_key = shuffle_same_key
# bucket sort; maintain random order inside each bucket
buckets = {}
for i, key in enumerate(self.keys):
if key not in buckets:
buckets[key] = [i]
else:
buckets[key].append(i)
self.buckets = buckets
def __iter__(self):
indices = []
for key in sorted(self.buckets.keys()):
v = self.buckets[key]
if self.shuffle_same_key:
random.shuffle(v)
indices += v
index_batches = []
for i in range(0, len(indices), self.batch_size):
j = min(i + self.batch_size, len(indices))
index_batches.append(indices[i:j])
del indices
if self.drop_last and len(index_batches[-1]) < self.batch_size:
index_batches = index_batches[:-1]
random.shuffle(index_batches)
for indices in index_batches:
yield indices
def __len__(self):
if self.drop_last:
return len(self.keys) // self.batch_size
else:
return (len(self.keys) + self.batch_size - 1) // self.batch_size
def convert_to_bucketed_dataloader(
dataloader: DataLoader, key_fn, shuffle_same_key=True
):
"""Convert a data loader to bucketed data loader with a given keys.
Args:
dataloader (DataLoader):
key_fn (Callable]): function to extract keys used for constructing
the bucketed data loader; should be of the same key as the
dataset.
shuffle_same_key (bool, optional): Whether to shuffle the instances of
the same keys. Defaults to False.
Returns:
DataLoader:
"""
assert (
dataloader.batch_size is not None
), "The `batch_size` must be present for the input dataloader"
dataset = dataloader.dataset
keys = [key_fn(dataset[i]) for i in range(len(dataset))]
batch_sampler = KeyBucketedBatchSampler(
keys,
batch_size=dataloader.batch_size,
drop_last=dataloader.drop_last,
shuffle_same_key=shuffle_same_key,
)
return DataLoader(
dataset, batch_sampler=batch_sampler, collate_fn=dataloader.collate_fn
)
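# Example (sketch): bucketing variable-length sequences by length so that each
# batch contains sequences of similar size, which reduces padding. Dataset
# items are assumed to be (sequence, label) pairs, hence `len(item[0])`.
def _example_bucketed_loader(dataset):
    loader = DataLoader(dataset, batch_size=32)
    return convert_to_bucketed_dataloader(loader, key_fn=lambda item: len(item[0]))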
def generate_sequence_mask(lengths, device=None):
"""
Args:
lengths (LongTensor): 1-D
Returns:
BoolTensor: [description]
"""
index = torch.arange(lengths.max(), device=device or lengths.device)
return index.unsqueeze(0) < lengths.unsqueeze(1)
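# Example (sketch): building a padding mask for a batch with sequence lengths
# 3, 1, and 2 (so the padded length is 3).
def _example_mask():
    lengths = torch.tensor([3, 1, 2])
    return generate_sequence_mask(lengths)
    # tensor([[ True,  True,  True],
    #         [ True, False, False],
    #         [ True,  True, False]])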
def set_eval_mode(module, root=True):
"""Disable the training behavior of Dropout and BatchNorm submodules while
leaving the rest of the module in training mode."""
if root:
module.train()
name = module.__class__.__name__
if "Dropout" in name or "BatchNorm" in name:
module.training = False
for child_module in module.children():
set_eval_mode(child_module, False)