This article is about how multiple image modalities are fused into a single input for a 2D or 3D U-Net, using nnUNet [1] as the example.
The following code snippet shows what happens at the first conv layer of nnUNet when a multi-modal input with 4 modalities is given. Note that the input shape corresponds to a batch size of 2 and 4 image modalities (say, 4 MRI sequences), and the remaining three dimensions represent an MRI volume of 48 slices, each 160x256 pixels:
>>> import torch
>>> from torch import nn
>>> # first conv layer: 4 input channels (one per modality), 32 output feature maps
>>> m = nn.Conv3d(in_channels=4, out_channels=32, kernel_size=[1,3,3], stride=[1,1,1], padding=[0,1,1])
>>> input = torch.randn(2, 4, 48, 160, 256)   # (batch, modalities, z, y, x)
>>> output = m(input)
>>> output.shape
torch.Size([2, 32, 48, 160, 256])
Thus, the input to the network is 5-dimensional:
(batch_size, channels, z, y, x)
Behind each of the 4 input channels is therefore a full 3D volume of dimension 48x160x256, as opposed to a 2D image of, say, 160x256.
The input shape for a single 3D modality would be (2, 1, 48, 160, 256), where the 1 refers to a single channel, as in the sketch below.
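For comparison, here is a minimal sketch of the single-modality case, reusing the example sizes from above (the layer and variable names are made up for illustration): the first layer would simply be built with in_channels=1 and accept one gray-scale volume per case.
>>> m1 = nn.Conv3d(in_channels=1, out_channels=32, kernel_size=[1,3,3], stride=[1,1,1], padding=[0,1,1])
>>> single = torch.randn(2, 1, 48, 160, 256)   # (batch, 1 channel, z, y, x)
>>> m1(single).shape
torch.Size([2, 32, 48, 160, 256])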
Note that the input above is gray-scale, not color. What about a 3D color volume as input? It would simply occupy 3 channels (R, G, B) in the same way. In nnUNet, multi-modal images are treated like color channels: BraTS, for example, which comes with T1, T1c, T2, and FLAIR images for each training case, therefore has 4 input channels [2].
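To make the channel analogy concrete, here is a minimal sketch of how the four co-registered modality volumes of one case could be stacked along the channel dimension before being fed to the network. The random tensors stand in for the actual loaded volumes; this is an illustration of the idea, not nnUNet's preprocessing code.
>>> t1    = torch.randn(48, 160, 256)   # stand-ins for the four co-registered BraTS volumes
>>> t1c   = torch.randn(48, 160, 256)
>>> t2    = torch.randn(48, 160, 256)
>>> flair = torch.randn(48, 160, 256)
>>> x = torch.stack([t1, t1c, t2, flair], dim=0)   # (4, 48, 160, 256): modalities become channels
>>> x = x.unsqueeze(0)                             # add batch dimension -> (1, 4, 48, 160, 256)
>>> m(x).shape
torch.Size([1, 32, 48, 160, 256])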
[1] https://github.com/MIC-DKFZ/nnUNet
[2] https://github.com/MIC-DKFZ/nnUNet/blob/master/documentation/common_questions.md