import os
import sys
import math
import yaml
import time
import shutil
import random
import argparse
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.models as models
from torch.autograd import Variable
from torch.utils import data
from torchvision import datasets, models, transforms

import matplotlib.pyplot as plt
from tqdm import tqdm
from datetime import datetime
from tensorboardX import SummaryWriter

from ptsemseg.models import get_model
from ptsemseg.loss import get_loss_function
from ptsemseg.loader import get_loader
from ptsemseg.utils import get_logger
from ptsemseg.metrics import runningScore, averageMeter
from ptsemseg.augmentations import get_composed_augmentations
from ptsemseg.schedulers import get_scheduler
from ptsemseg.optimizers import get_optimizer

# Local project imports
sys.path.append('/home/tejus/lane-seg-experiments/Segmentation/')
from datasets.kitti.config import CONFIG
from datasets.kitti.kitti_loader import kittiLoader
from datasets.tusimple.tusimple_loader import tusimpleLoader
from datasets.tusimple.augmentations import *  # provides Compose, RandomRotate, RandomHorizontallyFlip
from CAN import CAN
from metrics import runningScore  # local runningScore shadows the ptsemseg import above
# Definitions
TRAIN_BATCH = 3          # training batch size
VAL_BATCH = 4            # validation batch size
resume_training = False  # if True, also restore optimizer/scheduler state from the checkpoint
checkpoint_dir = '/home/tejus/lane-seg-experiments/Segmentation/CAN_logger/context_and_LFE/best_val_model.pkl'

# One log directory per run, named by timestamp
run_id = datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
logdir = os.path.join('runs', run_id)
writer = SummaryWriter(log_dir=logdir)
print('RUNDIR: {}'.format(logdir))

logger = get_logger(logdir)
logger.info('Let the party begin | Dilated convolutions')

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
torch.manual_seed(102)
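# Only the torch RNG is seeded above. If fully repeatable runs are wanted, the
# other RNGs could optionally be seeded as well; this is a suggestion, left
# commented out so behaviour is unchanged.
# random.seed(102)
# np.random.seed(102)
# torch.cuda.manual_seed_all(102)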
# Network definition
net = CAN()
optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, net.parameters()), lr=0.0001)
# optimizer = torch.optim.SGD(filter(lambda p: p.requires_grad, net.parameters()), lr=0.0001, momentum=0.9)  # 0.00001
# loss_fn = nn.BCEWithLogitsLoss()
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=2, verbose=True, min_lr=0.000001)
loss_fn = nn.CrossEntropyLoss()
net.to(device)
if not resume_training:
    # Fresh run: initialize the model from pretrained weights only
    checkpoint = torch.load(checkpoint_dir)
    net.load_state_dict(checkpoint["model_state"], strict=False)
else:
    # Resume: restore model, optimizer and scheduler state
    checkpoint = torch.load(checkpoint_dir)
    net.load_state_dict(checkpoint["model_state"], strict=False)
    optimizer.load_state_dict(checkpoint["optimizer_state"])
    scheduler.load_state_dict(checkpoint["scheduler_state"])
    start_iter = checkpoint["epoch"]
    logger.info(
        "Loaded checkpoint '{}' (epoch {})".format(
            checkpoint_dir, checkpoint["epoch"]
        )
    )
# Set up identity initialization for the dilated (context) layers, as described
# in the context-aggregation paper: each 3x3 kernel passes channel i straight to
# channel i, so the untrained context module initially leaves features unchanged.
params = net.state_dict()
dilated_conv_layers = [36, 39, 42, 45, 48, 51, 54]
for layer_idx in dilated_conv_layers:
    w = params['features.' + str(layer_idx) + '.weight']
    b = params['features.' + str(layer_idx) + '.bias']
    w.fill_(0)
    for i in range(w.shape[0]):
        w[i, i, 1, 1] = 1  # 1 at the kernel centre of channel i -> channel i
    # print(w)
    b.fill_(0)
    params['features.' + str(layer_idx) + '.weight'] = w
    params['features.' + str(layer_idx) + '.bias'] = b
# torch.save(net.state_dict(), 'test_identity.wts')

# Identity initialization for the final 1x1 layer of the context module
layer_idx = 56
w = params['features.' + str(layer_idx) + '.weight']
w.fill_(0)
for i in range(w.shape[0]):
    w[i, i, 0, 0] = 1
params['features.' + str(layer_idx) + '.weight'] = w
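# Optional sanity check, a minimal sketch: with the identity initialization
# above, a 3x3 conv should pass its input through unchanged. It assumes the
# dilated layers have equal input and output channel counts (which the init
# above already requires). Left commented out so the script's behaviour is
# unchanged.
# with torch.no_grad():
#     w36 = params['features.36.weight'].detach().cpu()
#     b36 = params['features.36.bias'].detach().cpu()
#     x = torch.randn(1, w36.shape[1], 16, 16)
#     y = F.conv2d(x, w36, b36, padding=1)
#     assert torch.allclose(x, y, atol=1e-6)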
net.train()

### freeze weights of frontend network (first 32 parameter tensors)
i = 0
for k, v in net.named_parameters():
    v.requires_grad = False
    i += 1
    if i == 32:
        break
###
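# Optional: report how many parameters remain trainable after freezing the
# frontend, as a quick check that the freeze took effect.
n_trainable = sum(p.numel() for p in net.parameters() if p.requires_grad)
n_total = sum(p.numel() for p in net.parameters())
logger.info('Trainable parameters: {}/{}'.format(n_trainable, n_total))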
# Data: TuSimple lane dataset with light augmentation for training
augmentations = Compose([RandomRotate(5), RandomHorizontallyFlip()])
train_dataset = tusimpleLoader('/home/tejus/Downloads/train_set/', split="train", augmentations=augmentations)
trainloader = torch.utils.data.DataLoader(train_dataset, batch_size=TRAIN_BATCH, shuffle=True, num_workers=TRAIN_BATCH, pin_memory=True)
val_dataset = tusimpleLoader('/home/tejus/Downloads/train_set/', split="val", augmentations=None)
valloader = torch.utils.data.DataLoader(val_dataset, batch_size=VAL_BATCH, shuffle=False, num_workers=VAL_BATCH, pin_memory=True)

# Metrics and bookkeeping (2-class segmentation)
running_metrics_val = runningScore(2)
best_val_loss = math.inf
val_loss = 0
ctr = 0
best_iou = -100
val_loss_meter = averageMeter()
time_meter = averageMeter()
## Compute validation loss for the pretrained weights (baseline before any training)
val_loss = 0
net.eval()
with torch.no_grad():
    for i, data in enumerate(valloader):
        print(i)
        if i > 10:  # only the first few batches, for a quick estimate
            break
        imgs, labels = data
        imgs, labels = imgs.to(device), labels.to(device)
        out = net(imgs)
        loss = loss_fn(out, labels)
        pred = out.data.max(1)[1]
        running_metrics_val.update(pred.cpu().numpy(), labels.cpu().numpy())
        val_loss_meter.update(loss.item())
        val_loss += loss.item()
print("val_loss = ", val_loss)
# Training pass over the first few batches
running_loss = 0
net.train()
for i, data in enumerate(trainloader):
    print(i)
    if i > 10:
        break
    imgs, labels = data
    imgs, labels = imgs.to(device), labels.to(device)
    out = net(imgs)
    loss = loss_fn(out, labels)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
    running_loss += loss.item()
print("running loss = ", running_loss)
# Validation pass after the training step
val_loss = 0
net.eval()
with torch.no_grad():
    for i, data in enumerate(valloader):
        print(i)
        if i > 10:
            break
        imgs, labels = data
        imgs, labels = imgs.to(device), labels.to(device)
        out = net(imgs)
        loss = loss_fn(out, labels)
        pred = out.data.max(1)[1]
        running_metrics_val.update(pred.cpu().numpy(), labels.cpu().numpy())
        val_loss_meter.update(loss.item())
        val_loss += loss.item()
print("val_loss = ", val_loss)