import torch
import torch.nn as nn
import torch.nn.functional as F

class Loss_combine(nn.Module):
    "Weighted sum of the three cross-entropy losses (grapheme root, vowel and consonant diacritic heads)."
    def __init__(self):
        super().__init__()

    def forward(self, input, target, reduction='mean'):
        x1,x2,x3 = input
        x1,x2,x3 = x1.half(),x2.half(),x3.half()
        y = target.long()
        return 0.7*F.cross_entropy(x1,y[:,0],reduction=reduction) \
             + 0.1*F.cross_entropy(x2,y[:,1],reduction=reduction) \
             + 0.1*F.cross_entropy(x3,y[:,2],reduction=reduction)
from fastai.vision import *
from numbers import Integral

def random_strat_splitter(y, train_size:int=1, seed:int=1):
    "Stratified split of `y`: returns (train_idx, rest_idx) with the class balance preserved in both."
    from sklearn.model_selection import StratifiedShuffleSplit
    sss = StratifiedShuffleSplit(n_splits=1, train_size=train_size, random_state=seed)
    idx = list(sss.split(np.arange(len(y)), y))[0]
    return idx[0], idx[1]
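# Quick sanity check for random_strat_splitter on toy labels (not competition data):
# both returned index arrays keep the class balance of `y`.
_y_demo = np.array([0]*80 + [1]*20)
_tr_idx, _rest_idx = random_strat_splitter(_y_demo, train_size=20, seed=1)
# len(_tr_idx) == 20, len(_rest_idx) == 80, each split roughly 80/20 between the two classes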
class MultiTfmLabelList(LabelList):
    def __init__(self, x:ItemList, y:ItemList, tfms:TfmList=None, tfm_y:bool=False, K=2, **kwargs):
        "K: number of transformed samples generated per item"
        self.x,self.y,self.tfm_y,self.K = x,y,tfm_y,K
        self.y.x = x
        self.item = None
        self.transform(tfms, **kwargs)

    def __getitem__(self, idxs:Union[int,np.ndarray])->'LabelList':
        "Return a single (x, y) if `idxs` is an integer or a new `LabelList` object if `idxs` is a range."
        idxs = try_int(idxs)
        if isinstance(idxs, Integral):
            if self.item is None: x,y = self.x[idxs],self.y[idxs]
            else:                 x,y = self.item, 0
            if self.tfms or self.tfmargs:
                # apply the train transforms K times -> K augmented views of the same item
                x = [x.apply_tfms(self.tfms, **self.tfmargs) for _ in range(self.K)]
            if hasattr(self, 'tfms_y') and self.tfm_y and self.item is None:
                y = y.apply_tfms(self.tfms_y, **{**self.tfmargs_y, 'do_resolve':False})
            if y is None: y = 0
            return x,y
        else:
            return self.new(self.x[idxs], self.y[idxs])
def MultiCollate(batch):
    "Collate function that stacks the K augmented views of each item into one extra batch dimension."
    batch = to_data(batch)
    if isinstance(batch[0][0], list): batch = [[torch.stack(s[0]), s[1]] for s in batch]
    return torch.utils.data.dataloader.default_collate(batch)
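# Rough illustration of what MultiCollate produces (tensor shapes here are made up for the demo):
# each sample is ([K augmented tensors], label), as returned by MultiTfmLabelList,
# and the K views are stacked into an extra dimension after the batch dimension.
_demo_batch = [([torch.randn(1, 8, 8) for _ in range(2)], 0) for _ in range(4)]
_xb, _yb = MultiCollate(_demo_batch)
# _xb.shape == torch.Size([4, 2, 1, 8, 8]); _yb.shape == torch.Size([4])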
device = 'cuda' if torch.cuda.is_available() else 'cpu'
def mixmatch(learn:Learner, ulist:ItemList, num_workers:int=None,
             K:int=2, T:float=.5, α:float=.75, λ:float=100) -> Learner:
    "Attach MixMatch semi-supervised training to `learn`, using `ulist` as the unlabeled data."
    labeled_data = learn.data
    if num_workers is None: num_workers = 1
    labeled_data.train_dl.num_workers = num_workers
    bs = labeled_data.train_dl.batch_size
    tfms = [labeled_data.train_ds.tfms, labeled_data.valid_ds.tfms]
    ulist = ulist.split_none()
    ulist.train._label_list = partial(MultiTfmLabelList, K=K)
    train_ul = ulist.label_empty().train    # train unlabeled LabelList
    valid_ll = learn.data.label_list.valid  # valid labeled LabelList
    udata = (LabelLists('.', train_ul, valid_ll)
             .transform(tfms)
             .databunch(bs=min(bs, len(train_ul)), val_bs=min(bs*2, len(valid_ll)),
                        num_workers=num_workers, dl_tfms=learn.data.dl_tfms, device=device,
                        collate_fn=MultiCollate)
             .normalize(learn.data.stats))
    learn.data = udata
    learn.callback_fns.append(partial(MixMatchCallback, labeled_data=labeled_data, T=T, α=α, λ=λ))
    return learn

Learner.mixmatch = mixmatch
def _mixup(x1, y1, x2, y2, α=.4):
    "MixUp: convex combination of two batches, with β clamped so the first batch always dominates."
    β = np.random.beta(α, α)
    β = max(β, 1 - β)
    x = β * x1 + (1 - β) * x2
    y = β * y1 + (1 - β) * y2
    return x, y
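# Tiny demo of _mixup on dummy tensors (shapes are placeholders; 186 = 168+11+7 concatenated
# one-hot targets): because β is clamped to max(β, 1-β), the first batch always dominates the mix.
_xd1, _xd2 = torch.randn(4, 1, 32, 32), torch.randn(4, 1, 32, 32)
_yd1, _yd2 = torch.rand(4, 186), torch.rand(4, 186)
_xmix, _ymix = _mixup(_xd1, _yd1, _xd2, _yd2, α=.75)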
def rand_bbox(last_input_size, λ):
    '''λ is always between .5 and 1'''
    W = last_input_size[2]
    H = last_input_size[3]
    cut_rat = torch.sqrt(torch.tensor(1.) - λ)  # in [0., .707]
    cut_w = (W * cut_rat).to(torch.int)
    cut_h = (H * cut_rat).to(torch.int)
    # uniformly sampled box centre
    cx = torch.LongTensor(1).random_(0, W).cuda()
    cy = torch.LongTensor(1).random_(0, H).cuda()
    bbx1 = torch.clamp(cx - cut_w // 2, 0, W)
    bby1 = torch.clamp(cy - cut_h // 2, 0, H)
    bbx2 = torch.clamp(cx + cut_w // 2, 0, W)
    bby2 = torch.clamp(cy + cut_h // 2, 0, H)
    return bbx1, bby1, bbx2, bby2
def sharpen(x, T=0.5):
    "Temperature sharpening: raise probabilities to 1/T and renormalise (lower T -> peakier distribution)."
    p_target = x**(1. / T)
    return p_target / p_target.sum(dim=1, keepdim=True)
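# Temperature sharpening in action: with T=0.5 the guessed distribution becomes more peaked,
# which is how MixMatch turns averaged predictions into low-entropy pseudo-labels.
_p = torch.tensor([[0.6, 0.3, 0.1]])
# sharpen(_p, T=0.5) -> approximately [[0.783, 0.196, 0.022]]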
def drop_cb_fn(learn, cb_name:str)->None:
    "Remove the callback factory named `cb_name` from `learn.callback_fns`."
    cbs = []
    for cb in learn.callback_fns:
        if isinstance(cb, functools.partial): cbn = cb.func.__name__
        else:                                 cbn = cb.__name__
        if cbn != cb_name: cbs.append(cb)
    learn.callback_fns = cbs
class MatchMixLoss(Module):
    "Adapt the loss function `crit` to go with MixMatch: total loss = Lx + λ·Lu."
    def __init__(self, crit=None, reduction='mean', λ=100):
        super().__init__()
        if crit is None: crit = nn.CrossEntropyLoss()
        if hasattr(crit, 'reduction'):
            self.crit = crit
            self.old_red = crit.reduction
            setattr(self.crit, 'reduction', 'none')
        else:
            self.crit = partial(crit, reduction='none')
            self.old_crit = crit
        self.reduction = reduction
        self.λ = λ

    def forward(self, output, target, bs=None):
        if bs is None:
            d = self.crit(output, target)
            if self.reduction == 'mean':  return d.mean()
            elif self.reduction == 'sum': return d.sum()
        else:
            # recover hard labels for the labeled part from the concatenated one-hot target [168 | 11 | 7]
            p1 = target[:, :168].argmax(-1).reshape(-1, 1)
            p2 = target[:, 168:168+11].argmax(-1).reshape(-1, 1)
            p3 = target[:, 168+11:168+11+7].argmax(-1).reshape(-1, 1)
            target_d = torch.cat((p1, p2, p3), dim=1)
            # labeled data: supervised cross-entropy loss Lx
            Lx = self.crit([output[0][:bs], output[1][:bs], output[2][:bs]],
                           target_d[:bs, 0:3].long()).mean()
            self.Lx = Lx.item()
            # unlabeled data: weighted MSE consistency loss Lu against the sharpened guessed labels
            u_p1 = torch.softmax(output[0][bs:], dim=1).half()
            u_p2 = torch.softmax(output[1][bs:], dim=1).half()
            u_p3 = torch.softmax(output[2][bs:], dim=1).half()
            Lu = 0.9  * F.mse_loss(u_p1, target[bs:, :168].half()) \
               + 0.06 * F.mse_loss(u_p2, target[bs:, 168:168+11].half()) \
               + 0.04 * F.mse_loss(u_p3, target[bs:, 168+11:168+11+7].half())
            self.Lu = (Lu * self.λ).item()
            d = Lx + Lu * self.λ
            if self.reduction == 'mean':  return d.mean()
            elif self.reduction == 'sum': return d.sum()
        return d

    def get_old(self):
        "Restore and return the original criterion."
        if hasattr(self, 'old_crit'): return self.old_crit
        elif hasattr(self, 'old_red'):
            setattr(self.crit, 'reduction', self.old_red)
            return self.crit
class MixMatchCallback(LearnerCallback):
    _order = -20
    def __init__(self, learn:Learner, labeled_data:DataBunch, T:float=.5,
                 K:int=2, α:float=.75, λ:float=100):
        super().__init__(learn)
        self.learn, self.T, self.K, self.α, self.λ = learn, T, K, α, λ
        self.labeled_dl = labeled_data.train_dl
        self.n_classes = [168, 11, 7]   # grapheme root, vowel diacritic, consonant diacritic
        self.c = 3
        self.labeled_data = labeled_data

    def on_train_begin(self, n_epochs, **kwargs):
        self.learn.loss_func = MatchMixLoss(crit=self.learn.loss_func, λ=self.λ)
        self.ldliter = iter(self.labeled_dl)
        self.smoothLx, self.smoothLu = SmoothenValue(0.98), SmoothenValue(0.98)
        self.recorder.add_metric_names(["train_Lx", "train_Lu*λ"])
        self.it = 0
        print('labeled dataset     : {:13,} samples'.format(len(self.labeled_data.train_ds)))
        print('unlabeled dataset   : {:13,} samples'.format(len(self.learn.data.train_ds)))
        total_samples = n_epochs * len(self.learn.data.train_dl) * \
                        self.learn.data.train_dl.batch_size * (self.K + 1)
        print('total train samples : {:13,} samples'.format(total_samples))

    def on_batch_begin(self, last_input, last_target, train, **kwargs):
        if not train: return
        try:
            Xx, Xy = next(self.ldliter)       # Xx already augmented
        except StopIteration:
            self.ldliter = iter(self.labeled_dl)
            Xx, Xy = next(self.ldliter)       # Xx already augmented
        # LABELED: one-hot encode the three targets -> pb of shape [bs, 168+11+7]
        bs = len(Xx)
        pb_1 = torch.eye(self.n_classes[0])[Xy[:,0].to(torch.long)].to(device)
        pb_2 = torch.eye(self.n_classes[1])[Xy[:,1].to(torch.long)].to(device)
        pb_3 = torch.eye(self.n_classes[2])[Xy[:,2].to(torch.long)].to(device)
        pb = torch.cat((pb_1, pb_2, pb_3), dim=1)
        # UNLABELED: guess labels by averaging predictions over the K augmented views, then sharpen
        shape = list(last_input.size()[2:])
        Ux = last_input.view([-1] + shape)    # Ux already augmented (K items)
        with torch.no_grad():
            p1,p2,p3 = self.learn.model(last_input[:, 0])
            p4,p5,p6 = self.learn.model(last_input[:, 1])
            p1 = sharpen(torch.softmax(torch.stack((p1,p4),dim=1),dim=2).mean(dim=1), T=self.T)
            p2 = sharpen(torch.softmax(torch.stack((p2,p5),dim=1),dim=2).mean(dim=1), T=self.T)
            p3 = sharpen(torch.softmax(torch.stack((p3,p6),dim=1),dim=2).mean(dim=1), T=self.T)
        # stack one-hot categories side by side for ease of data handling further down
        Uy = torch.cat((p1,p2,p3), dim=1)
        qb = Uy.repeat(1, 2).view((-1, Uy.size(-1)))
        Wx = torch.cat((Xx.half(), Ux.half()), dim=0)
        Wy = torch.cat((pb.half(), qb), dim=0)
        shuffle = torch.randperm(Wx.shape[0])
        if np.random.rand() < 0.5:
            # MIX: MixUp on inputs and targets
            mixed_input, mixed_target = _mixup(Wx, Wy, Wx[shuffle], Wy[shuffle], α=self.α)
        else:
            # CUT: CutMix, one random box per sample
            β = np.random.beta(1., 1., Wy.size(0))
            β = last_input.new(β)
            for i in range(Wy.size(0)):
                bbx1, bby1, bbx2, bby2 = rand_bbox(Wx.size(), β[i])
                Wx[i, ..., bby1:bby2, bbx1:bbx2] = Wx[shuffle[i], ..., bby1:bby2, bbx1:bbx2]
                # adjust lambda to exactly match the pixel ratio
                β[i] = 1 - ((bbx2 - bbx1) * (bby2 - bby1) / (last_input.size()[-1] * last_input.size()[-2]))
            β = β[:, None]
            mixed_target = β * Wy + (1 - β) * Wy[shuffle]
            mixed_input = Wx
        return {"last_input": mixed_input, "last_target": (mixed_target, bs)}

    def on_batch_end(self, train, **kwargs):
        if not train: return
        self.smoothLx.add_value(self.learn.loss_func.Lx)
        self.smoothLu.add_value(self.learn.loss_func.Lu)
        self.it += 1

    def on_epoch_end(self, last_metrics, **kwargs):
        return add_metrics(last_metrics, [self.smoothLx.smooth, self.smoothLu.smooth])

    def on_train_end(self, **kwargs):
        """At the end of training, loss_func and data are returned to their original values,
        and this callback is removed."""
        self.learn.loss_func = self.learn.loss_func.get_old()
        self.learn.data = self.labeled_data
        drop_cb_fn(self.learn, 'MixMatchCallback')
# Example of usage
train_size = 20000   # number of labeled images
tfms = get_transforms(...)  # your tfms

# source ItemList (not named `list`, which would shadow the built-in used inside random_strat_splitter)
il = (ImageList.from_df(df, path="", folder="data/train", suffix='.png',
                        cols='image_id', convert_mode='L')
      .split_by_rand_pct(valid_pct=0.2)
      .label_empty())
y_train = il.train.y.items
y_valid = il.valid.y.items
l_idx, u_idx = random_strat_splitter(y_train, train_size=train_size, seed=SEED)

data = (ItemLists('.',
                  il.train[l_idx],   # labeled train (in this case a subset of the original)
                  il.valid)          # labeled valid
        .label_from_df(cols=['grapheme_root','vowel_diacritic','consonant_diacritic'])
        .transform(tfms, size=SIZE)
        .databunch(bs=BS)
        .normalize(stats))
ulist = il.train[u_idx]              # unlabeled train

learn = Learner(data, model,
                loss_func=Loss_combine(),
                opt_func=Ranger,
                metrics=[Metric_grapheme(), Metric_vowel(), Metric_consonant(), Metric_tot()]
               ).mixmatch(ulist, α=.4, λ=75)
learn.clip_grad(1.)  # gradient clipping
learn.to_fp16()
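# With the MixMatch learner set up as above, training goes through the usual fastai API;
# the schedule below is only a placeholder, not the settings used for the competition.
learn.fit_one_cycle(5, max_lr=1e-2)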