Skip to content

Instantly share code, notes, and snippets.

@sizhky
Created April 4, 2020 05:30
Show Gist options
  • Save sizhky/9b31b8ce263422fc7fa861ef6dddad9e to your computer and use it in GitHub Desktop.
Save sizhky/9b31b8ce263422fc7fa861ef6dddad9e to your computer and use it in GitHub Desktop.
file name says it all
import time
def info(report):
    """Render a metrics mapping as tab-separated ``name: value`` pairs.

    Each value is formatted with three decimal places; an empty mapping
    yields an empty string.
    """
    pieces = []
    for name, value in report.items():
        pieces.append(f'{name}: {value:.3f}')
    return '\t'.join(pieces)
def report_metrics(pos, **report):
    """Print one progress line: epoch position, metrics and timing.

    Args:
        pos: Fractional epoch position (the training loop passes
            ``epoch + (batch + 1) / batches_per_epoch``).
        **report: Metric name -> numeric value pairs to print. The optional
            ``end`` key is removed and used as ``print``'s line terminator
            (defaults to a newline).

    Reads the module-level ``start`` (wall-clock start time) and ``n_epochs``
    to derive elapsed time and a linear estimate of the time remaining.
    """
    # Pull `end` out first so it is not formatted as a metric by `info`.
    end = report.pop('end', '\n')
    elapsed = time.time() - start
    if pos == 0:
        # No progress yet: the linear extrapolation would divide by zero.
        remaining = 'unknown'
    else:
        remaining = '{:.2f}s'.format(((n_epochs - pos) / pos) * elapsed)
    timing = '\t({:.2f}s - {} remaining)'.format(elapsed, remaining)
    current_iteration = f'EPOCH: {pos:.3f}\t'
    print(current_iteration + info(report) + timing, end=end)
def save_t(model, fpath, silent=True):
    """Serialize ``model`` to ``fpath`` with ``torch.save``.

    Args:
        model: Object to serialize (the whole object is saved, not just a
            state dict).
        fpath: Destination path.
        silent: When false, log the destination path after saving.
    """
    # `makedir` / `parent` are defined elsewhere; presumably they create the
    # directory containing `fpath` -- confirm against their definitions.
    target_dir = parent(fpath)
    makedir(target_dir)
    torch.save(model, fpath)
    if silent:
        return
    logger.info(f'saved model to {fpath}')
def load_t(fpath, device='cpu'):
    """Load an object previously written with ``torch.save``.

    Args:
        fpath: Path or file-like object passed to ``torch.load``.
        device: Forwarded to ``torch.load`` as ``map_location`` so stored
            tensors are remapped onto it. Defaults to ``'cpu'``.

    Returns:
        Whatever ``torch.load`` deserializes from ``fpath``.
    """
    # NOTE(security): torch.load is pickle-based; only load trusted files.
    # Previously `device` was ignored and map_location was always 'cpu'.
    return torch.load(fpath, map_location=device)
#=========================#=========================#=========================
from imgaug import augmenters as iaa
# Augmentation pipeline applied to every batch by `collate`.
seq = iaa.Sequential(children=[
    # Small random translation / rotation / shear on every image. Newly
    # exposed pixels are filled with 255 (which `collate` later maps to 0).
    iaa.geometric.Affine(
        translate_px=(1, 10),
        rotate=(-4, 4),
        shear=(-4, 4),
        cval=255,
        backend='cv2',
    ),
    # Each of the following is applied with probability 0.2.
    iaa.Sometimes(p=0.2, then_list=[iaa.blur.GaussianBlur(sigma=(0, 1.0))]),
    iaa.Sometimes(p=0.2, then_list=[iaa.LinearContrast(alpha=(0.75, 1.5))]),
])
#=========================#=========================#=========================
def collate(batch):
    """Augment a batch of images and build (input, target) tensors.

    The images are passed through the module-level ``seq`` pipeline, each one
    gets two leading axes and they are concatenated along the first axis.
    Assumes each image is a 2-D 8-bit array, giving a (B, 1, H, W) tensor --
    TODO confirm against the dataset.

    Returns:
        A pair ``(inputs, targets)``: ``inputs`` is ``1 - pixels / 255`` as
        float; ``targets`` is ``1 - (pixels / 255).long()``, i.e. 0 where the
        pixel equals 255 and 1 otherwise for 8-bit data.
    """
    augmented = seq(images=batch)
    per_image = [torch.tensor(img[None, None]) for img in augmented]
    pixels = torch.cat(per_image)
    inputs = 1 - pixels.float() / 255
    targets = 1 - (pixels / 255).long()
    return inputs, targets
# Two-class model moved onto the target device; its outputs are scored with
# cross-entropy against the integer targets produced by `collate`.
model = AutoEncoder(n_classes=2)
model = model.to(device)
optimizer = optim.Adam(params=model.parameters(), lr=1e-4)
loss_fn = nn.CrossEntropyLoss()
#=========================#=========================#=========================
# Training loop: `start` and `n_epochs` are also read by `report_metrics`.
start = time.time()
n_epochs = 4
best_loss = 1000  # sentinel; replaced by the first batch loss below it
for ex in range(n_epochs):
    N = len(trn_dl)
    for bx, (x, y) in enumerate(trn_dl):
        x, y = x.to(device), y.to(device)
        # Clear gradients from the previous step; without this PyTorch
        # accumulates them across batches and every update is wrong.
        optimizer.zero_grad()
        _y = model(x)
        # Drop the channel axis so targets are class-index maps.
        loss = loss_fn(_y, y.squeeze(dim=1))
        loss.backward()
        optimizer.step()
        # Track a plain float so no tensor is kept alive between iterations.
        loss_val = loss.item()
        if loss_val < best_loss:
            # Checkpoint whenever a single batch beats the best loss so far.
            save_t(model, 'models/0.1.pth')
            best_loss = loss_val
        if bx % 300 == 0:
            report_metrics(ex + ((bx + 1) / N), loss=loss_val, best_loss=best_loss, end='\n')
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment