XLA compilation error for TPU sample code
import os
import math
import time
import random
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import Adam, SGD
from torch.nn.parameter import Parameter
from torch.utils.data import DataLoader, Dataset
from torch.optim.lr_scheduler import CosineAnnealingLR
import torch_xla
import torch_xla.distributed.parallel_loader as pl
import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_multiprocessing as xmp
from tqdm.auto import tqdm
from sklearn.metrics import roc_auc_score
from sklearn.model_selection import StratifiedKFold
import timm
import albumentations
from albumentations.pytorch import ToTensorV2
from nnAudio.Spectrogram import CQT1992v2

OUTPUT_DIR = './'
# ====================================================
# CFG
# ====================================================
class CFG:
    num_workers = 4
    model_name = 'tf_efficientnet_b7_ns'
    epochs = 1
    T_max = 3
    lr = 1e-4
    min_lr = 1e-6
    batch_size = 48
    weight_decay = 1e-6
    gradient_accumulation_steps = 1
    max_grad_norm = 1000
    qtransform_params = {"sr": 2048, "fmin": 20, "fmax": 1024, "hop_length": 32, "bins_per_octave": 8, "n_bins": None}
    target_size = 1
    train = True
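# Rough input-size estimate (approximate numbers, not measured): each sample
# below is 3 waveforms of 4096 points, hstack-ed into a 12288-point signal.
# With hop_length=32 that gives ~385 time frames, and with fmin=20, fmax=1024,
# bins_per_octave=8 roughly ceil(8 * log2(1024/20)) ~ 46 frequency bins, so
# the network sees single-channel spectrograms of roughly 46 x 385.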
# ====================================================
# Utils
# ====================================================
def get_score(y_true, y_pred):
    score = roc_auc_score(y_true, y_pred)
    return score


def seed_torch(seed=42):
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
# ====================================================
# Dataset
# ====================================================
class TrainDataset(Dataset):
    def __init__(self, transform=None):
        self.wave_transform = CQT1992v2(**CFG.qtransform_params)
        self.transform = transform

    def __len__(self):
        return 2000

    def apply_qtransform(self, waves, transform):
        waves = np.hstack(waves)
        waves = waves / np.max(waves)
        waves = torch.from_numpy(waves).float()
        image = transform(waves)
        return image

    def __getitem__(self, idx):
        # synthetic 3-channel waveform standing in for the real competition data
        waves = np.random.randn(3, 4096)
        image = self.apply_qtransform(waves, self.wave_transform)
        image = image.squeeze().numpy()
        if self.transform:
            image = self.transform(image=image)['image']
        label = torch.tensor(1).float()
        return image, label
# ====================================================
# Transforms
# ====================================================
def get_transforms(*, data):
    if data == 'train':
        return albumentations.Compose([
            ToTensorV2(),
        ])
    elif data == 'valid':
        return albumentations.Compose([
            ToTensorV2(),
        ])
# ====================================================
# MODEL
# ====================================================
class CustomModel(nn.Module):
    def __init__(self, cfg, pretrained=False):
        super().__init__()
        self.cfg = cfg
        self.model = timm.create_model(self.cfg.model_name, pretrained=pretrained, in_chans=1)
        self.n_features = self.model.classifier.in_features
        self.model.classifier = nn.Linear(self.n_features, self.cfg.target_size)

    def forward(self, x):
        output = self.model(x)
        return output
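# Note: timm's create_model(..., in_chans=1) adapts the backbone's stem conv
# to single-channel input (folding the pretrained RGB weights), and
# `classifier` is the head attribute on timm EfficientNet models, so only the
# final linear layer needs to be replaced here.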
# ====================================================
# Helper functions
# ====================================================
class AverageMeter(object):
    """Computes and stores the average and current value"""
    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count


def loss_fn(outputs, targets):
    return nn.BCEWithLogitsLoss()(outputs, targets.view(-1, 1))

def train_fn(fold, train_loader, model, criterion, optimizer, epoch, scheduler, device):
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    scores = AverageMeter()
    # switch to train mode
    model.train()
    start = end = time.time()
    global_step = 0
    xm.master_print("Training time ... ")
    train_loader = tqdm(train_loader)
    for step, (images, labels) in enumerate(train_loader):
        # measure data loading time
        data_time.update(time.time() - end)
        images = images.to(device)
        labels = labels.to(device)
        batch_size = labels.size(0)
        y_preds = model(images)
        loss = criterion(y_preds.view(-1), labels)
        # record loss
        losses.update(loss.item(), batch_size)
        if CFG.gradient_accumulation_steps > 1:
            loss = loss / CFG.gradient_accumulation_steps
        loss.backward()
        grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), CFG.max_grad_norm)
        if (step + 1) % CFG.gradient_accumulation_steps == 0:
            xm.optimizer_step(optimizer)
            optimizer.zero_grad()
            global_step += 1
        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()
    return losses.avg
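# Note: xm.optimizer_step() all-reduces gradients across TPU cores before
# calling optimizer.step(); the per-device loader created by pl.ParallelLoader
# in train_loop() inserts the xm.mark_step() calls that actually trigger XLA
# graph compilation/execution each iteration, which is presumably where the
# compilation error reported below surfaces.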

def valid_fn(valid_loader, model, criterion, device):
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    scores = AverageMeter()
    # switch to evaluation mode
    model.eval()
    preds = []
    start = end = time.time()
    valid_labels = []
    for step, (images, labels) in enumerate(valid_loader):
        # measure data loading time
        data_time.update(time.time() - end)
        images = images.to(device)
        labels = labels.to(device)
        batch_size = labels.size(0)
        # compute loss
        with torch.no_grad():
            y_preds = model(images)
        xm.mark_step()
        loss = loss_fn(y_preds, labels)
        losses.update(loss.item(), batch_size)
        # record accuracy
        preds.append(y_preds.sigmoid().to('cpu').numpy())
        valid_labels.append(labels.to('cpu').numpy())
        if CFG.gradient_accumulation_steps > 1:
            loss = loss / CFG.gradient_accumulation_steps
        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()
    preds = np.concatenate(preds)
    valid_labels = np.concatenate(valid_labels)
    # note: with the constant dummy labels from TrainDataset (all 1s),
    # roc_auc_score will raise since only one class is present; real labels
    # are needed for this score to be meaningful
    score = get_score(valid_labels, preds)
    return losses.avg, preds, score

# ====================================================
# Train loop
# ====================================================
def train_loop(tid):
    device = xm.xla_device()
    train_dataset = TrainDataset(transform=get_transforms(data='train'))
    valid_dataset = TrainDataset(transform=get_transforms(data='train'))
    train_sampler = torch.utils.data.distributed.DistributedSampler(
        train_dataset,
        num_replicas=xm.xrt_world_size(),
        rank=xm.get_ordinal(),
        shuffle=True)
    valid_sampler = torch.utils.data.distributed.DistributedSampler(
        valid_dataset,
        num_replicas=xm.xrt_world_size(),
        rank=xm.get_ordinal(),
        shuffle=False,
    )
    train_loader = DataLoader(train_dataset,
                              batch_size=CFG.batch_size,
                              sampler=train_sampler,
                              num_workers=CFG.num_workers, pin_memory=True, drop_last=True)
    valid_loader = DataLoader(valid_dataset,
                              batch_size=CFG.batch_size,
                              sampler=valid_sampler,
                              num_workers=CFG.num_workers, pin_memory=True, drop_last=False)
    xm.master_print("Dataloader")

    # ====================================================
    # model & optimizer
    # ====================================================
    model = CustomModel(CFG, pretrained=True)
    model.to(device)
    optimizer = Adam(model.parameters(), lr=CFG.lr, weight_decay=CFG.weight_decay, amsgrad=False)
    scheduler = CosineAnnealingLR(optimizer, T_max=CFG.T_max, eta_min=CFG.min_lr, last_epoch=-1)
    criterion = nn.BCEWithLogitsLoss()

    best_score = 0.
    best_loss = np.inf
    fold = 0
    for epoch in range(CFG.epochs):
        start_time = time.time()
        # train
        para_loader = pl.ParallelLoader(train_loader, [device])
        avg_loss = train_fn(fold, para_loader.per_device_loader(device), model, criterion, optimizer, epoch, scheduler, device)
        # eval
        para_loader = pl.ParallelLoader(valid_loader, [device])
        avg_val_loss, preds, score = valid_fn(para_loader.per_device_loader(device), model, criterion, device)
        scheduler.step()
        elapsed = time.time() - start_time
        if score > best_score:
            best_score = score
            xm.save(model.state_dict(), OUTPUT_DIR + f'{CFG.model_name}_fold{fold}_best_score.pth')
        if avg_val_loss < best_loss:
            best_loss = avg_val_loss
            # save under a separate name so the best-loss checkpoint does not
            # overwrite the best-score checkpoint
            xm.save(model.state_dict(), OUTPUT_DIR + f'{CFG.model_name}_fold{fold}_best_loss.pth')
    return best_score, best_loss


if __name__ == "__main__":
    seed_torch()
    torch.set_default_tensor_type('torch.FloatTensor')
    xmp.spawn(train_loop, args=(), nprocs=None)
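# The script is meant to be launched directly on a TPU host with torch_xla
# installed; with nprocs=None, xmp.spawn() should fork one process per
# available TPU core (e.g. 8 on a v3-8).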
Below is the full log of the error.
The same script runs without error if EfficientNet-B6 (tf_efficientnet_b6_ns) is used instead of B7, so I suspect the failure is memory-related.
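For anyone reproducing this, the workaround above is a one-line change in CFG; if the memory hypothesis is right, lowering the per-core batch size would be the other thing to try (the value below is a guess I have not verified):

class CFG:
    # workaround noted above: drop the backbone from B7 to B6
    model_name = 'tf_efficientnet_b6_ns'
    # untested guess: smaller per-core batch, in case the failure is HBM pressure
    batch_size = 16
    # (all other fields as in the script above)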