Created
April 4, 2020 05:30
-
-
Save sizhky/9b31b8ce263422fc7fa861ef6dddad9e to your computer and use it in GitHub Desktop.
file name says it all
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import time | |
def info(report):
    """Format a metrics mapping as tab-separated ``name: value`` pairs.

    Args:
        report: Mapping of metric name -> value. Every value is rendered with
            ``f'{v:.3f}'``, so it must support float-style formatting.

    Returns:
        A single string such as ``'loss: 0.123\tacc: 0.900'``; an empty
        mapping yields ``''``.
    """
    return '\t'.join(f'{k}: {v:.3f}' for k, v in report.items())
def report_metrics(pos, **report):
    """Print a one-line progress summary with elapsed and estimated remaining time.

    Reads the module-level globals ``start`` (a ``time.time()`` timestamp taken
    when training began) and ``n_epochs`` (total number of epochs); both must be
    defined before this is called.

    Args:
        pos: Fractional epoch position (e.g. ``ex + (bx + 1) / N``). The
            remaining-time estimate needs ``pos > 0``; otherwise it prints ``nan``.
        **report: Metric name -> value pairs, formatted by ``info``. The special
            key ``end`` (default: newline) is removed and forwarded to ``print``.
    """
    end = report.pop('end', '\n')
    # Read the clock once so the shown elapsed time and the ETA agree.
    elapsed = time.time() - start
    # Linear extrapolation: elapsed time per unit of progress * progress left.
    remaining = ((n_epochs - pos) / pos) * elapsed if pos > 0 else float('nan')
    timing = '\t({:.2f}s - {:.2f}s remaining)'.format(elapsed, remaining)
    current_iteration = f'EPOCH: {pos:.3f}\t'
    print(current_iteration + info(report) + timing, end=end)
def save_t(model, fpath, silent=True):
    """Serialize ``model`` to ``fpath`` with ``torch.save``.

    Calls ``makedir(parent(fpath))`` first. Those helpers, and ``logger``, are
    defined outside this snippet; presumably they ensure the destination
    directory exists. TODO confirm.

    Args:
        model: Object to save; it is pickled as a whole via ``torch.save``.
        fpath: Destination file path.
        silent: When False, log the destination path via ``logger.info``.
    """
    destination_dir = parent(fpath)
    makedir(destination_dir)
    torch.save(model, fpath)
    if silent:
        return
    logger.info(f'saved model to {fpath}')
def load_t(fpath, device='cpu', **kwargs):
    """Load an object previously written with ``torch.save`` (e.g. by ``save_t``).

    Args:
        fpath: Path of the file to load.
        device: Forwarded to ``torch.load`` as ``map_location``. Any value
            ``torch.load`` accepts works: a device string, a ``torch.device``,
            a dict or a callable. Defaults to ``'cpu'``.
        **kwargs: Extra keyword arguments forwarded to ``torch.load``. For
            example, ``weights_only=False`` is needed to unpickle a whole
            ``nn.Module`` on PyTorch versions where ``weights_only`` defaults
            to True.

    Returns:
        The deserialized object.
    """
    # Previously map_location was hard-coded to 'cpu', silently ignoring `device`.
    # NOTE(security): torch.load unpickles its input; only load trusted files.
    return torch.load(fpath, map_location=device, **kwargs)
#=========================#=========================#========================= | |
from imgaug import augmenters as iaa | |
# Training-time augmentation pipeline: one geometric jitter step, then two
# photometric steps that each fire independently with probability 0.2.
seq = iaa.Sequential(
    children=[
        # Small affine jitter. Exposed borders are filled with cval=255
        # (white, assuming uint8 images — TODO confirm).
        # NOTE(review): translate_px=(1, 10) samples only positive shifts;
        # confirm the one-directional translation is intended.
        iaa.geometric.Affine(
            translate_px=(1, 10),
            rotate=(-4, 4),
            shear=(-4, 4),
            cval=255,
            backend='cv2',
        ),
        iaa.Sometimes(p=0.2, then_list=[iaa.blur.GaussianBlur(sigma=(0, 1.0))]),
        iaa.Sometimes(p=0.2, then_list=[iaa.LinearContrast(alpha=(0.75, 1.5))]),
    ],
)
#=========================#=========================#========================= | |
def collate(batch):
    """Augment a list of images with ``seq`` and build ``(inputs, targets)`` tensors.

    The augmented images are stacked along a new batch axis, and a singleton
    axis is inserted at position 1. For 2-D images this gives
    ``(B, 1, H, W)``; assumes grayscale 2-D arrays — TODO confirm.

    Returns:
        inputs: float tensor ``1 - pixels / 255``, i.e. intensities inverted so
            that 255 maps to 0.
        targets: long tensor ``1 - trunc(pixels / 255)``. For uint8 pixels this
            is 0 where a pixel equals 255 and 1 everywhere else.
    """
    augmented = seq(images=batch)
    pixels = torch.stack([torch.tensor(image) for image in augmented]).unsqueeze(1)
    inputs = 1 - pixels.float() / 255
    targets = 1 - (pixels / 255).long()
    return inputs, targets
# Model, loss and optimizer. AutoEncoder, device, nn and optim come from outside
# this snippet. n_classes=2 presumably matches the binary per-pixel targets
# built by `collate` — TODO confirm.
model = AutoEncoder(n_classes=2)
model = model.to(device)
loss_fn = nn.CrossEntropyLoss()
optimizer = optim.Adam(params=model.parameters(), lr=1e-4)
#=========================#=========================#========================= | |
# `start` and `n_epochs` are module-level globals read by `report_metrics`.
start = time.time()
n_epochs = 4
# Lowest per-batch training loss seen so far, kept as a plain float so that no
# autograd graph is retained between iterations.
best_loss = float('inf')
N = len(trn_dl)  # batches per epoch; loop-invariant
model.train()
for ex in range(n_epochs):
    for bx, (x, y) in enumerate(trn_dl):
        x, y = x.to(device), y.to(device)
        # Clear gradients from the previous step; without this they accumulate
        # across every batch and corrupt each update.
        optimizer.zero_grad()
        _y = model(x)
        # y is presumably (B, 1, H, W) from `collate`; drop the channel axis
        # so CrossEntropyLoss gets class-index targets — TODO confirm shape.
        loss = loss_fn(_y, y.squeeze(dim=1))
        loss.backward()
        optimizer.step()
        loss_value = loss.item()
        # Checkpoint whenever the per-batch training loss reaches a new minimum.
        if loss_value < best_loss:
            save_t(model, 'models/0.1.pth')
            best_loss = loss_value
        if bx % 300 == 0:
            report_metrics(ex + (bx + 1) / N, loss=loss_value, best_loss=best_loss, end='\n')
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment