This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def load_base_model(model_path, epoch, ctx, layer_name=None, n_inputs=2): | |
""" Loads the model from given model path | |
and returns a subnetwork that gives output from layer_name | |
""" | |
net = gluon.nn.SymbolBlock.imports( | |
model_path + "-symbol.json", | |
['data%i' % i for i in range(n_inputs)], | |
model_path + "-%.4d.params" % epoch, | |
ctx=ctx, | |
) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
few_words = ['great', 'excellent', 'best', 'perfect', 'wonderful', 'well', | |
'fun', 'love', 'amazing', 'also', 'enjoyed', 'favorite', 'it', | |
'and', 'loved', 'highly', 'bit', 'job', 'today', 'beautiful', | |
'you', 'definitely', 'superb', 'brilliant', 'world', 'liked', | |
'still', 'enjoy', 'life', 'very', 'especially', 'see', 'fantastic', | |
'both', 'shows', 'good', 'may', 'terrific', 'heart', 'classic', | |
'will', 'enjoyable', 'beautifully', 'always', 'true', 'perfectly', | |
'surprised', 'think', 'outstanding', 'most', | |
'bad', 'worst', 'awful', 'waste', 'boring', 'poor', 'terrible', |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def train(network, | |
train_data, | |
holdout_data, | |
loss, | |
epochs, | |
ctx, | |
lr=1e-2, | |
wd=1e-5, | |
optimizer='adam'): | |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Preparation of network arguments
ctx = [mx.gpu(0)]  # use a GPU (device 0); the context is kept as a one-element list
# Fitted 'token2index' step of the preprocessing pipeline, used to get the token to integer map.
tt = transformer_pipe.named_steps['token2index']
# Size of vocabulary of all tokens in training data: largest assigned index plus one.
max_idx = max(tt.tok2idx.values()) + 1
tok_embed_dim = 64  # embedding size of each token
review_embed_dim = 50  # embedding size of hidden state in GRU
# Embedding sizes handed to the network, keyed by embedding name.
# The hidden-state entry leaves its first slot as None.
input_output_embed_map = {
    "token_embed": (max_idx, tok_embed_dim),
    "hidden_embed": (None, review_embed_dim),
}
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class CustomSeqNet(gluon.nn.HybridBlock): | |
""" | |
Custom defined network for sequence data that is used to predict a binary label. | |
""" | |
def __init__(self, input_output_embed_map, dense_sizes=[100], dropouts=[0.2], activation="relu"): | |
""" | |
input_output_embed_map: {"token_embed": (max_tok_idx, tok_embed_dim), "hidden_embed": (None, hidden_embed_dim)}
""" | |
self.dense_sizes = dense_sizes # list of output dimension of dense layers |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def get_dataloader(dataset, | |
dataset_type="train", # valid/test | |
batch_size=256, | |
bucket_num=5, | |
shuffle=True, # true for training | |
num_workers=1): | |
# Batchify function appends the length of each sequence to feed as an additional input
combined_batchify_fn = nlp.data.batchify.Tuple( | |
nlp.data.batchify.Pad(axis=0, ret_length=True), |
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class Tokenize(BaseEstimator, TransformerMixin): | |
""" | |
Takes in pandas series and applies tokenization on each row based on given split pattern. | |
""" | |
def __init__(self, split_pat=f"([{string.punctuation}])"): | |
self.split_pat = split_pat # re pattern used to split string to tokens. default splits over any string punctuation | |
def tokenize(self, s): | |
""" Tokenize string """ | |
re_tok = re.compile(self.split_pat) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Transformers present in Pipeline work in series. 1->2->3->4
# NOTE(review): string.punctuation is interpolated into a regex character class
# without escaping. Its "\]" is read as an escaped "]", so the backslash itself
# never becomes a split character. Consider re.escape(string.punctuation) if
# that is not intended.
transformer_pipe = Pipeline(steps=[
    ("null_impute", NullImputer(strategy="constant", fill_value="null")),
    ("lower_case", LowerCaser()),
    ("tokenize", Tokenize(f"([{string.punctuation}])")),
    ("token2index", Tok2Idx()),
])
# Fit and transform on train data.
X_train_transformed = transformer_pipe.fit_transform(X_train)
# The validation split is only transformed, using the pipeline fitted above.
X_valid_transformed = transformer_pipe.transform(X_valid)
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
NewerOlder