Last active
April 11, 2021 23:29
-
-
Save ericspod/f4da372d22cc8da420ee74b8968303cd to your computer and use it in GitHub Desktop.
Base Engine, Trainer, and Evaluator
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import itertools
import threading
import warnings

import numpy as np
import torch
from ignite.engine.engine import Engine, Events
def ensure_tuple(vals):
    """
    Return `vals` as a tuple.

    Parameters
    ----------
    vals : object
        Any value; lists and tuples are converted element-wise, anything else is wrapped

    Returns
    -------
    tuple
        `tuple(vals)` if `vals` is a list or tuple, otherwise the single-value tuple `(vals,)`
    """
    if isinstance(vals, (list, tuple)):
        return tuple(vals)

    # strings, arrays, tensors, None, etc. are deliberately treated as a single value rather than iterated over
    return (vals,)
class BaseEngine(Engine):
    """
    Base training/evaluating engine inheriting from Ignite's Engine. This manages a single network's train/eval/infer
    process. Setups with multiple networks should have multiple instances of subtypes of this class.

    Attributes
    ----------
    net : torch.nn.Module
        The network to train or evaluate
    loss : torch.nn.modules.loss._Loss
        Loss object
    opt : torch.optim.Optimizer, optional
        Optimizer object for training or None
    device_ids : list of int
        CUDA device ID numbers stating which devices to compute on, empty for CPU computation
    useCuda : bool
        True if CUDA is to be used (ie. `device_ids` is not empty) and is available
    device : torch.device
        The device this object uses to create tensors
    nonBlocking : bool
        Determines if tensors are created as blocking or not, default is False
    lock : threading.RLock
        Anytime `net` is used for training or inference, or anytime tensors it relies on are accessed, this lock
        should be used to ensure exclusive access
    """

    def __init__(self, net, loss, opt=None, device_ids=(0,), step_func=None):
        """
        Initialize the engine with network, loss, and optimizer provided.

        Parameters
        ----------
        net : torch.nn.Module
            The network to train or evaluate, moved to `self.device`
        loss : torch.nn.modules.loss._Loss
            Loss object, moved to `self.device`
        opt : torch.optim.Optimizer, optional
            Optimizer object for training or None
        device_ids : int or sequence of int, optional
            CUDA device ID numbers stating which devices to compute on, empty sequence (or None) for CPU computation
        step_func : callable, optional
            Callable defining the training/evaluation iteration step behaviour, passed to the super-constructor call, if
            None `self.step` is passed instead allowing engine behaviour to be determined through inheritance
        """
        # accept a bare int as documented, and copy so the caller's sequence is never shared with this object
        self.device_ids = [] if device_ids is None else list(ensure_tuple(device_ids))

        # `is_available` must be called; the bare function object is always truthy and would select CUDA everywhere
        self.useCuda = len(self.device_ids) > 0 and torch.cuda.is_available()
        self.device = torch.device("cuda:%i" % self.device_ids[0] if self.useCuda else "cpu")
        self.nonBlocking = False
        self.lock = threading.RLock()

        self.net = net.to(self.device)
        self.loss = loss.to(self.device)
        self.opt = opt

        super().__init__(self.step if step_func is None else step_func)

    def to_tensor(self, arr):
        """
        Convert the array or collection of arrays to tensor(s) if necessary.

        Parameters
        ----------
        arr : np.ndarray, list, tuple, dict, or other
            A numpy array, or a (possibly nested) list, tuple, or dictionary containing numpy arrays to convert

        Returns
        -------
        torch.Tensor, tuple, dict, or other
            A tensor on `self.device` for an array input, a tuple for list/tuple input, a dictionary with the same
            keys for dictionary input, or `arr` itself unchanged if it is none of these types
        """
        if isinstance(arr, np.ndarray):
            return torch.from_numpy(arr).to(device=self.device, non_blocking=self.nonBlocking)

        if isinstance(arr, (list, tuple)):
            return tuple(map(self.to_tensor, arr))

        if isinstance(arr, dict):
            return {k: self.to_tensor(v) for k, v in arr.items()}

        return arr  # existing tensors and anything else pass through untouched

    def to_numpy(self, tensor):
        """
        Convert the tensor or collection of tensors to numpy array(s) if necessary.

        Parameters
        ----------
        tensor : torch.Tensor, np.ndarray, list, tuple, or dict
            A tensor, or a (possibly nested) list, tuple, or dictionary containing tensors to convert

        Returns
        -------
        np.ndarray, tuple, or dict
            The input array itself for array input, a tuple for list/tuple input, a dictionary with the same keys for
            dictionary input, otherwise the value is treated as a tensor and converted to an array on the CPU
        """
        if isinstance(tensor, np.ndarray):
            return tensor

        if isinstance(tensor, (list, tuple)):
            return tuple(map(self.to_numpy, tensor))

        if isinstance(tensor, dict):
            return {k: self.to_numpy(v) for k, v in tensor.items()}

        # detach from the autograd graph before conversion, numpy() refuses tensors which require grad
        return tensor.detach().cpu().numpy()

    def set_requires_grad(self, grad=True):
        """
        Set `requires_grad` for every parameter of `self.net` to `grad`.

        Parameters
        ----------
        grad : bool
            Value to set each `requires_grad` to
        """
        for p in self.net.parameters():
            p.requires_grad = grad

    def net_forward(self, inputs):
        """
        Apply the values from `inputs` to the network and return the results. If multiple devices are being computed
        on, `torch.nn.parallel.data_parallel` is used to broadcast the values to each per-device replica `self.net`.

        Parameters
        ----------
        inputs : list or tuple of tensors
            The input parameters for `self.net`

        Returns
        -------
        tuple of tensors
            The output(s) of the network contained in a tuple, if a single tensor is produced by the network this is
            placed in a single-value tuple
        """
        # TODO: add event for this method?
        if self.useCuda and len(self.device_ids) > 1:
            result = torch.nn.parallel.data_parallel(self.net, tuple(inputs), self.device_ids)
        else:
            result = self.net(*inputs)

        return ensure_tuple(result)

    def loss_forward(self, predictions, ground):
        """
        Compute the loss value using `self.loss` with input expanded from `predictions` and `ground`.

        Parameters
        ----------
        predictions : list or tuple of tensors
            The prediction values for `self.loss` which will be expanded as the first series of positional arguments
        ground : list or tuple of tensors
            The ground truth values for `self.loss` which will be expanded as the second series of positional arguments

        Returns
        -------
        object
            Whatever `self.loss` returns, passed back unmodified (it is not wrapped in a tuple)
        """
        # TODO: add event for this method?
        args = tuple(predictions) + tuple(ground)
        return self.loss(*args)

    def infer(self, infer_src):
        """
        Apply inference to the batches taken from `infer_src`, returning a list of results from the network. The input
        source is expected to be finite otherwise this method will not return. The `self.net_forward` method is called
        with every item in a generated batch passed as the argument tuple.

        Parameters
        ----------
        infer_src : iterable
            Iterable yielding tuples of numpy arrays containing the inputs to `self.net_forward`

        Returns
        -------
        list of tuples
            Returns a list of results from applying each tuple from the source to the network
        """
        # TODO: add event for this method?
        return list(self.infer_gen(infer_src))

    def infer_gen(self, infer_src):
        """
        Apply inference to the batches taken from `infer_src`, yielding the results from the network. The input source
        is expected to be finite otherwise this generator will not return and will rely on the consumer to stop the
        iteration. `self.net_forward` is called with the whole of a generated batch passed as the argument tuple.

        Parameters
        ----------
        infer_src : iterable
            Iterable yielding tuples of numpy arrays containing the inputs to `self.net_forward`

        Yields
        ------
        tuple
            Yields the result from applying each tuple from the source to the network, converted to numpy arrays
        """
        # TODO: add event for this method?
        for batch in infer_src:
            with self.lock, torch.no_grad():
                self.net.eval()
                net_inputs = self.to_tensor(ensure_tuple(batch))
                net_outputs = self.net_forward(net_inputs)
                result = self.to_numpy(net_outputs)

            # yield only after leaving the `with` block so the consumer does not run while this generator still holds
            # the lock or has gradient calculation disabled
            yield result

    def step(self, engine, batch):
        """
        Train/eval/infer step function, accepts the engine (which is `self`) and current batch as input. By default
        this method only asserts that `engine` is `self`.

        Parameters
        ----------
        engine : BaseEngine
            This is the same object as `self`
        batch : tuple of np.ndarray
            The batch tuple for the current iteration

        Returns
        -------
        None
            The loss result should be returned in overrides
        """
        assert engine is self
class Trainer(BaseEngine):
    """
    The basic engine subtype for training a network. The given `self.step` method is for training a network accepting
    inputs from a batch, the results from which are passed to a loss function whose output can be back-propagated.
    During training the converted batch is stored in `state.net_inputs`, network outputs in `state.net_outputs`, and loss
    function outputs in `state.loss_outputs`. These members of the state object can be accessed to inspect the training
    parameters.

    Attributes
    ----------
    net_input_indices : tuple of int
        Indices of network input tensors in each batch
    loss_pred_indices : tuple of int
        Indices of the prediction tensors in each network output
    loss_ground_indices : tuple of int
        Indices of the ground truth tensors in each batch
    """

    def __init__(self, net, loss, opt=None, device_ids=(0,), net_input_indices=(0,),
                 loss_pred_indices=(0,), loss_ground_indices=(-1,), step_func=None):
        """
        Create the trainer object with the given network, loss function, optimizer, and parameters stating which
        members of each batch tuple or network output are inputs for the network or loss function, and which are
        ground truth values. Changing these values allows various configurations of training a network whose
        output is passed to a loss function along with ground truth values, eg. a simple supervised environment.

        Parameters
        ----------
        net : torch.nn.Module
            The network to train or evaluate
        loss : torch.nn.modules.loss._Loss
            Loss object
        opt : torch.optim.Optimizer, optional
            Optimizer object for training, if None Adam is used instead with default parameters
        device_ids : sequence of int, optional
            CUDA device ID numbers stating which devices to compute on, passed to the `BaseEngine` constructor
        net_input_indices : sequence of int, optional
            Indices of network input tensors in each batch
        loss_pred_indices : sequence of int, optional
            Indices of the prediction tensors in each network output
        loss_ground_indices : sequence of int, optional
            Indices of the ground truth tensors in each batch
        step_func : callable, optional
            Callable defining the training/evaluation iteration step behaviour, passed to the super-constructor call, if
            None `self.step` is passed instead allowing engine behaviour to be determined through inheritance
        """
        # only fall back to Adam when no optimizer was supplied, a caller-provided optimizer must be kept
        if opt is None:
            opt = torch.optim.Adam(net.parameters())

        super().__init__(net, loss, opt, device_ids, step_func)

        self.net_input_indices = tuple(net_input_indices)
        self.loss_pred_indices = tuple(loss_pred_indices)
        self.loss_ground_indices = tuple(loss_ground_indices)

    def step(self, engine, batch):
        """
        The default network training loop for any training process with paired input and ground truth datasets. This
        trains the network for the number of substeps given in `self.state.num_substeps`, that is the network is trained
        for that many steps with the same input data and ground truths derived from `batch`. The indexing member
        `net_input_indices` is used to determine which arrays from `batch` are inputs to the network,
        `loss_pred_indices` to determine which outputs from the network are to be passed to the loss function as
        prediction values, and `loss_ground_indices` to determine which arrays from `batch` are ground truth.

        Parameters
        ----------
        engine : BaseEngine
            This is the same object as `self`
        batch : tuple of np.ndarray
            The batch tuple for the current iteration

        Returns
        -------
        float
            The result from the loss function for the last substep
        """
        with self.lock:
            self.state.net_inputs = self.to_tensor(ensure_tuple(batch))
            inputs = [self.state.net_inputs[i] for i in self.net_input_indices]
            ground = [self.state.net_inputs[i] for i in self.loss_ground_indices]

            # `num_substeps` is only set by `train`, default to a single substep if the engine was started directly
            # through `run` so the step does not fail on a missing state attribute
            num_substeps = getattr(self.state, "num_substeps", 1)

            for _ in range(num_substeps):
                self.net.train()
                self.opt.zero_grad()

                self.state.net_outputs = self.net_forward(inputs)
                pred = [self.state.net_outputs[i] for i in self.loss_pred_indices]

                self.state.loss_outputs = self.loss_forward(pred, ground)
                self.state.loss_outputs.backward()
                self.opt.step()

            return self.state.loss_outputs.item()

    def train(self, src, max_iterations=None, max_epochs=1, num_substeps=1):
        """
        Train the network with the given data source, maximum iteration, epoch, and substep counts.

        Parameters
        ----------
        src : iterable
            Iterable data source yielding batches of data
        max_iterations : int, optional
            Number of iterations to train for each epoch, passed to `self.run` as its third positional argument; if
            None that argument is left for the engine to determine from `src`
        max_epochs : int, optional
            Number of epochs (sets of iterations) to train
        num_substeps : int, optional
            Number of substeps to train, default of 1 implies the commonplace behaviour of training only once per batch

        Returns
        -------
        ignite.engine.engine.State
            The state object from the training run
        """
        def _set_state(_):
            # the state object is created by `run`, so the substep count can only be stored once the run has started
            self.state.num_substeps = num_substeps

        # the handle returned by `add_event_handler` is used as a context manager so the handler is only attached
        # for the duration of this run
        with self.add_event_handler(Events.STARTED, _set_state):
            return self.run(src, max_epochs, max_iterations)

    def get_evaluator(self, loss=None, step_func=None):
        """
        Return an Evaluator object referencing this object's network and loss objects, and configured with the same
        devices and index tuples.

        Parameters
        ----------
        loss : torch.nn.modules.loss._Loss, optional
            Loss object for the evaluator to use, if None this object's `loss` member is used
        step_func : callable, optional
            Evaluation step function to pass to Evaluator object, if None the default `step` method is used

        Returns
        -------
        Evaluator
            Evaluation object for this object's network
        """
        # compare against None explicitly rather than relying on the truthiness of a loss module
        eval_loss = self.loss if loss is None else loss

        return Evaluator(self.net, eval_loss, self.device_ids, self.net_input_indices,
                         self.loss_pred_indices, self.loss_ground_indices, step_func)
class Evaluator(BaseEngine):
    """
    Engine subclass for evaluating a network for validation or other analysis. The default `step` method implements a
    simple evaluation step which does a forward pass on the network and loss function, and returns the loss value.
    During evaluation the converted batch is stored in `state.net_inputs`, network outputs in `state.net_outputs`, and
    loss function outputs in `state.loss_outputs`. These members of the state object can be accessed to inspect the
    evaluation parameters.

    Attributes
    ----------
    net_input_indices : tuple of int
        Indices of network input tensors in each batch
    loss_pred_indices : tuple of int
        Indices of the prediction tensors in each network output
    loss_ground_indices : tuple of int
        Indices of the ground truth tensors in each batch
    """

    def __init__(self, net, loss, device_ids=(0,), net_input_indices=(0,),
                 loss_pred_indices=(0,), loss_ground_indices=(-1,), step_func=None):
        """
        Create the evaluator object with the given network, loss function, and parameters stating which members of each
        batch tuple or network output are inputs for the network or loss function. Changing these values allows various
        configurations of evaluating a network whose output is passed to a loss function with ground truth values.

        Parameters
        ----------
        net : torch.nn.Module
            The network to train or evaluate
        loss : torch.nn.modules.loss._Loss
            Loss object
        device_ids : sequence of int, optional
            CUDA device ID numbers stating which devices to compute on, passed to the `BaseEngine` constructor
        net_input_indices : sequence of int, optional
            Indices of network input tensors in each batch
        loss_pred_indices : sequence of int, optional
            Indices of the prediction tensors in each network output
        loss_ground_indices : sequence of int, optional
            Indices of the ground truth tensors in each batch
        step_func : callable, optional
            Callable defining the training/evaluation iteration step behaviour, passed to the super-constructor call, if
            None `self.step` is passed instead allowing engine behaviour to be determined through inheritance
        """
        super().__init__(net, loss, None, device_ids, step_func)

        # stored as tuples so that a sequence owned by the caller is never shared with (or mutated through) this object
        self.net_input_indices = tuple(net_input_indices)  # indices of the network input tensors in each batch
        self.loss_pred_indices = tuple(loss_pred_indices)  # indices of the prediction tensors in each network output
        self.loss_ground_indices = tuple(loss_ground_indices)  # indices of the ground truth tensors in each batch

    def step(self, engine, batch):
        """
        The default network evaluation loop for any process with paired input and ground truth datasets. This
        evaluates the network and loss function, returning the loss result. The indexing member `net_input_indices` is
        used to determine which arrays from `batch` are inputs to the network, `loss_pred_indices` to determine which
        outputs from the network are to be passed to the loss function as prediction values, and `loss_ground_indices`
        to determine which arrays from `batch` are ground truth.

        Parameters
        ----------
        engine : BaseEngine
            This is the same object as `self`
        batch : tuple of np.ndarray
            The batch tuple for the current iteration

        Returns
        -------
        float
            The result from the loss function
        """
        with self.lock, torch.no_grad():
            self.state.net_inputs = self.to_tensor(ensure_tuple(batch))
            inputs = [self.state.net_inputs[i] for i in self.net_input_indices]
            ground = [self.state.net_inputs[i] for i in self.loss_ground_indices]

            self.net.eval()
            self.state.net_outputs = self.net_forward(inputs)
            pred = [self.state.net_outputs[i] for i in self.loss_pred_indices]

            self.state.loss_outputs = self.loss_forward(pred, ground)

            return self.state.loss_outputs.item()

    def evaluate(self, src, max_iterations=None):
        """
        Evaluates the network for each batch in `src`, which must be finite or `max_iterations` must be a positive int.
        For each batch, the returned list will contain a pair storing the network output and loss output arrays.

        Parameters
        ----------
        src : iterable
            Input batch source
        max_iterations : int, optional
            Maximum number of evaluation iterations, passed to `self.run` as its third positional argument; if None
            (or 0) `self.run` is called without it
            NOTE(review): assumes the installed Ignite version treats that argument as the iteration count per epoch,
            confirm against the version in use

        Returns
        -------
        list of tuples
            List of each (network output, loss output) pairs for each batch from `src`
        """
        results = []

        def _collect_results(_):
            out = self.state.net_outputs
            loss = self.state.loss_outputs
            results.append(self.to_numpy((out, loss)))

        # the handler is only attached for the duration of this run
        with self.add_event_handler(Events.ITERATION_COMPLETED, _collect_results):
            if max_iterations:
                self.run(src, 1, max_iterations)
            else:
                self.run(src, 1)  # no limit requested, keep the original two-argument call

        return results

    def evaluate_gen(self, src, max_iterations=None):
        """
        Evaluates the network for each batch in `src`, which must be finite or `max_iterations` must be a positive int.
        For each batch, this generator yields a pair storing the network output and loss output arrays.

        Parameters
        ----------
        src : iterable
            Input batch source
        max_iterations : int, optional
            Maximum number of evaluation iterations, if None (or 0) iterations are performed until `src` is exhausted

        Yields
        ------
        tuple
            The (network output, loss output) pair for each batch from `src`
        """
        # cap the number of batches taken from `src` only when a limit was requested
        batches = itertools.islice(src, max_iterations) if max_iterations else src

        for batch in batches:
            # each batch is evaluated as its own single-iteration run so results can be yielded one at a time
            self.run([batch], 1)
            out = self.state.net_outputs
            loss = self.state.loss_outputs
            yield self.to_numpy((out, loss))

    def evaluate_mean_loss(self, src, max_iterations=0):
        """
        Calculate the mean loss over all of the inputs in `src`, weighting each batch's loss by its batch size. The
        batch size is taken as the first dimension of the first network output.

        Parameters
        ----------
        src : iterable
            Input batch source
        max_iterations : int, optional
            Maximum number of batches to evaluate, the default of 0 means all of `src` is evaluated

        Returns
        -------
        numeric
            Sum of (loss * batch size) over all batches divided by the total number of items

        Raises
        ------
        ZeroDivisionError
            If no batches were evaluated, ie. `src` is empty
        """
        total_size = 0
        total_loss = 0

        for output, eloss in self.evaluate_gen(src, max_iterations):
            batch_size = output[0].shape[0]
            total_loss += eloss * batch_size
            total_size += batch_size

        return total_loss / total_size
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment