- logging
- misc
- numpy
- pandas
- plotting
- torch
A submodule of useful utility functions for ML research projects. #Python #Utils
# File: logging
# -*- coding:utf-8 -*-
import os
import os.path as osp
import sys
import time
from logging import (
    INFO,
    Formatter,
    StreamHandler,
    basicConfig,
    getLevelName,
    getLogger,
)

LOG_DEFAULT_FORMAT = "[ %(asctime)s][%(module)s.%(funcName)s] %(message)s"
LOG_DEFAULT_LEVEL = INFO
def strftime(t=None):
    """Format a timestamp (default: now) as a 'YYYYmmdd-HHMMSS' string."""
    return time.strftime("%Y%m%d-%H%M%S", time.localtime(t or time.time()))
def init_logging(logging_dir, filename=None, level=None, log_format=None):
    """Configure the root logger to write to `logging_dir/filename`."""
    if not osp.exists(logging_dir):
        os.makedirs(logging_dir)
    filename = filename or strftime() + ".log"
    log_format = log_format or LOG_DEFAULT_FORMAT

    global LOG_DEFAULT_LEVEL
    if isinstance(level, str):
        level = getLevelName(level.upper())
    elif level is None:
        level = LOG_DEFAULT_LEVEL
    LOG_DEFAULT_LEVEL = level

    basicConfig(
        filename=osp.join(logging_dir, filename),
        format=log_format,
        level=level,
    )
def get_logger(name, level=None, log_format=None, print_to_std=True):
    """Get a logger by name.

    Args:
        level: logging level; if None, uses LOG_DEFAULT_LEVEL (INFO).
        log_format: if None, uses LOG_DEFAULT_FORMAT.
        print_to_std: whether to also log to stdout. Defaults to True.
    """
    if level is None:
        level = LOG_DEFAULT_LEVEL
    elif isinstance(level, str):
        level = getLevelName(level.upper())
    if not log_format:
        log_format = LOG_DEFAULT_FORMAT

    logger = getLogger(name)
    logger.setLevel(level)
    if print_to_std:
        handler = StreamHandler(sys.stdout)
        handler.setLevel(level)
        handler.setFormatter(Formatter(log_format))
        logger.addHandler(handler)
    return logger
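
# Example usage (a minimal sketch; "./logs" and the logger name "train" are
# arbitrary placeholders, not part of the module):
#
#     init_logging("./logs")              # root logger -> ./logs/<timestamp>.log
#     logger = get_logger("train", level="DEBUG")
#     logger.info("epoch %d done", 1)     # goes to stdout and the log file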
# File: misc
import json
import os
import shutil
import time


def set_rand_seed(seed):
    """Seed both the NumPy and PyTorch RNGs for reproducibility."""
    import numpy as np
    import torch

    np.random.seed(seed)
    torch.manual_seed(seed)


def set_printoptions(precision=4, linewidth=160):
    """Apply the same print formatting to NumPy and PyTorch."""
    import numpy as np
    import torch

    np.set_printoptions(precision=precision, linewidth=linewidth)
    torch.set_printoptions(precision=precision, linewidth=linewidth)
class RelativeChangeMonitor(object):
    """Track losses and stop when the relative change falls below `tol`."""

    def __init__(self, tol):
        self.tol = tol
        self._losses = []
        self._best_loss = float("inf")

    @property
    def save(self):
        return len(self._losses) > 0 and self._losses[-1] == self._best_loss

    @property
    def stop(self):
        return len(self._losses) > 1 and abs(
            (self._losses[-1] - self._losses[-2]) / self._best_loss
        ) < self.tol

    def register(self, loss):
        self._losses.append(loss)
        self._best_loss = min(self._best_loss, loss)
class EarlyStoppingMonitor(object):
    """Stop after the loss has failed to improve `patience` times in a row."""

    def __init__(self, patience):
        self._patience = patience
        self._best_loss = float("inf")
        self._curr_loss = float("inf")
        self._n_fails = 0

    @property
    def save(self):
        return self._curr_loss == self._best_loss

    @property
    def stop(self):
        return self._n_fails > self._patience

    def register(self, loss):
        self._curr_loss = loss
        if loss < self._best_loss:
            self._best_loss = loss
            self._n_fails = 0
        else:
            self._n_fails += 1
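
# Example training loop (a sketch; `evaluate` and `save_model` are hypothetical
# stand-ins for your own validation and checkpointing code):
#
#     monitor = EarlyStoppingMonitor(patience=5)
#     for epoch in range(100):
#         val_loss = evaluate(model)
#         monitor.register(val_loss)
#         if monitor.save:
#             save_model(model)
#         if monitor.stop:
#             break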
class Timer(object):
    """Context manager that prints the wall-clock time spent in its block."""

    def __init__(self, name=None):
        self.name = name

    def __enter__(self):
        self.tstart = time.time()

    def __exit__(self, exc_type, exc_value, traceback):
        if self.name:
            print("[{}] ".format(self.name), end="")
        dt = time.time() - self.tstart
        if dt < 60:
            print("Elapsed: {:.4f} sec.".format(dt))
        elif dt < 3600:
            print("Elapsed: {:.4f} min.".format(dt / 60))
        elif dt < 86400:
            print("Elapsed: {:.4f} hour.".format(dt / 3600))
        else:
            print("Elapsed: {:.4f} day.".format(dt / 86400))
def makedirs(dirs):
    """Create each directory in `dirs` if it does not already exist."""
    assert isinstance(dirs, list), "Argument dirs needs to be a list"
    for d in dirs:
        if not os.path.isdir(d):
            os.makedirs(d)


def export_json(obj, path):
    with open(path, "w") as fout:
        json.dump(obj, fout, indent=4)
def export_csv(df, path, append=False, index=False):
    """Write a DataFrame to CSV; the header is written only if the file is empty."""
    if not os.path.isdir(os.path.dirname(path)):
        os.makedirs(os.path.dirname(path))
    mode = "a" if append else "w"
    with open(path, mode) as f:
        df.to_csv(f, header=f.tell() == 0, index=index)
def counting_proc_to_event_seq(count_proc):
    """Convert a counting-process sample to an event sequence.

    Args:
        count_proc (list of ndarray): each array in the list contains the
            timestamps of the events that occurred on that dimension.

    Returns:
        list of 2-tuples: each tuple is (t, c), where t is the timestamp and
            c is the event type (the dimension index).
    """
    seq = []
    for i, ts in enumerate(count_proc):
        seq += [(t, i) for t in ts]
    seq = sorted(seq, key=lambda x: x[0])
    return seq
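
# Worked example: two dimensions with timestamps [1.0, 3.0] and [2.0] merge
# into a single time-ordered sequence of (timestamp, type) pairs:
#
#     counting_proc_to_event_seq([[1.0, 3.0], [2.0]])
#     # -> [(1.0, 0), (2.0, 1), (3.0, 0)]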
class AverageMeter(object):
    """Computes and stores the average and current value."""

    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count
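
# Example: track the running mean of a per-batch loss, weighting each update
# by the batch size (`loader`, `model`, `criterion` are hypothetical):
#
#     loss_meter = AverageMeter()
#     for x, y in loader:
#         loss = criterion(model(x), y)
#         loss_meter.update(loss.item(), n=x.size(0))
#     print(loss_meter.avg)   # dataset-level average loss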
def save_checkpoint(state, output_folder, is_best, filename="checkpoint.tar"):
    import torch

    torch.save(state, os.path.join(output_folder, filename))
    if is_best:
        shutil.copyfile(
            os.path.join(output_folder, filename),
            os.path.join(output_folder, "model_best.tar"),
        )
def get_freer_gpu(by="n_proc"):
    """Return the index of the least-busy GPU.

    Args:
        by (str): "n_proc" selects the GPU with the fewest running compute
            processes; "free_mem" selects the one with the most free memory.

    Returns:
        int: the index of the selected GPU.
    """
    from pynvml import (
        nvmlDeviceGetComputeRunningProcesses,
        nvmlDeviceGetCount,
        nvmlDeviceGetHandleByIndex,
        nvmlDeviceGetMemoryInfo,
        nvmlInit,
    )

    nvmlInit()
    n_devices = nvmlDeviceGetCount()
    gpu_id, gpu_state = None, None
    for i in range(n_devices):
        handle = nvmlDeviceGetHandleByIndex(i)
        if by == "n_proc":
            # negate so that "larger is better" holds for both criteria
            temp = -len(nvmlDeviceGetComputeRunningProcesses(handle))
        elif by == "free_mem":
            temp = nvmlDeviceGetMemoryInfo(handle).free
        else:
            raise ValueError("`by` can only be 'n_proc' or 'free_mem'.")
        if gpu_id is None or gpu_state < temp:
            gpu_id, gpu_state = i, temp
    return gpu_id
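
# Example: pin the current process to the least-busy GPU (assumes PyTorch and
# pynvml are installed):
#
#     device = torch.device("cuda:{}".format(get_freer_gpu(by="free_mem")))
#     model = model.to(device)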
def savefig(fig, path, save_pickle=False):
    """Save a matplotlib figure.

    Args:
        fig (matplotlib.figure.Figure): figure object.
        path (str): output file path.
        save_pickle (bool, optional): Defaults to False. Whether to pickle
            the figure object as well (to `path` + ".pkl").
    """
    fig.savefig(path, bbox_inches="tight")
    if save_pickle:
        import matplotlib
        import pickle

        # Figures created under IPython's `inline` backend fail to
        # pickle/unpickle; use `matplotlib.use` to switch backends first.
        if "inline" in matplotlib.get_backend():
            raise RuntimeError(
                "the `inline` backend of IPython will fail the pickle/"
                "unpickle. Please use `matplotlib.use` to switch to another "
                "backend."
            )
        with open(path + ".pkl", "wb") as f:
            pickle.dump(fig, f)
# File: numpy
import numpy as np
from scipy.sparse import csc_matrix, csr_matrix


def normalize_matrix(A, axis=1):
    """Normalize A so that it sums to one along `axis`.

    Sparse inputs are normalized in place; dense inputs are left untouched
    and a normalized copy is returned.

    Parameters
    ----------
    A : ndarray, csr_matrix, or csc_matrix, shape (n, m)
        The matrix to be normalized.
    axis : int
        The axis along which to normalize: 1 (rows) for csr_matrix,
        0 (columns) for csc_matrix.
    """
    if isinstance(A, np.ndarray):
        axis_sums = A.sum(axis=axis)
        A = A / np.expand_dims(axis_sums, axis=axis)
    elif isinstance(A, csr_matrix):
        if axis != 1:
            raise ValueError("axis must be 1 for csr_matrix.")
        row_sums = np.asarray(A.sum(axis=1)).ravel()
        row_indices, _ = A.nonzero()
        A.data /= row_sums[row_indices]
    elif isinstance(A, csc_matrix):
        if axis != 0:
            raise ValueError("axis must be 0 for csc_matrix.")
        col_sums = np.asarray(A.sum(axis=0)).ravel()
        _, col_indices = A.nonzero()
        A.data /= col_sums[col_indices]
    else:
        raise NotImplementedError("Not implemented for type=%s" % type(A))
    return A
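
# Example: row-normalize a dense matrix and a sparse matrix (note that sparse
# inputs are modified in place):
#
#     A = np.array([[1.0, 3.0], [2.0, 2.0]])
#     normalize_matrix(A, axis=1)     # -> [[0.25, 0.75], [0.5, 0.5]]
#
#     S = csr_matrix(A)
#     normalize_matrix(S, axis=1)     # S.data now holds the normalized rows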
# File: pandas
import numpy as np
import pandas as pd


def explode(df, col: str):
    """Explode an array-valued column.

    Args:
        df (DataFrame):
        col (str): name of the column to explode.

    Returns:
        DataFrame: the exploded DataFrame, whose length equals the total
            length of the arrays in the `col` column.
    """
    # Flatten the column of lists
    col_flat = [x for arr in df[col] for x in arr]
    # Row numbers to repeat
    lens = df[col].apply(len)
    ilocations = np.arange(len(df)).repeat(lens)
    # Replicate rows and add the flattened column back
    col_indices = [i for i, c in enumerate(df.columns) if c != col]
    new_df = df.iloc[ilocations, col_indices].copy()
    new_df[col] = col_flat
    return new_df
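
# Example: each list element becomes its own row, with the other columns
# repeated (pandas >= 0.25 also ships an equivalent `DataFrame.explode`):
#
#     df = pd.DataFrame({"id": [1, 2], "vals": [[10, 20], [30]]})
#     explode(df, "vals")
#     #    id  vals
#     # 0   1    10
#     # 0   1    20
#     # 1   2    30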
def applyParallel(dfGrouped, func):
    """Parallel apply over the groups of a grouped DataFrame.

    Args:
        dfGrouped (DataFrameGroupBy): the object returned by `groupby(...)`.
        func (Callable): the function to apply to each group; it must be
            picklable (e.g., a module-level function).

    Returns:
        List: one result per group.
    """
    from multiprocessing import Pool

    with Pool() as p:
        ret_list = p.map(func, (group for name, group in dfGrouped))
    return ret_list
def multiindex_pivot(df, index, column, values):
    """Pivot with multiple index columns.

    > From <https://github.com/pandas-dev/pandas/issues/23955>

    Args:
        df (DataFrame): the frame to pivot.
        index (Union[str, List[str]]): column(s) to use as the index.
        column (str): column whose values become the new columns.
        values (Union[str, List[str]]): column(s) to fill the table with.

    Returns:
        DataFrame: the pivoted frame, indexed by a MultiIndex over `index`.
    """
    if pd.__version__ >= "1.1":
        print(
            "[warning]: Since version 1.1.0, pandas `pivot` has supported "
            "multiple indices natively."
        )
    assert isinstance(column, str), "column needs to be a string."
    if isinstance(index, str):
        index = [index]
    if isinstance(values, str):
        values = [values]
    tuples_index = list(map(tuple, df[index].values))
    df = df.assign(tuples_index=tuples_index)
    df = df.pivot(index="tuples_index", columns=column, values=values)
    new_index = pd.MultiIndex.from_tuples(df.index, names=index)
    df.index = new_index
    return df
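
# Example (hypothetical data; on pandas >= 1.1, `df.pivot(index=["a", "b"], ...)`
# does this directly):
#
#     df = pd.DataFrame({"a": [1, 1], "b": [2, 2], "k": ["x", "y"], "v": [3, 4]})
#     multiindex_pivot(df, index=["a", "b"], column="k", values="v")
#     # -> one row indexed by (a=1, b=2), with columns ("v", "x") and ("v", "y")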
# File: plotting
import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np


def savefig(fig, path, save_pickle=False):
    """Save a matplotlib figure (with a transparent background).

    Args:
        fig (matplotlib.figure.Figure): figure object.
        path (str): output file path.
        save_pickle (bool, optional): Defaults to False. Whether to pickle
            the figure object as well (to `path` + ".pkl").
    """
    fig.savefig(path, bbox_inches="tight", transparent=True)
    if save_pickle:
        import matplotlib
        import pickle

        # Figures created under IPython's `inline` backend fail to
        # pickle/unpickle; use `matplotlib.use` to switch backends first.
        if "inline" in matplotlib.get_backend():
            raise RuntimeError(
                "the `inline` backend of IPython will fail the pickle/"
                "unpickle. Please use `matplotlib.use` to switch to another "
                "backend."
            )
        with open(path + ".pkl", "wb") as f:
            pickle.dump(fig, f)
def config_mpl_style(rc=None, scale=1):
    """Apply a publication-friendly rc style; `scale` multiplies font sizes."""
    my_rc = {
        "font.family": "serif",
        "axes.titlesize": 24 * scale,
        "axes.labelweight": "bold",
        "axes.labelsize": 20 * scale,
        "xtick.labelsize": 16 * scale,
        "ytick.labelsize": 16 * scale,
        "legend.title_fontsize": 16 * scale,
        "legend.fontsize": 14 * scale,
        "lines.markersize": 10,
        "lines.linewidth": 3,
        # embed TrueType (Type 42) fonts so PDF/PS text stays editable
        "pdf.fonttype": 42,
        "ps.fonttype": 42,
    }
    if rc is not None:
        assert isinstance(rc, dict)
        my_rc.update(rc)
    plt.style.use(my_rc)
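
# Example: enlarge everything by 25% and override a single entry (the figure
# size shown here is just an illustration):
#
#     config_mpl_style(rc={"figure.figsize": (8, 5)}, scale=1.25)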
def heatmap(
    A,
    B=None,
    labels=None,
    title="",
    table_like=False,
    color_bar=False,
    ax=None,
):
    """Draw a heatmap of A, optionally with dots sized by B, to visualize a
    weight matrix."""
    if ax is None:
        ax = plt.gca()
    # square cells
    ax.set_aspect("equal", "box")
    # turn off the frame
    ax.set_frame_on(False)
    # for a more natural, table-like display
    if table_like:
        ax.invert_yaxis()
        ax.xaxis.tick_top()
    # put the major ticks at the middle of each cell
    ax.set_yticks(np.arange(A.shape[0]) + 0.5, minor=False)
    ax.set_xticks(np.arange(A.shape[1]) + 0.5, minor=False)
    # turn off all tick marks
    ax.xaxis.set_tick_params(top=False, bottom=False)
    ax.yaxis.set_tick_params(left=False, right=False)
    # add labels
    if labels is not None:
        ax.set_xticklabels(labels, rotation=90)
        ax.set_yticklabels(labels)
    ax.set_title(title)
    # draw the heatmap on min-max normalized values
    A_normed = (A - A.min()) / (A.max() - A.min())
    heatmap = ax.pcolor(A_normed, cmap=plt.cm.Greys)
    # add dots whose radius scales with sqrt(w / B.max())
    if B is not None:
        assert B.shape == A.shape
        for (y, x), w in np.ndenumerate(B):
            r = 0.35 * np.sqrt(w / B.max())
            circle = plt.Circle((x + 0.5, y + 0.5), radius=r, color="darkgreen")
            ax.add_artist(circle)
    # add a colorbar
    if color_bar:
        ax.get_figure().colorbar(heatmap, ticks=[0, 1], orientation="vertical")
def plot_cross_validation(
    x,
    scores,
    show_error=True,
    allow_missing=False,
    xlabel="",
    ylabel="",
    title="",
    xscale="log",
    yscale=None,
    ax=None,
):
    """Plot a cross-validation curve with respect to some parameter.

    Parameters
    ----------
    x : array-like, shape (n_params,)
        The values of the parameter.
    scores : array-like, shape (n_params, n_folds)
        Each row stores the CV results for one parameter value. Note that it
        may contain np.nan.
    """
    if ax is None:
        ax = plt.gca()
    # axes style
    ax.grid()
    # aggregate the scores across folds
    if allow_missing:
        y = np.nanmean(scores, axis=1)
        err = np.nanstd(scores, axis=1)
    else:
        # drop parameter values with any missing fold
        idx = ~np.isnan(scores).any(axis=1)
        x = np.asarray(x)[idx]
        y = np.mean(scores[idx], axis=1)
        err = np.std(scores[idx], axis=1)
    # plot the curve
    fmt = "o-"
    if show_error:
        ax.errorbar(x, y, err, fmt=fmt)
    else:
        ax.plot(x, y, fmt)
    # set labels and title
    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)
    ax.set_title(title)
    # set axis limits
    if xscale is None:
        ax.set_xlim(xmin=0, xmax=max(x))
    elif xscale == "log":
        ax.set_xlim(xmin=min(x) / 2, xmax=max(x) * 2)
    # set axis scales
    if xscale is not None:
        ax.set_xscale(xscale)
    if yscale is not None:
        ax.set_yscale(yscale)
def ternaryplot(
    root_proba,
    colors=["red", "green", "blue"],
    markers=["s", "D", "o"],
    fontsize=20,
):
    """Scatter 3-class probabilities on a ternary simplex (requires `ternary`)."""
    root_proba = np.asarray(root_proba)
    assert (root_proba.min() >= 0) and (root_proba.max() <= 1)
    import ternary

    figure, tax = ternary.figure(scale=1)
    figure.set_size_inches(10, 10)
    tax.boundary(linewidth=2.0)
    tax.gridlines(color="blue", multiple=0.05, linewidth=0.5)
    tax.bottom_axis_label("Normal", fontsize=fontsize, color="brown")
    tax.right_axis_label("early MCI", fontsize=fontsize, color="brown")
    tax.left_axis_label("clinical MCI", fontsize=fontsize, color="brown")
    # plot the prediction boundary: lines from the center to each edge midpoint
    p = (1.0 / 3, 1.0 / 3, 1.0 / 3)
    p1 = (0, 0.5, 0.5)
    p2 = (0.5, 0, 0.5)
    p3 = (0.5, 0.5, 0)
    tax.line(p, p1, linestyle="--", color="brown", linewidth=3)
    tax.line(p, p2, linestyle="--", color="brown", linewidth=3)
    tax.line(p, p3, linestyle="--", color="brown", linewidth=3)
    # scatter plot of the points
    tax.scatter(
        root_proba, s=1, linewidth=3.5, marker=markers[0], color=colors[0]
    )
    tax.ticks(axis="lbr", multiple=0.1, linewidth=1)
    tax.clear_matplotlib_ticks()
    tax.show()
    return figure, tax
def barplot(
    x,
    hue,
    hue_labels=None,
    xlabel=None,
    ylabel=None,
    title=None,
    ax=None,
    cmap_name="Accent",
):
    """Plot the proportion of each hue value within each value of x.

    Parameters
    ----------
    x : array-like
        Group ids (assumed to be integers 0..n_group-1).
    hue : array-like
        Hue ids.
    """
    if ax is None:
        ax = plt.gca()
    n_group = len(set(x))
    n_hue = len(set(hue))
    hue2id = {h: i for i, h in enumerate(list(set(hue)))}
    # count hue occurrences per group, then normalize rows to proportions
    data = np.zeros([n_group, n_hue])
    for i in range(len(x)):
        data[x[i], hue2id[hue[i]]] += 1
    data = data / data.sum(axis=1)[:, None]
    # set the style
    default_color_list = ["lightgreen", "dodgerblue", "orangered", "black"]
    if n_hue <= len(default_color_list):
        colors = default_color_list[:n_hue]
    else:
        colors = plt.cm.get_cmap(cmap_name)(np.linspace(0, 1, n_hue))
    # draw one bar per hue within each group, side by side
    width = 0.5 / n_hue
    left = np.arange(n_group)
    for i in range(n_hue):
        _ = ax.bar(left + width * i, data[:, i], width, color=colors[i])
    ax.set_ylabel(ylabel)
    ax.set_xlabel(xlabel)
    ax.set_xticks(left + n_hue * width / 2)
    ax.set_xticklabels(map(str, range(n_group)))
    # add ticks
    ax.yaxis.set_tick_params(left=True, right=True)
    ax.xaxis.set_tick_params(top=True, bottom=True)
    ax.set_xlim(xmin=-1 + width * n_hue, xmax=n_group)
    # add a legend
    if hue_labels is not None and len(hue_labels) == n_hue:
        ax.legend(labels=hue_labels, loc="center left", bbox_to_anchor=(1, 0.5))
    ax.set_title(title)
def networkplot(
    weights,
    labels=None,
    max_node_size=3000,
    min_node_size=100,
    max_width=10,
    min_width=1,
    arrowsize=80,
    colorbar=True,
    x_margins=0,
    scale=1.0,
    ax=None,
):
    """Draw a weighted directed graph whose adjacency matrix is `weights`."""
    import networkx as nx

    assert weights.shape[0] == weights.shape[1]
    assert labels is None or weights.shape[0] == len(labels)
    if not ax:
        ax = plt.figure().gca()

    G = nx.DiGraph()
    for i in range(weights.shape[0]):
        for j in range(weights.shape[1]):
            if weights[i][j] != 0:
                G.add_edge(i, j, weight=weights[i][j])
    pos = nx.spring_layout(G, seed=0)
    # node size scales with the weight sum of outgoing edges
    node_size = np.sqrt(np.abs(weights).sum(0))
    node_size *= max_node_size / node_size.max()
    node_size = np.maximum(node_size, min_node_size)
    # edge color encodes the signed weight; width scales with its magnitude
    edge_color = np.asarray([e[2]["weight"] for e in G.edges(data=True)])
    width = np.sqrt(np.abs(edge_color) / np.abs(edge_color).max()) * max_width
    width = np.maximum(width, min_width)
    # choose a colormap according to the sign of the weights
    if weights.min() >= 0:
        edge_vmax = weights.max()
        edge_vmin = 0
        cmap = plt.cm.Reds
    elif weights.max() <= 0:
        edge_vmax = 0
        edge_vmin = weights.min()
        cmap = plt.cm.Blues_r
    else:
        edge_vmax = np.abs(weights).max()
        edge_vmin = -edge_vmax
        cmap = plt.cm.RdYlBu_r
    nx.draw_networkx(
        G,
        pos,
        node_size=node_size,
        node_color="#008D0A",
        labels={k: labels[k] for k in pos} if labels is not None else None,
        ax=ax,
    )
    # draw edges one by one so each can get its own arrow size
    arrowsize = width / width.max() * arrowsize
    for i, e in enumerate(G.edges()):
        nx.draw_networkx_edges(
            G,
            pos,
            edgelist=[e],
            arrowsize=arrowsize[i],
            width=width[i],
            edge_color=edge_color[i : i + 1],
            edge_cmap=cmap,
            edge_vmin=edge_vmin,
            edge_vmax=edge_vmax,
            node_size=node_size,
        )
    if colorbar:
        sm = mpl.cm.ScalarMappable(
            cmap=cmap, norm=plt.Normalize(vmin=edge_vmin, vmax=edge_vmax)
        )
        sm._A = []  # give the empty ScalarMappable an array so colorbar works
        plt.colorbar(sm, ax=ax)
    ax.margins(x=x_margins)
    ax.axis("off")
    return ax
# File: torch
import os
import random
import shutil

import torch
import torch.nn as nn
from torch.utils.data import DataLoader


class ResidualLayer(nn.Module):
    """A two-layer MLP with a (linear, if needed) skip connection."""

    def __init__(
        self, in_features, out_features, hidden_size=0, activation=None
    ):
        super().__init__()
        hidden_size = hidden_size or in_features
        self.net1 = nn.Sequential(
            nn.Linear(in_features, hidden_size),
            activation or nn.ReLU(),
            nn.Linear(hidden_size, out_features),
        )
        # the skip connection is the identity when the shapes already match
        if in_features == out_features:
            self.net2 = lambda x: x
        else:
            self.net2 = nn.Linear(in_features, out_features, bias=False)

    def forward(self, x):
        return self.net1(x) + self.net2(x)
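
# Shape check (a quick sketch):
#
#     layer = ResidualLayer(in_features=16, out_features=32)
#     out = layer(torch.randn(8, 16))
#     assert out.shape == (8, 32)   # the skip path is a bias-free Linear here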
def save_checkpoint(state, output_folder, is_best, filename="checkpoint.tar"):
    torch.save(state, os.path.join(output_folder, filename))
    if is_best:
        shutil.copyfile(
            os.path.join(output_folder, filename),
            os.path.join(output_folder, "model_best.tar"),
        )
def split_dataset(dataset, ratio: float):
    """Randomly split a dataset into two parts of sizes ratio : (1 - ratio)."""
    n = len(dataset)
    lengths = [int(n * ratio), n - int(n * ratio)]
    return torch.utils.data.random_split(dataset, lengths)


def split_dataloader(dataloader, ratio: float):
    """Randomly split a DataLoader's dataset and wrap both parts in new
    DataLoaders that keep the batch size, workers, collate_fn, and drop_last."""
    dataset = dataloader.dataset
    n = len(dataset)
    lengths = [int(n * ratio), n - int(n * ratio)]
    datasets = torch.utils.data.random_split(dataset, lengths)
    copied_fields = ["batch_size", "num_workers", "collate_fn", "drop_last"]
    dataloaders = []
    for d in datasets:
        dataloaders.append(
            DataLoader(
                dataset=d, **{k: getattr(dataloader, k) for k in copied_fields}
            )
        )
    return tuple(dataloaders)
class KeyBucketedBatchSampler(torch.utils.data.Sampler):
    """Pseudo-bucketed batch sampler.

    Samples so that items with the same or nearby keys are allocated to the
    same or nearby batches; the order of the batches is then shuffled.

    Args:
        keys (List[int]): the key of each dataset item, by which items are
            bucketed.
        batch_size (int):
        drop_last (bool, optional): Whether to drop the last incomplete batch.
            Defaults to False.
        shuffle_same_key (bool, optional): Whether to shuffle the items that
            share the same key. Defaults to True.
    """

    def __init__(
        self, keys, batch_size, drop_last=False, shuffle_same_key=True
    ):
        self.keys = keys
        self.batch_size = batch_size
        self.drop_last = drop_last
        self.shuffle_same_key = shuffle_same_key

        # bucket sort; maintain random order inside each bucket
        buckets = {}
        for i, key in enumerate(self.keys):
            if key not in buckets:
                buckets[key] = [i]
            else:
                buckets[key].append(i)
        self.buckets = buckets

    def __iter__(self):
        indices = []
        for key in sorted(self.buckets.keys()):
            v = self.buckets[key]
            if self.shuffle_same_key:
                random.shuffle(v)
            indices += v
        index_batches = []
        for i in range(0, len(indices), self.batch_size):
            j = min(i + self.batch_size, len(indices))
            index_batches.append(indices[i:j])
        del indices
        if self.drop_last and len(index_batches[-1]) < self.batch_size:
            index_batches = index_batches[:-1]
        random.shuffle(index_batches)
        for indices in index_batches:
            yield indices

    def __len__(self):
        if self.drop_last:
            return len(self.keys) // self.batch_size
        else:
            return (len(self.keys) + self.batch_size - 1) // self.batch_size
def convert_to_bucketed_dataloader(
    dataloader: DataLoader, key_fn, shuffle_same_key=True
):
    """Convert a DataLoader to a bucketed DataLoader with given keys.

    Args:
        dataloader (DataLoader):
        key_fn (Callable): maps a dataset item to the key used to bucket it;
            evaluated once per item when building the sampler.
        shuffle_same_key (bool, optional): Whether to shuffle the items that
            share the same key. Defaults to True.

    Returns:
        DataLoader:
    """
    assert (
        dataloader.batch_size is not None
    ), "The `batch_size` must be present for the input dataloader"

    dataset = dataloader.dataset
    keys = [key_fn(dataset[i]) for i in range(len(dataset))]
    batch_sampler = KeyBucketedBatchSampler(
        keys,
        batch_size=dataloader.batch_size,
        drop_last=dataloader.drop_last,
        shuffle_same_key=shuffle_same_key,
    )
    return DataLoader(
        dataset, batch_sampler=batch_sampler, collate_fn=dataloader.collate_fn
    )
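
# Example: bucket variable-length sequences by length so each batch pads to
# roughly the same size (a sketch; `pad_collate` is a hypothetical collate
# function, and each dataset item is assumed to have a meaningful len()):
#
#     loader = DataLoader(dataset, batch_size=32, collate_fn=pad_collate)
#     loader = convert_to_bucketed_dataloader(loader, key_fn=lambda item: len(item))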
def generate_sequence_mask(lengths, device=None):
    """Build a padding mask from sequence lengths.

    Args:
        lengths (LongTensor): 1-D tensor of sequence lengths, shape (B,).

    Returns:
        BoolTensor: shape (B, max_len); True at valid (non-padding) positions.
    """
    index = torch.arange(lengths.max(), device=device or lengths.device)
    return index.unsqueeze(0) < lengths.unsqueeze(1)
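
# Worked example:
#
#     generate_sequence_mask(torch.tensor([3, 1]))
#     # tensor([[ True,  True,  True],
#     #         [ True, False, False]])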
def set_eval_mode(module, root=True):
    """Put `module` in train mode, except that Dropout and BatchNorm layers
    behave as in eval mode (no dropout; running statistics for BatchNorm)."""
    if root:
        module.train()
    name = module.__class__.__name__
    if "Dropout" in name or "BatchNorm" in name:
        module.training = False
    for child_module in module.children():
        set_eval_mode(child_module, False)