Certain backends have additional requirements for the weights, and it would be good if we could transform the weights into the required format once and reuse the result across many iterations.
Currently there are two modules that could already benefit from this:
- cuDNN RNN could format the weights into a single contiguous block of memory
- BatchNorm could maintain fp32 running averages to stabilize fp16 training
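The "transform once, reuse across iterations" idea can be sketched as a small cache that re-runs an expensive per-backend transform only when the source weights change. This is a minimal illustration, not an existing API: the names `WeightCache` and `flatten`, and the explicit version counter used to detect weight updates, are all assumptions made for the example.

```python
# Hypothetical sketch: cache a per-backend weight transform and reuse it.
# WeightCache and flatten are illustrative names, not a real API.

class WeightCache:
    """Caches an expensive transform of the weights.

    The transform is re-run only when the source weights change,
    tracked here with a simple version counter.
    """

    def __init__(self, transform):
        self._transform = transform
        self._version = None
        self._cached = None

    def get(self, weights, version):
        # Re-transform only when the weights' version has changed.
        if version != self._version:
            self._cached = self._transform(weights)
            self._version = version
        return self._cached


# Example transform: concatenate per-layer weight lists into one
# contiguous block, in the spirit of what a cuDNN-style RNN backend
# might require.
def flatten(weight_lists):
    return [w for layer in weight_lists for w in layer]


cache = WeightCache(flatten)
weights = [[1.0, 2.0], [3.0, 4.0]]
flat = cache.get(weights, version=0)   # transform runs once
same = cache.get(weights, version=0)   # cached block is reused
```

In a real framework the version counter would be bumped whenever the optimizer updates the parameters, so the transformed block stays valid across forward passes within an iteration and is rebuilt only after each update.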