Training GraphSAGE w/ fp16 in DGL.
"""Training graphsage w/ fp16. | |
Usage: | |
python train_full.py --gpu 0 --fp16 --dataset | |
Note that GradScaler is not acitvated because the model successfully converges | |
without gradient scaling. | |
DGL's Message Passing APIs are not compatible with fp16 yet, hence we disabled | |
autocast when calling these APIs (e.g. apply_edges, update_all), see | |
https://github.com/yzh119/sage-fp16.git | |
In the default setting, using fp16 saves around 1GB GPU memory (from 4052mb | |
to 3042mb). | |
""" | |
import argparse
import time

import numpy as np
import networkx as nx
import torch
import torch.nn as nn
import torch.nn.functional as F

import dgl
import dgl.function as fn
from dgl import DGLGraph
from dgl.data import register_data_args, load_data


class SAGEConv(nn.Module):
    def __init__(self,
                 in_feats,
                 out_feats,
                 aggregator_type,
                 feat_drop=0.,
                 bias=True,
                 use_fp16=False,
                 norm=None,
                 activation=None):
        super(SAGEConv, self).__init__()
        self._in_src_feats, self._in_dst_feats = in_feats, in_feats
        self._out_feats = out_feats
        self._aggre_type = aggregator_type
        self.norm = norm
        self.feat_drop = nn.Dropout(feat_drop)
        self.activation = activation
        self.fc_self = nn.Linear(self._in_dst_feats, out_feats, bias=bias)
        self.fc_neigh = nn.Linear(self._in_src_feats, out_feats, bias=bias)
        self.use_fp16 = use_fp16
        self.reset_parameters()

    def reset_parameters(self):
        gain = nn.init.calculate_gain('relu')
        nn.init.xavier_uniform_(self.fc_self.weight, gain=gain)
        nn.init.xavier_uniform_(self.fc_neigh.weight, gain=gain)

    def forward(self, graph, feat):
        with graph.local_scope():
            feat_src = feat_dst = self.feat_drop(feat)
            h_self = feat_dst
            graph.srcdata['h'] = feat_src
            if self.use_fp16:
                # DGL message passing does not support fp16 yet: disable
                # autocast and cast features back to fp32 for aggregation.
                with torch.cuda.amp.autocast(enabled=False):
                    graph.srcdata['h'] = graph.srcdata['h'].float()
                    graph.update_all(fn.copy_src('h', 'm'), fn.mean('m', 'neigh'))
            else:
                graph.update_all(fn.copy_src('h', 'm'), fn.mean('m', 'neigh'))
            h_neigh = graph.dstdata['neigh']
            # GraphSAGE GCN does not require fc_self.
            rst = self.fc_self(h_self) + self.fc_neigh(h_neigh)
            # activation
            if self.activation is not None:
                rst = self.activation(rst)
            # normalization
            if self.norm is not None:
                rst = self.norm(rst)
            return rst


class GraphSAGE(nn.Module):
    def __init__(self,
                 in_feats,
                 n_hidden,
                 n_classes,
                 n_layers,
                 activation,
                 dropout,
                 aggregator_type,
                 use_fp16):
        super(GraphSAGE, self).__init__()
        self.layers = nn.ModuleList()
        self.dropout = nn.Dropout(dropout)
        self.activation = activation
        # input layer
        self.layers.append(SAGEConv(in_feats, n_hidden, aggregator_type, use_fp16=use_fp16))
        # hidden layers
        for i in range(n_layers - 1):
            self.layers.append(SAGEConv(n_hidden, n_hidden, aggregator_type, use_fp16=use_fp16))
        # output layer
        self.layers.append(SAGEConv(n_hidden, n_classes, aggregator_type, use_fp16=use_fp16))  # activation None

    def forward(self, graph, inputs):
        h = self.dropout(inputs)
        for l, layer in enumerate(self.layers):
            h = layer(graph, h)
            if l != len(self.layers) - 1:
                h = self.activation(h)
                h = self.dropout(h)
        return h


def evaluate(model, graph, features, labels, nid):
    model.eval()
    with torch.no_grad():
        logits = model(graph, features)
        logits = logits[nid]
        labels = labels[nid]
        _, indices = torch.max(logits, dim=1)
        correct = torch.sum(indices == labels)
        return correct.item() * 1.0 / len(labels)


def main(args):
    # load and preprocess dataset
    data = load_data(args)
    g = data[0]
    features = g.ndata['feat']
    labels = g.ndata['label']
    train_mask = g.ndata['train_mask']
    val_mask = g.ndata['val_mask']
    test_mask = g.ndata['test_mask']
    in_feats = features.shape[1]
    n_classes = data.num_classes
    n_edges = data.graph.number_of_edges()
    print("""----Data statistics------
      #Edges %d
      #Classes %d
      #Train samples %d
      #Val samples %d
      #Test samples %d""" %
          (n_edges, n_classes,
           train_mask.int().sum().item(),
           val_mask.int().sum().item(),
           test_mask.int().sum().item()))

    if args.gpu < 0:
        cuda = False
    else:
        cuda = True
        torch.cuda.set_device(args.gpu)
        features = features.cuda()
        labels = labels.cuda()
        train_mask = train_mask.cuda()
        val_mask = val_mask.cuda()
        test_mask = test_mask.cuda()
        print("use cuda:", args.gpu)

    train_nid = train_mask.nonzero().squeeze()
    val_nid = val_mask.nonzero().squeeze()
    test_nid = test_mask.nonzero().squeeze()

    # graph preprocess and calculate normalization factor
    g = dgl.remove_self_loop(g)
    n_edges = g.number_of_edges()
    if cuda:
        g = g.int().to(args.gpu)

    # create GraphSAGE model
    model = GraphSAGE(in_feats,
                      args.n_hidden,
                      n_classes,
                      args.n_layers,
                      F.relu,
                      args.dropout,
                      args.aggregator_type,
                      args.fp16)
    if cuda:
        model.cuda()

    if args.fp16:
        from torch.cuda.amp import GradScaler, autocast

    # use optimizer
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
    # if args.fp16:
    #     scaler = GradScaler()

    # initialize graph
    dur = []
    for epoch in range(args.n_epochs):
        model.train()
        if epoch >= 3:
            t0 = time.time()
        optimizer.zero_grad()
        # forward
        if args.fp16:
            with autocast():
                logits = model(g, features)
                loss = F.cross_entropy(logits[train_nid], labels[train_nid])
        else:
            logits = model(g, features)
            loss = F.cross_entropy(logits[train_nid], labels[train_nid])
        # if args.fp16:
        #     scaler.scale(loss).backward()
        #     scaler.step(optimizer)
        #     scaler.update()
        # else:
        loss.backward()
        optimizer.step()

        if epoch >= 3:
            dur.append(time.time() - t0)

        acc = evaluate(model, g, features, labels, val_nid)
        print("Epoch {:05d} | Time(s) {:.4f} | Loss {:.4f} | Accuracy {:.4f} | "
              "ETputs(KTEPS) {:.2f} | mem {:.2f} MB".format(epoch, np.mean(dur), loss.item(),
                                                            acc, n_edges / np.mean(dur) / 1000,
                                                            torch.cuda.max_memory_allocated() / 1024 / 1024))

    print()
    acc = evaluate(model, g, features, labels, test_nid)
    print("Test Accuracy {:.4f}".format(acc))


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='GraphSAGE')
    register_data_args(parser)
    parser.add_argument("--dropout", type=float, default=0.5,
                        help="dropout probability")
    parser.add_argument("--gpu", type=int, default=-1,
                        help="gpu")
    parser.add_argument("--lr", type=float, default=1e-2,
                        help="learning rate")
    parser.add_argument("--fp16", action='store_true',
                        help="train with fp16 (torch.cuda.amp autocast)")
    parser.add_argument("--n-epochs", type=int, default=200,
                        help="number of training epochs")
    parser.add_argument("--n-hidden", type=int, default=512,
                        help="number of hidden gcn units")
    parser.add_argument("--n-layers", type=int, default=1,
                        help="number of hidden gcn layers")
    parser.add_argument("--weight-decay", type=float, default=5e-4,
                        help="Weight for L2 loss")
    parser.add_argument("--aggregator-type", type=str, default="gcn",
                        help="Aggregator type: mean/gcn/pool/lstm")
    args = parser.parse_args()
    print(args)

    main(args)
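
# Example invocation (a sketch; assumes the 'cora' dataset, one of the datasets
# dgl.data.load_data accepts through the --dataset flag added by
# register_data_args):
#
#     python train_full.py --dataset cora --gpu 0 --fp16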