import torch
import torch.nn as nn
import torch.nn.functional as F

class Loss_combine(nn.Module):
    "Weighted sum of the three cross-entropy losses (grapheme root, vowel, consonant)."
    def __init__(self):
        super().__init__()

    def forward(self, input, target, reduction='mean'):
        x1, x2, x3 = input                             # logits for the three heads
        x1, x2, x3 = x1.half(), x2.half(), x3.half()   # keep fp16 for mixed-precision training
        y = target.long()
        return 0.7*F.cross_entropy(x1, y[:,0], reduction=reduction) \
             + 0.1*F.cross_entropy(x2, y[:,1], reduction=reduction) \
             + 0.1*F.cross_entropy(x3, y[:,2], reduction=reduction)
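# A quick shape sketch (illustrative, not part of the original gist): the model is assumed to
# return a 3-tuple of logits and the target to be a (bs, 3) tensor of class indices, e.g.
#   logits = (torch.randn(8, 168), torch.randn(8, 11), torch.randn(8, 7))
#   target = torch.randint(0, 7, (8, 3))
#   loss = Loss_combine()(logits, target)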
import functools
from fastai.vision import *
from numbers import Integral

def random_strat_splitter(y, train_size:int=1, seed:int=1):
    "Stratified shuffle split of `y` into `train_size` labeled indices and the remaining indices."
    from sklearn.model_selection import StratifiedShuffleSplit
    sss = StratifiedShuffleSplit(n_splits=1, train_size=train_size, random_state=seed)
    idx = list(sss.split(np.arange(len(y)), y))[0]
    return idx[0], idx[1]
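# Illustrative call (hypothetical numbers): stratify 200 labels into 20 "labeled" indices
# and the 180 remaining "unlabeled" indices:
#   y = np.random.randint(0, 5, 200)
#   l_idx, u_idx = random_strat_splitter(y, train_size=20, seed=42)
#   len(l_idx), len(u_idx)   # -> (20, 180)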
class MultiTfmLabelList(LabelList):
    def __init__(self, x:ItemList, y:ItemList, tfms:TfmList=None, tfm_y:bool=False, K=2, **kwargs):
        "K: number of transformed samples generated per item"
        self.x, self.y, self.tfm_y, self.K = x, y, tfm_y, K
        self.y.x = x
        self.item = None
        self.transform(tfms, **kwargs)

    def __getitem__(self, idxs:Union[int,np.ndarray])->'LabelList':
        "return a single (x, y) if `idxs` is an integer or a new `LabelList` object if `idxs` is a range."
        idxs = try_int(idxs)
        if isinstance(idxs, Integral):
            if self.item is None: x, y = self.x[idxs], self.y[idxs]
            else:                 x, y = self.item, 0
            if self.tfms or self.tfmargs:
                x = [x.apply_tfms(self.tfms, **self.tfmargs) for _ in range(self.K)]
            if hasattr(self, 'tfms_y') and self.tfm_y and self.item is None:
                y = y.apply_tfms(self.tfms_y, **{**self.tfmargs_y, 'do_resolve':False})
            if y is None: y = 0
            return x, y
        else:
            return self.new(self.x[idxs], self.y[idxs])
def MultiCollate(batch):
    "Collate function that stacks the K augmented views of each unlabeled item before batching."
    batch = to_data(batch)
    if isinstance(batch[0][0], list): batch = [[torch.stack(s[0]), s[1]] for s in batch]
    return torch.utils.data.dataloader.default_collate(batch)
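# Sketch of what MultiCollate handles (assumed shapes): each unlabeled item arrives as a list
# of K augmented views, so stacking first lets default_collate build a (bs, K, C, H, W) batch:
#   batch = [([torch.zeros(1, 128, 128)] * 2, 0) for _ in range(4)]   # 4 items, K = 2 views
#   xb, yb = MultiCollate(batch)
#   xb.shape   # -> torch.Size([4, 2, 1, 128, 128])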
device = 'cuda' if torch.cuda.is_available() else 'cpu'
def mixmatch(learn: Learner, ulist: ItemList, num_workers:int=None,
             K: int = 2, T: float = .5, α: float = .75, λ: float = 100) -> Learner:
    "Swap `learn.data` for the unlabeled DataBunch and register the MixMatch callback."
    labeled_data = learn.data
    if num_workers is None: num_workers = 1
    labeled_data.train_dl.num_workers = num_workers
    bs = labeled_data.train_dl.batch_size
    tfms = [labeled_data.train_ds.tfms, labeled_data.valid_ds.tfms]
    ulist = ulist.split_none()
    ulist.train._label_list = partial(MultiTfmLabelList, K=K)
    train_ul = ulist.label_empty().train    # train unlabeled LabelList
    valid_ll = learn.data.label_list.valid  # valid labeled LabelList
    udata = (LabelLists('.', train_ul, valid_ll)
             .transform(tfms)
             .databunch(bs=min(bs, len(train_ul)), val_bs=min(bs * 2, len(valid_ll)),
                        num_workers=num_workers, dl_tfms=learn.data.dl_tfms, device=device,
                        collate_fn=MultiCollate)
             .normalize(learn.data.stats))
    learn.data = udata
    learn.callback_fns.append(partial(MixMatchCallback, labeled_data=labeled_data, K=K, T=T, α=α, λ=λ))
    return learn

Learner.mixmatch = mixmatch
def _mixup(x1, y1, x2, y2, α=.4):
    "MixUp two batches with a Beta(α, α) coefficient."
    β = np.random.beta(α, α)
    β = max(β, 1 - β)
    x = β * x1 + (1 - β) * x2
    y = β * y1 + (1 - β) * y2
    return x, y
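# Worked example (illustrative): with β = 0.8 a mixed image is 0.8*x1 + 0.2*x2 and its soft
# target is 0.8*y1 + 0.2*y2; taking max(β, 1-β) keeps the first argument dominant, so each
# mixed sample stays closest to its "primary" input.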
def rand_bbox(last_input_size, λ):
    "Pick a random box covering roughly a (1 - λ) fraction of the image; λ is always between .5 and 1."
    W = last_input_size[2]
    H = last_input_size[3]
    cut_rat = torch.sqrt(torch.tensor(1.) - λ)  # 0. - .707
    cut_w = (W * cut_rat).to(torch.int)
    cut_h = (H * cut_rat).to(torch.int)
    # uniform box centre
    cx = torch.LongTensor(1).random_(0, W).cuda()
    cy = torch.LongTensor(1).random_(0, H).cuda()
    bbx1 = torch.clamp(cx - cut_w // 2, 0, W)
    bby1 = torch.clamp(cy - cut_h // 2, 0, H)
    bbx2 = torch.clamp(cx + cut_w // 2, 0, W)
    bby2 = torch.clamp(cy + cut_h // 2, 0, H)
    return bbx1, bby1, bbx2, bby2
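# Illustrative call (hypothetical sizes): for an (8, 1, 128, 128) batch and λ = 0.75 the cut
# side is 128*sqrt(1-0.75) = 64 px, i.e. the pasted box covers about 25% of the image:
#   bbx1, bby1, bbx2, bby2 = rand_bbox(torch.Size([8, 1, 128, 128]), torch.tensor(0.75))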
def sharpen(x, T=0.5):
    "Temperature sharpening of a probability distribution (T < 1 makes it peakier)."
    p_target = x**(1. / T)
    return p_target / p_target.sum(dim=1, keepdim=True)
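# Worked example (illustrative): sharpening [0.6, 0.4] with T = 0.5 squares the probabilities
# and renormalises: [0.36, 0.16] / 0.52 ≈ [0.69, 0.31], pushing the guessed label toward its
# argmax, as in the MixMatch label-guessing step.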
def drop_cb_fn(learn, cb_name:str)->None:
    "Remove the callback factory named `cb_name` from `learn.callback_fns`."
    cbs = []
    for cb in learn.callback_fns:
        if isinstance(cb, functools.partial): cbn = cb.func.__name__
        else: cbn = cb.__name__
        if cbn != cb_name: cbs.append(cb)
    learn.callback_fns = cbs
class MatchMixLoss(Module):
    "Adapt the loss function `crit` to go with MixMatch."
    def __init__(self, crit=None, reduction='mean', λ=100):
        super().__init__()
        if crit is None: crit = nn.CrossEntropyLoss()
        if hasattr(crit, 'reduction'):
            self.crit = crit
            self.old_red = crit.reduction
            setattr(self.crit, 'reduction', 'none')
        else:
            self.crit = partial(crit, reduction='none')
            self.old_crit = crit
        self.reduction = reduction
        self.λ = λ

    def forward(self, output, target, bs=None):
        if bs is None:
            # plain (e.g. validation) batches: fall back to the wrapped criterion
            d = self.crit(output, target)
            if self.reduction == 'mean': return d.mean()
            elif self.reduction == 'sum': return d.sum()
            return d
        else:
            # split the concatenated soft targets back into the three tasks
            p1 = target[:,:168].argmax(-1).reshape(-1,1)
            p2 = target[:,168:168+11].argmax(-1).reshape(-1,1)
            p3 = target[:,168+11:168+11+7].argmax(-1).reshape(-1,1)
            target_d = torch.cat((p1,p2,p3), dim=1)
            # Labeled data: supervised cross-entropy on the first `bs` rows
            Lx = self.crit([output[0][:bs],output[1][:bs],output[2][:bs]], target_d[:bs,0:3].long()).mean()
            self.Lx = Lx.item()
            # Unlabeled data: consistency (MSE) between predictions and the guessed labels
            u_p1 = torch.softmax(output[0][bs:], dim=1).half()
            u_p2 = torch.softmax(output[1][bs:], dim=1).half()
            u_p3 = torch.softmax(output[2][bs:], dim=1).half()
            Lu = 0.9 *F.mse_loss(u_p1, target[bs:,:168].half()) \
               + 0.06*F.mse_loss(u_p2, target[bs:,168:168+11].half()) \
               + 0.04*F.mse_loss(u_p3, target[bs:,168+11:168+11+7].half())
            self.Lu = (Lu * self.λ).item()
            d = Lx + Lu * self.λ
            if self.reduction == 'mean': return d.mean()
            elif self.reduction == 'sum': return d.sum()
            return d

    def get_old(self):
        "Restore the wrapped criterion's original reduction and return it."
        if hasattr(self, 'old_crit'): return self.old_crit
        elif hasattr(self, 'old_red'):
            setattr(self.crit, 'reduction', self.old_red)
            return self.crit
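# In MixMatch terms the objective above is L = Lx + λ·Lu: Lx is the supervised cross-entropy
# on the first `bs` (labeled) rows and Lu is the MSE consistency term on the unlabeled rows,
# with the same per-task hierarchy as Loss_combine (grapheme root weighted highest).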
class MixMatchCallback(LearnerCallback):
    _order = -20

    def __init__(self,
                 learn: Learner,
                 labeled_data: DataBunch,
                 T: float = .5,
                 K: int = 2,
                 α: float = .75,
                 λ: float = 100):
        super().__init__(learn)
        self.learn, self.T, self.K, self.α, self.λ = learn, T, K, α, λ
        self.labeled_dl = labeled_data.train_dl
        self.n_classes = [168, 11, 7]   # grapheme_root, vowel_diacritic, consonant_diacritic
        self.c = 3
        self.labeled_data = labeled_data

    def on_train_begin(self, n_epochs, **kwargs):
        self.learn.loss_func = MatchMixLoss(crit=self.learn.loss_func, λ=self.λ)
        self.ldliter = iter(self.labeled_dl)
        self.smoothLx, self.smoothLu = SmoothenValue(0.98), SmoothenValue(0.98)
        self.recorder.add_metric_names(["train_Lx", "train_Lu*λ"])
        self.it = 0
        print('labeled dataset     : {:13,} samples'.format(len(self.labeled_data.train_ds)))
        print('unlabeled dataset   : {:13,} samples'.format(len(self.learn.data.train_ds)))
        total_samples = n_epochs * len(self.learn.data.train_dl) * \
                        self.learn.data.train_dl.batch_size * (self.K + 1)
        print('total train samples : {:13,} samples'.format(total_samples))

    def on_batch_begin(self, last_input, last_target, train, **kwargs):
        if not train: return
        try:
            Xx, Xy = next(self.ldliter)          # Xx already augmented
        except StopIteration:
            self.ldliter = iter(self.labeled_dl)
            Xx, Xy = next(self.ldliter)          # Xx already augmented
        # LABELED: one-hot encode the three targets and concatenate them -> pb: (bs, 168+11+7)
        bs = len(Xx)
        pb_1 = torch.eye(self.n_classes[0])[Xy[:,0].to(torch.long)].to(device)
        pb_2 = torch.eye(self.n_classes[1])[Xy[:,1].to(torch.long)].to(device)
        pb_3 = torch.eye(self.n_classes[2])[Xy[:,2].to(torch.long)].to(device)
        pb = torch.cat((pb_1, pb_2, pb_3), dim=1)
        # UNLABELED: flatten the K augmented views and guess labels by averaging + sharpening
        shape = list(last_input.size()[2:])
        Ux = last_input.view([-1] + shape)       # Ux already augmented (K items)
        with torch.no_grad():
            p1, p2, p3 = self.learn.model(last_input[:, 0])
            p4, p5, p6 = self.learn.model(last_input[:, 1])
            p1 = sharpen(torch.softmax(torch.stack((p1, p4), dim=1), dim=2).mean(dim=1), T=self.T)
            p2 = sharpen(torch.softmax(torch.stack((p2, p5), dim=1), dim=2).mean(dim=1), T=self.T)
            p3 = sharpen(torch.softmax(torch.stack((p3, p6), dim=1), dim=2).mean(dim=1), T=self.T)
        # Stack one-hot categories side by side for ease of data handling further down
        Uy = torch.cat((p1, p2, p3), dim=1)
        qb = Uy.repeat(1, 2).view((-1, Uy.size(-1)))     # repeat guessed labels for both views
        Wx = torch.cat((Xx.half(), Ux.half()), dim=0)
        Wy = torch.cat((pb.half(), qb), dim=0)
        shuffle = torch.randperm(Wx.shape[0])
        if np.random.rand() < 0.5:
            # MIX: MixUp on the combined labeled + unlabeled batch
            mixed_input, mixed_target = _mixup(Wx, Wy, Wx[shuffle], Wy[shuffle], α=self.α)
        else:
            # CUT: CutMix on the combined batch
            β = np.random.beta(1., 1., Wy.size(0))
            β = last_input.new(β)
            for i in range(Wy.size(0)):
                bbx1, bby1, bbx2, bby2 = rand_bbox(Wx.size(), β[i])
                Wx[i, ..., bby1:bby2, bbx1:bbx2] = Wx[shuffle[i], ..., bby1:bby2, bbx1:bbx2]
                # adjust lambda to exactly match the pixel ratio of the pasted box
                β[i] = 1 - ((bbx2 - bbx1) * (bby2 - bby1) / (last_input.size()[-1] * last_input.size()[-2]))
            β = β[:, None]
            y = β * Wy + (1 - β) * Wy[shuffle]
            mixed_target = y
            mixed_input = Wx
        return {"last_input": mixed_input, "last_target": (mixed_target, bs)}

    def on_batch_end(self, train, **kwargs):
        if not train: return
        self.smoothLx.add_value(self.learn.loss_func.Lx)
        self.smoothLu.add_value(self.learn.loss_func.Lu)
        self.it += 1

    def on_epoch_end(self, last_metrics, **kwargs):
        return add_metrics(last_metrics, [self.smoothLx.smooth, self.smoothLu.smooth])

    def on_train_end(self, **kwargs):
        """At the end of training, loss_func and data are returned to their original values,
        and this callback is removed."""
        self.learn.loss_func = self.learn.loss_func.get_old()
        self.learn.data = self.labeled_data
        drop_cb_fn(self.learn, 'MixMatchCallback')
# Example of usage
train_size = 20000   # number of labeled images
tfms = get_transforms(...)   # your tfms
list = (ImageList.from_df(df, path="", folder="data/train", suffix='.png',
                          cols='image_id', convert_mode='L')
        .split_by_rand_pct(valid_pct=0.2)
        .label_empty())
y_train = list.train.y.items
y_valid = list.valid.y.items
l_idx, u_idx = random_strat_splitter(y_train, train_size=train_size, seed=SEED)
data = (ItemLists('.',
                  list.train[l_idx],   # labeled train (in this case a subset of the original)
                  list.valid)          # labeled valid
        .label_from_df(cols=['grapheme_root','vowel_diacritic','consonant_diacritic'])
        .transform(tfms, size=SIZE)
        .databunch(bs=BS)
        .normalize(stats))
ulist = list.train[u_idx]
learn = Learner(data,
                model,
                loss_func=Loss_combine(),
                opt_func=Ranger,
                metrics=[Metric_grapheme(), Metric_vowel(), Metric_consonant(), Metric_tot()]
               ).mixmatch(ulist, α=.4, λ=75)
learn.clip_grad = 1.0
learn.to_fp16()
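# From here training proceeds as usual; a minimal (hypothetical) schedule might be
#   learn.fit_one_cycle(EPOCHS, max_lr=LR)
# with EPOCHS and LR chosen for your setup.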