PyTorch module for classification or regression of categorical + continuous + text inputs. This module is built on the fast.ai library.
# Star import from the (pre-1.0) fast.ai library provides np, torch, Dataset and the other helpers used below.
from fastai.text import *


class MyDataset(Dataset):
    """Dataset yielding (categorical, continuous, text, target) samples."""
    def __init__(self, cats, conts, texts, y, is_reg, is_multi, reverse_text=False):
        n = len(cats[0]) if cats else len(conts[0])
        self.cats = np.stack(cats, 1).astype(np.int64) if cats else np.zeros((n, 1))
        self.conts = np.stack(conts, 1).astype(np.float32) if conts else np.zeros((n, 1))
        self.texts = np.zeros((n, 1)) if texts is None else np.array(texts)
        self.y = np.zeros((n, 1)) if y is None else np.array(y).reshape(-1, 1).astype(np.float32)
        if is_reg:
            self.y = self.y[:, None]
        self.is_reg = is_reg
        self.is_multi = is_multi
        self.reverse_text = reverse_text

    def __len__(self): return len(self.y)

    def __getitem__(self, idx):
        t = self.texts[idx]
        if self.reverse_text:
            t = list(reversed(t))
        return [self.cats[idx], self.conts[idx], np.array(t), self.y[idx]]

    @classmethod
    def from_data_frame(cls, df, cat_flds, cont_flds, text_fld, y_fld, is_reg=True, is_multi=False):
        cat_cols = [c.values for n, c in df[cat_flds].items()]
        cont_cols = [c.values for n, c in df[cont_flds].items()]
        return cls(cat_cols, cont_cols, df[text_fld], df[y_fld], is_reg, is_multi)
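A minimal usage sketch of the dataset on its own (the DataFrame and column names are hypothetical; in practice the dataset is built inside build_learner below):

# Hypothetical columns: 'store'/'dow' are label-encoded categoricals, 'price'/'temp' are
# continuous, 'desc_ids' holds numericalized text (lists of token ids), 'sales' is the target.
trn_ds = MyDataset.from_data_frame(trn_df, cat_flds=['store', 'dow'], cont_flds=['price', 'temp'],
                                   text_fld='desc_ids', y_fld='sales', is_reg=True)
x_cat, x_cont, text, y = trn_ds[0]  # int64 categoricals, float32 continuous, token-id array, target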
class MyModel(BasicModel):
    def get_layer_groups(self):
        # Two layer groups so discriminative learning rates can be applied to the
        # text encoder and the structured head separately.
        m = self.model
        return [m.rnn_enc, m.structured_model]


class MyLearner(Learner):
    def __init__(self, data, models, **kwargs):
        super().__init__(data, models, **kwargs)

    def _get_crit(self, data):
        # Loss depends on the task: regression, multi-label, or single-label classification.
        if data.is_reg:
            return F.mse_loss
        elif data.is_multi:
            return F.binary_cross_entropy
        else:
            return F.nll_loss

    def predict_array(self, x_cat, x_cont, text):
        self.model.eval()
        return to_np(self.model(to_gpu(V(T(x_cat))), to_gpu(V(T(x_cont))), to_gpu(V(T(text)))))

    def summary(self):
        x = [torch.ones(3, self.data.trn_ds.cats.shape[1]).long(),
             torch.rand(3, self.data.trn_ds.conts.shape[1])]
        return model_summary(self.model, x)

    # The wrapped model is a plain nn.Module (not a Sequential), so address the encoder by attribute.
    def save_encoder(self, name): save_model(self.model.rnn_enc, self.get_model_path(name))
    def load_encoder(self, name): load_model(self.model.rnn_enc, self.get_model_path(name))


def build_learner(trn_df, val_df, cat_flds, cont_flds, text_fld, y_fld):
    # Assumes bs, PATH and the text_* / struct_* / y_range hyperparameters are defined globally.
    trn_ds = MyDataset.from_data_frame(trn_df, cat_flds, cont_flds, text_fld, y_fld)
    val_ds = MyDataset.from_data_frame(val_df, cat_flds, cont_flds, text_fld, y_fld)
    trn_samp = SortishSampler(trn_df[text_fld], key=lambda x: len(trn_df[text_fld].iloc[x]), bs=bs//2)
    val_samp = SortSampler(val_df[text_fld], key=lambda x: len(val_df[text_fld].iloc[x]))
    trn_dl = DataLoader(trn_ds, bs//2, num_workers=1, pad_idx=1, sampler=trn_samp)
    val_dl = DataLoader(val_ds, bs, num_workers=1, pad_idx=1, sampler=val_samp)
    md = ModelData(PATH, trn_dl, val_dl)
    model = RNN_Structured_regressor(text_bptt, text_max_seq, text_ntoken, text_emb_sz, text_n_hid,
                                     text_n_layers, text_pad_token, struct_emb_szs, struct_n_cont, y_range)
    model_wrapper = MyModel(to_gpu(model))
    return MyLearner(md, model_wrapper, opt_fn=optim.Adam)
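A rough end-to-end usage sketch, assuming the globals referenced above (bs, PATH, the text_* and struct_* hyperparameters, y_range) are already defined and the text column holds numericalized token ids; the column names are made up:

learner = build_learner(trn_df, val_df, cat_flds=['store', 'dow'], cont_flds=['price', 'temp'],
                        text_fld='desc_ids', y_fld='sales')
learner.lr_find()                  # optional: scan for a reasonable learning rate
learner.fit(1e-3, 3, cycle_len=1)  # usual fast.ai training loop
preds = learner.predict()          # predictions on the validation set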
from fastai.text import *
from fastai.structured import proc_df
import pandas as pd
import numpy as np
class MixedInputModelWithText(nn.Module):
    """Fully connected head over concatenated categorical embeddings, continuous features and a text encoding."""
    def __init__(self, emb_szs, n_cont, emb_drop, out_sz, szs, drops,
                 y_range=None, use_bn=False, is_reg=True, is_multi=False, n_text=0):
        super().__init__()
        for i, (c, s) in enumerate(emb_szs):
            assert c > 1, f"cardinality must be >=2, got emb_szs[{i}]: ({c},{s})"
        if not is_reg and not is_multi:
            assert out_sz >= 2, "For classification with out_sz=1, use is_multi=True"
        self.embs = nn.ModuleList([nn.Embedding(c, s) for c, s in emb_szs])
        for emb in self.embs: self.emb_init(emb)
        n_emb = sum(e.embedding_dim for e in self.embs)
        self.n_emb, self.n_cont, self.text_emb_sz = n_emb, n_cont, n_text
        szs = [n_emb + n_cont + n_text] + szs
        self.lins = nn.ModuleList([nn.Linear(szs[i], szs[i + 1]) for i in range(len(szs) - 1)])
        self.bns = nn.ModuleList([nn.BatchNorm1d(sz) for sz in szs[1:]])
        for o in self.lins: kaiming_normal(o.weight.data)
        self.outp = nn.Linear(szs[-1], out_sz)
        kaiming_normal(self.outp.weight.data)
        self.emb_drop = nn.Dropout(emb_drop)
        self.drops = nn.ModuleList([nn.Dropout(drop) for drop in drops])
        self.bn = nn.BatchNorm1d(n_cont)
        self.use_bn, self.y_range = use_bn, y_range
        self.is_reg = is_reg
        self.is_multi = is_multi

    def emb_init(self, x):
        x = x.weight.data
        sc = 2 / (x.size(1) + 1)
        x.uniform_(-sc, sc)

    def forward(self, x_cat, x_cont, text_emb):
        # Embed the categoricals, apply dropout to the text encoding, batchnorm the continuous
        # features, then concatenate whichever of the three inputs are actually present.
        if self.n_emb != 0:
            x1 = [e(x_cat[:, i]) for i, e in enumerate(self.embs)]
            x1 = torch.cat(x1, 1)
            x1 = self.emb_drop(x1)
        else:
            x1 = torch.Tensor()
        if self.text_emb_sz != 0:
            x2 = self.emb_drop(text_emb)
        else:
            x2 = torch.Tensor()
        if self.n_cont != 0:
            x3 = self.bn(x_cont)
        else:
            x3 = torch.Tensor()  # same "empty" convention as x1/x2 when there are no continuous inputs
        all_xs = [x1, x2, x3]
        all_xs = [cur_x for cur_x in all_xs if cur_x.nelement() != 0]
        x = torch.cat(all_xs, 1)
        for l, d, b in zip(self.lins, self.drops, self.bns):
            x = F.relu(l(x))
            if self.use_bn: x = b(x)
            x = d(x)
        x = self.outp(x)
        if not self.is_reg:
            if self.is_multi:
                x = F.sigmoid(x)
            else:
                x = F.log_softmax(x)
        elif self.y_range:
            # Squash regression outputs into the given (min, max) range.
            x = F.sigmoid(x)
            x = x * (self.y_range[1] - self.y_range[0])
            x = x + self.y_range[0]
        return x
class MyMultiBatchRNN(RNN_Encoder):
    """RNN_Encoder that processes arbitrarily long sequences in bptt-sized chunks."""
    def __init__(self, bptt, max_seq, *args, **kwargs):
        self.max_seq, self.bptt = max_seq, bptt
        super().__init__(*args, **kwargs)
        self.reset()

    def concat(self, arrs):
        return [torch.cat([l[si] for l in arrs]) for si in range(len(arrs[0]))]

    def forward(self, input):
        sl, bs = input.size()
        for l in self.hidden:
            for h in l: h.data.zero_()
        raw_outputs, outputs = [], []
        for i in range(0, sl, self.bptt):
            r, o = super().forward(input[i: min(i + self.bptt, sl)])
            if i > (sl - self.max_seq):
                # only keep the outputs for the last max_seq timesteps
                raw_outputs.append(r)
                outputs.append(o)
        return self.concat(raw_outputs), self.concat(outputs)
class RNN_Structured_regressor(nn.Module):
    def __init__(self, text_bptt, text_max_seq, text_ntoken, text_emb_sz, text_n_hid, text_n_layers, text_pad_token,
                 struct_emb_szs, struct_n_cont, y_range, struct_layers_szs=[1000, 500]):
        super().__init__()
        self.rnn_enc = MyMultiBatchRNN(bptt=text_bptt, max_seq=text_max_seq, ntoken=text_ntoken, emb_sz=text_emb_sz,
                                       n_hid=text_n_hid, n_layers=text_n_layers, pad_token=text_pad_token,
                                       dropouth=0.3, dropouti=0.65, dropoute=0.1, wdrop=0.5, qrnn=False)
        # struct_emb_szs = [(text_emb_sz, text_emb_sz)] + struct_emb_szs
        self.structured_model = MixedInputModelWithText(struct_emb_szs, struct_n_cont, emb_drop=0.04, out_sz=1,
                                                        szs=struct_layers_szs, drops=[0.001, 0.01], y_range=y_range,
                                                        use_bn=False, is_reg=True, is_multi=False, n_text=text_emb_sz)

    def forward(self, x_cat, x_cont, text_inp):
        raw_outputs, outputs = self.rnn_enc(torch.t(text_inp))
        encoded_text = outputs[-1][-1]  # last timestep of the top layer's output; add max pooling afterwards
        return self.structured_model(x_cat, x_cont, encoded_text)
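For reference, a small shape-check sketch of the full model outside the Learner machinery (all hyperparameter values here are made up, and V/Variable wrapping assumes the pre-1.0 fast.ai / PyTorch 0.3 stack used above):

model = RNN_Structured_regressor(text_bptt=70, text_max_seq=140, text_ntoken=5000, text_emb_sz=200,
                                 text_n_hid=256, text_n_layers=2, text_pad_token=1,
                                 struct_emb_szs=[(10, 5), (8, 4)], struct_n_cont=3, y_range=(0., 10.))
model.eval()
x_cat  = V(torch.ones(16, 2).long())   # 16 samples, two categorical fields
x_cont = V(torch.rand(16, 3))          # three continuous fields
text   = V(torch.ones(16, 30).long())  # 30 token ids per sample; forward transposes to (seq_len, batch)
preds  = model(x_cat, x_cont, text)    # -> (16, 1) predictions squashed into y_range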