Last active
January 22, 2021 15:02
-
-
Save rozeappletree/6c8b9a3023ed62a603cf8cbec6efe01d to your computer and use it in GitHub Desktop.
Train_ MAPNet.ipynb
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"nbformat": 4, | |
"nbformat_minor": 0, | |
"metadata": { | |
"colab": { | |
"name": "Train_ MAPNet.ipynb", | |
"provenance": [], | |
"collapsed_sections": [], | |
"include_colab_link": true | |
}, | |
"kernelspec": { | |
"name": "python3", | |
"display_name": "Python 3" | |
}, | |
"accelerator": "GPU" | |
}, | |
"cells": [ | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "view-in-github", | |
"colab_type": "text" | |
}, | |
"source": [ | |
"<a href=\"https://colab.research.google.com/gist/rakesh4real/6c8b9a3023ed62a603cf8cbec6efe01d/train_-mapnet.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "EYmlIfEdUJMa", | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"outputId": "85c162ea-024c-453d-8220-9cff0870b94b" | |
}, | |
"source": [ | |
"# Symlink the nvidia-smi binary from /opt/bin into /usr/bin, then print the GPU status.\n", | |
"!ln -sf /opt/bin/nvidia-smi /usr/bin/nvidia-smi\n", | |
"import subprocess\n", | |
"\n", | |
"gpu_status = subprocess.getoutput('nvidia-smi')\n", | |
"print(gpu_status)" | |
], | |
"execution_count": null, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"text": [ | |
"Fri Jan 1 15:06:34 2021 \n", | |
"+-----------------------------------------------------------------------------+\n", | |
"| NVIDIA-SMI 418.67 Driver Version: 418.67 CUDA Version: 10.1 |\n", | |
"|-------------------------------+----------------------+----------------------+\n", | |
"| GPU Name Persistence-M| Bus-Id Disp.A | Volatile Uncorr. ECC |\n", | |
"| Fan Temp Perf Pwr:Usage/Cap| Memory-Usage | GPU-Util Compute M. |\n", | |
"|===============================+======================+======================|\n", | |
"| 0 Tesla T4 Off | 00000000:00:04.0 Off | 0 |\n", | |
"| N/A 50C P8 10W / 70W | 0MiB / 15079MiB | 0% Default |\n", | |
"+-------------------------------+----------------------+----------------------+\n", | |
" \n", | |
"+-----------------------------------------------------------------------------+\n", | |
"| Processes: GPU Memory |\n", | |
"| GPU PID Type Process name Usage |\n", | |
"|=============================================================================|\n", | |
"| No running processes found |\n", | |
"+-----------------------------------------------------------------------------+\n" | |
], | |
"name": "stdout" | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "2YCApa2aeU9C", | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"outputId": "053f2af7-b8be-4b58-ad4d-9a9aa1b1edc9" | |
}, | |
"source": [ | |
"# List the devices TensorFlow can use; a /device:GPU:0 entry confirms the accelerator is visible.\n", | |
"from tensorflow.python.client import device_lib\n", | |
"\n", | |
"local_devices = device_lib.list_local_devices()\n", | |
"print(local_devices)" | |
], | |
"execution_count": 1, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"text": [ | |
"[name: \"/device:CPU:0\"\n", | |
"device_type: \"CPU\"\n", | |
"memory_limit: 268435456\n", | |
"locality {\n", | |
"}\n", | |
"incarnation: 9162827415679841530\n", | |
", name: \"/device:GPU:0\"\n", | |
"device_type: \"GPU\"\n", | |
"memory_limit: 14638920512\n", | |
"locality {\n", | |
" bus_id: 1\n", | |
" links {\n", | |
" }\n", | |
"}\n", | |
"incarnation: 2329053322952791945\n", | |
"physical_device_desc: \"device: 0, name: Tesla T4, pci bus id: 0000:00:04.0, compute capability: 7.5\"\n", | |
"]\n" | |
], | |
"name": "stdout" | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "MFGVP64SaGsg", | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"outputId": "7f7c9518-517b-4ed8-c983-c2698987dc12" | |
}, | |
"source": [ | |
"# Fetch the MAPNet repository; it also provides the dataset read from MAPNet/dataset/ below.\n", | |
"# The existence check keeps the cell re-runnable (a second plain `git clone` fails).\n", | |
"![ -d MAPNet ] || git clone https://github.com/lehaifeng/MAPNet.git" | |
], | |
"execution_count": 2, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"text": [ | |
"Cloning into 'MAPNet'...\n", | |
"remote: Enumerating objects: 3, done.\u001b[K\n", | |
"remote: Counting objects: 100% (3/3), done.\u001b[K\n", | |
"remote: Compressing objects: 100% (3/3), done.\u001b[K\n", | |
"remote: Total 103 (delta 0), reused 0 (delta 0), pack-reused 100\u001b[K\n", | |
"Receiving objects: 100% (103/103), 3.94 MiB | 41.64 MiB/s, done.\n", | |
"Resolving deltas: 100% (29/29), done.\n" | |
], | |
"name": "stdout" | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "loZqWTY1ecre", | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 1000 | |
}, | |
"outputId": "da60509c-6661-4bf1-b4dd-2f17a1955250" | |
}, | |
"source": [ | |
"# TensorFlow 1.x is required: the model code uses tf.layers and tf.placeholder.\n", | |
"# Since 1.15 the `tensorflow` wheel already bundles GPU support on Linux, so a separate\n", | |
"# `tensorflow-gpu==1.15.0` install would only download an equivalent ~400 MB build again.\n", | |
"# Restart the runtime after this cell so the freshly installed TensorFlow is imported.\n", | |
"!pip install tensorflow==1.15.0\n", | |
"# `utils` is the PyPI package imported by the next cell; pin it for reproducibility.\n", | |
"!pip install utils==1.0.1" | |
], | |
"execution_count": 3, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"text": [ | |
"Collecting tensorflow==1.15.0\n", | |
"\u001b[?25l Downloading https://files.pythonhosted.org/packages/3f/98/5a99af92fb911d7a88a0005ad55005f35b4c1ba8d75fba02df726cd936e6/tensorflow-1.15.0-cp36-cp36m-manylinux2010_x86_64.whl (412.3MB)\n", | |
"\u001b[K |████████████████████████████████| 412.3MB 39kB/s \n", | |
"\u001b[?25hRequirement already satisfied: keras-preprocessing>=1.0.5 in /usr/local/lib/python3.6/dist-packages (from tensorflow==1.15.0) (1.1.2)\n", | |
"Requirement already satisfied: termcolor>=1.1.0 in /usr/local/lib/python3.6/dist-packages (from tensorflow==1.15.0) (1.1.0)\n", | |
"Requirement already satisfied: six>=1.10.0 in /usr/local/lib/python3.6/dist-packages (from tensorflow==1.15.0) (1.15.0)\n", | |
"Requirement already satisfied: numpy<2.0,>=1.16.0 in /usr/local/lib/python3.6/dist-packages (from tensorflow==1.15.0) (1.19.5)\n", | |
"Requirement already satisfied: opt-einsum>=2.3.2 in /usr/local/lib/python3.6/dist-packages (from tensorflow==1.15.0) (3.3.0)\n", | |
"Requirement already satisfied: protobuf>=3.6.1 in /usr/local/lib/python3.6/dist-packages (from tensorflow==1.15.0) (3.12.4)\n", | |
"Requirement already satisfied: google-pasta>=0.1.6 in /usr/local/lib/python3.6/dist-packages (from tensorflow==1.15.0) (0.2.0)\n", | |
"Collecting keras-applications>=1.0.8\n", | |
"\u001b[?25l Downloading https://files.pythonhosted.org/packages/71/e3/19762fdfc62877ae9102edf6342d71b28fbfd9dea3d2f96a882ce099b03f/Keras_Applications-1.0.8-py3-none-any.whl (50kB)\n", | |
"\u001b[K |████████████████████████████████| 51kB 9.0MB/s \n", | |
"\u001b[?25hCollecting tensorboard<1.16.0,>=1.15.0\n", | |
"\u001b[?25l Downloading https://files.pythonhosted.org/packages/1e/e9/d3d747a97f7188f48aa5eda486907f3b345cd409f0a0850468ba867db246/tensorboard-1.15.0-py3-none-any.whl (3.8MB)\n", | |
"\u001b[K |████████████████████████████████| 3.8MB 56.2MB/s \n", | |
"\u001b[?25hCollecting tensorflow-estimator==1.15.1\n", | |
"\u001b[?25l Downloading https://files.pythonhosted.org/packages/de/62/2ee9cd74c9fa2fa450877847ba560b260f5d0fb70ee0595203082dafcc9d/tensorflow_estimator-1.15.1-py2.py3-none-any.whl (503kB)\n", | |
"\u001b[K |████████████████████████████████| 512kB 51.0MB/s \n", | |
"\u001b[?25hRequirement already satisfied: grpcio>=1.8.6 in /usr/local/lib/python3.6/dist-packages (from tensorflow==1.15.0) (1.32.0)\n", | |
"Requirement already satisfied: wheel>=0.26 in /usr/local/lib/python3.6/dist-packages (from tensorflow==1.15.0) (0.36.2)\n", | |
"Requirement already satisfied: absl-py>=0.7.0 in /usr/local/lib/python3.6/dist-packages (from tensorflow==1.15.0) (0.10.0)\n", | |
"Requirement already satisfied: wrapt>=1.11.1 in /usr/local/lib/python3.6/dist-packages (from tensorflow==1.15.0) (1.12.1)\n", | |
"Collecting gast==0.2.2\n", | |
" Downloading https://files.pythonhosted.org/packages/4e/35/11749bf99b2d4e3cceb4d55ca22590b0d7c2c62b9de38ac4a4a7f4687421/gast-0.2.2.tar.gz\n", | |
"Requirement already satisfied: astor>=0.6.0 in /usr/local/lib/python3.6/dist-packages (from tensorflow==1.15.0) (0.8.1)\n", | |
"Requirement already satisfied: setuptools in /usr/local/lib/python3.6/dist-packages (from protobuf>=3.6.1->tensorflow==1.15.0) (51.3.3)\n", | |
"Requirement already satisfied: h5py in /usr/local/lib/python3.6/dist-packages (from keras-applications>=1.0.8->tensorflow==1.15.0) (2.10.0)\n", | |
"Requirement already satisfied: markdown>=2.6.8 in /usr/local/lib/python3.6/dist-packages (from tensorboard<1.16.0,>=1.15.0->tensorflow==1.15.0) (3.3.3)\n", | |
"Requirement already satisfied: werkzeug>=0.11.15 in /usr/local/lib/python3.6/dist-packages (from tensorboard<1.16.0,>=1.15.0->tensorflow==1.15.0) (1.0.1)\n", | |
"Requirement already satisfied: importlib-metadata; python_version < \"3.8\" in /usr/local/lib/python3.6/dist-packages (from markdown>=2.6.8->tensorboard<1.16.0,>=1.15.0->tensorflow==1.15.0) (3.3.0)\n", | |
"Requirement already satisfied: typing-extensions>=3.6.4; python_version < \"3.8\" in /usr/local/lib/python3.6/dist-packages (from importlib-metadata; python_version < \"3.8\"->markdown>=2.6.8->tensorboard<1.16.0,>=1.15.0->tensorflow==1.15.0) (3.7.4.3)\n", | |
"Requirement already satisfied: zipp>=0.5 in /usr/local/lib/python3.6/dist-packages (from importlib-metadata; python_version < \"3.8\"->markdown>=2.6.8->tensorboard<1.16.0,>=1.15.0->tensorflow==1.15.0) (3.4.0)\n", | |
"Building wheels for collected packages: gast\n", | |
" Building wheel for gast (setup.py) ... \u001b[?25l\u001b[?25hdone\n", | |
" Created wheel for gast: filename=gast-0.2.2-cp36-none-any.whl size=7540 sha256=d2b49c833459c411dd5782e3d6b916d2bde994c067fccdad3cfb9aae9c82f3bc\n", | |
" Stored in directory: /root/.cache/pip/wheels/5c/2e/7e/a1d4d4fcebe6c381f378ce7743a3ced3699feb89bcfbdadadd\n", | |
"Successfully built gast\n", | |
"\u001b[31mERROR: tensorflow-probability 0.12.1 has requirement gast>=0.3.2, but you'll have gast 0.2.2 which is incompatible.\u001b[0m\n", | |
"Installing collected packages: keras-applications, tensorboard, tensorflow-estimator, gast, tensorflow\n", | |
" Found existing installation: tensorboard 2.4.0\n", | |
" Uninstalling tensorboard-2.4.0:\n", | |
" Successfully uninstalled tensorboard-2.4.0\n", | |
" Found existing installation: tensorflow-estimator 2.4.0\n", | |
" Uninstalling tensorflow-estimator-2.4.0:\n", | |
" Successfully uninstalled tensorflow-estimator-2.4.0\n", | |
" Found existing installation: gast 0.3.3\n", | |
" Uninstalling gast-0.3.3:\n", | |
" Successfully uninstalled gast-0.3.3\n", | |
" Found existing installation: tensorflow 2.4.0\n", | |
" Uninstalling tensorflow-2.4.0:\n", | |
" Successfully uninstalled tensorflow-2.4.0\n", | |
"Successfully installed gast-0.2.2 keras-applications-1.0.8 tensorboard-1.15.0 tensorflow-1.15.0 tensorflow-estimator-1.15.1\n" | |
], | |
"name": "stdout" | |
}, | |
{ | |
"output_type": "display_data", | |
"data": { | |
"application/vnd.colab-display-data+json": { | |
"pip_warning": { | |
"packages": [ | |
"gast", | |
"tensorboard", | |
"tensorflow" | |
] | |
} | |
} | |
}, | |
"metadata": { | |
"tags": [] | |
} | |
}, | |
{ | |
"output_type": "stream", | |
"text": [ | |
"Collecting tensorflow-gpu==1.15.0\n", | |
"\u001b[?25l Downloading https://files.pythonhosted.org/packages/a5/ad/933140e74973fb917a194ab814785e7c23680ca5dee6d663a509fe9579b6/tensorflow_gpu-1.15.0-cp36-cp36m-manylinux2010_x86_64.whl (411.5MB)\n", | |
"\u001b[K |████████████████████████████████| 411.5MB 41kB/s \n", | |
"\u001b[?25hRequirement already satisfied: google-pasta>=0.1.6 in /usr/local/lib/python3.6/dist-packages (from tensorflow-gpu==1.15.0) (0.2.0)\n", | |
"Requirement already satisfied: tensorboard<1.16.0,>=1.15.0 in /usr/local/lib/python3.6/dist-packages (from tensorflow-gpu==1.15.0) (1.15.0)\n", | |
"Requirement already satisfied: wrapt>=1.11.1 in /usr/local/lib/python3.6/dist-packages (from tensorflow-gpu==1.15.0) (1.12.1)\n", | |
"Requirement already satisfied: astor>=0.6.0 in /usr/local/lib/python3.6/dist-packages (from tensorflow-gpu==1.15.0) (0.8.1)\n", | |
"Requirement already satisfied: wheel>=0.26 in /usr/local/lib/python3.6/dist-packages (from tensorflow-gpu==1.15.0) (0.36.2)\n", | |
"Requirement already satisfied: keras-applications>=1.0.8 in /usr/local/lib/python3.6/dist-packages (from tensorflow-gpu==1.15.0) (1.0.8)\n", | |
"Requirement already satisfied: gast==0.2.2 in /usr/local/lib/python3.6/dist-packages (from tensorflow-gpu==1.15.0) (0.2.2)\n", | |
"Requirement already satisfied: numpy<2.0,>=1.16.0 in /usr/local/lib/python3.6/dist-packages (from tensorflow-gpu==1.15.0) (1.19.5)\n", | |
"Requirement already satisfied: protobuf>=3.6.1 in /usr/local/lib/python3.6/dist-packages (from tensorflow-gpu==1.15.0) (3.12.4)\n", | |
"Requirement already satisfied: keras-preprocessing>=1.0.5 in /usr/local/lib/python3.6/dist-packages (from tensorflow-gpu==1.15.0) (1.1.2)\n", | |
"Requirement already satisfied: opt-einsum>=2.3.2 in /usr/local/lib/python3.6/dist-packages (from tensorflow-gpu==1.15.0) (3.3.0)\n", | |
"Requirement already satisfied: grpcio>=1.8.6 in /usr/local/lib/python3.6/dist-packages (from tensorflow-gpu==1.15.0) (1.32.0)\n", | |
"Requirement already satisfied: absl-py>=0.7.0 in /usr/local/lib/python3.6/dist-packages (from tensorflow-gpu==1.15.0) (0.10.0)\n", | |
"Requirement already satisfied: termcolor>=1.1.0 in /usr/local/lib/python3.6/dist-packages (from tensorflow-gpu==1.15.0) (1.1.0)\n", | |
"Requirement already satisfied: tensorflow-estimator==1.15.1 in /usr/local/lib/python3.6/dist-packages (from tensorflow-gpu==1.15.0) (1.15.1)\n", | |
"Requirement already satisfied: six>=1.10.0 in /usr/local/lib/python3.6/dist-packages (from tensorflow-gpu==1.15.0) (1.15.0)\n", | |
"Requirement already satisfied: setuptools>=41.0.0 in /usr/local/lib/python3.6/dist-packages (from tensorboard<1.16.0,>=1.15.0->tensorflow-gpu==1.15.0) (51.3.3)\n", | |
"Requirement already satisfied: werkzeug>=0.11.15 in /usr/local/lib/python3.6/dist-packages (from tensorboard<1.16.0,>=1.15.0->tensorflow-gpu==1.15.0) (1.0.1)\n", | |
"Requirement already satisfied: markdown>=2.6.8 in /usr/local/lib/python3.6/dist-packages (from tensorboard<1.16.0,>=1.15.0->tensorflow-gpu==1.15.0) (3.3.3)\n", | |
"Requirement already satisfied: h5py in /usr/local/lib/python3.6/dist-packages (from keras-applications>=1.0.8->tensorflow-gpu==1.15.0) (2.10.0)\n", | |
"Requirement already satisfied: importlib-metadata; python_version < \"3.8\" in /usr/local/lib/python3.6/dist-packages (from markdown>=2.6.8->tensorboard<1.16.0,>=1.15.0->tensorflow-gpu==1.15.0) (3.3.0)\n", | |
"Requirement already satisfied: typing-extensions>=3.6.4; python_version < \"3.8\" in /usr/local/lib/python3.6/dist-packages (from importlib-metadata; python_version < \"3.8\"->markdown>=2.6.8->tensorboard<1.16.0,>=1.15.0->tensorflow-gpu==1.15.0) (3.7.4.3)\n", | |
"Requirement already satisfied: zipp>=0.5 in /usr/local/lib/python3.6/dist-packages (from importlib-metadata; python_version < \"3.8\"->markdown>=2.6.8->tensorboard<1.16.0,>=1.15.0->tensorflow-gpu==1.15.0) (3.4.0)\n", | |
"Installing collected packages: tensorflow-gpu\n", | |
"Successfully installed tensorflow-gpu-1.15.0\n" | |
], | |
"name": "stdout" | |
}, | |
{ | |
"output_type": "display_data", | |
"data": { | |
"application/vnd.colab-display-data+json": { | |
"pip_warning": { | |
"packages": [ | |
"tensorflow" | |
] | |
} | |
} | |
}, | |
"metadata": { | |
"tags": [] | |
} | |
}, | |
{ | |
"output_type": "stream", | |
"text": [ | |
"Collecting utils\n", | |
" Downloading https://files.pythonhosted.org/packages/55/e6/c2d2b2703e7debc8b501caae0e6f7ead148fd0faa3c8131292a599930029/utils-1.0.1-py2.py3-none-any.whl\n", | |
"Installing collected packages: utils\n", | |
"Successfully installed utils-1.0.1\n" | |
], | |
"name": "stdout" | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "R-DalbHfjrZ2" | |
}, | |
"source": [ | |
"# Shared imports: standard library first, then third-party packages.\n", | |
"import argparse\n", | |
"import os\n", | |
"import time\n", | |
"\n", | |
"import utils\n", | |
"import tensorflow as tf\n", | |
"import numpy as np\n", | |
"import skimage.io as io  # NOTE: this alias shadows the standard-library `io` module\n", | |
"from tensorflow.python.framework import ops\n", | |
"\n", | |
"# Clear any graph left over from an earlier run of the notebook in this kernel.\n", | |
"ops.reset_default_graph()" | |
], | |
"execution_count": 1, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "3y4Isl-6cFn2" | |
}, | |
"source": [ | |
"#from google.colab import drive\n", | |
"#drive.mount('/content/drive')" | |
], | |
"execution_count": 2, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "4YCjOIIQcFk5", | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"outputId": "def0deeb-0783-48ac-d0e8-bd97161ba5a2" | |
}, | |
"source": [ | |
"DATA_DIR = '/content/MAPNet/dataset/'\n", | |
"import os\n", | |
"\n", | |
"# Fail fast when the dataset is missing: prepare_data() globs under DATA_DIR and would\n", | |
"# otherwise return empty path lists, leaving nothing to train on.\n", | |
"if not os.path.isdir(DATA_DIR):\n", | |
"    raise FileNotFoundError(f'no data available at {DATA_DIR}; run the git clone cell first')\n", | |
"print('Done!!')" | |
], | |
"execution_count": 3, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"text": [ | |
"Done!!\n" | |
], | |
"name": "stdout" | |
} | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "_UmAQg4ASFP2" | |
}, | |
"source": [ | |
"# **Load data**" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "GOCrdX8ybFZ8" | |
}, | |
"source": [ | |
"#load data\n", | |
"import numpy as np\n", | |
"import glob\n", | |
"import scipy\n", | |
"import random\n", | |
"import cv2\n", | |
"import skimage.io\n", | |
"\n", | |
"def load_batch(x, y, augment=True):\n", | |
"    \"\"\"Read image/label PNG pairs and return one batch.\n", | |
"\n", | |
"    Args:\n", | |
"        x: sequence of image file paths.\n", | |
"        y: sequence of label file paths, aligned index-by-index with ``x``.\n", | |
"        augment: apply data_augmentation (random flips/rotations) to each pair.\n", | |
"            Pass False for validation/test batches so they are scored unmodified.\n", | |
"\n", | |
"    Returns:\n", | |
"        (images, labels): ``images`` is a list of arrays divided by 255.0;\n", | |
"        ``labels`` is a float32 array of shape (batch, H, W, 1).\n", | |
"\n", | |
"    Raises:\n", | |
"        ValueError: if ``x`` and ``y`` have different lengths.\n", | |
"    \"\"\"\n", | |
"    if len(x) != len(y):\n", | |
"        raise ValueError(f'got {len(x)} images but {len(y)} labels')\n", | |
"    images = []\n", | |
"    labels = []\n", | |
"    for img_path, lab_path in zip(x, y):\n", | |
"        img = skimage.io.imread(img_path)\n", | |
"        lab = skimage.io.imread(lab_path)\n", | |
"        # NOTE(review): labels go straight into sigmoid cross-entropy, which expects\n", | |
"        # targets in [0, 1] -- confirm the masks are stored as 0/1 rather than 0/255.\n", | |
"        if augment:\n", | |
"            img, lab = data_augmentation(img, lab)\n", | |
"        # Append the channel axis expected by the (None, H, W, 1) label placeholder;\n", | |
"        # H and W come from the file instead of being hard-coded to 512.\n", | |
"        lab = lab.reshape(lab.shape[0], lab.shape[1], 1)\n", | |
"        images.append(img / 255.0)\n", | |
"        labels.append(lab)\n", | |
"    labels = np.array(labels).astype(np.float32)\n", | |
"    return images, labels\n", | |
"\n", | |
"\n", | |
"def prepare_data(data_dir=None):\n", | |
"    \"\"\"Collect sorted image/label paths for the train and test splits.\n", | |
"\n", | |
"    Looks for ``<data_dir>/{train,test}/{img,lab}/*.png`` (``data_dir`` defaults to\n", | |
"    DATA_DIR). Images and labels are paired purely by sorted order, so each img/\n", | |
"    folder must hold the same file names as its lab/ folder.\n", | |
"\n", | |
"    Returns:\n", | |
"        (train_img, train_label, test_img, test_label) as numpy arrays of paths.\n", | |
"\n", | |
"    Raises:\n", | |
"        ValueError: if no training images are found or a split has a different\n", | |
"            number of images and labels.\n", | |
"    \"\"\"\n", | |
"    root = DATA_DIR if data_dir is None else data_dir\n", | |
"\n", | |
"    def _pngs(split, kind):\n", | |
"        # Sorted so that img/ and lab/ line up when they use identical file names.\n", | |
"        return np.array(sorted(glob.glob(os.path.join(root, split, kind, '*.png'))))\n", | |
"\n", | |
"    img = _pngs('train', 'img')\n", | |
"    label = _pngs('train', 'lab')\n", | |
"    test_img = _pngs('test', 'img')\n", | |
"    test_label = _pngs('test', 'lab')\n", | |
"\n", | |
"    if len(img) == 0:\n", | |
"        raise ValueError(f'no training images found under {root}')\n", | |
"    for split, images, labels in (('train', img, label), ('test', test_img, test_label)):\n", | |
"        if len(images) != len(labels):\n", | |
"            raise ValueError(f'{split}: {len(images)} images but {len(labels)} labels')\n", | |
"\n", | |
"    return img, label, test_img, test_label\n", | |
"\n", | |
"\n", | |
"def data_augmentation(image, label, h_flip=True, v_flip=True, rotation=True):\n", | |
"    \"\"\"Apply identical random flips / 90-degree rotations to an image and its label.\n", | |
"\n", | |
"    Each enabled transform fires with probability 1/2; a firing rotation draws\n", | |
"    k uniformly from {0, 1, 2, 3} and turns both arrays by k * 90 degrees.\n", | |
"\n", | |
"    Args:\n", | |
"        image: array of shape (H, W) or (H, W, C).\n", | |
"        label: array sharing the image's leading (H, W) axes.\n", | |
"        h_flip, v_flip, rotation: switch individual transforms off (default: all on).\n", | |
"\n", | |
"    Returns:\n", | |
"        (image, label), possibly as views of the inputs.\n", | |
"    \"\"\"\n", | |
"    if h_flip and random.randint(0, 1):\n", | |
"        image = np.fliplr(image)\n", | |
"        label = np.fliplr(label)\n", | |
"    if v_flip and random.randint(0, 1):\n", | |
"        image = np.flipud(image)\n", | |
"        label = np.flipud(label)\n", | |
"\n", | |
"    if rotation and random.randint(0, 1):\n", | |
"        quarter_turns = random.randint(0, 3)\n", | |
"        if quarter_turns:\n", | |
"            # np.rot90 permutes pixels exactly (same counter-clockwise sense as a positive\n", | |
"            # cv2.getRotationMatrix2D angle). cv2.warpAffine about (W // 2, H // 2) is off by\n", | |
"            # half a pixel for even sizes: it shifted every rotated tile by one pixel,\n", | |
"            # dropping an edge row/column and zero-filling the opposite edge.\n", | |
"            # Assumes square tiles; a 90/270-degree turn swaps H and W otherwise.\n", | |
"            image = np.rot90(image, quarter_turns)\n", | |
"            label = np.rot90(label, quarter_turns)\n", | |
"\n", | |
"    return image, label" | |
], | |
"execution_count": 4, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "3IUD9qj4SI9j" | |
}, | |
"source": [ | |
"# **Load MAPNet**" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "m1Kfv7WxbjTU" | |
}, | |
"source": [ | |
"# mapnet\n", | |
"import tensorflow as tf\n", | |
"\n", | |
"# from keras.layers import UpSampling2D\n", | |
"\n", | |
"\n", | |
"def conv2d(input,filters,kernel_size=3,stride=1,padding='SAME'):\n", | |
"    \"\"\"Bias-free tf.layers.conv2d with a variance-scaling kernel initializer.\n", | |
"\n", | |
"    ``stride`` is passed through as ``strides``; ``padding`` is e.g. 'SAME' or 'VALID'.\n", | |
"    \"\"\"\n", | |
"    initializer = tf.variance_scaling_initializer()\n", | |
"    return tf.layers.conv2d(\n", | |
"        input,\n", | |
"        filters=filters,\n", | |
"        kernel_size=kernel_size,\n", | |
"        strides=stride,\n", | |
"        padding=padding,\n", | |
"        use_bias=False,\n", | |
"        kernel_initializer=initializer,\n", | |
"    )\n", | |
"\n", | |
"\n", | |
"def bn(input,is_training=True):\n", | |
"    \"\"\"Batch normalization with momentum 0.99 and epsilon 1e-3.\n", | |
"\n", | |
"    ``is_training`` may be a Python bool or a tf.bool tensor. Moving-average updates are\n", | |
"    registered in tf.GraphKeys.UPDATE_OPS and must be run alongside the train step.\n", | |
"    \"\"\"\n", | |
"    return tf.layers.batch_normalization(\n", | |
"        input, training=is_training, momentum=0.99, epsilon=1e-3)\n", | |
"\n", | |
"\n", | |
"def bottleneck(x, size,is_training,downsampe=False):\n", | |
"    \"\"\"Pre-activation bottleneck: BN-ReLU-conv 1x1 (size) -> 3x3 (size) -> 1x1 (size * 4).\n", | |
"\n", | |
"    With ``downsampe=True`` the shortcut is a BN-ReLU-1x1 conv projecting ``x`` to\n", | |
"    ``size * 4`` channels (stride 1, so no spatial downsampling despite the name);\n", | |
"    otherwise ``x`` is added unchanged and must already have ``size * 4`` channels.\n", | |
"    \"\"\"\n", | |
"    # Layer creation order matters: tf.layers names variables by creation order,\n", | |
"    # so keep it stable for checkpoint compatibility.\n", | |
"    out = x\n", | |
"    for width, kernel in ((size, 1), (size, 3), (size * 4, 1)):\n", | |
"        out = tf.nn.relu(bn(out, is_training))\n", | |
"        if kernel == 1:\n", | |
"            out = conv2d(out, width, 1, padding='VALID')\n", | |
"        else:\n", | |
"            out = conv2d(out, width, 3)\n", | |
"    shortcut = x\n", | |
"    if downsampe:\n", | |
"        shortcut = conv2d(tf.nn.relu(bn(x, is_training)), size * 4, 1, padding='VALID')\n", | |
"    return tf.add(out, shortcut)\n", | |
"\n", | |
"\n", | |
"def resblock(x, size,is_training):\n", | |
"    \"\"\"Two pre-activation BN-ReLU-3x3 conv layers plus an identity shortcut.\n", | |
"\n", | |
"    ``x`` must already have ``size`` channels for the final addition.\n", | |
"    \"\"\"\n", | |
"    out = x\n", | |
"    for _ in range(2):\n", | |
"        out = conv2d(tf.nn.relu(bn(out, is_training)), size, 3)\n", | |
"    return tf.add(out, x)\n", | |
"\n", | |
"\n", | |
"def stage0(x,is_training):\n", | |
"    \"\"\"First stage: one projecting bottleneck, then three identity bottlenecks (64 * 4 channels out).\"\"\"\n", | |
"    x = bottleneck(x, 64, is_training, downsampe=True)\n", | |
"    for _ in range(3):\n", | |
"        x = bottleneck(x, 64, is_training)\n", | |
"    return x\n", | |
"\n", | |
"\n", | |
"def translayer(x, in_channels, out_channels,is_training):\n", | |
"    \"\"\"Transition layer: map ``len(in_channels)`` branches to ``len(out_channels)`` branches.\n", | |
"\n", | |
"    Branch i < len(in_channels) gets a BN-ReLU-3x3 conv of x[i] to ``out_channels[i]``;\n", | |
"    every further branch is a stride-2 BN-ReLU-3x3 conv of x[-1]. Only the length of\n", | |
"    ``in_channels`` is used.\n", | |
"    \"\"\"\n", | |
"    num_in = len(in_channels)\n", | |
"    branches = []\n", | |
"    for i, width in enumerate(out_channels):\n", | |
"        is_new = i >= num_in\n", | |
"        source = x[-1] if is_new else x[i]\n", | |
"        stride = 2 if is_new else 1\n", | |
"        branches.append(conv2d(tf.nn.relu(bn(source, is_training)), width, 3, stride=stride))\n", | |
"    return branches\n", | |
"\n", | |
"\n", | |
"def convb(x, block_num, channels,is_training):\n", | |
"    \"\"\"Apply ``block_num`` resblocks to each branch x[i] (width ``channels[i]``), branch by branch.\"\"\"\n", | |
"    outputs = []\n", | |
"    for i, width in enumerate(channels):\n", | |
"        branch = x[i]\n", | |
"        for _ in range(block_num):\n", | |
"            branch = resblock(branch, width, is_training)\n", | |
"        outputs.append(branch)\n", | |
"    return outputs\n", | |
"\n", | |
"\n", | |
"def featfuse(x, channels, is_training, multi_scale_output=True):\n", | |
"    # Fuse the parallel branches in x. channels[j] is branch j's width; x[j] is assumed\n", | |
"    # to get coarser as j grows (translayer adds new branches with stride-2 convs).\n", | |
"    # multi_scale_output=True: returns one tensor per branch i, equal to x[i] plus every\n", | |
"    #   other branch resampled to branch i's resolution and projected to channels[i].\n", | |
"    # multi_scale_output=False: only i == 0 is built; each coarser branch is upsampled\n", | |
"    #   to branch 0's size, keeps its own width channels[j] and is appended as-is; x[0]\n", | |
"    #   is appended last, giving [up(x[1]), up(x[2]), ..., x[0]].\n", | |
"    # Upsampling is UpSampling2D (nearest neighbour by default); downsampling is 2x2 max pooling.\n", | |
"    # NOTE: layer creation order determines tf.layers variable names; do not reorder.\n", | |
"    out = []\n", | |
"    for i in range(len(channels) if multi_scale_output else 1):\n", | |
"        residual = x[i]\n", | |
"        for j in range(len(channels)):\n", | |
"            if j > i:\n", | |
"                # Branch j is coarser: BN-ReLU-1x1 conv, then upsample by 2 ** (j - i).\n", | |
"                if multi_scale_output == False:\n", | |
"                    # Single-scale mode keeps branch j's own width and collects it separately.\n", | |
"                    y = bn(x[j], is_training)\n", | |
"                    y = tf.nn.relu(y)\n", | |
"                    y = conv2d(y, channels[j], 1, padding='VALID')\n", | |
"                    out.append(tf.keras.layers.UpSampling2D(size=2 ** (j - i))(y))\n", | |
"                else:\n", | |
"                    # Multi-scale mode projects to channels[i] and sums into branch i.\n", | |
"                    y = bn(x[j], is_training)\n", | |
"                    y = tf.nn.relu(y)\n", | |
"                    y = conv2d(y, channels[i], 1, padding='VALID')\n", | |
"                    y = tf.keras.layers.UpSampling2D(size=2 ** (j - i))(y)\n", | |
"                    residual = tf.add(residual, y)\n", | |
"\n", | |
"            elif j < i:\n", | |
"                # Branch j is finer: (i - j) rounds of BN-ReLU-1x1 conv + 2x2 max pool.\n", | |
"                y = x[j]\n", | |
"                for k in range(i - j):\n", | |
"                    # Only the last round switches to branch i's width channels[i].\n", | |
"                    if k == i - j - 1:\n", | |
"                        y = bn(y, is_training)\n", | |
"                        y = tf.nn.relu(y)\n", | |
"                        y = conv2d(y, channels[i], 1)\n", | |
"                        y = tf.layers.max_pooling2d(y, 2, 2)\n", | |
"\n", | |
"                    else:\n", | |
"                        y = bn(y, is_training)\n", | |
"                        y = tf.nn.relu(y)\n", | |
"                        y = conv2d(y, channels[j], 1)\n", | |
"                        y = tf.layers.max_pooling2d(y, 2, 2)\n", | |
"\n", | |
"                residual = tf.add(residual, y)\n", | |
"        out.append(residual)\n", | |
"    return out\n", | |
"\n", | |
"\n", | |
"def convblock(x, channels,is_training, multi_scale_output=True):\n", | |
"    \"\"\"One multi-branch module: four resblocks per branch, then cross-branch fusion.\"\"\"\n", | |
"    refined = convb(x, 4, channels, is_training)\n", | |
"    return featfuse(refined, channels, is_training, multi_scale_output=multi_scale_output)\n", | |
"\n", | |
"\n", | |
"def stage(x, num_modules, channels, is_training,multi_scale_output=True):\n", | |
"    \"\"\"Chain ``num_modules`` convblocks; only the last may switch to single-scale output.\"\"\"\n", | |
"    out = x\n", | |
"    last = num_modules - 1\n", | |
"    for i in range(num_modules):\n", | |
"        single_scale = i == last and multi_scale_output == False\n", | |
"        out = convblock(out, channels, is_training, multi_scale_output=not single_scale)\n", | |
"    return out\n", | |
"\n", | |
"\n", | |
"def pyramid_pooling_block(input, bin_sizes):\n", | |
"    \"\"\"Pyramid pooling: pool into bin grids, project, upsample and add back to ``input``.\n", | |
"\n", | |
"    For each ``bin_size`` the map is average-pooled with window/stride\n", | |
"    (H // bin_size, W // bin_size), projected to C // 4 channels by a 1x1 conv and\n", | |
"    bilinearly resized back to (H, W). The pooled maps are concatenated and added to\n", | |
"    ``input``, which requires len(bin_sizes) * (C // 4) == C.\n", | |
"    Height and width are handled separately, so non-square maps work as well.\n", | |
"    \"\"\"\n", | |
"    h = int(input.shape[1])\n", | |
"    w = int(input.shape[2])\n", | |
"    c = int(input.shape[-1])\n", | |
"    pool_list = []\n", | |
"    for bin_size in bin_sizes:\n", | |
"        window = (h // bin_size, w // bin_size)\n", | |
"        pool1 = tf.layers.average_pooling2d(input, window, window)\n", | |
"        pool1 = conv2d(pool1, c // 4, 1)\n", | |
"        pool1 = tf.image.resize_bilinear(pool1, (h, w))\n", | |
"        pool_list.append(pool1)\n", | |
"    pool = tf.concat(pool_list, axis=3)\n", | |
"    return tf.add(input, pool)\n", | |
"\n", | |
"\n", | |
"def spatial_pooling(input):\n", | |
"    \"\"\"Multi-scale max pooling.\n", | |
"\n", | |
"    Max-pools ``input`` with windows 2, 3, 5 and 6 (stride equal to the window),\n", | |
"    resizes each result back to the input's (H, W) bilinearly, and concatenates\n", | |
"    them with ``input`` on the last axis (5x the input channels).\n", | |
"    \"\"\"\n", | |
"    size = (input.shape[1], input.shape[2])\n", | |
"    pooled = [\n", | |
"        tf.image.resize_bilinear(tf.layers.max_pooling2d(input, k, k), size)\n", | |
"        for k in (2, 3, 5, 6)\n", | |
"    ]\n", | |
"    return tf.concat(pooled + [input], axis=-1)\n", | |
"\n", | |
"\n", | |
"def channel_squeeze(input,filters,name=\"channel_squeeze\"):\n", | |
"    \"\"\"Squeeze-and-excitation style channel gating.\n", | |
"\n", | |
"    Averages ``input`` over axes 1 and 2, feeds the result through two dense layers\n", | |
"    (ReLU then sigmoid, both with ``filters`` units, i.e. no bottleneck reduction) and\n", | |
"    multiplies ``input`` by the resulting per-channel gates. ``filters`` must equal the\n", | |
"    channel count of ``input``; a tf.Dimension such as ``t.shape[-1]`` is accepted.\n", | |
"    \"\"\"\n", | |
"    # The default name must be a valid scope name: TensorFlow rejects \" \" with a ValueError.\n", | |
"    filters = int(filters)  # accept tf.Dimension as well as int\n", | |
"    with tf.name_scope(name):\n", | |
"        squeeze = tf.reduce_mean(input, axis=[1, 2])\n", | |
"        with tf.name_scope(name + \"fc1\"):\n", | |
"            fc1 = tf.layers.dense(squeeze, use_bias=True, units=filters)\n", | |
"            fc1 = tf.nn.relu(fc1)\n", | |
"        with tf.name_scope(name + \"fc2\"):\n", | |
"            fc2 = tf.layers.dense(fc1, use_bias=True, units=filters)\n", | |
"            fc2 = tf.nn.sigmoid(fc2)\n", | |
"        result = tf.reshape(fc2, [-1, 1, 1, filters])\n", | |
"        return input * result\n", | |
"\n", | |
"\n", | |
"def mapnet(input, is_training=True):\n", | |
"    # Build the MAPNet graph and return 1-channel per-pixel logits (no sigmoid here;\n", | |
"    # the training cell applies sigmoid / sigmoid cross-entropy).\n", | |
"    # Resolutions in the comments below are relative to the input's H x W\n", | |
"    # (assumes H and W are multiples of 16 so the branch sizes line up -- TODO confirm).\n", | |
"    # NOTE: tf.layers names variables by creation order (conv2d, conv2d_1, ...), so\n", | |
"    # reordering the layer calls here changes variable names and breaks old checkpoints.\n", | |
"    channels_s2 = [64, 128]\n", | |
"    channels_s3 = [64, 128, 256]\n", | |
"    num_modules_s2 = 2\n", | |
"    num_modules_s3 = 3\n", | |
"\n", | |
"    # Stem: stride-2 3x3 conv and two 3x3 convs (64 ch), then 2x2 max pool -> 1/4 resolution.\n", | |
"    conv_1 = conv2d(input, 64, stride=2)\n", | |
"    conv_1 = bn(conv_1, is_training)\n", | |
"    conv_1 = tf.nn.relu(conv_1)\n", | |
"    conv_2 = conv2d(conv_1, 64)\n", | |
"    conv_2 = bn(conv_2, is_training)\n", | |
"    conv_2 = tf.nn.relu(conv_2)\n", | |
"    conv_3 = conv2d(conv_2, 64)\n", | |
"    conv_3 = bn(conv_3, is_training)\n", | |
"    conv_3 = tf.nn.relu(conv_3)\n", | |
"    conv_4 = tf.layers.max_pooling2d(conv_3, 2, 2)\n", | |
"\n", | |
"    # Multi-branch body: 1 branch (256 ch) -> 2 branches [64, 128] at 1/4, 1/8 ->\n", | |
"    # 3 branches [64, 128, 256] at 1/4, 1/8, 1/16. The last stage uses single-scale\n", | |
"    # fusion, so stage3 = [up(128 ch), up(256 ch), 64 ch], all at 1/4 (see featfuse).\n", | |
"    stage1 = stage0(conv_4,is_training)\n", | |
"    trans1 = translayer([stage1], [256], channels_s2,is_training)\n", | |
"    stage2 = stage(trans1, num_modules_s2, channels_s2,is_training)\n", | |
"    trans2 = translayer(stage2, channels_s2, channels_s3,is_training)\n", | |
"    stage3 = stage(trans2, num_modules_s3, channels_s3,is_training,multi_scale_output=False)\n", | |
"\n", | |
"    # Channel-attention path over all three fused maps (128 + 256 + 64 = 448 channels).\n", | |
"    stg3=tf.concat(stage3,axis=-1)\n", | |
"    squeeze=channel_squeeze(stg3, stg3.shape[-1], name=\"squeeze\")\n", | |
"\n", | |
"    # Spatial-context path: multi-scale max pooling over the two upsampled branches (384 ch in).\n", | |
"    spatial=tf.concat([stage3[0],stage3[1]],axis=-1)\n", | |
"    # spatial=pyramid_pooling_block(spatial, [1, 2, 4, 8])\n", | |
"    spatial=spatial_pooling(spatial)\n", | |
"\n", | |
"    # Merge both paths, then BN-ReLU and a 1x1 projection to 128 channels.\n", | |
"    new_feature = tf.concat([spatial, squeeze], axis=-1)\n", | |
"    new_feature = bn(new_feature, is_training)\n", | |
"    new_feature = tf.nn.relu(new_feature)\n", | |
"    result=conv2d(new_feature, 128, 1, padding='SAME')\n", | |
"\n", | |
"    # Decoder: two bilinear 2x upsamplings (1/4 -> 1/2 -> input resolution), each\n", | |
"    # followed by BN-ReLU-3x3 conv (64, then 32 channels).\n", | |
"    up1=tf.image.resize_bilinear(result,size=(stage3[0].shape[1]*2,stage3[0].shape[2]*2))\n", | |
"    up1 = bn(up1, is_training)\n", | |
"    up1 = tf.nn.relu(up1)\n", | |
"    up1 = conv2d(up1, 64, 3)\n", | |
"\n", | |
"    up2 = tf.image.resize_bilinear(up1, size=(up1.shape[1]*2, up1.shape[2]*2))\n", | |
"    up2 = bn(up2, is_training)\n", | |
"    up2 = tf.nn.relu(up2)\n", | |
"    up2 = conv2d(up2, 32, 3)\n", | |
"\n", | |
"    # Head: BN-ReLU and a 1x1 conv to a single logit channel.\n", | |
"    up2 = bn(up2, is_training)\n", | |
"    up2 = tf.nn.relu(up2)\n", | |
"    final = conv2d(up2, 1, 1, padding='valid')\n", | |
"\n", | |
"    return final" | |
], | |
"execution_count": 5, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "UFEOU_1YSW4P" | |
}, | |
"source": [ | |
"# **Train**" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "RF81kkV8by-h", | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"outputId": "525f37be-3681-4d9e-abe3-b2777b0aae93" | |
}, | |
"source": [ | |
"def _str2bool(value):\n", | |
"    # argparse's type=bool is a trap: bool('False') is True, so any given value enabled the flag.\n", | |
"    if isinstance(value, bool):\n", | |
"        return value\n", | |
"    lowered = value.strip().lower()\n", | |
"    if lowered in ('1', 'true', 't', 'yes', 'y'):\n", | |
"        return True\n", | |
"    if lowered in ('0', 'false', 'f', 'no', 'n'):\n", | |
"        return False\n", | |
"    raise argparse.ArgumentTypeError(f'expected a boolean value, got {value!r}')\n", | |
"\n", | |
"\n", | |
"parser = argparse.ArgumentParser()\n", | |
"parser.add_argument('--batch_size', type=int, default=4, help='Number of images in each batch')\n", | |
"parser.add_argument('--learning_rate', type=float, default=0.001, help='Learning rate for the optimizer')\n", | |
"parser.add_argument('--crop_height', type=int, default=512, help='Height of cropped input image to network')\n", | |
"parser.add_argument('--crop_width', type=int, default=512, help='Width of cropped input image to network')\n", | |
"# NOTE(review): this help text duplicates --crop_width; confirm what clip_size controls.\n", | |
"parser.add_argument('--clip_size', type=int, default=450, help='Width of cropped input image to network')\n", | |
"parser.add_argument('--num_epochs', type=int, default=200, help='Number of epochs to train for')\n", | |
"# NOTE(review): the visible load_batch/data_augmentation do not read the four augmentation flags below.\n", | |
"parser.add_argument('--h_flip', type=_str2bool, default=True, help='Whether to randomly flip the image horizontally for data augmentation')\n", | |
"parser.add_argument('--v_flip', type=_str2bool, default=True, help='Whether to randomly flip the image vertically for data augmentation')\n", | |
"parser.add_argument('--color', type=_str2bool, default=True, help='Whether to apply color augmentation')\n", | |
"parser.add_argument('--rotation', type=_str2bool, default=True, help='Whether to randomly rotate the image by a multiple of 90 degrees for data augmentation')\n", | |
"parser.add_argument('--start_valid', type=int, default=20, help='Number of epoch to valid')\n", | |
"parser.add_argument('--valid_step', type=int, default=1, help=\"Number of step to validation\")\n", | |
"# Jupyter starts the kernel with `-f <connection file>`; declare it so parsing succeeds.\n", | |
"parser.add_argument('-f')\n", | |
"\n", | |
"# parse_known_args ignores any other launcher arguments instead of exiting the kernel.\n", | |
"args, _unknown_args = parser.parse_known_args()\n", | |
"num_images=[]\n", | |
"train_img, train_label,valid_img,valid_lab= prepare_data()\n", | |
"# Integer division: a final partial batch (len(train_img) % batch_size images) is not counted.\n", | |
"num_batches=len(train_img)//(args.batch_size)" | |
], | |
"execution_count": 6, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"text": [ | |
"[DEBUG] train label ['/content/MAPNet/dataset/train/lab/7000000.png'\n", | |
" '/content/MAPNet/dataset/train/lab/7000001.png'\n", | |
" '/content/MAPNet/dataset/train/lab/7000002.png'\n", | |
" '/content/MAPNet/dataset/train/lab/7000003.png'\n", | |
" '/content/MAPNet/dataset/train/lab/7000006.png'\n", | |
" '/content/MAPNet/dataset/train/lab/7000007.png']\n" | |
], | |
"name": "stdout" | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "j0EaxcHnkB7z", | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"outputId": "38319ec3-08e0-40e7-bf4f-fa8eb4193683" | |
}, | |
"source": [ | |
"# Rebind `tf` to the TF1 compatibility API and disable TF2 behaviours (eager execution,\n", | |
"# resource variables, ...) so the placeholder/graph-mode code below runs as written.\n", | |
"# Functions defined in earlier cells look up the global `tf` when called, so they use it too.\n", | |
"import tensorflow.compat.v1 as tf\n", | |
"tf.disable_v2_behavior()" | |
], | |
"execution_count": 7, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"text": [ | |
"WARNING:tensorflow:From /usr/local/lib/python3.6/dist-packages/tensorflow_core/python/compat/v2_compat.py:68: disable_resource_variables (from tensorflow.python.ops.variable_scope) is deprecated and will be removed in a future version.\n", | |
"Instructions for updating:\n", | |
"non-resource variables are not supported in the long term\n" | |
], | |
"name": "stdout" | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "OEdHFizTg9KZ", | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"outputId": "4de1ff73-46de-4dfc-daf3-dcf775c50dd4" | |
}, | |
"source": [ | |
"img=tf.placeholder(tf.float32,[None,args.crop_height,args.crop_width,3])\n", | |
"is_training=tf.placeholder(tf.bool)\n", | |
"# Labels share img's spatial size, so the width axis must be crop_width\n", | |
"# (repeating crop_height fails for any non-square crop).\n", | |
"label=tf.placeholder(tf.float32,[None,args.crop_height,args.crop_width,1])\n", | |
"\n", | |
"pred=mapnet(img,is_training)  # raw logits, one channel\n", | |
"pred1=tf.nn.sigmoid(pred)  # per-pixel probabilities for inference\n", | |
"\n", | |
"# Batch-norm moving averages are updated via UPDATE_OPS; make every training step run them.\n", | |
"update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)\n", | |
"with tf.control_dependencies(update_ops):\n", | |
"\n", | |
"    sig=tf.nn.sigmoid_cross_entropy_with_logits(labels=label, logits=pred)\n", | |
"    sigmoid_cross_entropy_loss = tf.reduce_mean(sig)\n", | |
"    train_step = tf.train.AdamOptimizer(args.learning_rate).minimize(sigmoid_cross_entropy_loss)\n", | |
"saver=tf.train.Saver(var_list=tf.global_variables())" | |
], | |
"execution_count": 8, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"text": [ | |
"WARNING:tensorflow:From <ipython-input-5-474ff2aaaabc>:10: conv2d (from tensorflow.python.layers.convolutional) is deprecated and will be removed in a future version.\n", | |
"Instructions for updating:\n", | |
"Use `tf.keras.layers.Conv2D` instead.\n", | |
"WARNING:tensorflow:From /usr/local/lib/python3.6/dist-packages/tensorflow_core/python/layers/convolutional.py:424: Layer.apply (from tensorflow.python.keras.engine.base_layer) is deprecated and will be removed in a future version.\n", | |
"Instructions for updating:\n", | |
"Please use `layer.__call__` method instead.\n", | |
"WARNING:tensorflow:From <ipython-input-5-474ff2aaaabc>:14: batch_normalization (from tensorflow.python.layers.normalization) is deprecated and will be removed in a future version.\n", | |
"Instructions for updating:\n", | |
"Use keras.layers.BatchNormalization instead. In particular, `tf.control_dependencies(tf.GraphKeys.UPDATE_OPS)` should not be used (consult the `tf.keras.layers.batch_normalization` documentation).\n", | |
"WARNING:tensorflow:From <ipython-input-5-474ff2aaaabc>:192: max_pooling2d (from tensorflow.python.layers.pooling) is deprecated and will be removed in a future version.\n", | |
"Instructions for updating:\n", | |
"Use keras.layers.MaxPooling2D instead.\n", | |
"WARNING:tensorflow:From <ipython-input-5-474ff2aaaabc>:168: dense (from tensorflow.python.layers.core) is deprecated and will be removed in a future version.\n", | |
"Instructions for updating:\n", | |
"Use keras.layers.Dense instead.\n", | |
"WARNING:tensorflow:From /usr/local/lib/python3.6/dist-packages/tensorflow_core/python/ops/nn_impl.py:183: where (from tensorflow.python.ops.array_ops) is deprecated and will be removed in a future version.\n", | |
"Instructions for updating:\n", | |
"Use tf.where in 2.0, which has the same broadcast rule as np.where\n" | |
], | |
"name": "stdout" | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "Gfv4V3AOxvgr" | |
}, | |
"source": [ | |
"# Per-split metric history (epoch index, IoU, loss), appended to during training\n",
"# and read by plot_curves().\n",
"history = dict(\n",
"    train=dict(iter=[], iou=[], loss=[]),\n",
"    val=dict(iter=[], loss=[], iou=[]),\n",
")\n",
"\n",
"CHECKPOINTS_DIR = './'  # directory that load() and train() read/write model checkpoints in\n",
"PRINT_EVERY = 3  # epochs; NOTE(review): not referenced by the training code -- wire it up or remove"
], | |
"execution_count": 15, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "fbvtMhxJd1ec" | |
}, | |
"source": [ | |
"from tqdm import tqdm\n", | |
"\n", | |
"def load():\n",
"    \"\"\"Restore the newest checkpoint found in CHECKPOINTS_DIR into the global `sess`.\n",
"\n",
"    Returns:\n",
"        (loaded, counter): whether a checkpoint was restored, and the step number parsed\n",
"        from the last run of digits in its file name (``model.ckpt-11`` -> 11).\n",
"        ``(False, 0)`` when no checkpoint exists.\n",
"    \"\"\"\n",
"    import re\n",
"    print(\"[INFO] Reading checkpoints dir...\")\n",
"    checkpoint_dir = CHECKPOINTS_DIR\n",
"\n",
"    ckpt = tf.train.get_checkpoint_state(checkpoint_dir)\n",
"    if not (ckpt and ckpt.model_checkpoint_path):\n",
"        print(\"[INFO] Checkpoint not found\")\n",
"        return False, 0\n",
"\n",
"    ckpt_name = os.path.basename(ckpt.model_checkpoint_path)\n",
"    saver.restore(sess, os.path.join(checkpoint_dir, ckpt_name))\n",
"    # Raw string: \"\\d\" in a plain literal is an invalid escape sequence (warns on newer Pythons).\n",
"    # re.search finds the same first match as next(re.finditer(...)) but returns None\n",
"    # instead of raising StopIteration when the name contains no digits.\n",
"    match = re.search(r\"(\\d+)(?!.*\\d)\", ckpt_name)\n",
"    counter = int(match.group(0)) if match else 0\n",
"    print(\"[INFO] Checkpoint {} read successed\".format(ckpt_name))\n",
"    return True, counter\n",
"\n", | |
"def train():\n",
"    \"\"\"Train MAPNet, resuming from the newest checkpoint in CHECKPOINTS_DIR if there is one.\n",
"\n",
"    Uses notebook globals: the default session `sess`, `saver`, the graph tensors\n",
"    (`img`, `label`, `is_training`, `train_step`, `sigmoid_cross_entropy_loss`), the\n",
"    data lists from prepare_data(), `num_batches`, `args`, `history`, `load_batch()`,\n",
"    `load()` and `validation()`.\n",
"\n",
"    Side effects: appends each epoch's median training loss to history['train'] and\n",
"    validation IoU/loss to history['val']; saves a checkpoint whenever validation IoU\n",
"    beats the best so far, and a final checkpoint after the last epoch.\n",
"    \"\"\"\n",
"    tf.global_variables_initializer().run()\n",
"\n",
"    could_load, checkpoint_counter = load()\n",
"    if could_load:\n",
"        # NOTE(review): a fresh run starts `counter` at 1, so this epoch/batch split is one\n",
"        # batch ahead of what was trained; kept as-is so existing checkpoints resume unchanged.\n",
"        start_epoch = checkpoint_counter // num_batches\n",
"        start_batch_id = checkpoint_counter - start_epoch * num_batches\n",
"        counter = checkpoint_counter\n",
"        print(\"[INFO] Checkpoint Load Success!\")\n",
"    else:\n",
"        start_epoch = 0\n",
"        start_batch_id = 0\n",
"        counter = 1\n",
"        print(\"[INFO] Checkpoint load failed. Training from scratch...\")\n",
"\n",
"    # Best validation IoU so far; the best-model checkpoint is only written when it is beaten.\n",
"    IOU = 0.65\n",
"\n",
"    # Hard-coded overrides of the run config. They used to be applied after the summary\n",
"    # below was printed, so the reported epoch count could differ from the one used.\n",
"    args.num_epochs = 10\n",
"    args.start_valid = 0\n",
"\n",
"    print(\"==================================================================\")\n",
"    print(\"[INFO] GENERAL INFORMATION\")\n",
"    print(\"==================================================================\")\n",
"    print(\"Total train image:{}\".format(len(train_img)))\n",
"    print(\"Total validate image:{}\".format(len(valid_img)))\n",
"    print(\"Total epoch:{}\".format(args.num_epochs))\n",
"    print(\"Batch size:{}\".format(args.batch_size))\n",
"    print(\"Learning rate:{}\".format(args.learning_rate))\n",
"\n",
"    print(\"==================================================================\")\n",
"    print(\"[INFO] DATA AUGMENTATION\")\n",
"    print(\"==================================================================\")\n",
"    print(\"h_flip: {}\".format(args.h_flip))\n",
"    print(\"v_flip: {}\".format(args.v_flip))\n",
"    print(\"rotate: {}\".format(args.rotation))\n",
"    print(\"clip size: {}\".format(args.clip_size))\n",
"\n",
"    print(\"==================================================================\")\n",
"    print(\"[INFO] TRAINING STARTED\")\n",
"    print(\"==================================================================\")\n",
"\n",
"    checkpoint_path = os.path.join(CHECKPOINTS_DIR, 'model.ckpt')\n",
"    loss_tmp = []\n",
"    for i in range(start_epoch, args.num_epochs):\n",
"        # ---- training batches for epoch i (a resumed epoch starts at start_batch_id) ----\n",
"        id_list = np.random.permutation(len(train_img))\n",
"        batch_pbar = tqdm(range(start_batch_id, num_batches), desc=f\"[TRAIN] Epoch {i}\")\n",
"        for j in batch_pbar:\n",
"            batch_ids = id_list[j * args.batch_size:(j + 1) * args.batch_size]\n",
"            img_d = [train_img[k] for k in batch_ids]\n",
"            lab_d = [train_label[k] for k in batch_ids]\n",
"            # An earlier debug print showed per-sample shapes (512, 512, 3) and (512, 512, 1).\n",
"            x_batch, y_batch = load_batch(img_d, lab_d)\n",
"\n",
"            feed_dict = {img: x_batch,\n",
"                         label: y_batch,\n",
"                         is_training: True}\n",
"            # The logits are not used here, so only the train op and the loss are fetched.\n",
"            _, loss = sess.run([train_step, sigmoid_cross_entropy_loss], feed_dict=feed_dict)\n",
"            loss_tmp.append(loss)\n",
"\n",
"            if j == num_batches - 1:\n",
"                tmp = np.median(loss_tmp)\n",
"                # Training stats belong in history['train']; they used to be written to 'val' with\n",
"                # a placeholder IoU of 0.2, and the misspelled `.appned` raised AttributeError.\n",
"                history['train']['iter'].append(i)\n",
"                history['train']['loss'].append(tmp)\n",
"                batch_pbar.set_description(f\"[TRAIN] Epoch {i} --- Iter {counter} --- Loss {tmp}\")\n",
"                loss_tmp.clear()\n",
"\n",
"            counter += 1\n",
"        start_batch_id = 0\n",
"\n",
"        # ---- validation every args.valid_step epochs after args.start_valid ----\n",
"        if i > args.start_valid and (i - args.start_valid) % args.valid_step == 0:\n",
"            val_iou, val_loss = validation()\n",
"            history['val']['iter'].append(i)\n",
"            history['val']['iou'].append(val_iou)\n",
"            history['val']['loss'].append(val_loss)\n",
"            # Keep the best model by validation IoU (val_loss could be used instead).\n",
"            if val_iou > IOU:\n",
"                print(f\"[INFO] Saving best model as checkpoint... val_iou: {val_iou}\")\n",
"                saver.save(sess, checkpoint_path, global_step=counter, write_meta_graph=True)\n",
"                IOU = val_iou\n",
"\n",
"    saver.save(sess, checkpoint_path, global_step=counter)\n",
"\n", | |
"\n", | |
"\n", | |
"def f_iou(predict, label):\n",
"    \"\"\"Return (intersection, union) pixel counts for the foreground class.\n",
"\n",
"    A pixel is foreground where its value equals 1. The counts are symmetric in the\n",
"    two arguments: union = |predict == 1| + |label == 1| - intersection.\n",
"    \"\"\"\n",
"    pred_fg = predict == 1\n",
"    true_fg = label == 1\n",
"    intersection = np.count_nonzero(pred_fg & true_fg)\n",
"    union = np.count_nonzero(pred_fg) + np.count_nonzero(true_fg) - intersection\n",
"    return intersection, union\n",
"\n", | |
"\n", | |
"\n", | |
"def validation():\n",
"    \"\"\"Evaluate the current model on the whole validation set, one image per step.\n",
"\n",
"    Returns:\n",
"        (iou, loss): foreground IoU (pixels equal to 1) accumulated over all validation\n",
"        images, and the median of the per-image mean sigmoid cross-entropy losses.\n",
"\n",
"    Only the loss and the probabilities are fetched: this used to also run `train_step`,\n",
"    which applied optimizer updates using validation data.\n",
"    \"\"\"\n",
"    inter = 0\n",
"    unin = 0\n",
"    loss_accumulator = []\n",
"\n",
"    batch_pbar = tqdm(range(len(valid_img)), desc=\"Validating -- \")\n",
"    for j in batch_pbar:\n",
"        # Image scaled to [0, 1] with a batch axis added: (1, H, W, 3).\n",
"        x_batch = np.expand_dims(io.imread(valid_img[j]) / 255.0, axis=0)\n",
"        # Label read once and reused for the loss feed and the IoU: (H, W) -> (1, H, W, 1).\n",
"        # NOTE(review): assumes label PNGs store foreground as 1 (not 255) -- confirm.\n",
"        gt_value = io.imread(valid_lab[j])\n",
"        y_actual_batch = gt_value[np.newaxis, ..., np.newaxis]\n",
"\n",
"        feed_dict = {img: x_batch,\n",
"                     label: y_actual_batch,\n",
"                     is_training: False}\n",
"        loss, predict = sess.run([sigmoid_cross_entropy_loss, pred1], feed_dict=feed_dict)\n",
"        loss_accumulator.append(loss)\n",
"\n",
"        # Threshold the probabilities at 0.5 and drop the batch/channel axes.\n",
"        result = np.squeeze(predict >= 0.5).astype(np.uint8)\n",
"        intr, unn = f_iou(result, gt_value)\n",
"        inter += intr\n",
"        unin += unn\n",
"\n",
"        if j == len(valid_img) - 1:\n",
"            batch_pbar.set_description(f\"[VALID] --- Loss {np.median(loss_accumulator):.4f} --- IOU {(inter*1.0)/(unin+1e-10):.4f}\")\n",
"\n",
"    iou = (inter * 1.0) / (unin + 1e-10)\n",
"    loss = np.median(loss_accumulator)\n",
"    return iou, loss\n"
], | |
"execution_count": 16, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "qH24Mi-ceC67", | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"outputId": "45eca95c-953e-43e2-d3b3-437fd1dca060" | |
}, | |
"source": [ | |
"# train() and validation() reference the global name `sess`; this `with` block binds it and\n",
"# makes it the default session, which train()'s `initializer.run()` relies on.\n",
"with tf.Session() as sess:\n",
" train()"
], | |
"execution_count": 18, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"text": [ | |
"[INFO] Reading checkpoints dir...\n", | |
"INFO:tensorflow:Restoring parameters from ./model.ckpt-11\n", | |
"[INFO] Checkpoint model.ckpt-11 read successed\n", | |
"[INFO] Checkpoint Load Success!\n", | |
"==================================================================\n", | |
"[INFO] GENERAL INFORMATION\n", | |
"==================================================================\n", | |
"Total train image:6\n", | |
"Total validate image:5\n", | |
"Total epoch:10\n", | |
"Batch size:4\n", | |
"Learning rate:0.001\n", | |
"==================================================================\n", | |
"[INFO] DATA AUGMENTATION\n", | |
"==================================================================\n", | |
"h_flip: True\n", | |
"v_flip: True\n", | |
"rotate: True\n", | |
"clip size: 450\n", | |
"==================================================================\n", | |
"[INFO] TRAINING STARTED\n", | |
"==================================================================\n" | |
], | |
"name": "stdout" | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "zR4gtIDtz6Iq" | |
}, | |
"source": [ | |
"import matplotlib.pyplot as plt\n", | |
"\n", | |
"def plot_curves():\n",
"    \"\"\"Plot training and validation loss against epoch number from the global `history`.\"\"\"\n",
"    _, ax = plt.subplots(figsize=(10, 7))\n",
"\n",
"    # Validation first, then training, so line colours match the previous version of this plot.\n",
"    for split, name in (('val', 'Val Loss'), ('train', 'Train Loss')):\n",
"        losses = history[split]['loss']\n",
"        epochs = history[split]['iter']\n",
"        # Fall back to the list position if the epoch and loss lists are out of sync.\n",
"        if len(epochs) != len(losses):\n",
"            epochs = range(len(losses))\n",
"        ax.plot(epochs, losses, marker='o', label=name)\n",
"\n",
"    # Epochs go on the x-axis and loss on the y-axis (the labels used to be swapped).\n",
"    ax.set_xlabel(\"Epoch Number\")\n",
"    ax.set_ylabel(\"Loss\")\n",
"    ax.set_title(\"Training / validation loss\")\n",
"    ax.legend()\n",
"    ax.grid()\n",
"    plt.show()"
], | |
"execution_count": 13, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 442 | |
}, | |
"id": "D8DwQ4HH4b-i", | |
"outputId": "7708037c-7943-4cb1-f33d-ebe54f8e219b" | |
}, | |
"source": [ | |
"# Plot the loss curves stored in the global `history` by the training run above.\n",
"plot_curves()"
], | |
"execution_count": 14, | |
"outputs": [ | |
{ | |
"output_type": "display_data", | |
"data": { | |
"image/png": "\n", | |
"text/plain": [ | |
"<Figure size 720x504 with 1 Axes>" | |
] | |
}, | |
"metadata": { | |
"tags": [], | |
"needs_background": "light" | |
} | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "D-l6EXE_knU_" | |
}, | |
"source": [ | |
"" | |
], | |
"execution_count": null, | |
"outputs": [] | |
} | |
] | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment