{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Benchmarking keras vs pytorch on IFFT2D"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"I am currently working on deep learning models for MRI reconstruction. A lot of these models make heavy use of FFT2D and IFFT2D, allowing artifacts to be corrected alternately in image space and k-space.\n",
"\n",
"I therefore need a deep learning framework which allows for an efficient and easy-to-use implementation of the IFFT2D/FFT2D pair. I was originally using `keras` with a `tensorflow` backend because I was familiar with it and because I thought there was no reason for the Fourier transform operations to be inefficient.\n",
"\n",
"It turns out that the Fourier transform (at least in 2D) is particularly inefficient in `tensorflow`. You can follow this [GitHub issue](https://github.com/tensorflow/tensorflow/issues/6541) to learn more (it's closed, but the discussion is still going on).\n",
"\n",
"I thus decided to see whether `pytorch` would perform better. The answer, according to this simple benchmark, is that it is almost 40 times faster than `keras` with a `tensorflow` backend."
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"# this is just to make sure we are running on CPU only\n",
"import os\n",
"os.environ[\"CUDA_VISIBLE_DEVICES\"]=\"-1\""
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"# this is to get rid of tensorflow deprecation warnings\n",
"import warnings\n",
"warnings.filterwarnings('ignore')\n",
"import logging\n",
"logging.getLogger('tensorflow').disabled = True"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Using TensorFlow backend.\n"
]
}
],
"source": [
"from keras.layers import Input, Lambda, Conv2D, concatenate\n",
"from keras.models import Model\n",
"import numpy as np\n",
"import tensorflow as tf\n",
"from tensorflow.signal import ifft2d\n",
"import torch\n",
"from torch import nn"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'1.3.0.dev20190828'"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"torch.__version__"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'1.14.0'"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"tf.__version__"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'2.2.4'"
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"import keras\n",
"keras.__version__"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Fake data creation"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
"dtype = torch.float32\n",
"device = torch.device(\"cpu\")\n",
"\n",
"N, im_size = 35, 320\n",
"# Create random input and output data\n",
"x = torch.randn(N, im_size, im_size, 2, device=device, dtype=dtype)\n",
"y = torch.randn(N, im_size, im_size, device=device, dtype=dtype)\n",
"\n",
"# in numpy\n",
"x_np = x.numpy()\n",
"x_np = x_np[...,0] + 1j * x_np[..., 1]\n",
"x_np = x_np[..., None]\n",
"y_np = y.numpy()\n",
"y_np = y_np[..., None]"
]
},
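{
"cell_type": "markdown",
"metadata": {},
"source": [
"A quick sanity check on the data layout: `pytorch` 1.3 represents complex numbers as a trailing dimension of size 2 (real and imaginary parts), while the `keras` model below expects a genuine complex dtype. The minimal check below simply verifies that the numpy conversion above matches the torch layout; it is not part of the timed benchmark."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# sanity check: x_np should be a complex array of shape (N, im_size, im_size, 1)\n",
"assert x_np.shape == (N, im_size, im_size, 1)\n",
"# its real and imaginary parts must match the two torch channels\n",
"np.testing.assert_allclose(x_np[..., 0].real, x[..., 0].numpy())\n",
"np.testing.assert_allclose(x_np[..., 0].imag, x[..., 1].numpy())"
]
},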
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Pytorch"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Model definition and creation"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
"class ConvInvNet(torch.nn.Module):\n",
"    def __init__(self):\n",
"        super(ConvInvNet, self).__init__()\n",
"        self.conv = nn.Conv2d(2, 2, kernel_size=3, padding=1)\n",
"\n",
"    def forward(self, x):\n",
"        y = torch.ifft(x, 2)\n",
"        # this is because pytorch doesn't support NHWC\n",
"        y = y.permute(0, 3, 1, 2)\n",
"        y = self.conv(y)\n",
"        y = y.permute(0, 2, 3, 1)\n",
"        y = (y ** 2).sum(dim=-1).sqrt()\n",
"        return y\n",
"\n",
"model_conv = ConvInvNet()\n",
"criterion = torch.nn.MSELoss(reduction='sum')\n",
"optimizer = torch.optim.SGD(model_conv.parameters(), lr=1e-4)"
]
},
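{
"cell_type": "markdown",
"metadata": {},
"source": [
"Side note: `torch.ifft` is the pre-1.8 API, which takes a real tensor whose last dimension holds the real and imaginary parts. In `pytorch` >= 1.8 it was removed in favour of the `torch.fft` module; a rough sketch of an equivalent forward pass under that newer API (not used in the timings below) would be:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# hypothetical equivalent of ConvInvNet for pytorch >= 1.8,\n",
"# where torch.ifft was replaced by the torch.fft module\n",
"class ConvInvNetNewAPI(torch.nn.Module):\n",
"    def __init__(self):\n",
"        super().__init__()\n",
"        self.conv = nn.Conv2d(2, 2, kernel_size=3, padding=1)\n",
"\n",
"    def forward(self, x):\n",
"        # x: (N, H, W, 2) with real/imag stacked in the last dimension\n",
"        x_complex = torch.view_as_complex(x.contiguous())\n",
"        y = torch.fft.ifft2(x_complex, dim=(-2, -1))\n",
"        y = torch.view_as_real(y)\n",
"        y = y.permute(0, 3, 1, 2)\n",
"        y = self.conv(y)\n",
"        y = y.permute(0, 2, 3, 1)\n",
"        return (y ** 2).sum(dim=-1).sqrt()"
]
},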
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Computations"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"CPU times: user 1.31 s, sys: 577 ms, total: 1.89 s\n",
"Wall time: 99 ms\n"
]
}
],
"source": [
"%%time\n",
"# predicting\n",
"r = model_conv(x)"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"CPU times: user 2.98 s, sys: 783 ms, total: 3.76 s\n",
"Wall time: 213 ms\n"
]
}
],
"source": [
"%%time\n",
"# training\n",
"r = model_conv(x)\n",
"loss = criterion(r, np.squeeze(y))\n",
"optimizer.zero_grad()\n",
"loss.backward()\n",
"optimizer.step()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Keras"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Model definition and creation"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [],
"source": [
"def concatenate_real_imag(x):\n",
"    x_real = Lambda(tf.math.real)(x)\n",
"    x_imag = Lambda(tf.math.imag)(x)\n",
"    return concatenate([x_real, x_imag])\n",
"\n",
"def to_complex(x):\n",
"    return tf.complex(x[0], x[1])\n",
"\n",
"def complex_from_half(x, n, output_shape):\n",
"    return Lambda(lambda x: to_complex([x[..., :n], x[..., n:]]), output_shape=output_shape)(x)\n",
"\n",
"input_size = (320, None, 1)\n",
"kspace_input = Input(input_size, dtype='complex64', name='kspace_input')\n",
"inv_kspace = Lambda(ifft2d, output_shape=input_size)(kspace_input)\n",
"inv_kspace = concatenate_real_imag(inv_kspace)\n",
"inv_kspace = Conv2D(\n",
"    2,\n",
"    3,\n",
"    activation='linear',\n",
"    padding='same',\n",
"    kernel_initializer='he_normal',\n",
")(inv_kspace)\n",
"inv_kspace = complex_from_half(inv_kspace, 1, input_size)\n",
"abs_inv_kspace = Lambda(tf.math.abs)(inv_kspace)\n",
"model_conv_keras = Model(inputs=kspace_input, outputs=abs_inv_kspace)\n",
"\n",
"model_conv_keras.compile(\n",
"    optimizer='sgd',\n",
"    loss='mse',\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Computations"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"CPU times: user 6.07 s, sys: 377 ms, total: 6.45 s\n",
"Wall time: 4.12 s\n"
]
}
],
"source": [
"%%time\n",
"# predicting\n",
"r = model_conv_keras.predict_on_batch(x_np)"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 1/1\n",
"35/35 [==============================] - 4s 122ms/step - loss: 1.0144\n",
"CPU times: user 10.2 s, sys: 640 ms, total: 10.8 s\n",
"Wall time: 4.43 s\n"
]
}
],
"source": [
"%%time\n",
"# training\n",
"r = model_conv_keras.fit(x_np, y_np, batch_size=35)"
]
},
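{
"cell_type": "markdown",
"metadata": {},
"source": [
"To put numbers on the claim from the introduction, the wall times recorded above give roughly the following speedups (prediction: 4.12 s vs 99 ms; one training step: 4.43 s vs 213 ms):"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# speedup summary computed from the wall times recorded in the cells above\n",
"keras_pred, torch_pred = 4.12, 0.099    # seconds\n",
"keras_train, torch_train = 4.43, 0.213  # seconds\n",
"print(f\"prediction speedup: {keras_pred / torch_pred:.1f}x\")\n",
"print(f\"training speedup: {keras_train / torch_train:.1f}x\")"
]
},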
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.8"
}
},
"nbformat": 4,
"nbformat_minor": 2
} |