Created
May 17, 2021 20:25
PROTEINS_Embedding.ipynb
| { | |
| "nbformat": 4, | |
| "nbformat_minor": 0, | |
| "metadata": { | |
| "colab": { | |
| "name": "PROTEINS_Embedding.ipynb", | |
| "provenance": [], | |
| "collapsed_sections": [], | |
| "mount_file_id": "1W0syoaSSP-MpFhYC14cBV0O32Esyr_zF", | |
| "authorship_tag": "ABX9TyO4Z9AYEuRzgb85dO5EB4Ch", | |
| "include_colab_link": true | |
| }, | |
| "kernelspec": { | |
| "name": "python3", | |
| "display_name": "Python 3" | |
| }, | |
| "language_info": { | |
| "name": "python" | |
| } | |
| }, | |
| "cells": [ | |
| { | |
| "cell_type": "markdown", | |
| "metadata": { | |
| "id": "view-in-github", | |
| "colab_type": "text" | |
| }, | |
| "source": [ | |
| "<a href=\"https://colab.research.google.com/gist/sidneyarcidiacono/534e45f660a0aed62abb5096485c53de/proteins_embedding.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": { | |
| "id": "7lqL-LQstoi0" | |
| }, | |
| "source": [ | |
| "# Graph Neural Network Classification on the PROTEINS Dataset\n", | |
| "\n", | |
| "For the first approach, I'm going to use [Spektral](https://graphneural.network/getting-started/) to build a GCN layer and then perform the classification. \n", | |
| "\n", | |
| "Spektral is a Python library for graph neural networks, built on TensorFlow and Keras. \n", | |
| "\n", | |
| "Our second experiment will be built with PyTorch Geometric." | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "metadata": { | |
| "id": "Q_884Zbltkwo", | |
| "colab": { | |
| "base_uri": "https://localhost:8080/" | |
| }, | |
| "outputId": "b0639189-f59f-47c7-83c7-2bd39918e693" | |
| }, | |
| "source": [ | |
| "# Uncomment me and run!\n", | |
| "# !pip install spektral" | |
| ], | |
| "execution_count": null, | |
| "outputs": [ | |
| { | |
| "output_type": "stream", | |
| "text": [ | |
| "Requirement already satisfied: spektral in /usr/local/lib/python3.7/dist-packages (1.0.6)\n", | |
| "Requirement already satisfied: tqdm in /usr/local/lib/python3.7/dist-packages (from spektral) (4.41.1)\n", | |
| "Requirement already satisfied: lxml in /usr/local/lib/python3.7/dist-packages (from spektral) (4.2.6)\n", | |
| "Requirement already satisfied: scipy in /usr/local/lib/python3.7/dist-packages (from spektral) (1.4.1)\n", | |
| "Requirement already satisfied: networkx in /usr/local/lib/python3.7/dist-packages (from spektral) (2.5.1)\n", | |
| "Requirement already satisfied: scikit-learn in /usr/local/lib/python3.7/dist-packages (from spektral) (0.22.2.post1)\n", | |
| "Requirement already satisfied: tensorflow>=2.1.0 in /usr/local/lib/python3.7/dist-packages (from spektral) (2.4.1)\n", | |
| "Requirement already satisfied: pandas in /usr/local/lib/python3.7/dist-packages (from spektral) (1.1.5)\n", | |
| "Requirement already satisfied: numpy in /usr/local/lib/python3.7/dist-packages (from spektral) (1.19.5)\n", | |
| "Requirement already satisfied: joblib in /usr/local/lib/python3.7/dist-packages (from spektral) (1.0.1)\n", | |
| "Requirement already satisfied: requests in /usr/local/lib/python3.7/dist-packages (from spektral) (2.23.0)\n", | |
| "Requirement already satisfied: decorator<5,>=4.3 in /usr/local/lib/python3.7/dist-packages (from networkx->spektral) (4.4.2)\n", | |
| "Requirement already satisfied: termcolor~=1.1.0 in /usr/local/lib/python3.7/dist-packages (from tensorflow>=2.1.0->spektral) (1.1.0)\n", | |
| "Requirement already satisfied: absl-py~=0.10 in /usr/local/lib/python3.7/dist-packages (from tensorflow>=2.1.0->spektral) (0.12.0)\n", | |
| "Requirement already satisfied: astunparse~=1.6.3 in /usr/local/lib/python3.7/dist-packages (from tensorflow>=2.1.0->spektral) (1.6.3)\n", | |
| "Requirement already satisfied: opt-einsum~=3.3.0 in /usr/local/lib/python3.7/dist-packages (from tensorflow>=2.1.0->spektral) (3.3.0)\n", | |
| "Requirement already satisfied: typing-extensions~=3.7.4 in /usr/local/lib/python3.7/dist-packages (from tensorflow>=2.1.0->spektral) (3.7.4.3)\n", | |
| "Requirement already satisfied: gast==0.3.3 in /usr/local/lib/python3.7/dist-packages (from tensorflow>=2.1.0->spektral) (0.3.3)\n", | |
| "Requirement already satisfied: protobuf>=3.9.2 in /usr/local/lib/python3.7/dist-packages (from tensorflow>=2.1.0->spektral) (3.12.4)\n", | |
| "Requirement already satisfied: wheel~=0.35 in /usr/local/lib/python3.7/dist-packages (from tensorflow>=2.1.0->spektral) (0.36.2)\n", | |
| "Requirement already satisfied: grpcio~=1.32.0 in /usr/local/lib/python3.7/dist-packages (from tensorflow>=2.1.0->spektral) (1.32.0)\n", | |
| "Requirement already satisfied: google-pasta~=0.2 in /usr/local/lib/python3.7/dist-packages (from tensorflow>=2.1.0->spektral) (0.2.0)\n", | |
| "Requirement already satisfied: tensorflow-estimator<2.5.0,>=2.4.0 in /usr/local/lib/python3.7/dist-packages (from tensorflow>=2.1.0->spektral) (2.4.0)\n", | |
| "Requirement already satisfied: six~=1.15.0 in /usr/local/lib/python3.7/dist-packages (from tensorflow>=2.1.0->spektral) (1.15.0)\n", | |
| "Requirement already satisfied: keras-preprocessing~=1.1.2 in /usr/local/lib/python3.7/dist-packages (from tensorflow>=2.1.0->spektral) (1.1.2)\n", | |
| "Requirement already satisfied: h5py~=2.10.0 in /usr/local/lib/python3.7/dist-packages (from tensorflow>=2.1.0->spektral) (2.10.0)\n", | |
| "Requirement already satisfied: tensorboard~=2.4 in /usr/local/lib/python3.7/dist-packages (from tensorflow>=2.1.0->spektral) (2.4.1)\n", | |
| "Requirement already satisfied: flatbuffers~=1.12.0 in /usr/local/lib/python3.7/dist-packages (from tensorflow>=2.1.0->spektral) (1.12)\n", | |
| "Requirement already satisfied: wrapt~=1.12.1 in /usr/local/lib/python3.7/dist-packages (from tensorflow>=2.1.0->spektral) (1.12.1)\n", | |
| "Requirement already satisfied: pytz>=2017.2 in /usr/local/lib/python3.7/dist-packages (from pandas->spektral) (2018.9)\n", | |
| "Requirement already satisfied: python-dateutil>=2.7.3 in /usr/local/lib/python3.7/dist-packages (from pandas->spektral) (2.8.1)\n", | |
| "Requirement already satisfied: urllib3!=1.25.0,!=1.25.1,<1.26,>=1.21.1 in /usr/local/lib/python3.7/dist-packages (from requests->spektral) (1.24.3)\n", | |
| "Requirement already satisfied: chardet<4,>=3.0.2 in /usr/local/lib/python3.7/dist-packages (from requests->spektral) (3.0.4)\n", | |
| "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.7/dist-packages (from requests->spektral) (2020.12.5)\n", | |
| "Requirement already satisfied: idna<3,>=2.5 in /usr/local/lib/python3.7/dist-packages (from requests->spektral) (2.10)\n", | |
| "Requirement already satisfied: setuptools in /usr/local/lib/python3.7/dist-packages (from protobuf>=3.9.2->tensorflow>=2.1.0->spektral) (56.1.0)\n", | |
| "Requirement already satisfied: tensorboard-plugin-wit>=1.6.0 in /usr/local/lib/python3.7/dist-packages (from tensorboard~=2.4->tensorflow>=2.1.0->spektral) (1.8.0)\n", | |
| "Requirement already satisfied: werkzeug>=0.11.15 in /usr/local/lib/python3.7/dist-packages (from tensorboard~=2.4->tensorflow>=2.1.0->spektral) (2.0.0)\n", | |
| "Requirement already satisfied: google-auth-oauthlib<0.5,>=0.4.1 in /usr/local/lib/python3.7/dist-packages (from tensorboard~=2.4->tensorflow>=2.1.0->spektral) (0.4.4)\n", | |
| "Requirement already satisfied: markdown>=2.6.8 in /usr/local/lib/python3.7/dist-packages (from tensorboard~=2.4->tensorflow>=2.1.0->spektral) (3.3.4)\n", | |
| "Requirement already satisfied: google-auth<2,>=1.6.3 in /usr/local/lib/python3.7/dist-packages (from tensorboard~=2.4->tensorflow>=2.1.0->spektral) (1.30.0)\n", | |
| "Requirement already satisfied: requests-oauthlib>=0.7.0 in /usr/local/lib/python3.7/dist-packages (from google-auth-oauthlib<0.5,>=0.4.1->tensorboard~=2.4->tensorflow>=2.1.0->spektral) (1.3.0)\n", | |
| "Requirement already satisfied: importlib-metadata; python_version < \"3.8\" in /usr/local/lib/python3.7/dist-packages (from markdown>=2.6.8->tensorboard~=2.4->tensorflow>=2.1.0->spektral) (4.0.1)\n", | |
| "Requirement already satisfied: rsa<5,>=3.1.4; python_version >= \"3.6\" in /usr/local/lib/python3.7/dist-packages (from google-auth<2,>=1.6.3->tensorboard~=2.4->tensorflow>=2.1.0->spektral) (4.7.2)\n", | |
| "Requirement already satisfied: cachetools<5.0,>=2.0.0 in /usr/local/lib/python3.7/dist-packages (from google-auth<2,>=1.6.3->tensorboard~=2.4->tensorflow>=2.1.0->spektral) (4.2.2)\n", | |
| "Requirement already satisfied: pyasn1-modules>=0.2.1 in /usr/local/lib/python3.7/dist-packages (from google-auth<2,>=1.6.3->tensorboard~=2.4->tensorflow>=2.1.0->spektral) (0.2.8)\n", | |
| "Requirement already satisfied: oauthlib>=3.0.0 in /usr/local/lib/python3.7/dist-packages (from requests-oauthlib>=0.7.0->google-auth-oauthlib<0.5,>=0.4.1->tensorboard~=2.4->tensorflow>=2.1.0->spektral) (3.1.0)\n", | |
| "Requirement already satisfied: zipp>=0.5 in /usr/local/lib/python3.7/dist-packages (from importlib-metadata; python_version < \"3.8\"->markdown>=2.6.8->tensorboard~=2.4->tensorflow>=2.1.0->spektral) (3.4.1)\n", | |
| "Requirement already satisfied: pyasn1>=0.1.3 in /usr/local/lib/python3.7/dist-packages (from rsa<5,>=3.1.4; python_version >= \"3.6\"->google-auth<2,>=1.6.3->tensorboard~=2.4->tensorflow>=2.1.0->spektral) (0.4.8)\n" | |
| ], | |
| "name": "stdout" | |
| } | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "metadata": { | |
| "colab": { | |
| "base_uri": "https://localhost:8080/" | |
| }, | |
| "id": "Fw3cynAtwQOE", | |
| "outputId": "ff98649c-08fd-4459-cc8f-3bc1d0877ef0" | |
| }, | |
| "source": [ | |
| "# Reading in the PROTEINS dataset\n", | |
| "from spektral.datasets import TUDataset\n", | |
| "\n", | |
| "# Spektral provides the TUDataset class, which contains benchmark datasets for graph classification\n", | |
| "data = TUDataset('PROTEINS')\n", | |
| "data" | |
| ], | |
| "execution_count": null, | |
| "outputs": [ | |
| { | |
| "output_type": "stream", | |
| "text": [ | |
| "Successfully loaded PROTEINS.\n" | |
| ], | |
| "name": "stdout" | |
| }, | |
| { | |
| "output_type": "execute_result", | |
| "data": { | |
| "text/plain": [ | |
| "TUDataset(n_graphs=1113)" | |
| ] | |
| }, | |
| "metadata": { | |
| "tags": [] | |
| }, | |
| "execution_count": 4 | |
| } | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "metadata": { | |
| "id": "3AZCWahwwj7A" | |
| }, | |
| "source": [ | |
| "# Since we want to utilize the Spektral GCN layer, we want to follow the original paper for this method and perform some preprocessing:\n", | |
| "from spektral.transforms import GCNFilter\n", | |
| "\n", | |
| "data.apply(GCNFilter())" | |
| ], | |
| "execution_count": null, | |
| "outputs": [] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "metadata": { | |
| "id": "uDNI3sTt5mJQ" | |
| }, | |
| "source": [ | |
| "# Split our train and test data. Slicing alone would take the first 80% and the last 20% in dataset order, which isn't ideal, so we shuffle the data first.\n", | |
| "import numpy as np\n", | |
| "\n", | |
| "np.random.shuffle(data)\n", | |
| "split = int(0.8 * len(data))\n", | |
| "data_train, data_test = data[:split], data[split:]" | |
| ], | |
| "execution_count": null, | |
| "outputs": [] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "metadata": { | |
| "id": "Lwvnn79Ewm7M" | |
| }, | |
| "source": [ | |
| "# Spektral is built on top of Keras, so we can use the Keras functional API to build a model that first embeds,\n", | |
| "# then sums the nodes together (global pooling), then classifies the result with a dense softmax layer\n", | |
| "\n", | |
| "# First, let's import the necessary layers:\n", | |
| "from tensorflow.keras.models import Model\n", | |
| "from tensorflow.keras.layers import Dense, Dropout\n", | |
| "from spektral.layers import GCNConv, GlobalSumPool" | |
| ], | |
| "execution_count": null, | |
| "outputs": [] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "metadata": { | |
| "id": "PrTSQXkZyW78" | |
| }, | |
| "source": [ | |
| "# Now, we can use model subclassing to define our model:\n", | |
| "\n", | |
| "class ProteinsGNN(Model):\n", | |
| "\n", | |
| "    def __init__(self, n_hidden, n_labels):\n", | |
| "        super().__init__()\n", | |
| "        # Define our GCN layer with n_hidden output channels (hidden units)\n", | |
| "        self.graph_conv = GCNConv(n_hidden)\n", | |
| "        # Define our global pooling layer\n", | |
| "        self.pool = GlobalSumPool()\n", | |
| "        # Define our dropout layer with a dropout rate of 0.5 (50%)\n", | |
| "        self.dropout = Dropout(0.5)\n", | |
| "        # Define our Dense layer, with softmax activation function\n", | |
| "        self.dense = Dense(n_labels, 'softmax')\n", | |
| "\n", | |
| "    # Define the call method, which runs the model's forward pass on its inputs\n", | |
| "    def call(self, inputs):\n", | |
| "        out = self.graph_conv(inputs)\n", | |
| "        out = self.dropout(out)\n", | |
| "        out = self.pool(out)\n", | |
| "        out = self.dense(out)\n", | |
| "\n", | |
| "        return out" | |
| ], | |
| "execution_count": null, | |
| "outputs": [] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "metadata": { | |
| "id": "6CWW1urKzRrU" | |
| }, | |
| "source": [ | |
| "# Instantiate our model for training\n", | |
| "model = ProteinsGNN(32, data.n_labels)" | |
| ], | |
| "execution_count": null, | |
| "outputs": [] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "metadata": { | |
| "id": "WG2YY6CMzf3I" | |
| }, | |
| "source": [ | |
| "# Compile model with our optimizer (adam) and loss function\n", | |
| "model.compile('adam', 'categorical_crossentropy')" | |
| ], | |
| "execution_count": null, | |
| "outputs": [] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "metadata": { | |
| "id": "qw1hlui8zpYg" | |
| }, | |
| "source": [ | |
| "# Here's the trick - we can't just pass our dataset to Keras' fit() method.\n", | |
| "# Instead, we have to use Loaders, which Spektral's docs walk us through. Loaders create mini-batches by iterating over the graphs in the dataset.\n", | |
| "# Since we're using Spektral for an experiment, for our first trial we'll use the loader recommended in the getting started tutorial.\n", | |
| "\n", | |
| "# TODO: read up on modes and try other loaders later\n", | |
| "from spektral.data import BatchLoader\n", | |
| "\n", | |
| "loader = BatchLoader(data_train, batch_size=32)" | |
| ], | |
| "execution_count": null, | |
| "outputs": [] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "metadata": { | |
| "colab": { | |
| "base_uri": "https://localhost:8080/" | |
| }, | |
| "id": "AmJ70FZy0Jgw", | |
| "outputId": "2313fc8d-c1f9-45a3-f4ed-0493eb0d174b" | |
| }, | |
| "source": [ | |
| "# Now we can train! We don't need to specify a batch size, since our loader is basically a generator\n", | |
| "# But we do need to specify the steps_per_epoch parameter\n", | |
| "\n", | |
| "model.fit(loader.load(), steps_per_epoch=loader.steps_per_epoch, epochs=10)" | |
| ], | |
| "execution_count": null, | |
| "outputs": [ | |
| { | |
| "output_type": "stream", | |
| "text": [ | |
| "Epoch 1/10\n", | |
| "28/28 [==============================] - 2s 22ms/step - loss: 90.4481\n", | |
| "Epoch 2/10\n", | |
| "28/28 [==============================] - 1s 22ms/step - loss: 12.5490\n", | |
| "Epoch 3/10\n", | |
| "28/28 [==============================] - 1s 19ms/step - loss: 6.0662\n", | |
| "Epoch 4/10\n", | |
| "28/28 [==============================] - 1s 22ms/step - loss: 5.2109\n", | |
| "Epoch 5/10\n", | |
| "28/28 [==============================] - 1s 20ms/step - loss: 5.5455\n", | |
| "Epoch 6/10\n", | |
| "28/28 [==============================] - 1s 21ms/step - loss: 5.9871\n", | |
| "Epoch 7/10\n", | |
| "28/28 [==============================] - 1s 19ms/step - loss: 5.7645\n", | |
| "Epoch 8/10\n", | |
| "28/28 [==============================] - 1s 21ms/step - loss: 5.5593\n", | |
| "Epoch 9/10\n", | |
| "28/28 [==============================] - 1s 21ms/step - loss: 5.2851\n", | |
| "Epoch 10/10\n", | |
| "28/28 [==============================] - 1s 18ms/step - loss: 4.8937\n" | |
| ], | |
| "name": "stdout" | |
| }, | |
| { | |
| "output_type": "execute_result", | |
| "data": { | |
| "text/plain": [ | |
| "<tensorflow.python.keras.callbacks.History at 0x7fde717c7990>" | |
| ] | |
| }, | |
| "metadata": { | |
| "tags": [] | |
| }, | |
| "execution_count": 12 | |
| } | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "metadata": { | |
| "id": "9GOzdAb01Ei4" | |
| }, | |
| "source": [ | |
| "# To evaluate, let's instantiate another loader to test\n", | |
| "\n", | |
| "test_loader = BatchLoader(data_test, batch_size=32)" | |
| ], | |
| "execution_count": null, | |
| "outputs": [] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "metadata": { | |
| "colab": { | |
| "base_uri": "https://localhost:8080/" | |
| }, | |
| "id": "3wZS1SEK6H4o", | |
| "outputId": "adfd9400-9dca-451f-a726-711039d90345" | |
| }, | |
| "source": [ | |
| "# And feed it to our model by calling .load() on the test loader (not the training loader)\n", | |
| "\n", | |
| "loss = model.evaluate(test_loader.load(), steps=test_loader.steps_per_epoch)\n", | |
| "\n", | |
| "print('Test loss: {}'.format(loss))" | |
| ], | |
| "execution_count": null, | |
| "outputs": [ | |
| { | |
| "output_type": "stream", | |
| "text": [ | |
| "28/28 [==============================] - 1s 13ms/step - loss: 3.8737\n", | |
| "Test loss: 3.87371563911438\n" | |
| ], | |
| "name": "stdout" | |
| } | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": { | |
| "id": "Z5Lud8oKFzGk" | |
| }, | |
| "source": [ | |
| "## PyTorch Geometric GCN\n", | |
| "\n", | |
| "PyTorch Geometric provides [GCN layers](https://pytorch-geometric.readthedocs.io/en/latest/modules/nn.html) based on Kipf & Welling's original paper, [\"Semi-Supervised Classification with Graph Convolutional Networks\"](https://arxiv.org/abs/1609.02907), on which I've based the bulk of my research and write-ups.\n", | |
| "\n", | |
| "While my original goal was to build this from scratch using my [original experiment](https://colab.research.google.com/drive/1NUQgUdrdvIddewdQyGEpas_mPaFzC8-e?usp=sharing) (based on [this](https://towardsdatascience.com/understanding-graph-convolutional-networks-for-node-classification-a2bfdb7aba7b) resource), I ran into difficulties trying to embed and classify such a large dataset, specifically with Colab RAM allowances.\n", | |
| "\n", | |
| "For this reason, I sought out different methods and found that this problem had already been solved. For purposes of time and demonstration, I chose to delve into PyTorch Geometric rather than reinvent the wheel. \n", | |
| "\n", | |
| "To learn how to implement this approach with this library, I relied on the PyTorch Geometric [documentation](https://pytorch-geometric.readthedocs.io/en/latest/index.html) as well as [this notebook](https://colab.research.google.com/drive/1I8a0DfQ3fI7Njc62__mVXUlcAleUclnb?usp=sharing) written by [email protected]. \n", | |
| "\n", | |
| "I would like to extend thanks and all due credit to these authors, as this work and research would not be possible without them. Further credits and citations can be found in the [README](https://github.com/sidneyarcidiacono/graph-convolutional-networks) of this repository." | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "metadata": { | |
| "id": "NbFimfvQTMuz", | |
| "colab": { | |
| "base_uri": "https://localhost:8080/" | |
| }, | |
| "outputId": "fd864d8c-f97b-4797-81df-b89cdca82993" | |
| }, | |
| "source": [ | |
| "# Install required packages.\n", | |
| "!pip install -q torch-scatter -f https://pytorch-geometric.com/whl/torch-1.8.0+cu101.html\n", | |
| "!pip install -q torch-sparse -f https://pytorch-geometric.com/whl/torch-1.8.0+cu101.html\n", | |
| "!pip install -q torch-geometric" | |
| ], | |
| "execution_count": null, | |
| "outputs": [ | |
| { | |
| "output_type": "stream", | |
| "text": [ | |
| "\u001b[K |████████████████████████████████| 2.6MB 354kB/s \n", | |
| "\u001b[K |████████████████████████████████| 1.5MB 313kB/s \n", | |
| "\u001b[K |████████████████████████████████| 215kB 7.7MB/s \n", | |
| "\u001b[K |████████████████████████████████| 235kB 12.6MB/s \n", | |
| "\u001b[K |████████████████████████████████| 2.2MB 16.2MB/s \n", | |
| "\u001b[K |████████████████████████████████| 51kB 8.4MB/s \n", | |
| "\u001b[?25h Building wheel for torch-geometric (setup.py) ... \u001b[?25l\u001b[?25hdone\n" | |
| ], | |
| "name": "stdout" | |
| } | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "metadata": { | |
| "colab": { | |
| "base_uri": "https://localhost:8080/" | |
| }, | |
| "id": "GXT8XfFcQZU5", | |
| "outputId": "93d25903-f6df-4586-fd2d-03f36aa3b9c4" | |
| }, | |
| "source": [ | |
| "import torch\n", | |
| "from torch_geometric.datasets import TUDataset\n", | |
| "\n", | |
| "# Like Spektral, PyTorch Geometric provides us with the benchmark TUDataset collection\n", | |
| "dataset = TUDataset(root='data/TUDataset', name='PROTEINS')\n", | |
| "\n", | |
| "# Let's take a look at our data. We'll look at dataset (all data) and data (our first graph):\n", | |
| "\n", | |
| "data = dataset[0] # Get the first graph object.\n", | |
| "\n", | |
| "print()\n", | |
| "print(f'Dataset: {dataset}:')\n", | |
| "print('====================')\n", | |
| "# How many graphs?\n", | |
| "print(f'Number of graphs: {len(dataset)}')\n", | |
| "# How many features?\n", | |
| "print(f'Number of features: {dataset.num_features}')\n", | |
| "# Now, in our first graph, how many edges?\n", | |
| "print(f'Number of edges: {data.num_edges}')\n", | |
| "# Average node degree?\n", | |
| "print(f'Average node degree: {data.num_edges / data.num_nodes:.2f}')\n", | |
| "# Do we have isolated nodes?\n", | |
| "print(f'Contains isolated nodes: {data.contains_isolated_nodes()}')\n", | |
| "# Do we contain self-loops?\n", | |
| "print(f'Contains self-loops: {data.contains_self_loops()}')\n", | |
| "# Is this an undirected graph?\n", | |
| "print(f'Is undirected: {data.is_undirected()}')" | |
| ], | |
| "execution_count": null, | |
| "outputs": [ | |
| { | |
| "output_type": "stream", | |
| "text": [ | |
| "\n", | |
| "Dataset: PROTEINS(1113):\n", | |
| "====================\n", | |
| "Number of graphs: 1113\n", | |
| "Number of features: 3\n", | |
| "Number of edges: 162\n", | |
| "Average node degree: 3.86\n", | |
| "Contains isolated nodes: False\n", | |
| "Contains self-loops: False\n", | |
| "Is undirected: True\n" | |
| ], | |
| "name": "stdout" | |
| } | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "metadata": { | |
| "colab": { | |
| "base_uri": "https://localhost:8080/" | |
| }, | |
| "id": "2-pZeqpXQcIt", | |
| "outputId": "62cd278c-59a4-4c74-e630-c1f955848d1f" | |
| }, | |
| "source": [ | |
| "# Now, we need to perform our train/test split.\n", | |
| "# We create a seed, and then shuffle our data\n", | |
| "torch.manual_seed(12345)\n", | |
| "dataset = dataset.shuffle()\n", | |
| "\n", | |
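| "# Once it's shuffled, we slice the data to split: the first 150 graphs form the test set,\n", | |
| "# and the training set skips those and also leaves out the last 150 graphs, which go unused here\n", | |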
| "train_dataset = dataset[150:-150]\n", | |
| "test_dataset = dataset[0:150]\n", | |
| "\n", | |
| "# Take a look at the training versus test graphs\n", | |
| "print(f'Number of training graphs: {len(train_dataset)}')\n", | |
| "print(f'Number of test graphs: {len(test_dataset)}')" | |
| ], | |
| "execution_count": null, | |
| "outputs": [ | |
| { | |
| "output_type": "stream", | |
| "text": [ | |
| "Number of training graphs: 813\n", | |
| "Number of test graphs: 150\n" | |
| ], | |
| "name": "stdout" | |
| } | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "metadata": { | |
| "colab": { | |
| "base_uri": "https://localhost:8080/" | |
| }, | |
| "id": "-ERyFSsCQAcQ", | |
| "outputId": "a6d8c785-e844-4495-e266-31dcb5464415" | |
| }, | |
| "source": [ | |
| "# Import DataLoader for batching\n", | |
| "from torch_geometric.data import DataLoader\n", | |
| "\n", | |
| "# Our DataLoader stacks the adjacency matrices of a batch into one block-diagonal matrix, and concatenates features\n", | |
| "# and target matrices in the node dimension. This allows differing numbers of nodes and edges \n", | |
| "# over examples in one batch. (from PyTorch Geometric docs)\n", | |
| "train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)\n", | |
| "test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)\n", | |
| "\n", | |
| "# Take a look at the output to understand this further:\n", | |
| "for step, data in enumerate(train_loader):\n", | |
| "    print(f'Step {step + 1}:')\n", | |
| "    print('=======')\n", | |
| "    print(f'Number of graphs in the current batch: {data.num_graphs}')\n", | |
| "    print(data)\n", | |
| "    print()" | |
| ], | |
| "execution_count": null, | |
| "outputs": [ | |
| { | |
| "output_type": "stream", | |
| "text": [ | |
| "Step 1:\n", | |
| "=======\n", | |
| "Number of graphs in the current batch: 64\n", | |
| "Batch(batch=[2506], edge_index=[2, 9562], ptr=[65], x=[2506, 3], y=[64])\n", | |
| "\n", | |
| "Step 2:\n", | |
| "=======\n", | |
| "Number of graphs in the current batch: 64\n", | |
| "Batch(batch=[2271], edge_index=[2, 8220], ptr=[65], x=[2271, 3], y=[64])\n", | |
| "\n", | |
| "Step 3:\n", | |
| "=======\n", | |
| "Number of graphs in the current batch: 64\n", | |
| "Batch(batch=[2739], edge_index=[2, 10702], ptr=[65], x=[2739, 3], y=[64])\n", | |
| "\n", | |
| "Step 4:\n", | |
| "=======\n", | |
| "Number of graphs in the current batch: 64\n", | |
| "Batch(batch=[3546], edge_index=[2, 12958], ptr=[65], x=[3546, 3], y=[64])\n", | |
| "\n", | |
| "Step 5:\n", | |
| "=======\n", | |
| "Number of graphs in the current batch: 64\n", | |
| "Batch(batch=[2231], edge_index=[2, 8432], ptr=[65], x=[2231, 3], y=[64])\n", | |
| "\n", | |
| "Step 6:\n", | |
| "=======\n", | |
| "Number of graphs in the current batch: 64\n", | |
| "Batch(batch=[2822], edge_index=[2, 10366], ptr=[65], x=[2822, 3], y=[64])\n", | |
| "\n", | |
| "Step 7:\n", | |
| "=======\n", | |
| "Number of graphs in the current batch: 64\n", | |
| "Batch(batch=[2491], edge_index=[2, 9146], ptr=[65], x=[2491, 3], y=[64])\n", | |
| "\n", | |
| "Step 8:\n", | |
| "=======\n", | |
| "Number of graphs in the current batch: 64\n", | |
| "Batch(batch=[2088], edge_index=[2, 8114], ptr=[65], x=[2088, 3], y=[64])\n", | |
| "\n", | |
| "Step 9:\n", | |
| "=======\n", | |
| "Number of graphs in the current batch: 64\n", | |
| "Batch(batch=[2220], edge_index=[2, 8314], ptr=[65], x=[2220, 3], y=[64])\n", | |
| "\n", | |
| "Step 10:\n", | |
| "=======\n", | |
| "Number of graphs in the current batch: 64\n", | |
| "Batch(batch=[2356], edge_index=[2, 8736], ptr=[65], x=[2356, 3], y=[64])\n", | |
| "\n", | |
| "Step 11:\n", | |
| "=======\n", | |
| "Number of graphs in the current batch: 64\n", | |
| "Batch(batch=[2914], edge_index=[2, 10848], ptr=[65], x=[2914, 3], y=[64])\n", | |
| "\n", | |
| "Step 12:\n", | |
| "=======\n", | |
| "Number of graphs in the current batch: 64\n", | |
| "Batch(batch=[2166], edge_index=[2, 7972], ptr=[65], x=[2166, 3], y=[64])\n", | |
| "\n", | |
| "Step 13:\n", | |
| "=======\n", | |
| "Number of graphs in the current batch: 45\n", | |
| "Batch(batch=[1763], edge_index=[2, 6524], ptr=[46], x=[1763, 3], y=[45])\n", | |
| "\n" | |
| ], | |
| "name": "stdout" | |
| } | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "metadata": { | |
| "colab": { | |
| "base_uri": "https://localhost:8080/" | |
| }, | |
| "id": "Buu47UjnOJtW", | |
| "outputId": "97bdb1bb-f958-4e48-ddea-24093c3c2bd9" | |
| }, | |
| "source": [ | |
| "# Import everything we need to build our network:\n", | |
| "from torch.nn import Linear\n", | |
| "import torch.nn.functional as F\n", | |
| "from torch_geometric.nn import GCNConv\n", | |
| "from torch_geometric.nn import global_mean_pool\n", | |
| "\n", | |
| "# Define our GCN class as a pytorch Module\n", | |
| "class GCN(torch.nn.Module):\n", | |
| "    def __init__(self, hidden_channels):\n", | |
| "        super(GCN, self).__init__()\n", | |
| "        # We instantiate three of PyTorch Geometric's GCNConv layers\n", | |
| "        self.conv1 = GCNConv(dataset.num_node_features, hidden_channels)\n", | |
| "        self.conv2 = GCNConv(hidden_channels, hidden_channels)\n", | |
| "        self.conv3 = GCNConv(hidden_channels, hidden_channels)\n", | |
| "        # Our final linear layer will define our output\n", | |
| "        self.lin = Linear(hidden_channels, dataset.num_classes)\n", | |
| "\n", | |
| "    def forward(self, x, edge_index, batch):\n", | |
| "        # 1. Obtain node embeddings\n", | |
| "        x = self.conv1(x, edge_index)\n", | |
| "        x = x.relu()\n", | |
| "        x = self.conv2(x, edge_index)\n", | |
| "        x = x.relu()\n", | |
| "        x = self.conv3(x, edge_index)\n", | |
| "\n", | |
| "        # 2. Readout layer\n", | |
| "        x = global_mean_pool(x, batch)  # [batch_size, hidden_channels]\n", | |
| "\n", | |
| "        # 3. Apply a final classifier\n", | |
| "        x = F.dropout(x, p=0.5, training=self.training)\n", | |
| "        x = self.lin(x)\n", | |
| "        return x\n", | |
| "\n", | |
| "model = GCN(hidden_channels=64)\n", | |
| "print(model)" | |
| ], | |
| "execution_count": null, | |
| "outputs": [ | |
| { | |
| "output_type": "stream", | |
| "text": [ | |
| "GCN(\n", | |
| " (conv1): GCNConv(3, 64)\n", | |
| " (conv2): GCNConv(64, 64)\n", | |
| " (conv3): GCNConv(64, 64)\n", | |
| " (lin): Linear(in_features=64, out_features=2, bias=True)\n", | |
| ")\n" | |
| ], | |
| "name": "stdout" | |
| } | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "metadata": { | |
| "colab": { | |
| "base_uri": "https://localhost:8080/" | |
| }, | |
| "id": "tpON7LPyOYJB", | |
| "outputId": "00af1b56-2b42-43e5-f336-fd2696b3b563" | |
| }, | |
| "source": [ | |
| "# Initialize our model from our GCN class:\n", | |
| "model = GCN(hidden_channels=64)\n", | |
| "# Set our optimizer (adam)\n", | |
| "optimizer = torch.optim.Adam(model.parameters(), lr=0.01)\n", | |
| "# Define our loss function\n", | |
| "criterion = torch.nn.CrossEntropyLoss()\n", | |
| "\n", | |
| "# Initialize our train function\n", | |
| "def train():\n", | |
| "    model.train()\n", | |
| "\n", | |
| "    for data in train_loader:  # Iterate in batches over the training dataset.\n", | |
| "        out = model(data.x, data.edge_index, data.batch)  # Perform a single forward pass.\n", | |
| "        loss = criterion(out, data.y)  # Compute the loss.\n", | |
| "        loss.backward()  # Derive gradients.\n", | |
| "        optimizer.step()  # Update parameters based on gradients.\n", | |
| "        optimizer.zero_grad()  # Clear gradients.\n", | |
| "\n", | |
| "# Define our test function\n", | |
| "def test(loader):\n", | |
| "    model.eval()\n", | |
| "\n", | |
| "    correct = 0\n", | |
| "    for data in loader:  # Iterate in batches over the training/test dataset.\n", | |
| "        out = model(data.x, data.edge_index, data.batch)\n", | |
| "        pred = out.argmax(dim=1)  # Use the class with highest probability.\n", | |
| "        correct += int((pred == data.y).sum())  # Check against ground-truth labels.\n", | |
| "    return correct / len(loader.dataset)  # Derive ratio of correct predictions.\n", | |
| "\n", | |
| "# Run for 200 epochs (range is exclusive in the upper bound)\n", | |
| "for epoch in range(1, 201):\n", | |
| "    train()\n", | |
| "    train_acc = test(train_loader)\n", | |
| "    test_acc = test(test_loader)\n", | |
| "    print(f'Epoch: {epoch:03d}, Train Acc: {train_acc:.4f}, Test Acc: {test_acc:.4f}')" | |
| ], | |
| "execution_count": null, | |
| "outputs": [ | |
| { | |
| "output_type": "stream", | |
| "text": [ | |
| "Epoch: 001, Train Acc: 0.6138, Test Acc: 0.5533\n", | |
| "Epoch: 002, Train Acc: 0.6335, Test Acc: 0.5933\n", | |
| "Epoch: 003, Train Acc: 0.6950, Test Acc: 0.6467\n", | |
| "Epoch: 004, Train Acc: 0.6950, Test Acc: 0.6667\n", | |
| "Epoch: 005, Train Acc: 0.7085, Test Acc: 0.6600\n", | |
| "Epoch: 006, Train Acc: 0.6765, Test Acc: 0.6400\n", | |
| "Epoch: 007, Train Acc: 0.7245, Test Acc: 0.6667\n", | |
| "Epoch: 008, Train Acc: 0.6839, Test Acc: 0.6600\n", | |
| "Epoch: 009, Train Acc: 0.6986, Test Acc: 0.6400\n", | |
| "Epoch: 010, Train Acc: 0.6790, Test Acc: 0.6533\n", | |
| "Epoch: 011, Train Acc: 0.6925, Test Acc: 0.6400\n", | |
| "Epoch: 012, Train Acc: 0.7060, Test Acc: 0.6533\n", | |
| "Epoch: 013, Train Acc: 0.7060, Test Acc: 0.6667\n", | |
| "Epoch: 014, Train Acc: 0.6851, Test Acc: 0.6333\n", | |
| "Epoch: 015, Train Acc: 0.7134, Test Acc: 0.6800\n", | |
| "Epoch: 016, Train Acc: 0.7245, Test Acc: 0.6733\n", | |
| "Epoch: 017, Train Acc: 0.7245, Test Acc: 0.7000\n", | |
| "Epoch: 018, Train Acc: 0.6482, Test Acc: 0.5933\n", | |
| "Epoch: 019, Train Acc: 0.6802, Test Acc: 0.6600\n", | |
| "Epoch: 020, Train Acc: 0.6962, Test Acc: 0.6400\n", | |
| "Epoch: 021, Train Acc: 0.6691, Test Acc: 0.6000\n", | |
| "Epoch: 022, Train Acc: 0.7269, Test Acc: 0.6733\n", | |
| "Epoch: 023, Train Acc: 0.7159, Test Acc: 0.6667\n", | |
| "Epoch: 024, Train Acc: 0.7122, Test Acc: 0.6733\n", | |
| "Epoch: 025, Train Acc: 0.7060, Test Acc: 0.6867\n", | |
| "Epoch: 026, Train Acc: 0.6753, Test Acc: 0.6267\n", | |
| "Epoch: 027, Train Acc: 0.6814, Test Acc: 0.6467\n", | |
| "Epoch: 028, Train Acc: 0.7085, Test Acc: 0.6600\n", | |
| "Epoch: 029, Train Acc: 0.6827, Test Acc: 0.6333\n", | |
| "Epoch: 030, Train Acc: 0.6986, Test Acc: 0.6533\n", | |
| "Epoch: 031, Train Acc: 0.7097, Test Acc: 0.6733\n", | |
| "Epoch: 032, Train Acc: 0.7196, Test Acc: 0.6800\n", | |
| "Epoch: 033, Train Acc: 0.6974, Test Acc: 0.6467\n", | |
| "Epoch: 034, Train Acc: 0.7023, Test Acc: 0.6733\n", | |
| "Epoch: 035, Train Acc: 0.7048, Test Acc: 0.6600\n", | |
| "Epoch: 036, Train Acc: 0.7122, Test Acc: 0.6600\n", | |
| "Epoch: 037, Train Acc: 0.7073, Test Acc: 0.6733\n", | |
| "Epoch: 038, Train Acc: 0.7269, Test Acc: 0.6867\n", | |
| "Epoch: 039, Train Acc: 0.6642, Test Acc: 0.6000\n", | |
| "Epoch: 040, Train Acc: 0.6999, Test Acc: 0.6533\n", | |
| "Epoch: 041, Train Acc: 0.7122, Test Acc: 0.6733\n", | |
| "Epoch: 042, Train Acc: 0.7048, Test Acc: 0.6533\n", | |
| "Epoch: 043, Train Acc: 0.7282, Test Acc: 0.6800\n", | |
| "Epoch: 044, Train Acc: 0.6445, Test Acc: 0.6000\n", | |
| "Epoch: 045, Train Acc: 0.7097, Test Acc: 0.6800\n", | |
| "Epoch: 046, Train Acc: 0.7109, Test Acc: 0.6667\n", | |
| "Epoch: 047, Train Acc: 0.7109, Test Acc: 0.6600\n", | |
| "Epoch: 048, Train Acc: 0.7331, Test Acc: 0.6733\n", | |
| "Epoch: 049, Train Acc: 0.7220, Test Acc: 0.6733\n", | |
| "Epoch: 050, Train Acc: 0.7159, Test Acc: 0.6867\n", | |
| "Epoch: 051, Train Acc: 0.7048, Test Acc: 0.6733\n", | |
| "Epoch: 052, Train Acc: 0.6888, Test Acc: 0.6733\n", | |
| "Epoch: 053, Train Acc: 0.7331, Test Acc: 0.6800\n", | |
| "Epoch: 054, Train Acc: 0.7134, Test Acc: 0.6733\n", | |
| "Epoch: 055, Train Acc: 0.7196, Test Acc: 0.6867\n", | |
| "Epoch: 056, Train Acc: 0.6937, Test Acc: 0.6600\n", | |
| "Epoch: 057, Train Acc: 0.7109, Test Acc: 0.6667\n", | |
| "Epoch: 058, Train Acc: 0.7134, Test Acc: 0.6733\n", | |
| "Epoch: 059, Train Acc: 0.7269, Test Acc: 0.6867\n", | |
| "Epoch: 060, Train Acc: 0.7085, Test Acc: 0.6600\n", | |
| "Epoch: 061, Train Acc: 0.7048, Test Acc: 0.6733\n", | |
| "Epoch: 062, Train Acc: 0.7331, Test Acc: 0.6933\n", | |
| "Epoch: 063, Train Acc: 0.7196, Test Acc: 0.6600\n", | |
| "Epoch: 064, Train Acc: 0.7159, Test Acc: 0.6867\n", | |
| "Epoch: 065, Train Acc: 0.7355, Test Acc: 0.6933\n", | |
| "Epoch: 066, Train Acc: 0.7097, Test Acc: 0.6600\n", | |
| "Epoch: 067, Train Acc: 0.6999, Test Acc: 0.6733\n", | |
| "Epoch: 068, Train Acc: 0.6728, Test Acc: 0.6467\n", | |
| "Epoch: 069, Train Acc: 0.7282, Test Acc: 0.6600\n", | |
| "Epoch: 070, Train Acc: 0.7146, Test Acc: 0.6867\n", | |
| "Epoch: 071, Train Acc: 0.6839, Test Acc: 0.6600\n", | |
| "Epoch: 072, Train Acc: 0.7306, Test Acc: 0.6800\n", | |
| "Epoch: 073, Train Acc: 0.7159, Test Acc: 0.6600\n", | |
| "Epoch: 074, Train Acc: 0.7122, Test Acc: 0.6733\n", | |
| "Epoch: 075, Train Acc: 0.7183, Test Acc: 0.6667\n", | |
| "Epoch: 076, Train Acc: 0.6950, Test Acc: 0.6200\n", | |
| "Epoch: 077, Train Acc: 0.6974, Test Acc: 0.6733\n", | |
| "Epoch: 078, Train Acc: 0.7159, Test Acc: 0.6600\n", | |
| "Epoch: 079, Train Acc: 0.7232, Test Acc: 0.6667\n", | |
| "Epoch: 080, Train Acc: 0.7269, Test Acc: 0.6867\n", | |
| "Epoch: 081, Train Acc: 0.6900, Test Acc: 0.6467\n", | |
| "Epoch: 082, Train Acc: 0.6777, Test Acc: 0.6733\n", | |
| "Epoch: 083, Train Acc: 0.6863, Test Acc: 0.6467\n", | |
| "Epoch: 084, Train Acc: 0.7159, Test Acc: 0.6733\n", | |
| "Epoch: 085, Train Acc: 0.7048, Test Acc: 0.6400\n", | |
| "Epoch: 086, Train Acc: 0.7269, Test Acc: 0.6600\n", | |
| "Epoch: 087, Train Acc: 0.7306, Test Acc: 0.6867\n", | |
| "Epoch: 088, Train Acc: 0.7036, Test Acc: 0.6667\n", | |
| "Epoch: 089, Train Acc: 0.7232, Test Acc: 0.6867\n", | |
| "Epoch: 090, Train Acc: 0.6740, Test Acc: 0.5933\n", | |
| "Epoch: 091, Train Acc: 0.7282, Test Acc: 0.6867\n", | |
| "Epoch: 092, Train Acc: 0.6986, Test Acc: 0.6667\n", | |
| "Epoch: 093, Train Acc: 0.7085, Test Acc: 0.6600\n", | |
| "Epoch: 094, Train Acc: 0.7109, Test Acc: 0.6733\n", | |
| "Epoch: 095, Train Acc: 0.6925, Test Acc: 0.6533\n", | |
| "Epoch: 096, Train Acc: 0.7306, Test Acc: 0.6800\n", | |
| "Epoch: 097, Train Acc: 0.6913, Test Acc: 0.6600\n", | |
| "Epoch: 098, Train Acc: 0.7048, Test Acc: 0.6733\n", | |
| "Epoch: 099, Train Acc: 0.7294, Test Acc: 0.6867\n", | |
| "Epoch: 100, Train Acc: 0.7208, Test Acc: 0.6667\n", | |
| "Epoch: 101, Train Acc: 0.7122, Test Acc: 0.6600\n", | |
| "Epoch: 102, Train Acc: 0.7294, Test Acc: 0.6733\n", | |
| "Epoch: 103, Train Acc: 0.7429, Test Acc: 0.7000\n", | |
| "Epoch: 104, Train Acc: 0.7232, Test Acc: 0.7000\n", | |
| "Epoch: 105, Train Acc: 0.7208, Test Acc: 0.6467\n", | |
| "Epoch: 106, Train Acc: 0.7405, Test Acc: 0.6867\n", | |
| "Epoch: 107, Train Acc: 0.7282, Test Acc: 0.6800\n", | |
| "Epoch: 108, Train Acc: 0.6851, Test Acc: 0.6533\n", | |
| "Epoch: 109, Train Acc: 0.7183, Test Acc: 0.6867\n", | |
| "Epoch: 110, Train Acc: 0.7159, Test Acc: 0.6667\n", | |
| "Epoch: 111, Train Acc: 0.7196, Test Acc: 0.6533\n", | |
| "Epoch: 112, Train Acc: 0.7183, Test Acc: 0.6800\n", | |
| "Epoch: 113, Train Acc: 0.7294, Test Acc: 0.6733\n", | |
| "Epoch: 114, Train Acc: 0.6986, Test Acc: 0.6467\n", | |
| "Epoch: 115, Train Acc: 0.7405, Test Acc: 0.6933\n", | |
| "Epoch: 116, Train Acc: 0.6962, Test Acc: 0.6600\n", | |
| "Epoch: 117, Train Acc: 0.7405, Test Acc: 0.7000\n", | |
| "Epoch: 118, Train Acc: 0.7294, Test Acc: 0.6867\n", | |
| "Epoch: 119, Train Acc: 0.7429, Test Acc: 0.6933\n", | |
| "Epoch: 120, Train Acc: 0.7232, Test Acc: 0.6467\n", | |
| "Epoch: 121, Train Acc: 0.7405, Test Acc: 0.6867\n", | |
| "Epoch: 122, Train Acc: 0.6950, Test Acc: 0.6467\n", | |
| "Epoch: 123, Train Acc: 0.7122, Test Acc: 0.6600\n", | |
| "Epoch: 124, Train Acc: 0.7208, Test Acc: 0.6933\n", | |
| "Epoch: 125, Train Acc: 0.7380, Test Acc: 0.6933\n", | |
| "Epoch: 126, Train Acc: 0.7405, Test Acc: 0.6867\n", | |
| "Epoch: 127, Train Acc: 0.6950, Test Acc: 0.6533\n", | |
| "Epoch: 128, Train Acc: 0.6999, Test Acc: 0.6467\n", | |
| "Epoch: 129, Train Acc: 0.6617, Test Acc: 0.5933\n", | |
| "Epoch: 130, Train Acc: 0.6974, Test Acc: 0.6467\n", | |
| "Epoch: 131, Train Acc: 0.7232, Test Acc: 0.6533\n", | |
| "Epoch: 132, Train Acc: 0.7085, Test Acc: 0.6533\n", | |
| "Epoch: 133, Train Acc: 0.7294, Test Acc: 0.6800\n", | |
| "Epoch: 134, Train Acc: 0.7392, Test Acc: 0.6933\n", | |
| "Epoch: 135, Train Acc: 0.7134, Test Acc: 0.6733\n", | |
| "Epoch: 136, Train Acc: 0.7343, Test Acc: 0.6800\n", | |
| "Epoch: 137, Train Acc: 0.6740, Test Acc: 0.5867\n", | |
| "Epoch: 138, Train Acc: 0.7294, Test Acc: 0.6933\n", | |
| "Epoch: 139, Train Acc: 0.7085, Test Acc: 0.6467\n", | |
| "Epoch: 140, Train Acc: 0.7331, Test Acc: 0.6667\n", | |
| "Epoch: 141, Train Acc: 0.7109, Test Acc: 0.6333\n", | |
| "Epoch: 142, Train Acc: 0.7331, Test Acc: 0.6733\n", | |
| "Epoch: 143, Train Acc: 0.7159, Test Acc: 0.6400\n", | |
| "Epoch: 144, Train Acc: 0.7392, Test Acc: 0.6933\n", | |
| "Epoch: 145, Train Acc: 0.6974, Test Acc: 0.6067\n", | |
| "Epoch: 146, Train Acc: 0.7232, Test Acc: 0.6600\n", | |
| "Epoch: 147, Train Acc: 0.7146, Test Acc: 0.6533\n", | |
| "Epoch: 148, Train Acc: 0.7134, Test Acc: 0.6467\n", | |
| "Epoch: 149, Train Acc: 0.7257, Test Acc: 0.6467\n", | |
| "Epoch: 150, Train Acc: 0.7442, Test Acc: 0.6800\n", | |
| "Epoch: 151, Train Acc: 0.7085, Test Acc: 0.6467\n", | |
| "Epoch: 152, Train Acc: 0.7011, Test Acc: 0.6133\n", | |
| "Epoch: 153, Train Acc: 0.7392, Test Acc: 0.6667\n", | |
| "Epoch: 154, Train Acc: 0.7245, Test Acc: 0.6800\n", | |
| "Epoch: 155, Train Acc: 0.7269, Test Acc: 0.6667\n", | |
| "Epoch: 156, Train Acc: 0.7257, Test Acc: 0.6600\n", | |
| "Epoch: 157, Train Acc: 0.6876, Test Acc: 0.6467\n", | |
| "Epoch: 158, Train Acc: 0.7417, Test Acc: 0.6867\n", | |
| "Epoch: 159, Train Acc: 0.7355, Test Acc: 0.6667\n", | |
| "Epoch: 160, Train Acc: 0.7355, Test Acc: 0.6533\n", | |
| "Epoch: 161, Train Acc: 0.7306, Test Acc: 0.6867\n", | |
| "Epoch: 162, Train Acc: 0.7036, Test Acc: 0.6667\n", | |
| "Epoch: 163, Train Acc: 0.7368, Test Acc: 0.6600\n", | |
| "Epoch: 164, Train Acc: 0.7294, Test Acc: 0.6467\n", | |
| "Epoch: 165, Train Acc: 0.7282, Test Acc: 0.6533\n", | |
| "Epoch: 166, Train Acc: 0.7257, Test Acc: 0.6467\n", | |
| "Epoch: 167, Train Acc: 0.7060, Test Acc: 0.6600\n", | |
| "Epoch: 168, Train Acc: 0.7306, Test Acc: 0.6533\n", | |
| "Epoch: 169, Train Acc: 0.7269, Test Acc: 0.6867\n", | |
| "Epoch: 170, Train Acc: 0.7159, Test Acc: 0.6667\n", | |
| "Epoch: 171, Train Acc: 0.7208, Test Acc: 0.6600\n", | |
| "Epoch: 172, Train Acc: 0.7208, Test Acc: 0.6733\n", | |
| "Epoch: 173, Train Acc: 0.7245, Test Acc: 0.6600\n", | |
| "Epoch: 174, Train Acc: 0.7282, Test Acc: 0.6600\n", | |
| "Epoch: 175, Train Acc: 0.7220, Test Acc: 0.6400\n", | |
| "Epoch: 176, Train Acc: 0.6950, Test Acc: 0.5733\n", | |
| "Epoch: 177, Train Acc: 0.7294, Test Acc: 0.6600\n", | |
| "Epoch: 178, Train Acc: 0.7331, Test Acc: 0.6533\n", | |
| "Epoch: 179, Train Acc: 0.7380, Test Acc: 0.6800\n", | |
| "Epoch: 180, Train Acc: 0.7343, Test Acc: 0.6867\n", | |
| "Epoch: 181, Train Acc: 0.7392, Test Acc: 0.6933\n", | |
| "Epoch: 182, Train Acc: 0.7183, Test Acc: 0.6600\n", | |
| "Epoch: 183, Train Acc: 0.7294, Test Acc: 0.6467\n", | |
| "Epoch: 184, Train Acc: 0.6900, Test Acc: 0.6467\n", | |
| "Epoch: 185, Train Acc: 0.7196, Test Acc: 0.6333\n", | |
| "Epoch: 186, Train Acc: 0.7392, Test Acc: 0.6933\n", | |
| "Epoch: 187, Train Acc: 0.7208, Test Acc: 0.6533\n", | |
| "Epoch: 188, Train Acc: 0.6986, Test Acc: 0.6733\n", | |
| "Epoch: 189, Train Acc: 0.7257, Test Acc: 0.6333\n", | |
| "Epoch: 190, Train Acc: 0.7257, Test Acc: 0.6733\n", | |
| "Epoch: 191, Train Acc: 0.7269, Test Acc: 0.6400\n", | |
| "Epoch: 192, Train Acc: 0.7380, Test Acc: 0.6733\n", | |
| "Epoch: 193, Train Acc: 0.7306, Test Acc: 0.6733\n", | |
| "Epoch: 194, Train Acc: 0.7306, Test Acc: 0.6467\n", | |
| "Epoch: 195, Train Acc: 0.7331, Test Acc: 0.6800\n", | |
| "Epoch: 196, Train Acc: 0.7257, Test Acc: 0.6400\n", | |
| "Epoch: 197, Train Acc: 0.7146, Test Acc: 0.6467\n", | |
| "Epoch: 198, Train Acc: 0.7232, Test Acc: 0.6467\n", | |
| "Epoch: 199, Train Acc: 0.7134, Test Acc: 0.6267\n", | |
| "Epoch: 200, Train Acc: 0.7134, Test Acc: 0.6267\n" | |
| ], | |
| "name": "stdout" | |
| } | |
| ] | |
| } | |
| ] | |
| } |