This article explains how multiple image modalities are fused as input to a 2D or 3D U-Net, using nnUNet [1] as an example.
The following code snippet shows what happens in nnUNet at the first convolutional layer when a multi-modal input with 4 modalities is given. The input shape represents a batch size of 2, 4 image modalities (say, 4 MRI sequences), and the remaining three dimensions represent an MRI volume of 48 slices of 160x256 pixels each:
>>> import torch
>>> from torch import nn
>>> m = nn.Conv3d(in_channels=4, out_channels=32, kernel_size=(1, 3, 3), stride=(1, 1, 1), padding=(0, 1, 1))
>>> input = torch.randn(2, 4, 48, 160, 256)
>>> output = m(input)
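The fusion of the four modalities happens right here in the first layer: the weight tensor of a Conv3d has shape (out_channels, in_channels, *kernel_size), so each of the 32 output feature maps is computed from all 4 input modalities at once, while the padding of (0, 1, 1) with a (1, 3, 3) kernel preserves the spatial dimensions. Continuing the snippet above, this can be verified by inspecting the shapes:

>>> output.shape
torch.Size([2, 32, 48, 160, 256])
>>> m.weight.shape
torch.Size([32, 4, 1, 3, 3])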