Skip to content

Instantly share code, notes, and snippets.

@ucalyptus
Created March 16, 2019 11:53
Show Gist options
  • Save ucalyptus/8d612c173fe384e1a1e5b52b9ab59990 to your computer and use it in GitHub Desktop.
Save ucalyptus/8d612c173fe384e1a1e5b52b9ab59990 to your computer and use it in GitHub Desktop.
Unets_pytorch.ipynb
Display the source blob
Display the rendered blob
Raw
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"name": "Unets_pytorch.ipynb",
"version": "0.3.2",
"provenance": [],
"include_colab_link": true
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
}
},
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"<a href=\"https://colab.research.google.com/gist/ucalyptus/8d612c173fe384e1a1e5b52b9ab59990/unets_pytorch.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"metadata": {
"id": "2_5Oj3jyEpfL",
"colab_type": "code",
"colab": {}
},
"cell_type": "code",
"source": [
"import torch \n",
"import torch.nn as nn"
],
"execution_count": 0,
"outputs": []
},
{
"metadata": {
"id": "NcaWeE1LRMOr",
"colab_type": "code",
"colab": {}
},
"cell_type": "code",
"source": [
"def conv_block(in_dim, out_dim, act_fn):\n",
"    \"\"\"Build a 3x3 Conv2d -> BatchNorm2d -> act_fn stack.\n",
"\n",
"    The convolution uses stride 1 and padding 1, so height and width are\n",
"    unchanged and only the channel count goes from in_dim to out_dim.\n",
"    act_fn is an nn.Module instance that is inserted as-is (not copied),\n",
"    so handing the same instance to several blocks shares it.\n",
"    \"\"\"\n",
"    layers = [\n",
"        nn.Conv2d(in_dim, out_dim, kernel_size=3, stride=1, padding=1),\n",
"        nn.BatchNorm2d(out_dim),\n",
"        act_fn,\n",
"    ]\n",
"    return nn.Sequential(*layers)\n"
],
"execution_count": 0,
"outputs": []
},
{
"metadata": {
"id": "ANvv49V5TPD_",
"colab_type": "code",
"colab": {}
},
"cell_type": "code",
"source": [
"def conv_trans_block(in_dim, out_dim, act_fn):\n",
"    \"\"\"Build a ConvTranspose2d -> BatchNorm2d -> act_fn upsampling stack.\n",
"\n",
"    With kernel_size=3, stride=2, padding=1 and output_padding=1 the\n",
"    transposed convolution maps an H x W input to 2H x 2W, while the channel\n",
"    count goes from in_dim to out_dim. act_fn is used as the given module\n",
"    instance (it is not copied).\n",
"    \"\"\"\n",
"    upsample = nn.ConvTranspose2d(\n",
"        in_dim, out_dim, kernel_size=3, stride=2, padding=1, output_padding=1\n",
"    )\n",
"    norm = nn.BatchNorm2d(out_dim)\n",
"    return nn.Sequential(upsample, norm, act_fn)\n"
],
"execution_count": 0,
"outputs": []
},
{
"metadata": {
"id": "sOYu4hBDTR0e",
"colab_type": "code",
"colab": {}
},
"cell_type": "code",
"source": [
"def conv_block_2(in_dim, out_dim, act_fn):\n",
"    \"\"\"Build conv_block followed by a second 3x3 Conv2d -> BatchNorm2d.\n",
"\n",
"    Only the first convolution is followed by act_fn; the block ends on the\n",
"    BatchNorm2d output with no activation. Both convolutions are 3x3 with\n",
"    stride 1 and padding 1, so height and width are preserved.\n",
"    \"\"\"\n",
"    first = conv_block(in_dim, out_dim, act_fn)\n",
"    second_conv = nn.Conv2d(out_dim, out_dim, kernel_size=3, stride=1, padding=1)\n",
"    second_norm = nn.BatchNorm2d(out_dim)\n",
"    return nn.Sequential(first, second_conv, second_norm)\n"
],
"execution_count": 0,
"outputs": []
},
{
"metadata": {
"id": "6__Nb-Y5TZZR",
"colab_type": "code",
"colab": {}
},
"cell_type": "code",
"source": [
"def conv_block_3(in_dim, out_dim, act_fn):\n",
"    \"\"\"Build two conv_blocks followed by a third 3x3 Conv2d -> BatchNorm2d.\n",
"\n",
"    act_fn follows the first two convolutions only; the block ends on the\n",
"    BatchNorm2d output with no activation. All convolutions are 3x3 with\n",
"    stride 1 and padding 1, so height and width are preserved.\n",
"    \"\"\"\n",
"    stages = [conv_block(in_dim, out_dim, act_fn)]\n",
"    stages.append(conv_block(out_dim, out_dim, act_fn))\n",
"    stages += [\n",
"        nn.Conv2d(out_dim, out_dim, kernel_size=3, stride=1, padding=1),\n",
"        nn.BatchNorm2d(out_dim),\n",
"    ]\n",
"    return nn.Sequential(*stages)\n"
],
"execution_count": 0,
"outputs": []
},
{
"metadata": {
"id": "w6FREVL-ThiE",
"colab_type": "code",
"colab": {}
},
"cell_type": "code",
"source": [
"class UNet(nn.Module):\n",
"    \"\"\"U-Net style encoder/decoder with skip connections.\n",
"\n",
"    Args:\n",
"        in_dim: number of channels of the input tensor.\n",
"        out_dim: number of channels of the output tensor.\n",
"        num_filter: channel width of the first encoder stage; deeper stages\n",
"            use num_filter * 2, * 4, * 8 and the bridge uses num_filter * 16.\n",
"\n",
"    Each encoder stage is followed by 2x2 max pooling (halving height and\n",
"    width) and each decoder stage starts with conv_trans_block (doubling\n",
"    them), so the input height and width must be divisible by 16 for the\n",
"    skip-connection concatenations to line up. The output has the same\n",
"    height and width as the input and is squashed by Tanh into [-1, 1].\n",
"    \"\"\"\n",
"\n",
"    def __init__(self, in_dim, out_dim, num_filter):\n",
"        super(UNet, self).__init__()\n",
"        self.in_dim = in_dim\n",
"        self.out_dim = out_dim\n",
"        self.num_filter = num_filter\n",
"        # One activation instance is shared by all encoder (resp. decoder)\n",
"        # blocks; neither activation has learnable parameters.\n",
"        encoder_act_fn = nn.LeakyReLU(0.2, inplace=True)\n",
"        decoder_act_fn = nn.ReLU()\n",
"\n",
"        # Encoder. conv_block_2 keeps the spatial size, so the MaxPool2d\n",
"        # layers between stages do the downsampling that conv_trans_block\n",
"        # undoes in the decoder. MaxPool2d has no parameters, so these add\n",
"        # no state_dict entries.\n",
"        self.down_1 = conv_block_2(self.in_dim, self.num_filter, encoder_act_fn)\n",
"        self.pool_1 = nn.MaxPool2d(kernel_size=2, stride=2)\n",
"        self.down_2 = conv_block_2(self.num_filter * 1, self.num_filter * 2, encoder_act_fn)\n",
"        self.pool_2 = nn.MaxPool2d(kernel_size=2, stride=2)\n",
"        self.down_3 = conv_block_2(self.num_filter * 2, self.num_filter * 4, encoder_act_fn)\n",
"        self.pool_3 = nn.MaxPool2d(kernel_size=2, stride=2)\n",
"        self.down_4 = conv_block_2(self.num_filter * 4, self.num_filter * 8, encoder_act_fn)\n",
"        self.pool_4 = nn.MaxPool2d(kernel_size=2, stride=2)\n",
"        self.bridge = conv_block_2(self.num_filter * 8, self.num_filter * 16, encoder_act_fn)\n",
"\n",
"        # Decoder. Each up_k receives twice its output width: the upsampled\n",
"        # features concatenated with the same-resolution encoder output.\n",
"        self.trans_1 = conv_trans_block(self.num_filter * 16, self.num_filter * 8, decoder_act_fn)\n",
"        self.up_1 = conv_block_2(self.num_filter * 16, self.num_filter * 8, decoder_act_fn)\n",
"        self.trans_2 = conv_trans_block(self.num_filter * 8, self.num_filter * 4, decoder_act_fn)\n",
"        self.up_2 = conv_block_2(self.num_filter * 8, self.num_filter * 4, decoder_act_fn)\n",
"        self.trans_3 = conv_trans_block(self.num_filter * 4, self.num_filter * 2, decoder_act_fn)\n",
"        self.up_3 = conv_block_2(self.num_filter * 4, self.num_filter * 2, decoder_act_fn)\n",
"        self.trans_4 = conv_trans_block(self.num_filter * 2, self.num_filter * 1, decoder_act_fn)\n",
"        self.up_4 = conv_block_2(self.num_filter * 2, self.num_filter * 1, decoder_act_fn)\n",
"\n",
"        self.out = nn.Sequential(nn.Conv2d(self.num_filter, self.out_dim, 3, 1, 1), nn.Tanh())\n",
"\n",
"    def forward(self, input):\n",
"        # input: (N, in_dim, H, W) with H and W divisible by 16 (assumed,\n",
"        # not checked). The parameter name is kept for caller compatibility.\n",
"        down_1 = self.down_1(input)\n",
"        down_2 = self.down_2(self.pool_1(down_1))\n",
"        down_3 = self.down_3(self.pool_2(down_2))\n",
"        down_4 = self.down_4(self.pool_3(down_3))\n",
"        bridge = self.bridge(self.pool_4(down_4))\n",
"\n",
"        # dim=1 is the channel axis of these NCHW tensors.\n",
"        trans_1 = self.trans_1(bridge)\n",
"        up_1 = self.up_1(torch.cat([trans_1, down_4], dim=1))\n",
"        trans_2 = self.trans_2(up_1)\n",
"        up_2 = self.up_2(torch.cat([trans_2, down_3], dim=1))\n",
"        trans_3 = self.trans_3(up_2)\n",
"        up_3 = self.up_3(torch.cat([trans_3, down_2], dim=1))\n",
"        trans_4 = self.trans_4(up_3)\n",
"        up_4 = self.up_4(torch.cat([trans_4, down_1], dim=1))\n",
"\n",
"        return self.out(up_4)\n"
],
"execution_count": 0,
"outputs": []
},
{
"metadata": {
"id": "XSjvGJlBT-W2",
"colab_type": "code",
"colab": {}
},
"cell_type": "code",
"source": [
""
],
"execution_count": 0,
"outputs": []
},
{
"metadata": {
"id": "jJxBmpQRT_to",
"colab_type": "code",
"colab": {}
},
"cell_type": "code",
"source": [
""
],
"execution_count": 0,
"outputs": []
}
]
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment