Created
March 23, 2023 03:48
Understanding Code with Graph Neural Networks
{ | |
"cells": [ | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "XiP6-MBFmKAB" | |
}, | |
"source": [ | |
"# 0. Introduction\n", | |
"\n", | |
"In this notebook, we take a closer look at how to apply Graph Neural Networks (GNNs) to the task of graph-level prediction on the ogbg-code2 dataset. The prediction task is defined as: \n", | |
"Given a graph representation of a program, specifically the body of a method, generate the set of tokens that form the method's name.\n", | |
"\n", | |
"\n", | |
"\n", | |
"The dataset is a collection of 450,000 abstract syntax trees (ASTs) generated from Python GitHub repositories. The dataset aligns well with our task as it also incorporates nodes and edges from the AST as well as the tokenized method names from which the AST was generated. The training, validation, and test splits are provided with the dataset." | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "cDdFlkbqQYfq" | |
}, | |
"source": [ | |
"# 1. Install Dependencies" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 1, | |
"metadata": { | |
"id": "Be3f5csQWOBY" | |
}, | |
"outputs": [], | |
"source": [ | |
"# Basic Python dependencies\n", | |
"import os\n", | |
"import time\n", | |
"from collections import defaultdict, namedtuple\n", | |
"import random\n", | |
"random.seed(2)\n", | |
"\n", | |
"# Basic data handling libraries\n", | |
"import numpy as np\n", | |
"from tqdm import tqdm, trange\n", | |
"import pandas as pd\n", | |
"import copy\n", | |
"import json\n", | |
"import matplotlib.pyplot as plt" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "hh8eI-xgRcK9" | |
}, | |
"source": [ | |
"We'll be using [PyG](https://pytorch-geometric.readthedocs.io/en/latest/) (PyTorch Geometric), a library built upon PyTorch to easily write and train Graph Neural Networks (GNNs) on structured datasets.\n", | |
"\n", | |
"Next, we will load the [Open Graph Benchmark](https://ogb.stanford.edu/docs/lsc/) (OGB) dataset from the ogb package. OGB is a collection of realistic, large-scale, and diverse benchmark datasets for machine learning on graphs. The ogb package not only provides data loaders for each dataset but also model evaluators.\n", | |
"\n", | |
"_Note: This cell might take a while (~5 minutes) to run_" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 2, | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "pOCAIKchQfK2", | |
"outputId": "4ca1b8d7-cc15-42e7-e88b-0af60cb483be" | |
}, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"name": "stdout", | |
"text": [ | |
"Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/\n", | |
"Looking in links: https://pytorch-geometric.com/whl/torch-1.13.1+cu116.html\n", | |
"Requirement already satisfied: torch-scatter in /usr/local/lib/python3.9/dist-packages (2.1.1+pt113cu116)\n", | |
"Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/\n", | |
"Looking in links: https://pytorch-geometric.com/whl/torch-1.13.1+cu116.html\n", | |
"Requirement already satisfied: torch-sparse in /usr/local/lib/python3.9/dist-packages (0.6.17+pt113cu116)\n", | |
"Requirement already satisfied: scipy in /usr/local/lib/python3.9/dist-packages (from torch-sparse) (1.10.1)\n", | |
"Requirement already satisfied: numpy<1.27.0,>=1.19.5 in /usr/local/lib/python3.9/dist-packages (from scipy->torch-sparse) (1.22.4)\n", | |
"Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/\n", | |
"Requirement already satisfied: torch-geometric in /usr/local/lib/python3.9/dist-packages (2.2.0)\n", | |
"Requirement already satisfied: numpy in /usr/local/lib/python3.9/dist-packages (from torch-geometric) (1.22.4)\n", | |
"Requirement already satisfied: requests in /usr/local/lib/python3.9/dist-packages (from torch-geometric) (2.27.1)\n", | |
"Requirement already satisfied: jinja2 in /usr/local/lib/python3.9/dist-packages (from torch-geometric) (3.1.2)\n", | |
"Requirement already satisfied: psutil>=5.8.0 in /usr/local/lib/python3.9/dist-packages (from torch-geometric) (5.9.4)\n", | |
"Requirement already satisfied: tqdm in /usr/local/lib/python3.9/dist-packages (from torch-geometric) (4.65.0)\n", | |
"Requirement already satisfied: pyparsing in /usr/local/lib/python3.9/dist-packages (from torch-geometric) (3.0.9)\n", | |
"Requirement already satisfied: scikit-learn in /usr/local/lib/python3.9/dist-packages (from torch-geometric) (1.2.2)\n", | |
"Requirement already satisfied: scipy in /usr/local/lib/python3.9/dist-packages (from torch-geometric) (1.10.1)\n", | |
"Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.9/dist-packages (from jinja2->torch-geometric) (2.1.2)\n", | |
"Requirement already satisfied: urllib3<1.27,>=1.21.1 in /usr/local/lib/python3.9/dist-packages (from requests->torch-geometric) (1.26.15)\n", | |
"Requirement already satisfied: charset-normalizer~=2.0.0 in /usr/local/lib/python3.9/dist-packages (from requests->torch-geometric) (2.0.12)\n", | |
"Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.9/dist-packages (from requests->torch-geometric) (3.4)\n", | |
"Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.9/dist-packages (from requests->torch-geometric) (2022.12.7)\n", | |
"Requirement already satisfied: joblib>=1.1.1 in /usr/local/lib/python3.9/dist-packages (from scikit-learn->torch-geometric) (1.1.1)\n", | |
"Requirement already satisfied: threadpoolctl>=2.0.0 in /usr/local/lib/python3.9/dist-packages (from scikit-learn->torch-geometric) (3.1.0)\n", | |
" Preparing metadata (setup.py) ... \u001b[?25l\u001b[?25hdone\n", | |
"Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/\n", | |
"Requirement already satisfied: ogb in /usr/local/lib/python3.9/dist-packages (1.3.5)\n", | |
"Requirement already satisfied: pandas>=0.24.0 in /usr/local/lib/python3.9/dist-packages (from ogb) (1.4.4)\n", | |
"Requirement already satisfied: urllib3>=1.24.0 in /usr/local/lib/python3.9/dist-packages (from ogb) (1.26.15)\n", | |
"Requirement already satisfied: scikit-learn>=0.20.0 in /usr/local/lib/python3.9/dist-packages (from ogb) (1.2.2)\n", | |
"Requirement already satisfied: numpy>=1.16.0 in /usr/local/lib/python3.9/dist-packages (from ogb) (1.22.4)\n", | |
"Requirement already satisfied: six>=1.12.0 in /usr/local/lib/python3.9/dist-packages (from ogb) (1.16.0)\n", | |
"Requirement already satisfied: outdated>=0.2.0 in /usr/local/lib/python3.9/dist-packages (from ogb) (0.2.2)\n", | |
"Requirement already satisfied: tqdm>=4.29.0 in /usr/local/lib/python3.9/dist-packages (from ogb) (4.65.0)\n", | |
"Requirement already satisfied: torch>=1.6.0 in /usr/local/lib/python3.9/dist-packages (from ogb) (1.13.1+cu116)\n", | |
"Requirement already satisfied: setuptools>=44 in /usr/local/lib/python3.9/dist-packages (from outdated>=0.2.0->ogb) (67.6.0)\n", | |
"Requirement already satisfied: littleutils in /usr/local/lib/python3.9/dist-packages (from outdated>=0.2.0->ogb) (0.2.2)\n", | |
"Requirement already satisfied: requests in /usr/local/lib/python3.9/dist-packages (from outdated>=0.2.0->ogb) (2.27.1)\n", | |
"Requirement already satisfied: pytz>=2020.1 in /usr/local/lib/python3.9/dist-packages (from pandas>=0.24.0->ogb) (2022.7.1)\n", | |
"Requirement already satisfied: python-dateutil>=2.8.1 in /usr/local/lib/python3.9/dist-packages (from pandas>=0.24.0->ogb) (2.8.2)\n", | |
"Requirement already satisfied: joblib>=1.1.1 in /usr/local/lib/python3.9/dist-packages (from scikit-learn>=0.20.0->ogb) (1.1.1)\n", | |
"Requirement already satisfied: threadpoolctl>=2.0.0 in /usr/local/lib/python3.9/dist-packages (from scikit-learn>=0.20.0->ogb) (3.1.0)\n", | |
"Requirement already satisfied: scipy>=1.3.2 in /usr/local/lib/python3.9/dist-packages (from scikit-learn>=0.20.0->ogb) (1.10.1)\n", | |
"Requirement already satisfied: typing-extensions in /usr/local/lib/python3.9/dist-packages (from torch>=1.6.0->ogb) (4.5.0)\n", | |
"Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.9/dist-packages (from requests->outdated>=0.2.0->ogb) (3.4)\n", | |
"Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.9/dist-packages (from requests->outdated>=0.2.0->ogb) (2022.12.7)\n", | |
"Requirement already satisfied: charset-normalizer~=2.0.0 in /usr/local/lib/python3.9/dist-packages (from requests->outdated>=0.2.0->ogb) (2.0.12)\n" | |
] | |
} | |
], | |
"source": [ | |
"!pip install torch-scatter -f https://pytorch-geometric.com/whl/torch-1.13.1+cu116.html\n", | |
"!pip install torch-sparse -f https://pytorch-geometric.com/whl/torch-1.13.1+cu116.html\n", | |
"!pip install torch-geometric\n", | |
"!pip install -q git+https://github.com/snap-stanford/deepsnap.git\n", | |
"!pip install ogb" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 3, | |
"metadata": { | |
"id": "1T0jRDVgV3iN" | |
}, | |
"outputs": [], | |
"source": [ | |
"import torch\n", | |
"import torch_scatter\n", | |
"import torch.nn as nn\n", | |
"import torch.nn.functional as F\n", | |
"\n", | |
"import torch_geometric.nn as pyg_nn\n", | |
"import torch_geometric.utils as pyg_utils\n", | |
"\n", | |
"from torch import Tensor\n", | |
"from typing import Union, Tuple, Optional\n", | |
"from torch_geometric.typing import (OptPairTensor, Adj, Size, NoneType, OptTensor)\n", | |
"\n", | |
"from torch.nn import Parameter, Linear\n", | |
"from torch_sparse import SparseTensor, set_diag\n", | |
"from torch_geometric.nn.conv import MessagePassing\n", | |
"from torch_geometric.utils import remove_self_loops, add_self_loops, softmax\n", | |
"from torch_geometric.nn import global_add_pool\n", | |
"\n", | |
"from torch_geometric.data import DataLoader\n", | |
"import torch_geometric.transforms as T" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 4, | |
"metadata": { | |
"id": "JaiQ4WZBfDW7" | |
}, | |
"outputs": [], | |
"source": [ | |
"# Use GPU when available\n", | |
"device = torch.device(\"cuda:0\") if torch.cuda.is_available() else torch.device(\"cpu\")" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "ipKWwVO1buFr" | |
}, | |
"source": [ | |
"# 2. Setup" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "ZPfG9w44qBxt" | |
}, | |
"source": [ | |
"## 2.1 Load Dataset\n", | |
"\n", | |
"The `ogbg-code2` dataset provides 452,741 different graphs, and the task is to learn a model that can predict a set of tokens that represents the method name for a given graph. The dataset has a pre-defined project split, where the ASTs for the train set are obtained from GitHub projects that do not appear in the validation and test sets.\n", | |
"\n", | |
"Download, extract, and import the `ogbg-code2` dataset.\n", | |
"\n", | |
"https://ogb.stanford.edu/docs/graphprop/" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 5, | |
"metadata": { | |
"id": "eLOsu-8-Y0Tx", | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"outputId": "49aeb1fa-2ce9-4c97-9aa8-3e0eab8a89c3" | |
}, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"name": "stdout", | |
"text": [ | |
"The ogbg-code2 dataset has 452741 graphs\n", | |
"Sample:\n", | |
"Data(edge_index=[2, 243], x=[244, 2], node_is_attributed=[244, 1], node_dfs_order=[244, 1], node_depth=[244, 1], y=[1], num_nodes=244)\n", | |
"Splits: Train: 407976, Val: 22817\n" | |
] | |
} | |
], | |
"source": [ | |
"from ogb.nodeproppred import PygNodePropPredDataset\n", | |
"from ogb.graphproppred import PygGraphPropPredDataset, Evaluator\n", | |
"\n", | |
"dataset_name = 'ogbg-code2'\n", | |
"# Load the dataset\n", | |
"dataset = PygGraphPropPredDataset(name=dataset_name)\n", | |
"print('The {} dataset has {} graphs'.format(dataset_name, len(dataset)))\n", | |
"print('Sample:')\n", | |
"# Extract sample graph\n", | |
"print(dataset[0])\n", | |
"split_idx = dataset.get_idx_split()\n", | |
"train_idx, valid_idx = split_idx['train'], split_idx['valid']\n", | |
"print(f'Splits: Train: {len(train_idx)}, Val: {len(valid_idx)}')" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "9ojDO44URLu8" | |
}, | |
"source": [ | |
"Load provided mapping for node types and attributes." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 6, | |
"metadata": { | |
"id": "kPkqdj4wQ7Vy" | |
}, | |
"outputs": [], | |
"source": [ | |
"nodetypes_mapping = pd.read_csv(os.path.join(dataset.root, 'mapping', 'typeidx2type.csv.gz'))\n", | |
"nodeattributes_mapping = pd.read_csv(os.path.join(dataset.root, 'mapping', 'attridx2attr.csv.gz'))" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "Hon0DGcApn2f" | |
}, | |
"source": [ | |
"## 2.2 Model Configuration" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "OOmHc6Evptds" | |
}, | |
"source": [ | |
"All model configuration and hyperparameters are stored in a single `Config` object as defined below. The most notable parameters is: `max_vocab_size` which determines the size of our vocabulary. Since our model is generating a set of tokens, we can limit the predictions by having a fixed length vocabulary." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 7, | |
"metadata": { | |
"id": "QoeuE8aDiXnb" | |
}, | |
"outputs": [], | |
"source": [ | |
"batch_size, emb_dim, lr, max_iter = 128, 256, .005, 25\n", | |
"cfg = {\n", | |
" 'device': device,\n", | |
" 'dataset_name': dataset_name,\n", | |
" 'model_save_path': f'gin_{batch_size}bz_{emb_dim}emb_{lr}lr_{max_iter}it.net',\n", | |
" 'history_write_path': f'gin_{batch_size}bz_{emb_dim}emb_{lr}lr_{max_iter}it_history.json',\n", | |
" 'max_vocab_size': 5000,\n", | |
" 'max_seq_len': 5,\n", | |
" 'batch_size': batch_size,\n", | |
" 'emb_dim': emb_dim,\n", | |
" 'num_nodetypes': len(nodetypes_mapping['type']),\n", | |
" 'num_nodeattributes': len(nodeattributes_mapping['attr']),\n", | |
" 'max_depth': 20,\n", | |
" 'num_layers': 5,\n", | |
" 'max_iter': max_iter,\n", | |
" 'dropout': 0.2,\n", | |
" 'lr': lr\n", | |
"}\n", | |
"Cfg = namedtuple('Cfg', cfg)\n", | |
"cfg = Cfg(**cfg)" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "8gy3VjT_qTWK" | |
}, | |
"source": [ | |
"## 2.3 Build Vocabulary\n", | |
"\n", | |
"Here we extract the vocabulary from ground truth sequences (list of words) in training set." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 8, | |
"metadata": { | |
"id": "aC3op-nklRXb", | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"outputId": "d2402390-26cf-4974-ce85-f32e8cc7ab70" | |
}, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"name": "stdout", | |
"text": [ | |
"Top 5000 words proportion: 0.9566695094108582\n" | |
] | |
} | |
], | |
"source": [ | |
"# Given a sequence of words and word-to-index mapping, generate a fixed length tensor representation\n", | |
"def encode_seq(seq, v2i, PAD):\n", | |
" words_idx = [v2i[word] for word in seq[:cfg.max_seq_len]]\n", | |
" pad_factor = max(0, cfg.max_seq_len - len(words_idx))\n", | |
" return torch.as_tensor(words_idx + [PAD]*pad_factor, dtype = torch.long, device = cfg.device)\n", | |
"\n", | |
"def build_vocab(target, target_idx_list):\n", | |
" # stores mapping between words and indices\n", | |
" v2i = defaultdict(lambda: len(v2i))\n", | |
" VOID = v2i['<void>']\n", | |
" PAD = v2i['<pad>']\n", | |
" UNK = v2i['<unk>']\n", | |
"\n", | |
" target_encodings = torch.stack([encode_seq(target[idx], v2i, PAD) for idx in target_idx_list]).to(cfg.device)\n", | |
" \n", | |
" # so far we have an infinite length vocabulary which is not feasible\n", | |
" # restrict to top-k (cfg.max_vocab_size) appearing words\n", | |
" counts = torch.bincount(target_encodings.view(-1))\n", | |
" topk = torch.topk(counts, k = cfg.max_vocab_size)[1]\n", | |
" print(f'Top {cfg.max_vocab_size} words proportion: {(counts[topk].sum() / counts.sum()).cpu().numpy()}')\n", | |
"\n", | |
" topk_keep = torch.as_tensor([VOID,PAD,UNK], dtype = torch.long, device = cfg.device)\n", | |
" topk = torch.cat([topk, topk_keep])\n", | |
"\n", | |
" # rewire the indices based on the top-k search above\n", | |
" i2v = {v: k for k, v in v2i.items()}\n", | |
" v2i = {}\n", | |
" for k, new_k in zip(sorted(topk.unique().cpu().numpy()), range(len(topk))):\n", | |
" v2i[i2v[k]] = new_k\n", | |
"\n", | |
" v2i = defaultdict(lambda: UNK, v2i)\n", | |
" i2v = {v: k for k, v in v2i.items()}\n", | |
" return v2i, i2v, UNK, PAD\n", | |
"\n", | |
"# Build vocab from existing labels in *training set* only\n", | |
"v2i, i2v, UNK, PAD = build_vocab(dataset.data.y, train_idx)" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"source": [ | |
"For all the examples in the training set, roughly **96%** are in our fixed-length vocabulary. This ratio provides a litmus test of whether we need to tune `max_vocab_len`." | |
], | |
"metadata": { | |
"id": "PLKtwrCRhgR5" | |
} | |
}, | |
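The coverage ratio discussed above can be illustrated on a toy corpus. The following is a minimal pure-Python sketch, not the notebook's tensor-based implementation: the helper `coverage` and the sample method names are hypothetical and are not taken from `ogbg-code2`.

```python
from collections import Counter

def coverage(sequences, k):
    """Fraction of token occurrences covered by the k most frequent tokens."""
    counts = Counter(tok for seq in sequences for tok in seq)
    total = sum(counts.values())
    covered = sum(c for _, c in counts.most_common(k))
    return covered / total

# Hypothetical tokenized method names, for illustration only
names = [["get", "name"], ["set", "name"], ["get", "value"], ["parse", "args"]]

print(coverage(names, k=2))  # "get" and "name" cover 4 of 8 tokens -> 0.5
print(coverage(names, k=6))  # all six distinct tokens are kept -> 1.0
```

A low ratio would suggest that many target tokens collapse to `<unk>`, which is one reason to consider a larger vocabulary.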
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "rWVhGhMVW1BN" | |
}, | |
"source": [ | |
"## 2.4 Augment Features\n", | |
"\n", | |
"PyG's transforms are a general way to modify and customize Data objects. Here we define two such transforms: \n", | |
"1. Encode labels as Tensors with `encode_target_tensor`.\n", | |
"2. Augment next-token edges with inverse relation and edge attributes with `augment_edge`.\n" | |
] | |
}, | |
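To make the next-token augmentation concrete, here is a small pure-Python sketch of the idea, assuming nodes are already listed in DFS order. The function `next_token_edges` and the flag list are hypothetical and use plain lists instead of PyG tensors; each attribute row is `[is_next_token, is_inverse]`.

```python
def next_token_edges(is_attributed):
    """Link consecutive attributed nodes and add the inverse edges.

    Returns (edge_index, edge_attr): edge_index is [sources, targets] and
    each edge_attr row is [is_next_token, is_inverse].
    """
    idx = [i for i, flag in enumerate(is_attributed) if flag == 1]
    src, dst = idx[:-1], idx[1:]
    # forward edges first, then the same edges reversed
    edge_index = [src + dst, dst + src]
    edge_attr = [[1, 0] for _ in src] + [[1, 1] for _ in src]
    return edge_index, edge_attr

# Attributed nodes sit at positions 1, 3, 4, 5, 8, 9 and 12
flags = [0, 1, 0, 1, 1, 1, 0, 0, 1, 1, 0, 0, 1]
edge_index, edge_attr = next_token_edges(flags)
print(edge_index[0][:6])  # sources of the forward edges
print(edge_index[1][:6])  # targets of the forward edges
```

With seven attributed nodes the sketch yields six forward edges and six inverse edges.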
{ | |
"cell_type": "code", | |
"execution_count": 9, | |
"metadata": { | |
"id": "29RhVmeiXkAP" | |
}, | |
"outputs": [], | |
"source": [ | |
"def encode_target_tensor(data):\n", | |
" data.y_tensor = encode_seq(data.y, v2i, PAD).unsqueeze(dim = 0)\n", | |
" return data\n", | |
"\n", | |
"# The augment_edge transform is originally defined in the official OGB example here:\n", | |
"# https://github.com/snap-stanford/ogb (source)\n", | |
"def augment_edge(data):\n", | |
" '''\n", | |
" Input:\n", | |
" data: PyG data object\n", | |
" Output:\n", | |
" data (edges are augmented in the following ways):\n", | |
" data.edge_index: Added next-token edge. The inverse edges were also added.\n", | |
" data.edge_attr (torch.Long):\n", | |
" data.edge_attr[:,0]: whether it is AST edge (0) for next-token edge (1)\n", | |
" data.edge_attr[:,1]: whether it is original direction (0) or inverse direction (1)\n", | |
" '''\n", | |
" ##### AST edge\n", | |
" edge_index_ast = data.edge_index\n", | |
" edge_attr_ast = torch.zeros((edge_index_ast.size(1), 2))\n", | |
"\n", | |
" ##### Inverse AST edge\n", | |
" edge_index_ast_inverse = torch.stack([edge_index_ast[1], edge_index_ast[0]], dim = 0)\n", | |
" edge_attr_ast_inverse = torch.cat([torch.zeros(edge_index_ast_inverse.size(1), 1), torch.ones(edge_index_ast_inverse.size(1), 1)], dim = 1)\n", | |
"\n", | |
"\n", | |
" ##### Next-token edge\n", | |
"\n", | |
" ## Since the nodes are already sorted in dfs ordering in our case, we can just do the following.\n", | |
" attributed_node_idx_in_dfs_order = torch.where(data.node_is_attributed.view(-1,) == 1)[0]\n", | |
"\n", | |
" ## build next token edge\n", | |
" # Given: attributed_node_idx_in_dfs_order\n", | |
" # [1, 3, 4, 5, 8, 9, 12]\n", | |
" # Output:\n", | |
" # [[1, 3, 4, 5, 8, 9]\n", | |
" # [3, 4, 5, 8, 9, 12]\n", | |
" edge_index_nextoken = torch.stack([attributed_node_idx_in_dfs_order[:-1], attributed_node_idx_in_dfs_order[1:]], dim = 0)\n", | |
" edge_attr_nextoken = torch.cat([torch.ones(edge_index_nextoken.size(1), 1), torch.zeros(edge_index_nextoken.size(1), 1)], dim = 1)\n", | |
"\n", | |
"\n", | |
" ##### Inverse next-token edge\n", | |
" edge_index_nextoken_inverse = torch.stack([edge_index_nextoken[1], edge_index_nextoken[0]], dim = 0)\n", | |
" edge_attr_nextoken_inverse = torch.ones((edge_index_nextoken.size(1), 2))\n", | |
"\n", | |
"\n", | |
" data.edge_index = torch.cat([edge_index_ast, edge_index_ast_inverse, edge_index_nextoken, edge_index_nextoken_inverse], dim = 1)\n", | |
" data.edge_attr = torch.cat([edge_attr_ast, edge_attr_ast_inverse, edge_attr_nextoken, edge_attr_nextoken_inverse], dim = 0)\n", | |
"\n", | |
" return data" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 10, | |
"metadata": { | |
"id": "j80XYIl_f4gI" | |
}, | |
"outputs": [], | |
"source": [ | |
"# Compose multiple transforms and apply it to the dataset\n", | |
"dataset.transform = T.Compose([augment_edge, encode_target_tensor])" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "jcjP4V9IKq30" | |
}, | |
"source": [ | |
"## 2.5 Create Data Loaders" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "78Mho4d5zVRZ" | |
}, | |
"source": [ | |
"PyG automatically takes care of batching multiple graphs into a single giant graph with the help of the `DataLoader` class:" | |
] | |
}, | |
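The batching described above merges several graphs into one disconnected graph by shifting node indices and recording which graph each node came from. The sketch below shows that idea in plain Python. The `collate` helper and its tuple-based graph format are hypothetical simplifications, not PyG's actual implementation.

```python
def collate(graphs):
    """Merge graphs into one disconnected graph.

    Each graph is (num_nodes, edges), with edges given as (src, dst) pairs.
    Returns the shifted edge list and a batch vector mapping node -> graph id.
    """
    edges, batch, offset = [], [], 0
    for graph_id, (num_nodes, graph_edges) in enumerate(graphs):
        # shift node indices so they stay unique inside the merged graph
        edges.extend((s + offset, d + offset) for s, d in graph_edges)
        batch.extend([graph_id] * num_nodes)
        offset += num_nodes
    return edges, batch

g0 = (3, [(0, 1), (1, 2)])  # 3 nodes, 2 edges
g1 = (2, [(0, 1)])          # 2 nodes, 1 edge
edges, batch = collate([g0, g1])
print(edges)  # [(0, 1), (1, 2), (3, 4)]
print(batch)  # [0, 0, 0, 1, 1]
```

Because no edge connects nodes from different graphs, message passing on the merged graph never mixes information between samples.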
{ | |
"cell_type": "code", | |
"execution_count": 11, | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "8knLefdXNrYd", | |
"outputId": "964a74fe-97cb-4844-8964-b71f1f50786e" | |
}, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"name": "stdout", | |
"text": [ | |
"Number of batches: Train: 3188, Val: 179\n" | |
] | |
}, | |
{ | |
"output_type": "stream", | |
"name": "stderr", | |
"text": [ | |
"/usr/local/lib/python3.9/dist-packages/torch_geometric/deprecation.py:12: UserWarning: 'data.DataLoader' is deprecated, use 'loader.DataLoader' instead\n", | |
" warnings.warn(out)\n" | |
] | |
} | |
], | |
"source": [ | |
"# use evaluation metrics defined in the OGB dataset\n", | |
"evaluator = Evaluator(cfg.dataset_name)\n", | |
"\n", | |
"# create data loaders for each split\n", | |
"# shuffle = true for training set ensures data is reshuffled for every epoch\n", | |
"train_loader = DataLoader(dataset[train_idx], batch_size = cfg.batch_size, shuffle = True) # 407976\n", | |
"valid_loader = DataLoader(dataset[valid_idx], batch_size = cfg.batch_size, shuffle = False) # 22817\n", | |
"print(f'Number of batches: Train: {len(train_loader)}, Val: {len(valid_loader)}')" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "MFdpuYnKRuWA" | |
}, | |
"source": [ | |
"# 3. Create Model\n", | |
"\n", | |
"Here are the steps to train a GNN for graph property prediction:\n", | |
"\n", | |
"1. Embed each node by performing multiple rounds of message passing\n", | |
"2. Aggregate node embeddings into a unified graph embedding (readout layer)\n", | |
"3. Train a final classifier on the graph embedding\n", | |
"\n", | |
"\n", | |
"\n", | |
"The figure above illustrates a high-level architecture of our GNN model.\n", | |
"\n", | |
"For [Graph Isomorphism Network](https://cs.stanford.edu/people/jure/pubs/gin-iclr19.pdf) (GIN), the readout layer simply takes the sum of \n", | |
"node embeddings. PyG provides this functionality via `global_sum_pool`, which takes in the node embeddings of all nodes in the mini-batch and the assignment vector batch to compute a graph embedding of size `[batch_size, hidden_dim]` for each graph in the batch.\n", | |
"\n", | |
"`NodeEncoder` defines 3 `nn.Embedding` layers to encode node features: type, attribute, and depth." | |
] | |
}, | |
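The sum readout can be sketched without any tensor library. The example below is a hypothetical pure-Python stand-in for what the `global_add_pool` function imported earlier computes: it adds up the embeddings of all nodes that share a graph id in the assignment vector. The name `sum_readout` and the toy values are illustrative only.

```python
def sum_readout(node_emb, batch, num_graphs):
    """Sum node embeddings per graph, using batch[i] as the graph id of node i."""
    dim = len(node_emb[0])
    out = [[0.0] * dim for _ in range(num_graphs)]
    for emb, graph_id in zip(node_emb, batch):
        for j, value in enumerate(emb):
            out[graph_id][j] += value
    return out

node_emb = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]  # three nodes, hidden_dim = 2
batch = [0, 0, 1]                                # first two nodes belong to graph 0
graph_emb = sum_readout(node_emb, batch, num_graphs=2)
print(graph_emb)  # [[4.0, 6.0], [5.0, 6.0]]
```

The result has one row per graph, which matches the `[batch_size, hidden_dim]` shape mentioned in the text.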
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "qyhDpe39nQ_E" | |
}, | |
"source": [ | |
"## 3.1 Define Model\n", | |
"\n", | |
"The final architecture for applying GNNs to the task of graph property prediction then looks as follows and allows for complete end-to-end training:" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 12, | |
"metadata": { | |
"id": "PAQqTDiwneC7" | |
}, | |
"outputs": [], | |
"source": [ | |
"# Embed node type, attribute, and depth\n", | |
"class NodeEncoder(torch.nn.Module):\n", | |
" def __init__(self, cfg):\n", | |
" super(NodeEncoder, self).__init__()\n", | |
" self.type_enc = nn.Embedding(cfg.num_nodetypes, cfg.emb_dim)\n", | |
" self.attr_enc = nn.Embedding(cfg.num_nodeattributes, cfg.emb_dim)\n", | |
" self.depth_enc = nn.Embedding(cfg.max_depth + 1, cfg.emb_dim)\n", | |
" self.cfg = cfg\n", | |
" self.apply(self._init_weights)\n", | |
"\n", | |
" def _init_weights(self, module):\n", | |
" for param in module.parameters():\n", | |
" param.data.uniform_(.0, 1.0)\n", | |
"\n", | |
" def forward(self, x, depth):\n", | |
" depth[depth > self.cfg.max_depth] = self.cfg.max_depth\n", | |
" # add the node feature embeddings\n", | |
" return self.type_enc(x[:,0]) + self.attr_enc(x[:,1]) + self.depth_enc(depth)\n", | |
"\n", | |
"# Apply GIN convolution given the graph structure\n", | |
"# message propagation is enabled by inheriting from MessagePassing\n", | |
"# use sum \"add\" aggregation\n", | |
"class GINConv(MessagePassing):\n", | |
" def __init__(self, emb_dim):\n", | |
" super(GINConv, self).__init__(aggr=\"add\")\n", | |
" self.mlp = nn.Sequential(\n", | |
" nn.Linear(emb_dim, 2*emb_dim), \n", | |
" nn.BatchNorm1d(2*emb_dim), \n", | |
" nn.ReLU(), \n", | |
" nn.Linear(2*emb_dim, emb_dim))\n", | |
" self.eps = nn.Parameter(torch.Tensor([0]))\n", | |
" self.edge_encoder = nn.Linear(2, emb_dim)\n", | |
" self.apply(self._init_weights)\n", | |
"\n", | |
" def _init_weights(self, module):\n", | |
" for param in module.parameters():\n", | |
" param.data.uniform_(.0, 1.0)\n", | |
"\n", | |
" def forward(self, x, edge_index, edge_attr):\n", | |
" emb = self.edge_encoder(edge_attr)\n", | |
" emb = self.propagate(edge_index, x = x, edge_emb = emb)\n", | |
" emb += x * (self.eps + 1)\n", | |
" res = self.mlp(emb)\n", | |
" return res\n", | |
"\n", | |
" def message(self, x_j, edge_emb):\n", | |
" return F.relu(x_j + edge_emb)\n", | |
"\n", | |
"# Node embedding GNN\n", | |
"# Encodes node attributes with provided NodeEncoder\n", | |
"# then applies Message Passing with GINConv\n", | |
"# Intra layers: Linear --> Batch Norm --> Dropout --> Activation\n", | |
"class NodeGIN(torch.nn.Module):\n", | |
" def __init__(self, cfg):\n", | |
" super(NodeGIN, self).__init__()\n", | |
" self.encoder = NodeEncoder(cfg)\n", | |
" self.convs = nn.ModuleList([GINConv(cfg.emb_dim) for l in range(cfg.num_layers)])\n", | |
" self.bns = nn.ModuleList([nn.BatchNorm1d(cfg.emb_dim) for l in range(cfg.num_layers)])\n", | |
" self.cfg = cfg\n", | |
" self.apply(self._init_weights)\n", | |
"\n", | |
" def _init_weights(self, module):\n", | |
" for param in module.parameters():\n", | |
" param.data.uniform_(.0, 1.0)\n", | |
"\n", | |
" def forward(self, data_batch):\n", | |
" x, edge_index, edge_attr, node_depth = data_batch.x, data_batch.edge_index, data_batch.edge_attr, data_batch.node_depth\n", | |
" h_last = self.encoder(x, node_depth.view(-1,))\n", | |
" for l in range(self.cfg.num_layers):\n", | |
" h = self.convs[l](h_last, edge_index, edge_attr)\n", | |
" h = self.bns[l](h)\n", | |
" h = F.relu(h) if l < (self.cfg.num_layers-1) else h\n", | |
" h = F.dropout(h, p = self.cfg.dropout, training = self.training)\n", | |
" h_last = h\n", | |
" return h_last\n", | |
"\n", | |
"# Wrapper GNN to generate final predictions: \n", | |
"# sequence of tokens that form method name for given graph\n", | |
"# Intra layers: Linear --> Batch Norm --> Dropout --> Activation --> Aggregation\n", | |
"# GIN aggregate with global_add_pool\n", | |
"class GIN(nn.Module):\n", | |
" def __init__(self, v2i, cfg):\n", | |
" super(GIN, self).__init__()\n", | |
" self.node_GIN = NodeGIN(cfg)\n", | |
" self.pred_lins = nn.ModuleList([nn.Linear(cfg.emb_dim, len(v2i)) for i in range(cfg.max_seq_len)])\n", | |
" self.v2i = v2i\n", | |
" self.cfg = cfg\n", | |
" self.apply(self._init_weights)\n", | |
" \n", | |
" def _init_weights(self, module):\n", | |
" for param in module.parameters():\n", | |
" param.data.uniform_(.0, 1.0)\n", | |
"\n", | |
" def forward(self, data_batch):\n", | |
" node_emb = self.node_GIN(data_batch)\n", | |
" graph_emb = global_add_pool(node_emb, data_batch.batch)\n", | |
" preds = [lin(graph_emb) for lin in self.pred_lins]\n", | |
" return preds\n", | |
" \n", | |
" def save(self):\n", | |
" print('')\n", | |
" print(f'Saving model: {self.cfg.model_save_path}')\n", | |
" summary = dict(params=self.cfg, v2i=dict(self.v2i), state=self.state_dict())\n", | |
" torch.save(summary, self.cfg.model_save_path)\n", | |
"\n", | |
" @staticmethod\n", | |
" def load(model_path_fname, use_cuda = True):\n", | |
" data = torch.load(model_path_fname)\n", | |
" cfg, v2i, state = data['params'], data['v2i'], data['state']\n", | |
" model = GIN(v2i, cfg)\n", | |
" model.load_state_dict(state)\n", | |
" if use_cuda: model.cuda()\n", | |
" model.eval()\n", | |
" return model" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "QNSBk2iKnXDq" | |
}, | |
"source": [ | |
"## 3.2 Create Model" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "H0_3qNYL6hfV" | |
}, | |
"source": [ | |
"Instantiate the model and load it on the GPU if available. Use [Adam optimizer](https://pytorch.org/docs/stable/generated/torch.optim.Adam.html) and [Cross-Entropy Loss](https://pytorch.org/docs/stable/generated/torch.nn.CrossEntropyLoss.html) for multi-class classificaiton." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 13, | |
"metadata": { | |
"id": "vcCnD9j9R3Vb" | |
}, | |
"outputs": [], | |
"source": [ | |
"model = GIN(v2i, cfg).to(cfg.device)\n", | |
"optimizer = torch.optim.Adam(model.parameters(), lr = cfg.lr)\n", | |
"ceLoss = torch.nn.CrossEntropyLoss()" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "MUhZo6hsVi4Q" | |
}, | |
"source": [ | |
"## 3.3 Create Evaluator\n", | |
"\n", | |
"Wrapper to evaluate on validation and test sets using the OGB provided evaluation metric. The `decode` function takes a sequence of predicted tokens and converts them back to words using the reverse vocabulary mapping `i2v`." | |
] | |
}, | |
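The decoding step and a set-based score can be sketched in plain Python. Everything below is illustrative: `decode` mirrors the idea of dropping padding before mapping indices to words, and `set_f1` is a hypothetical per-sample F1 over token sets, written under the assumption that the OGB metric compares predicted and reference tokens as sets. The toy `i2v` mapping is not the notebook's real vocabulary.

```python
def decode(pred_ids, i2v, pad_id):
    """Drop padding and map token indices back to words."""
    return [[i2v[i] for i in row if i != pad_id] for row in pred_ids]

def set_f1(pred, ref):
    """F1 between predicted and reference tokens, both treated as sets."""
    pred_set, ref_set = set(pred), set(ref)
    true_positive = len(pred_set & ref_set)
    if true_positive == 0:
        return 0.0
    precision = true_positive / len(pred_set)
    recall = true_positive / len(ref_set)
    return 2 * precision * recall / (precision + recall)

# Toy reverse vocabulary, for illustration only
i2v = {0: "<void>", 1: "<pad>", 2: "<unk>", 3: "get", 4: "name"}
decoded = decode([[3, 4, 1, 1, 1]], i2v, pad_id=1)
print(decoded)  # [['get', 'name']]

score = set_f1(decoded[0], ["get", "user", "name"])
print(round(score, 3))  # precision 1.0, recall 2/3 -> 0.8
```

Treating names as sets ignores token order, so a prediction is rewarded for recovering the right sub-tokens even when they appear in a different position.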
{ | |
"cell_type": "code", | |
"execution_count": 14, | |
"metadata": { | |
"id": "0z12QwKXVohb" | |
}, | |
"outputs": [], | |
"source": [ | |
"class Eval():\n", | |
"    def __init__(self, model, loaders, evaluator, i2v, PAD, cfg):\n", | |
"        self.model = model\n", | |
"        self.loaders = loaders\n", | |
"        self.ogb_eval = evaluator.eval\n", | |
"        self.i2v = i2v\n", | |
"        # token ids that are removed from predictions before decoding\n", | |
"        self.filter_tokens = torch.as_tensor([PAD], dtype = torch.long, device = cfg.device)\n", | |
"        self.cfg = cfg\n", | |
"\n", | |
"    def decode(self, pred):\n", | |
"        # convert each row of token ids into a list of words, skipping PAD\n", | |
"        res = []\n", | |
"        for p in pred:\n", | |
"            mask = torch.isin(p, self.filter_tokens, invert = True)\n", | |
"            p_trimmed = p[mask]\n", | |
"            res.append([self.i2v[i] for i in p_trimmed.cpu().numpy()])\n", | |
"        return res\n", | |
"\n", | |
"    def eval(self):\n", | |
"        loaders_metrics = []\n", | |
"        for loader in self.loaders:\n", | |
"            preds_acc = []\n", | |
"            labels_acc = []\n", | |
"            for batch_id, batch in enumerate(tqdm(loader, desc='Eval Batch')):\n", | |
"                batch = batch.to(self.cfg.device)\n", | |
"                # disable gradient tracking during inference\n", | |
"                with torch.no_grad():\n", | |
"                    pred = self.model(batch)\n", | |
"                    # most likely token per output position, shape (batch, seq_len)\n", | |
"                    pred_max = torch.cat([torch.argmax(p, dim = 1).view(-1,1) for p in pred], dim = 1)\n", | |
"                pred_max_decoded = self.decode(pred_max)\n", | |
"                preds_acc.extend(pred_max_decoded)\n", | |
"                labels_acc.extend([*batch.y])\n", | |
"            metrics = self.ogb_eval({'seq_pred': preds_acc, 'seq_ref': labels_acc})\n", | |
"            metrics['acc'] = Eval.compute_accuracy(preds_acc, labels_acc)\n", | |
"            loaders_metrics.append(metrics)\n", | |
"        return loaders_metrics\n", | |
"\n", | |
"    @staticmethod\n", | |
"    def compute_accuracy(seq_pred, seq_ref):\n", | |
"        # set-based score per sample, averaged over all samples\n", | |
"        acc = []\n", | |
"        for l, p in zip(seq_ref, seq_pred):\n", | |
"            label = set(l)\n", | |
"            prediction = set(p)\n", | |
"            n = len(label)\n", | |
"            true_positive = len(label.intersection(prediction))\n", | |
"            false_positive = len(prediction - label)\n", | |
"            false_negative = len(label - prediction)\n", | |
"            # n == true_positive + false_negative, so true_negative is always 0 and\n", | |
"            # the score equals per-sample recall (true_positive / n); assumes n > 0\n", | |
"            true_negative = max(n - (true_positive + false_positive + false_negative), 0)\n", | |
"            acc.append((true_positive + true_negative) / n)\n", | |
"        return np.average(acc)\n", | |
"\n", | |
"model_evaluator = Eval(model, [valid_loader], evaluator, i2v, PAD, cfg)" | |
] | |
}, | |
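{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"To make the set-based scores concrete, the following minimal sketch scores one toy prediction against one toy reference. The helper `set_metrics` and the token lists are hypothetical and are not part of this notebook or of the OGB API. It assumes that precision, recall, and F1 are computed on token sets, in the same way `compute_accuracy` treats sequences as sets." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# Hypothetical helper for illustration only (not the OGB evaluator itself)\n", | |
"def set_metrics(reference, prediction):\n", | |
"    # compare tokens as sets: order and duplicates are ignored\n", | |
"    label, pred = set(reference), set(prediction)\n", | |
"    tp = len(label & pred)   # tokens predicted and present in the reference\n", | |
"    fp = len(pred - label)   # tokens predicted but absent from the reference\n", | |
"    fn = len(label - pred)   # reference tokens that were not predicted\n", | |
"    precision = tp / (tp + fp) if tp + fp > 0 else 0.0\n", | |
"    recall = tp / (tp + fn) if tp + fn > 0 else 0.0\n", | |
"    f1 = 2 * precision * recall / (precision + recall) if precision + recall > 0 else 0.0\n", | |
"    return precision, recall, f1\n", | |
"\n", | |
"# toy reference and prediction (assumed data)\n", | |
"toy_ref = ['get', 'file', 'path']\n", | |
"toy_pred = ['get', 'path', 'name']\n", | |
"p, r, f1 = set_metrics(toy_ref, toy_pred)\n", | |
"print(f'Precision: {p:.4f}, Recall: {r:.4f}, F1: {f1:.4f}')" | |
] | |
}, | |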
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "k4R6uvBBSeFc" | |
}, | |
"source": [ | |
"# 4. Training\n", | |
"\n", | |
"Finally, let's train our network and see how well it performs on the training and validation sets.\n", | |
"\n", | |
"_Note: Due to the size of the dataset (~450k graphs), training is slow. In the run logged below, each iteration (3,188 training batches followed by validation) took roughly 6 to 7 minutes on a standard GPU._" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 15, | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "GEwQQONSSg6t", | |
"outputId": "25236f72-fbc4-4e08-a8bc-2f1fe4d69474" | |
}, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"name": "stderr", | |
"text": [ | |
"Training Batch: 100%|██████████| 3188/3188 [07:23<00:00, 7.18it/s]\n", | |
"Eval Batch: 100%|██████████| 179/179 [00:14<00:00, 12.05it/s]\n" | |
] | |
}, | |
{ | |
"output_type": "stream", | |
"name": "stdout", | |
"text": [ | |
"\n", | |
"Saving model: gin_128bz_256emb_0.005lr_25it.net\n", | |
"\n", | |
"It: 01, Loss: 25.8113, Accuracy: 0.0128, F1: 0.0176 \n", | |
"------------------------------------------------------------\n" | |
] | |
}, | |
{ | |
"output_type": "stream", | |
"name": "stderr", | |
"text": [ | |
"Training Batch: 100%|██████████| 3188/3188 [06:23<00:00, 8.32it/s]\n", | |
"Eval Batch: 100%|██████████| 179/179 [00:15<00:00, 11.58it/s]\n" | |
] | |
}, | |
{ | |
"output_type": "stream", | |
"name": "stdout", | |
"text": [ | |
"\n", | |
"Saving model: gin_128bz_256emb_0.005lr_25it.net\n", | |
"\n", | |
"It: 02, Loss: 3.8183, Accuracy: 0.0346, F1: 0.0478 \n", | |
"------------------------------------------------------------\n" | |
] | |
}, | |
{ | |
"output_type": "stream", | |
"name": "stderr", | |
"text": [ | |
"Training Batch: 100%|██████████| 3188/3188 [06:23<00:00, 8.32it/s]\n", | |
"Eval Batch: 100%|██████████| 179/179 [00:14<00:00, 12.09it/s]\n" | |
] | |
}, | |
{ | |
"output_type": "stream", | |
"name": "stdout", | |
"text": [ | |
"\n", | |
"It: 03, Loss: 3.7375, Accuracy: 0.0306, F1: 0.0422 \n", | |
"------------------------------------------------------------\n" | |
] | |
}, | |
{ | |
"output_type": "stream", | |
"name": "stderr", | |
"text": [ | |
"Training Batch: 100%|██████████| 3188/3188 [06:20<00:00, 8.38it/s]\n", | |
"Eval Batch: 100%|██████████| 179/179 [00:15<00:00, 11.85it/s]\n" | |
] | |
}, | |
{ | |
"output_type": "stream", | |
"name": "stdout", | |
"text": [ | |
"\n", | |
"It: 04, Loss: 3.7877, Accuracy: 0.0326, F1: 0.0431 \n", | |
"------------------------------------------------------------\n" | |
] | |
}, | |
{ | |
"output_type": "stream", | |
"name": "stderr", | |
"text": [ | |
"Training Batch: 100%|██████████| 3188/3188 [06:24<00:00, 8.30it/s]\n", | |
"Eval Batch: 100%|██████████| 179/179 [00:14<00:00, 12.23it/s]\n" | |
] | |
}, | |
{ | |
"output_type": "stream", | |
"name": "stdout", | |
"text": [ | |
"\n", | |
"It: 05, Loss: 3.8740, Accuracy: 0.0300, F1: 0.0400 \n", | |
"------------------------------------------------------------\n" | |
] | |
}, | |
{ | |
"output_type": "stream", | |
"name": "stderr", | |
"text": [ | |
"Training Batch: 100%|██████████| 3188/3188 [06:17<00:00, 8.44it/s]\n", | |
"Eval Batch: 100%|██████████| 179/179 [00:14<00:00, 12.09it/s]\n" | |
] | |
}, | |
{ | |
"output_type": "stream", | |
"name": "stdout", | |
"text": [ | |
"\n", | |
"Saving model: gin_128bz_256emb_0.005lr_25it.net\n", | |
"\n", | |
"It: 06, Loss: 3.7172, Accuracy: 0.0369, F1: 0.0499 \n", | |
"------------------------------------------------------------\n" | |
] | |
}, | |
{ | |
"output_type": "stream", | |
"name": "stderr", | |
"text": [ | |
"Training Batch: 100%|██████████| 3188/3188 [06:17<00:00, 8.45it/s]\n", | |
"Eval Batch: 100%|██████████| 179/179 [00:14<00:00, 12.13it/s]\n" | |
] | |
}, | |
{ | |
"output_type": "stream", | |
"name": "stdout", | |
"text": [ | |
"\n", | |
"It: 07, Loss: 3.6980, Accuracy: 0.0259, F1: 0.0345 \n", | |
"------------------------------------------------------------\n" | |
] | |
}, | |
{ | |
"output_type": "stream", | |
"name": "stderr", | |
"text": [ | |
"Training Batch: 100%|██████████| 3188/3188 [06:19<00:00, 8.39it/s]\n", | |
"Eval Batch: 100%|██████████| 179/179 [00:14<00:00, 12.37it/s]\n" | |
] | |
}, | |
{ | |
"output_type": "stream", | |
"name": "stdout", | |
"text": [ | |
"\n", | |
"It: 08, Loss: 3.5636, Accuracy: 0.0347, F1: 0.0458 \n", | |
"------------------------------------------------------------\n" | |
] | |
}, | |
{ | |
"output_type": "stream", | |
"name": "stderr", | |
"text": [ | |
"Training Batch: 100%|██████████| 3188/3188 [06:15<00:00, 8.49it/s]\n", | |
"Eval Batch: 100%|██████████| 179/179 [00:14<00:00, 12.43it/s]\n" | |
] | |
}, | |
{ | |
"output_type": "stream", | |
"name": "stdout", | |
"text": [ | |
"\n", | |
"It: 09, Loss: 3.4992, Accuracy: 0.0225, F1: 0.0289 \n", | |
"------------------------------------------------------------\n" | |
] | |
}, | |
{ | |
"output_type": "stream", | |
"name": "stderr", | |
"text": [ | |
"Training Batch: 100%|██████████| 3188/3188 [06:17<00:00, 8.45it/s]\n", | |
"Eval Batch: 100%|██████████| 179/179 [00:14<00:00, 12.36it/s]\n" | |
] | |
}, | |
{ | |
"output_type": "stream", | |
"name": "stdout", | |
"text": [ | |
"\n", | |
"It: 10, Loss: 3.6468, Accuracy: 0.0338, F1: 0.0420 \n", | |
"------------------------------------------------------------\n" | |
] | |
}, | |
{ | |
"output_type": "stream", | |
"name": "stderr", | |
"text": [ | |
"Training Batch: 100%|██████████| 3188/3188 [06:15<00:00, 8.50it/s]\n", | |
"Eval Batch: 100%|██████████| 179/179 [00:14<00:00, 12.25it/s]\n" | |
] | |
}, | |
{ | |
"output_type": "stream", | |
"name": "stdout", | |
"text": [ | |
"\n", | |
"It: 11, Loss: 3.4158, Accuracy: 0.0428, F1: 0.0488 \n", | |
"------------------------------------------------------------\n" | |
] | |
}, | |
{ | |
"output_type": "stream", | |
"name": "stderr", | |
"text": [ | |
"Training Batch: 100%|██████████| 3188/3188 [06:19<00:00, 8.41it/s]\n", | |
"Eval Batch: 100%|██████████| 179/179 [00:14<00:00, 12.38it/s]\n" | |
] | |
}, | |
{ | |
"output_type": "stream", | |
"name": "stdout", | |
"text": [ | |
"\n", | |
"Saving model: gin_128bz_256emb_0.005lr_25it.net\n", | |
"\n", | |
"It: 12, Loss: 3.5234, Accuracy: 0.0438, F1: 0.0560 \n", | |
"------------------------------------------------------------\n" | |
] | |
}, | |
{ | |
"output_type": "stream", | |
"name": "stderr", | |
"text": [ | |
"Training Batch: 100%|██████████| 3188/3188 [06:19<00:00, 8.39it/s]\n", | |
"Eval Batch: 100%|██████████| 179/179 [00:15<00:00, 11.81it/s]\n" | |
] | |
}, | |
{ | |
"output_type": "stream", | |
"name": "stdout", | |
"text": [ | |
"\n", | |
"Saving model: gin_128bz_256emb_0.005lr_25it.net\n", | |
"\n", | |
"It: 13, Loss: 3.4410, Accuracy: 0.0471, F1: 0.0596 \n", | |
"------------------------------------------------------------\n" | |
] | |
}, | |
{ | |
"output_type": "stream", | |
"name": "stderr", | |
"text": [ | |
"Training Batch: 100%|██████████| 3188/3188 [06:22<00:00, 8.33it/s]\n", | |
"Eval Batch: 100%|██████████| 179/179 [00:14<00:00, 12.12it/s]\n" | |
] | |
}, | |
{ | |
"output_type": "stream", | |
"name": "stdout", | |
"text": [ | |
"\n", | |
"It: 14, Loss: 3.3213, Accuracy: 0.0478, F1: 0.0569 \n", | |
"------------------------------------------------------------\n" | |
] | |
}, | |
{ | |
"output_type": "stream", | |
"name": "stderr", | |
"text": [ | |
"Training Batch: 100%|██████████| 3188/3188 [06:20<00:00, 8.37it/s]\n", | |
"Eval Batch: 100%|██████████| 179/179 [00:15<00:00, 11.68it/s]\n" | |
] | |
}, | |
{ | |
"output_type": "stream", | |
"name": "stdout", | |
"text": [ | |
"\n", | |
"Saving model: gin_128bz_256emb_0.005lr_25it.net\n", | |
"\n", | |
"It: 15, Loss: 3.3446, Accuracy: 0.0487, F1: 0.0619 \n", | |
"------------------------------------------------------------\n" | |
] | |
}, | |
{ | |
"output_type": "stream", | |
"name": "stderr", | |
"text": [ | |
"Training Batch: 100%|██████████| 3188/3188 [06:15<00:00, 8.48it/s]\n", | |
"Eval Batch: 100%|██████████| 179/179 [00:15<00:00, 11.85it/s]\n" | |
] | |
}, | |
{ | |
"output_type": "stream", | |
"name": "stdout", | |
"text": [ | |
"\n", | |
"It: 16, Loss: 3.4295, Accuracy: 0.0461, F1: 0.0562 \n", | |
"------------------------------------------------------------\n" | |
] | |
}, | |
{ | |
"output_type": "stream", | |
"name": "stderr", | |
"text": [ | |
"Training Batch: 100%|██████████| 3188/3188 [06:18<00:00, 8.42it/s]\n", | |
"Eval Batch: 100%|██████████| 179/179 [00:14<00:00, 12.28it/s]\n" | |
] | |
}, | |
{ | |
"output_type": "stream", | |
"name": "stdout", | |
"text": [ | |
"\n", | |
"It: 17, Loss: 3.2419, Accuracy: 0.0523, F1: 0.0603 \n", | |
"------------------------------------------------------------\n" | |
] | |
}, | |
{ | |
"output_type": "stream", | |
"name": "stderr", | |
"text": [ | |
"Training Batch: 100%|██████████| 3188/3188 [06:20<00:00, 8.37it/s]\n", | |
"Eval Batch: 100%|██████████| 179/179 [00:15<00:00, 11.65it/s]\n" | |
] | |
}, | |
{ | |
"output_type": "stream", | |
"name": "stdout", | |
"text": [ | |
"\n", | |
"Saving model: gin_128bz_256emb_0.005lr_25it.net\n", | |
"\n", | |
"It: 18, Loss: 3.2632, Accuracy: 0.0540, F1: 0.0641 \n", | |
"------------------------------------------------------------\n" | |
] | |
}, | |
{ | |
"output_type": "stream", | |
"name": "stderr", | |
"text": [ | |
"Training Batch: 100%|██████████| 3188/3188 [06:22<00:00, 8.33it/s]\n", | |
"Eval Batch: 100%|██████████| 179/179 [00:14<00:00, 12.11it/s]\n" | |
] | |
}, | |
{ | |
"output_type": "stream", | |
"name": "stdout", | |
"text": [ | |
"\n", | |
"Saving model: gin_128bz_256emb_0.005lr_25it.net\n", | |
"\n", | |
"It: 19, Loss: 3.3026, Accuracy: 0.0603, F1: 0.0707 \n", | |
"------------------------------------------------------------\n" | |
] | |
}, | |
{ | |
"output_type": "stream", | |
"name": "stderr", | |
"text": [ | |
"Training Batch: 100%|██████████| 3188/3188 [06:21<00:00, 8.36it/s]\n", | |
"Eval Batch: 100%|██████████| 179/179 [00:14<00:00, 12.19it/s]\n" | |
] | |
}, | |
{ | |
"output_type": "stream", | |
"name": "stdout", | |
"text": [ | |
"\n", | |
"It: 20, Loss: 3.2678, Accuracy: 0.0477, F1: 0.0597 \n", | |
"------------------------------------------------------------\n" | |
] | |
}, | |
{ | |
"output_type": "stream", | |
"name": "stderr", | |
"text": [ | |
"Training Batch: 100%|██████████| 3188/3188 [06:20<00:00, 8.39it/s]\n", | |
"Eval Batch: 100%|██████████| 179/179 [00:14<00:00, 12.20it/s]\n" | |
] | |
}, | |
{ | |
"output_type": "stream", | |
"name": "stdout", | |
"text": [ | |
"\n", | |
"Saving model: gin_128bz_256emb_0.005lr_25it.net\n", | |
"\n", | |
"It: 21, Loss: 3.1666, Accuracy: 0.0586, F1: 0.0711 \n", | |
"------------------------------------------------------------\n" | |
] | |
}, | |
{ | |
"output_type": "stream", | |
"name": "stderr", | |
"text": [ | |
"Training Batch: 100%|██████████| 3188/3188 [06:21<00:00, 8.35it/s]\n", | |
"Eval Batch: 100%|██████████| 179/179 [00:15<00:00, 11.82it/s]\n" | |
] | |
}, | |
{ | |
"output_type": "stream", | |
"name": "stdout", | |
"text": [ | |
"\n", | |
"It: 22, Loss: 3.1788, Accuracy: 0.0558, F1: 0.0669 \n", | |
"------------------------------------------------------------\n" | |
] | |
}, | |
{ | |
"output_type": "stream", | |
"name": "stderr", | |
"text": [ | |
"Training Batch: 100%|██████████| 3188/3188 [06:20<00:00, 8.38it/s]\n", | |
"Eval Batch: 100%|██████████| 179/179 [00:14<00:00, 12.45it/s]\n" | |
] | |
}, | |
{ | |
"output_type": "stream", | |
"name": "stdout", | |
"text": [ | |
"\n", | |
"It: 23, Loss: 3.1845, Accuracy: 0.0515, F1: 0.0625 \n", | |
"------------------------------------------------------------\n" | |
] | |
}, | |
{ | |
"output_type": "stream", | |
"name": "stderr", | |
"text": [ | |
"Training Batch: 100%|██████████| 3188/3188 [06:17<00:00, 8.44it/s]\n", | |
"Eval Batch: 100%|██████████| 179/179 [00:14<00:00, 12.37it/s]\n" | |
] | |
}, | |
{ | |
"output_type": "stream", | |
"name": "stdout", | |
"text": [ | |
"\n", | |
"Saving model: gin_128bz_256emb_0.005lr_25it.net\n", | |
"\n", | |
"It: 24, Loss: 3.1288, Accuracy: 0.0703, F1: 0.0822 \n", | |
"------------------------------------------------------------\n" | |
] | |
}, | |
{ | |
"output_type": "stream", | |
"name": "stderr", | |
"text": [ | |
"Training Batch: 100%|██████████| 3188/3188 [06:16<00:00, 8.46it/s]\n", | |
"Eval Batch: 100%|██████████| 179/179 [00:14<00:00, 12.37it/s]" | |
] | |
}, | |
{ | |
"output_type": "stream", | |
"name": "stdout", | |
"text": [ | |
"\n", | |
"It: 25, Loss: 3.0778, Accuracy: 0.0602, F1: 0.0735 \n", | |
"------------------------------------------------------------\n" | |
] | |
}, | |
{ | |
"output_type": "stream", | |
"name": "stderr", | |
"text": [ | |
"\n" | |
] | |
} | |
], | |
"source": [ | |
"# store metrics for each iteration\n", | |
"history = { 'train_loss': [], 'val_metrics': [], 'val_acc': [], 'val_f1': [] }\n", | |
"\n", | |
"for it in range(cfg.max_iter):\n", | |
"    # place model in training mode\n", | |
"    model.train()\n", | |
"    train_loss = .0\n", | |
"    # iterate through batches defined in the dataloader\n", | |
"    for batch_id, batch in enumerate(tqdm(train_loader, desc = 'Training Batch')):\n", | |
"        optimizer.zero_grad()\n", | |
"        batch = batch.to(cfg.device)\n", | |
"        pred = model(batch)\n", | |
"        # average the cross-entropy loss over all output positions\n", | |
"        curr_loss = .0\n", | |
"        for i in range(len(pred)):\n", | |
"            curr_loss += ceLoss(pred[i].to(torch.float32), batch.y_tensor[:,i])\n", | |
"        curr_loss = curr_loss / len(pred)\n", | |
"\n", | |
"        curr_loss.backward()\n", | |
"        optimizer.step()\n", | |
"        train_loss += curr_loss.item()\n", | |
"    avg_loss = train_loss / (batch_id + 1)\n", | |
"    history['train_loss'].append(avg_loss)\n", | |
"\n", | |
"    # place model in evaluation mode (changes layers such as dropout and batch norm;\n", | |
"    # gradient tracking is disabled separately via torch.no_grad() in Eval.eval)\n", | |
"    model.eval()\n", | |
"\n", | |
"    # evaluate model on the validation set\n", | |
"    val_metrics = model_evaluator.eval()[0] # dataset.eval_metric\n", | |
"    history['val_metrics'].append(val_metrics)\n", | |
"    history['val_acc'].append(val_metrics['acc'])\n", | |
"    history['val_f1'].append(val_metrics['F1'])\n", | |
"\n", | |
"    # save the model whenever this iteration has the best validation F1 so far\n", | |
"    best_it = np.argmax(np.array(history['val_f1']))\n", | |
"    if best_it == it:\n", | |
"        model.save()\n", | |
"\n", | |
"    print('')\n", | |
"    print(f'It: {(it+1):02d}, '\n", | |
"          f'Loss: {avg_loss:.4f}, '\n", | |
"          f'Accuracy: {val_metrics[\"acc\"]:.4f}, '\n", | |
"          f'F1: {val_metrics[\"F1\"]:.4f} ')\n", | |
"    print('-'*60)\n", | |
"\n", | |
"    # write the metric history to disk\n", | |
"    with open(cfg.history_write_path, 'w') as f:\n", | |
"        json.dump(history, f)" | |
] | |
}, | |
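{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"The per-position loss used in the training loop can be checked in isolation on toy tensors. This is a minimal sketch: it assumes, as the loop suggests, that the model returns one logits tensor per output position and that `ceLoss` behaves like `torch.nn.CrossEntropyLoss`. The toy sizes and the names `toy_pred`, `toy_y`, and `criterion` are hypothetical." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"import torch\n", | |
"\n", | |
"torch.manual_seed(0)\n", | |
"batch_size, vocab_size, seq_len = 4, 10, 3  # toy sizes (assumptions)\n", | |
"\n", | |
"# one (batch_size, vocab_size) logits tensor per output position\n", | |
"toy_pred = [torch.randn(batch_size, vocab_size) for _ in range(seq_len)]\n", | |
"# target token ids, one column per output position\n", | |
"toy_y = torch.randint(0, vocab_size, (batch_size, seq_len))\n", | |
"\n", | |
"# assumed stand-in for the notebook's ceLoss\n", | |
"criterion = torch.nn.CrossEntropyLoss()\n", | |
"\n", | |
"# sum the loss of every position, then average over positions\n", | |
"toy_loss = sum(criterion(toy_pred[i], toy_y[:, i]) for i in range(seq_len)) / seq_len\n", | |
"print(f'Averaged loss over {seq_len} positions: {toy_loss.item():.4f}')" | |
] | |
}, | |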
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "SAnBazBFWx0K" | |
}, | |
"source": [ | |
"# 5. Results\n", | |
"\n", | |
"Plot the training loss and the validation accuracy to assess model performance and guide hyperparameter tuning." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 24, | |
"metadata": { | |
"id": "WMEhrYG5Xs8C", | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 580 | |
}, | |
"outputId": "4ad6ac17-87d6-40db-9280-4ac5d74119ed" | |
}, | |
"outputs": [ | |
{ | |
"output_type": "display_data", | |
"data": { | |
"text/plain": [ | |
"<Figure size 432x288 with 1 Axes>" | |
], | |
"image/png": "\n" | |
}, | |
"metadata": { | |
"needs_background": "light" | |
} | |
}, | |
{ | |
"output_type": "display_data", | |
"data": { | |
"text/plain": [ | |
"<Figure size 432x288 with 1 Axes>" | |
], | |
"image/png": "\n" | |
}, | |
"metadata": { | |
"needs_background": "light" | |
} | |
}, | |
{ | |
"output_type": "stream", | |
"name": "stdout", | |
"text": [ | |
"\n", | |
"Best model (it: 24): Loss: 3.1288, Accuracy: 0.0703, F1: 0.0822, Precision: 0.1112, Recall: 0.0703 \n" | |
] | |
} | |
], | |
"source": [ | |
"plt.title(f'Training {cfg.max_iter} iterations on {dataset.name}')\n", | |
"plt.plot(np.log(history['train_loss']), label=\"Training Loss (log scale)\")\n", | |
"plt.legend()\n", | |
"plt.show()\n", | |
"\n", | |
"plt.title(f'Evaluation on {dataset.name}')\n", | |
"plt.plot(history['val_acc'], label=\"Validation Accuracy\")\n", | |
"plt.legend()\n", | |
"plt.show()\n", | |
"\n", | |
"best_it = np.argmax(np.array(history['val_f1']))\n", | |
"print('')\n", | |
"print(f'Best model (it: {best_it+1}): '\n", | |
" f'Loss: {history[\"train_loss\"][best_it]:.4f}, '\n", | |
" f'Accuracy: {history[\"val_acc\"][best_it]:.4f}, '\n", | |
" f'F1: {history[\"val_f1\"][best_it]:.4f}, '\n", | |
" f'Precision: {history[\"val_metrics\"][best_it][\"precision\"]:.4f}, '\n", | |
" f'Recall: {history[\"val_metrics\"][best_it][\"recall\"]:.4f} ')" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "S49PVlqe9IIs" | |
}, | |
"source": [ | |
"# 6. Summary\n", | |
"\n", | |
"You have learned how graphs can be batched together for better GPU utilization, and how to apply readout layers for obtaining graph embeddings rather than node embeddings.\n", | |
"\n", | |
"We ran the experiments with the GIN architecture on a Python code dataset, feeding structured information from the AST into the message-passing framework with the aim of improving performance. In the run above, the best checkpoint (iteration 24) reached a validation F1 of 0.0822.\n", | |
"\n", | |
"Finally, you can refer to the [OGB leaderboard](https://ogb.stanford.edu/docs/leader_graphprop/#ogbg-code2) to help you implement more complex GNN architectures and extend this experiment." | |
] | |
} | |
], | |
"metadata": { | |
"colab": { | |
"provenance": [] | |
}, | |
"gpuClass": "standard", | |
"kernelspec": { | |
"display_name": "Python 3", | |
"name": "python3" | |
}, | |
"language_info": { | |
"name": "python" | |
}, | |
"accelerator": "GPU" | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 0 | |
} |