
@madebyollin
Created January 12, 2025 05:35
Useful neural network training concepts (narrow usage, broad applicability)

These useful concepts show up in specific areas of the NN-training literature but can be applied pretty broadly.

  1. Non-leaky augmentations: you can add arbitrary augmentations during training, without substantially biasing in-domain performance, by adding a secondary input that tells the network which augmentations were used. This technique shows up in the Karras et al. image generation papers (ex. https://arxiv.org/pdf/2206.00364), but it's applicable whenever you want good performance on limited data (sketch 1 below).
  2. Batch-stratified sampling: rather than generating per-sample random numbers with e.g. torch.rand(batch_size), you can use (torch.randperm(batch_size) + torch.rand(batch_size)) / batch_size instead, which gives each sample the same marginal distribution but has lower variance across the batch, and therefore trains more stably. This shows up in k-diffusion https://github.com/crowsonkb/k-diffusion/commit/a2b7b5f1ea0d3711a06661ca9e41b4e6089e5707, but it's applicable whenever you're randomizing data across the batch axis (sketch 2 below).
  3. Replay buffers: when your data samples are non-i.i.d., rather than always training on the next batch of data, you can maintain an in-memory pool of data and periodically sample from / add to it. The larger your pool, the closer your training batches will be to i.i.d., and the more stably your network will train. This shows up in the RL literature https://pytorch.org/rl/main/tutorials/rb_tutorial.html, but it's applicable whenever you have a non-stationary input distribution (sketch 3 below).
  4. Patchwise training: if your network has a limited spatial receptive field, training on a large batch of smallish patches will give you stabler / lower-variance training than training on a small batch of big patches (and you can always finetune on large images at the end). This shows up in the autoencoder training literature https://arxiv.org/pdf/2012.09841, but it's applicable whenever you're training a spatially equivariant network (sketch 4 below).
  5. Fourier encoding: if your network needs to consume an unbounded value, or a bounded value with many bits of precision, you can pre-encode the value as a vector of sin(x*k) features with varying k. This allows the network to be equally sensitive to both small and large changes in the value. Fourier encoding shows up in diffusion models, https://www.vincentsitzmann.com/siren/, https://www.matthewtancik.com/nerf, etc., but it's applicable whenever you need to make a network react to small changes in input signals (sketch 5 below).
  6. Small-noise augmentation: if you want to limit the network's sensitivity to small changes in a particular input signal, you can add small-std Gaussian noise to that signal during training, so that changes smaller than the noise level carry no usable information and get ignored. This shows up in the cascaded upsampling / autoregression literature https://arxiv.org/pdf/2106.15282, but it's applicable any time you want to make a network ignore small changes in input signals (sketch 6 below).
  7. Dropout augmentation: if you have an input signal C that the network gets during training but may or may not get during inference/testing, you can drop C out during training (zero it randomly, possibly with an extra indicator bit) so that the same network learns to predict results both with and without C. This technique shows up in classifier-free-guided diffusion models https://arxiv.org/pdf/2207.12598, but it's applicable any time you want a model to support an "optional" input signal (sketch 7 below).
  8. Periodic weight averaging: you can get a better network checkpoint for free by averaging N checkpoints saved at different, partially-converged points in training (sketch 8 below). This technique lets you keep training with a constant lr and (mostly) still get a "cooled" checkpoint out at any time. It shows up in the diffusion literature all the time (though see https://arxiv.org/pdf/2405.18392v2 fig. 8 or https://arxiv.org/pdf/2312.02696 fig. 14 for why lr cooldown might still be worth doing).
  9. Feature matching loss: if you are training a network network_1 that produces a signal x_pred intended to be consumed by another network network_2, then in addition to penalizing |x_pred - x|, you can also penalize differences in the consumer network's intermediate features, |network_2[:n](x_pred).mean(axes) - network_2[:n](x).mean(axes)|. This additional loss ensures that your x_pred signals will be "similar" to the ground-truth x signals according to whatever criteria network_2 cares about, rather than just similar according to L2/L1 loss. This technique shows up in the GAN training literature, but it's applicable any time you're training one network to produce outputs for another network (sketch 9 below).
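
Sketch 1, non-leaky augmentations: a minimal PyTorch sketch, assuming an image batch x of shape (B, C, H, W) and a hypothetical model that accepts an extra conditioning input; only horizontal flips are shown, but the same pattern extends to any augmentation set.

```python
import torch

def augment_with_labels(x):
    """Randomly h-flip each image and return a conditioning vector recording
    which augmentation was applied (0 = none, 1 = horizontal flip)."""
    flip = torch.rand(x.shape[0], device=x.device) < 0.5
    x_aug = torch.where(flip.view(-1, 1, 1, 1), x.flip(-1), x)
    aug_label = flip.float().unsqueeze(1)  # (B, 1), fed to the model as input
    return x_aug, aug_label

# usage inside a training step (model / loss_fn / target_for are placeholders):
#   x_aug, aug_label = augment_with_labels(x)
#   loss = loss_fn(model(x_aug, aug_label), target_for(x_aug))
# at inference time, pass aug_label = zeros, i.e. "no augmentation applied".
```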
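
Sketch 2, batch-stratified sampling: the expression above as a small helper (not k-diffusion's exact implementation).

```python
import torch

def stratified_uniform(batch_size):
    """Low-variance stand-in for torch.rand(batch_size): each sample lands in
    its own width-1/batch_size stratum of [0, 1), in a random order."""
    return (torch.randperm(batch_size) + torch.rand(batch_size)) / batch_size

# e.g. diffusion timesteps: t = stratified_uniform(64) instead of torch.rand(64)
```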
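
Sketch 3, replay buffers: a minimal in-memory pool (TorchRL's replay buffers, linked above, are the full-featured version).

```python
import random
from collections import deque

class ReplayBuffer:
    """Fixed-capacity pool of recent samples; drawing training batches from it
    makes consecutive batches closer to i.i.d. than training on the raw stream."""
    def __init__(self, capacity=10_000):
        self.pool = deque(maxlen=capacity)

    def add(self, sample):
        self.pool.append(sample)  # oldest sample is evicted once at capacity

    def sample(self, batch_size):
        return random.sample(self.pool, min(batch_size, len(self.pool)))

# usage: buffer.add(new_sample) as data streams in; batch = buffer.sample(64)
```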
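
Sketch 4, patchwise training: a minimal cropping helper, assuming (B, C, H, W) image tensors; the patch size and patch count are illustrative.

```python
import torch

def random_patches(images, patch_size=64, patches_per_image=4):
    """Crop several random patches from each large image, trading a small batch
    of big images for a larger batch of small patches."""
    _, _, h, w = images.shape
    patches = []
    for img in images:
        for _ in range(patches_per_image):
            y = torch.randint(0, h - patch_size + 1, ()).item()
            x = torch.randint(0, w - patch_size + 1, ()).item()
            patches.append(img[:, y:y + patch_size, x:x + patch_size])
    return torch.stack(patches)  # (B * patches_per_image, C, patch_size, patch_size)
```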
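
Sketch 5, Fourier encoding: a minimal sketch using geometrically spaced frequencies with sin/cos pairs; this frequency schedule is one common choice, not the only one.

```python
import math
import torch

def fourier_encode(x, num_freqs=8):
    """Encode a scalar feature x of shape (B,) as sin/cos features at
    frequencies 1, 2, 4, ..., so both tiny and large changes in x move
    some component of the encoding by a noticeable amount."""
    freqs = 2.0 ** torch.arange(num_freqs, device=x.device).float()
    angles = x.unsqueeze(-1) * freqs * math.pi               # (B, num_freqs)
    return torch.cat([angles.sin(), angles.cos()], dim=-1)   # (B, 2 * num_freqs)

# e.g. t_embedding = fourier_encode(t) for diffusion timesteps t in [0, 1]
```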
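
Sketch 6, small-noise augmentation: a minimal sketch; the noise std is a hyperparameter, and the cascaded-generation papers additionally condition the network on the sampled noise level, which is omitted here.

```python
import torch

def noise_augment(cond, std=0.05):
    """Perturb a conditioning signal with small Gaussian noise during training,
    so the network can't rely on details of cond that are smaller than ~std."""
    return cond + std * torch.randn_like(cond)

# usage during training (low_res is a placeholder conditioning signal):
#   y = model(x, noise_augment(low_res))
```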
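
Sketch 7, dropout augmentation: a minimal sketch of conditioning dropout with an indicator bit, assuming the conditioning tensor's first axis is the batch axis.

```python
import torch

def drop_condition(cond, p=0.1):
    """Zero out the conditioning signal for a random p-fraction of the batch and
    return an indicator bit, so one network learns both conditional and
    unconditional prediction."""
    keep = (torch.rand(cond.shape[0], device=cond.device) >= p).float()
    mask = keep.view(-1, *([1] * (cond.dim() - 1)))
    return cond * mask, keep.unsqueeze(1)  # (dropped cond, (B, 1) indicator)

# training:  cond_in, has_cond = drop_condition(cond); y = model(x, cond_in, has_cond)
# inference: run the model once with cond and once with zeroed cond to get the
# conditional and unconditional predictions that classifier-free guidance combines
```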
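
Sketch 8, periodic weight averaging: a minimal sketch that uniformly averages saved checkpoint state_dicts; keeping an exponential moving average of the weights alongside training is the streaming variant of the same idea.

```python
import copy
import torch

def average_checkpoints(state_dicts):
    """Uniformly average N checkpoints' weights into a single state_dict
    (non-float entries, e.g. BatchNorm's num_batches_tracked, are kept from
    the first checkpoint)."""
    avg = copy.deepcopy(state_dicts[0])
    for key in avg:
        if avg[key].is_floating_point():
            stacked = torch.stack([sd[key].float() for sd in state_dicts])
            avg[key] = stacked.mean(dim=0).to(avg[key].dtype)
    return avg

# usage (paths is a placeholder list of saved state_dict files):
#   model.load_state_dict(average_checkpoints([torch.load(p) for p in paths]))
```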
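
Sketch 9, feature matching loss: a minimal sketch, assuming the consumer network_2 is an nn.Sequential (so its first n layers can be sliced off), produces (B, C, H, W) features, and is frozen / excluded from the optimizer.

```python
import torch
import torch.nn.functional as F

def feature_matching_loss(x_pred, x, consumer, n_layers=3, axes=(0, 2, 3)):
    """Penalize differences between the consumer network's intermediate feature
    statistics on predicted vs. ground-truth inputs."""
    trunk = consumer[:n_layers]                # first n layers of the consumer
    with torch.no_grad():
        target_feats = trunk(x).mean(dim=axes)
    pred_feats = trunk(x_pred).mean(dim=axes)  # gradients flow back into x_pred
    return F.l1_loss(pred_feats, target_feats)

# total loss for network_1 might look like:
#   loss = F.l1_loss(x_pred, x) + 0.1 * feature_matching_loss(x_pred, x, network_2)
```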

TODO: find more of these
