Last active
April 20, 2021 11:40
-
-
Save casperdcl/e0c2f5bff1d731d353c07b5f5f6226ee to your computer and use it in GitHub Desktop.
Radiol. 290(3) 649-656
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
"""
Residual U-net implementation based on [1].

Usage:
    >>> from chen2019 import network
    >>> # input_data.shape == (num_slices, slice_height, slice_width, num_channels)
    >>> model = network(input_data.shape[1:])
    >>> model.fit(input_data, output_data, epochs=100, batch_size=input_data.shape[0] // 4, ...)

TODO:
- do we need more epochs?
- do we use triplets of convolutional blocks rather than pairs?
- do we use a different batch size?

[1] K. T. Chen et al. 2019 Radiol. 290(3) 649-656
    "Ultra-Low-Dose 18F-Florbetaben Amyloid PET Imaging Using Deep Learning with Multi-Contrast MRI Inputs"
"""
from tensorflow import keras | |
# Module author metadata.
# NOTE(review): the e-mail address was obfuscated ("[email protected]") by the
# web page this file was copied from — restore the real address if known.
__author__ = "Casper da Costa-Luis <[email protected]>"
def network(
    input_shape,
    residual_input_channel=1,
    lr=2e-4,
    dtype="float32",
    filters=(16, 32, 64, 128),
    blocks_per_level=2,
):
    """
    Build and compile a residual U-net based on [1].

    The U-net predicts a single-channel residual (layer "residual") which is
    added to one channel of the input (layer "generated").

    input_shape  : shape of one sample without the batch axis, i.e.
        ``(slice_height, slice_width, num_channels)``. Spatial sizes that are
        known (not ``None``) must be divisible by ``2 ** (len(filters) - 1)``
        so that every upsampled feature map matches its skip connection.
    residual_input_channel  : input channel index to use for residual addition.
        Negative values count from the last channel, as in Python indexing.
    lr  : learning rate of the Adam optimiser.
    dtype  : dtype of the input and of every layer given a ``dtype`` below.
    filters  : number of convolution filters per U-net level; the last entry
        is the bottleneck, the preceding entries are the encoder levels
        (mirrored in reverse order by the decoder).
    blocks_per_level  : number of Conv-BatchNorm-ReLU blocks per level
        (default 2, i.e. pairs of blocks as in the original implementation).

    Returns a ``keras.Model`` compiled with Adam and MAE loss; its summary is
    printed as a side effect.

    Raises ValueError for an empty ``filters``, ``blocks_per_level < 1``, an
    out-of-range ``residual_input_channel`` (when the channel count is known),
    or a known spatial size incompatible with the number of poolings.
    """
    filters = list(filters)
    if not filters:
        raise ValueError("filters must contain at least one entry")
    if blocks_per_level < 1:
        raise ValueError(f"blocks_per_level must be >= 1, got {blocks_per_level}")

    num_channels = input_shape[-1]
    if num_channels is not None and not (
        -num_channels <= residual_input_channel < num_channels
    ):
        raise ValueError(
            f"residual_input_channel={residual_input_channel} is out of range"
            f" for {num_channels} input channel(s)"
        )

    # Each encoder level max-pools with padding="same" (size rounds up) and each
    # decoder level upsamples by exactly 2, so the skip connections only line up
    # when every spatial size survives all poolings without rounding.
    factor = 2 ** (len(filters) - 1)
    for size in input_shape[:-1]:
        if size is not None and size % factor:
            raise ValueError(
                f"spatial size {size} of input_shape={tuple(input_shape)}"
                f" is not divisible by {factor}"
            )

    x = inputs = keras.layers.Input(input_shape, dtype=dtype)

    def block(x, n_filters):
        """Conv2D (3x3, same padding, no bias) -> BatchNormalization -> ReLU."""
        x = keras.layers.Conv2D(n_filters, 3, padding="same", use_bias=False, dtype=dtype)(x)
        x = keras.layers.BatchNormalization(dtype=dtype)(x)
        x = keras.layers.ReLU(dtype=dtype)(x)
        return x

    def level(x, n_filters):
        """Apply ``blocks_per_level`` consecutive blocks with ``n_filters`` filters."""
        for _ in range(blocks_per_level):
            x = block(x, n_filters)
        return x

    # U-net
    ## encode
    skips = []
    for n_filters in filters[:-1]:
        x = level(x, n_filters)
        skips.append(x)
        x = keras.layers.MaxPool2D(dtype=dtype, padding="same")(x)
    x = level(x, filters[-1])
    ## decode
    for n_filters in reversed(filters[:-1]):
        x = keras.layers.UpSampling2D(interpolation="bilinear", dtype=dtype)(x)
        x = keras.layers.Concatenate()([x, skips.pop()])
        x = level(x, n_filters)

    residual = keras.layers.Conv2D(1, 1, padding="same", dtype=dtype, name="residual")(x)
    start = residual_input_channel
    # For start == -1, start + 1 == 0 would give an empty slice; None slices to the end.
    stop = start + 1 or None
    outputs = keras.layers.Add(name="generated")([inputs[..., start:stop], residual])

    model = keras.Model(inputs=inputs, outputs=outputs)
    opt = keras.optimizers.Adam(lr)
    model.compile(opt, loss="mae")
    model.summary()
    return model
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Incorporated into NiftyML (https://www.nifty.ml/en/latest/examples/)