These useful concepts show up in specific areas of NN-training literature but can be applied pretty broadly.
- Non-leaky augmentations: you can add arbitrary augmentations during training, without substantially biasing in-domain performance, by adding a secondary input that tells the network which augmentations were used. This technique shows up in the Karras et al. image generation papers (e.g. https://arxiv.org/pdf/2206.00364), but it's applicable whenever you want good performance on limited data.
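A minimal sketch of the idea, assuming a flip augmentation and a toy network (the names `augment` and `ConditionedNet` are illustrative, not from the paper):

```python
import torch
import torch.nn as nn

def augment(x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    """Randomly horizontally flip each image; return the augmented batch
    plus a per-sample one-hot label saying which augmentation was used."""
    flip = torch.rand(x.shape[0]) < 0.5
    x_aug = torch.where(flip[:, None, None, None], x.flip(-1), x)
    label = torch.stack([(~flip).float(), flip.float()], dim=1)  # [no-flip, flip]
    return x_aug, label

class ConditionedNet(nn.Module):
    """Toy network that consumes the image together with the augmentation
    label, so augmented samples don't leak into the clean distribution."""
    def __init__(self, in_dim: int, cond_dim: int = 2):
        super().__init__()
        self.body = nn.Linear(in_dim + cond_dim, 16)

    def forward(self, x: torch.Tensor, aug_label: torch.Tensor) -> torch.Tensor:
        return self.body(torch.cat([x.flatten(1), aug_label], dim=1))
```

At inference time you simply pass the "no augmentation" label, so the network produces in-domain predictions.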
- Batch-stratified sampling: rather than generating per-sample random numbers with e.g. `torch.rand(batch_size)`, you can use `(th.randperm(batch_size) + th.rand(batch_size)) / batch_size` instead (out-of-place, so the integer `randperm` output promotes to float), which has the same distribution but lower variance, and therefore trains more stably. This shows up in k-diffusion https://github.com/crowsonkb/k-diffusion/commit/a2b7b5f1ea0d3711a06661ca9e41b4e6089e5707, but it's applicable whenever you're randomizing data across the batch axis.
- Replay buffers: when your data samples are non-i.i.d., rather than always training on the next batch of data, you can maintain an in-memory pool of data and periodically sample from / add to it. The larger your pool, the closer your data will be to i.i.d., and the more stably your network will train. This shows up in RL literature https://pytorch.org/rl/main/tutorials/rb_tutorial.html, but it's applicable whenever you have a non-stationary input distribution.
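The batch-stratified sampling trick above fits in a one-line helper. A sketch (the name `stratified_uniform` is illustrative):

```python
import torch as th

def stratified_uniform(batch_size: int) -> th.Tensor:
    """Low-variance alternative to th.rand(batch_size): each sample lands
    in its own stratum of width 1/batch_size, in shuffled order, so every
    batch covers [0, 1) evenly while each sample is still marginally
    uniform."""
    return (th.randperm(batch_size) + th.rand(batch_size)) / batch_size
```

Each draw is still uniform on [0, 1), but across the batch exactly one sample falls in each of the `batch_size` equal-width bins.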
- Patchwise training: if your network has limited spatial receptive field, training on a large batch of smallish patches will give you stabler / lower-variance training than training on a small batch of big patches (and you can always finetune on large images at the end). This shows up in autoencoder training literature https://arxiv.org/pdf/2012.09841, but it's applicable whenever you're training a spatially-equivariant network.
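A sketch of the patch-cropping step, assuming square patches and sampling source images with replacement (the helper `random_patches` is illustrative):

```python
import torch

def random_patches(images: torch.Tensor, patch: int, n: int) -> torch.Tensor:
    """Cut n random square patches from a batch of images (B, C, H, W),
    so a small batch of large images becomes a larger, lower-variance
    batch of small patches."""
    b, _, h, w = images.shape
    idx = torch.randint(b, (n,)).tolist()
    ys = torch.randint(h - patch + 1, (n,)).tolist()
    xs = torch.randint(w - patch + 1, (n,)).tolist()
    return torch.stack([images[i, :, y:y + patch, x:x + patch]
                        for i, y, x in zip(idx, ys, xs)])
```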
- Fourier encoding: if your network needs to consume an unbounded value, or a bounded value with many bits of precision, you can pre-encode the value as a vector of sin(x*k) features with varying k. This allows the network to be equally sensitive to both small and large changes in the value. Fourier encoding shows up in diffusion models, https://www.vincentsitzmann.com/siren/, https://www.matthewtancik.com/nerf, etc., but it's applicable whenever you need to make a network react to small changes in input signals.
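A minimal sketch, using geometrically spaced frequencies and both sin and cos features (a common choice; the exact frequency schedule varies by paper):

```python
import torch

def fourier_encode(x: torch.Tensor, num_freqs: int = 8) -> torch.Tensor:
    """Encode scalar values as [sin(x*k), cos(x*k)] features for
    frequencies k = 2^0 .. 2^(num_freqs-1). Low frequencies track large
    changes in x; high frequencies resolve tiny ones."""
    k = 2.0 ** torch.arange(num_freqs)                      # (F,)
    angles = x[..., None] * k                               # (..., F)
    return torch.cat([angles.sin(), angles.cos()], dim=-1)  # (..., 2F)
```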
- Noise augmentation: if you want to limit the network's sensitivity to small changes of a particular input signal, you can add small-std Gaussian noise to that signal during training, which forces the network to ignore sub-noise-level changes in the signal. This shows up in cascaded upsampling / autoregression literature https://arxiv.org/pdf/2106.15282, but it's applicable any time you want to make a network ignore small changes in input signals.
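The augmentation itself is a one-liner; a sketch (the `std` value here is arbitrary and should be tuned to the scale of variation you want ignored):

```python
import torch

def noise_augment(signal: torch.Tensor, std: float = 0.1) -> torch.Tensor:
    """Add small-std Gaussian noise to an input signal during training,
    so the downstream network learns to ignore sub-std variations."""
    return signal + std * torch.randn_like(signal)
```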
- Dropout augmentation: if you have an input signal `C` that the network gets during training but may or may not get during inference/testing, you can drop `C` out during training (zero it randomly, possibly with an extra indicator bit) so that the same network learns to predict results with and without `C`. This technique shows up in classifier-free-guided diffusion models https://arxiv.org/pdf/2207.12598, but it's applicable any time that you want a model to support an "optional" input signal.
- Periodic weight averaging: you can get a better network checkpoint for free by averaging N different network checkpoints taken from multiple partially-converged training iterations. This technique allows you to keep training with constant lr and (mostly) still get a "cooled" checkpoint out at any time. This technique shows up in diffusion literature all the time (though see https://arxiv.org/pdf/2405.18392v2 fig 8 or https://arxiv.org/pdf/2312.02696 fig 14 for why lr cooldown might still be worth doing).
- Feature matching loss: if you are training a network `network_1` that produces a signal `x_pred` that's intended to be consumed by another network `network_2`, in addition to penalizing `|x_pred - x|`, you can also penalize differences in the intermediate features of the consumer network, `|network_2[:n](x_pred).mean(axes) - network_2[:n](x).mean(axes)|`. This kind of additional loss ensures that your `x_pred` signals will be "similar" to the ground-truth `x` signals according to whatever criteria `network_2` cares about, rather than just similar according to L2/L1 loss. This technique shows up in GAN training literature, but it's applicable any time you're training networks to produce outputs for other networks.
TODO: find more of these