# Assumed imports: the code mixes tf.layers with Keras layers and regularizers,
# so tf.keras (TF 1.x) is the most likely source.
import tensorflow as tf
from tensorflow.keras import regularizers
from tensorflow.keras.layers import Activation, GlobalMaxPooling1D, SpatialDropout1D

VOCAB_SIZE = 115
MAX_ONE_STROKE_LEN = 50
MAX_STROKE_NUM = 22
MAX_TOKEN_LEN = 10 + 2

# must match the dimension of the trained model
EXTRACTED_FEATURE_DIM = 256

FE_DROPOUT_RATE = 0.5
FE_L2_REGULARIZATION_RATE = 0.01
ENCODER_DROPOUT_RATE = 0.1
DECODER_DROPOUT_RATE = 0.1
L2_REGULARIZATION_RATE = 1e-6
FEATURE_EXTRACTER_KERNEL_SIZE = 7
EMBEDDING_SIZE = 32
def fe_conv1d(filternum, kernelsize, x):
    """1-D convolution with L2 regularization on kernel, bias and activations."""
    return tf.layers.Conv1D(
        filternum, kernelsize,
        kernel_regularizer=regularizers.l2(L2_REGULARIZATION_RATE),
        bias_regularizer=regularizers.l2(L2_REGULARIZATION_RATE),
        activity_regularizer=regularizers.l2(L2_REGULARIZATION_RATE))(x)

def feature_extractor(input_stroke_t, is_training):
    """input_stroke_t shape: (batch, MAX_STROKE_NUM, MAX_ONE_STROKE_LEN, INPUT_TYPE_DIM)"""
    with tf.variable_scope("feature_extractor"):
        inpshape = input_stroke_t.shape
        # Fold the stroke dimension into the batch so each stroke is encoded independently.
        x = tf.reshape(input_stroke_t, [-1, inpshape[2], inpshape[3]])
        # (batch*MAX_STROKE_NUM, MAX_ONE_STROKE_LEN, INPUT_TYPE_DIM)
        x = fe_conv1d(32, FEATURE_EXTRACTER_KERNEL_SIZE, x)
        x = tf.layers.BatchNormalization()(x, training=is_training)
        x = Activation('relu')(x)
        # (batch*MAX_STROKE_NUM, ~MAX_ONE_STROKE_LEN, 32) -- 'valid' padding trims kernel_size-1 samples
        x = tf.layers.MaxPooling1D(pool_size=2, strides=2)(x)
        x = tf.layers.Dropout(FE_DROPOUT_RATE)(x, training=is_training)
        # (batch*MAX_STROKE_NUM, ~MAX_ONE_STROKE_LEN/2, 32)
        x = fe_conv1d(64, FEATURE_EXTRACTER_KERNEL_SIZE, x)
        x = tf.layers.BatchNormalization()(x, training=is_training)
        x = Activation('relu')(x)
        x = tf.layers.MaxPooling1D(pool_size=2, strides=2)(x)
        x = tf.layers.Dropout(FE_DROPOUT_RATE)(x, training=is_training)
        # (batch*MAX_STROKE_NUM, ~MAX_ONE_STROKE_LEN/4, 64)
        x = fe_conv1d(EXTRACTED_FEATURE_DIM, FEATURE_EXTRACTER_KERNEL_SIZE, x)
        x = tf.layers.BatchNormalization()(x, training=is_training)
        x = Activation('relu')(x)
        x = tf.layers.Dropout(FE_DROPOUT_RATE)(x, training=is_training)
        # Max-pool over time: one EXTRACTED_FEATURE_DIM vector per stroke.
        x = GlobalMaxPooling1D()(x)
        x = tf.reshape(x, [-1, inpshape[1], EXTRACTED_FEATURE_DIM])
        return x

TRANSFORMER_H = 8
TRANSFORMER_D = 64
TRANSFORMER_D_MODEL = TRANSFORMER_H * TRANSFORMER_D

def layer_norm(x):
    """Layer norm over the last dimension only."""
    return tf.contrib.layers.layer_norm(inputs=x, begin_norm_axis=-1, begin_params_axis=-1)

# Same as the TensorFlow official implementation, but with each dim given explicitly for XLA:
# https://github.com/tensorflow/models/blob/master/official/transformer/model/attention_layer.py
def split_heads(x, seqlen, num_heads=TRANSFORMER_H, one_depth=TRANSFORMER_D):
    # [batch, seqlen, num_heads*one_depth] -> [batch, num_heads, seqlen, one_depth]
    x = tf.reshape(x, [-1, seqlen, num_heads, one_depth])
    return tf.transpose(x, [0, 2, 1, 3])

def multihead_attention(query, query_seq_len, y, y_seq_len, mask_bias, dropout_rate, is_training):
    num_head = TRANSFORMER_H
    size_per_head = TRANSFORMER_D
    multi_q = tf.layers.Dense(num_head * size_per_head, use_bias=False)(query)
    multi_k = tf.layers.Dense(num_head * size_per_head, use_bias=False)(y)
    multi_v = tf.layers.Dense(num_head * size_per_head, use_bias=False)(y)
    k_seq_len = v_seq_len = y_seq_len
    multi_q = split_heads(multi_q, query_seq_len)
    multi_k = split_heads(multi_k, k_seq_len)
    multi_v = split_heads(multi_v, v_seq_len)
    # Scale queries by 1/sqrt(d_k)
    multi_q *= TRANSFORMER_D ** -0.5
    # Calculate dot-product attention
    logits = tf.matmul(multi_q, multi_k, transpose_b=True)
    logits += mask_bias
    weights = tf.nn.softmax(logits)
    weights = tf.layers.Dropout(dropout_rate)(weights, training=is_training)
    attention_output = tf.matmul(weights, multi_v)
    # attention_output: [batch, num_head, query_seq_len, size_per_head]
    # Recombine heads: transpose to [batch, query_seq_len, num_head, size_per_head], then flatten
    attention_output = tf.transpose(attention_output, [0, 2, 1, 3])
    attention_output = tf.reshape(attention_output, [-1, query_seq_len, TRANSFORMER_D_MODEL])
    attention_output = tf.layers.Dense(TRANSFORMER_D_MODEL, use_bias=False)(attention_output)
    return attention_output

def myembedding(input, num_classes, embedding_size, seq_num, name):
    """Embedding lookup via one-hot matmul (tends to be friendlier to TPU/XLA than gather)."""
    with tf.variable_scope(name, reuse=tf.AUTO_REUSE):
        randinitializer = lambda: tf.random_uniform([num_classes, embedding_size], -0.05, 0.05)
        embedmat = tf.get_variable(name, initializer=randinitializer)
        onehot = tf.one_hot(input, num_classes)
        flatten_onehot = tf.reshape(onehot, [-1, num_classes])
        return tf.reshape(tf.matmul(flatten_onehot, embedmat), [-1, seq_num, embedding_size])

def embed_stroke(stroke_features):
    """Add learned positional embeddings to the per-stroke feature vectors."""
    pos_stroke = tf.range(
        0,
        tf.shape(stroke_features)[1],
        delta=1,
        dtype=tf.int32,
        name='range')
    pos_stroke = tf.expand_dims(pos_stroke, axis=0)
    pos_stroke_embed = myembedding(pos_stroke, MAX_STROKE_NUM, EXTRACTED_FEATURE_DIM, MAX_STROKE_NUM, "stroke_pos_embed")
    stroke_pos_embedded = stroke_features + tf.cast(x=pos_stroke_embed, dtype=stroke_features.dtype)
    return stroke_pos_embedded

def mask_to_maskbias(mask):
    """Turn a {0,1} padding mask into an additive bias: masked (mask==0) positions get ~ -infinity."""
    mask = tf.cast(mask, tf.float32)
    mask_bias = (-1e9) * (1. - mask)
    mask_bias = tf.expand_dims(tf.expand_dims(mask_bias, axis=1), axis=1)
    # [batch, 1, 1, length]
    return mask_bias
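
# Worked example of the padding-mask bias: a stroke mask of [1, 1, 1, 0, 0]
# yields [0, 0, 0, -1e9, -1e9], reshaped to [batch, 1, 1, length] so it
# broadcasts over attention heads and query positions; the softmax then assigns
# ~0 weight to the padded strokes.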

def encSelfAttenOneBlock(input, mask_bias, is_training):
    # Self-attention sub-layer with residual connection and layer norm.
    attention_output = multihead_attention(input, MAX_STROKE_NUM, input, MAX_STROKE_NUM, mask_bias, ENCODER_DROPOUT_RATE, is_training)
    attention_output = SpatialDropout1D(ENCODER_DROPOUT_RATE)(attention_output, training=is_training)
    attention_output = layer_norm(attention_output + input)
    # Position-wise feed-forward sub-layer.
    intermediate_output = tf.layers.Dense(2048, activation='relu', use_bias=False)(attention_output)
    layer_output = tf.layers.Dense(TRANSFORMER_D_MODEL, use_bias=False)(intermediate_output)
    layer_output = SpatialDropout1D(ENCODER_DROPOUT_RATE)(layer_output, training=is_training)
    return layer_norm(layer_output + attention_output)

def encoder_SelfAttention(input, mask_bias, is_training):
    # Project stroke features to the model dimension, then stack 6 encoder blocks.
    x = tf.layers.Dense(TRANSFORMER_D_MODEL, use_bias=False)(input)
    x = SpatialDropout1D(ENCODER_DROPOUT_RATE)(x, training=is_training)
    for _ in range(6):
        x = encSelfAttenOneBlock(x, mask_bias, is_training)
    return x

# Dynamic shapes cause TPUEstimator export to fail...
# For the Transformer decoder, share the same weights between the embedding and
# the output (reverse-embedding) projection.
class SharedEmbedder:
    def __init__(self, num_classes, embedding_size=TRANSFORMER_D_MODEL, seq_num=MAX_TOKEN_LEN, name="dec_embed"):
        self.num_classes = num_classes
        self.embedding_size = embedding_size
        self.seq_num = seq_num
        with tf.variable_scope(name, reuse=tf.AUTO_REUSE):
            randinitializer = lambda: tf.random_uniform([num_classes, embedding_size], -0.05, 0.05)
            self.embedmat = tf.get_variable(name, initializer=randinitializer)

    def embed(self, input):
        onehot = tf.one_hot(input, self.num_classes)
        flatten_onehot = tf.reshape(onehot, [-1, self.num_classes])
        return tf.reshape(tf.matmul(flatten_onehot, self.embedmat), [-1, self.seq_num, self.embedding_size])

    def reverse_embed(self, x):
        """Project hidden states back to vocabulary logits via the transposed embedding matrix."""
        x = tf.reshape(x, [-1, self.embedding_size])
        logits = tf.matmul(x, self.embedmat, transpose_b=True)
        return tf.reshape(logits, [-1, self.seq_num, self.num_classes])

def embed_decoder_trans(decoder_input_t):
    """Token embedding plus learned positional embedding for the decoder input."""
    embedder = SharedEmbedder(VOCAB_SIZE, TRANSFORMER_D_MODEL, MAX_TOKEN_LEN, "dec_embed")
    dec_input_embedded = embedder.embed(decoder_input_t)
    dec_pos_input = tf.range(
        0,
        tf.shape(decoder_input_t)[1],
        delta=1,
        dtype=tf.int32,
        name='range')
    dec_pos_input = tf.expand_dims(dec_pos_input, axis=0)
    dec_pos_embed = myembedding(dec_pos_input, MAX_TOKEN_LEN, TRANSFORMER_D_MODEL, MAX_TOKEN_LEN, "dec_pos_embed")
    dec_embedded = dec_input_embedded + tf.cast(x=dec_pos_embed, dtype=dec_input_embedded.dtype)
    return dec_embedded, embedder

# Causal (subsequent) mask for decoder self-attention: position i may only attend to positions <= i.
def subsequent_mask_bias(size):
    mask = tf.linalg.band_part(tf.ones([size, size], dtype=tf.float32), -1, 0)
    mask = tf.reshape(mask, [1, 1, size, size])
    mask_bias = (-1e9) * (1. - mask)
    return mask_bias
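
# Worked example: subsequent_mask_bias(3) produces the additive bias
#   [[   0, -1e9, -1e9],
#    [   0,    0, -1e9],
#    [   0,    0,    0]]
# (broadcast over batch and heads), so each decoder position ignores later tokens.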

def decoderTransOneBlock(decoder_inputs, decoder_mask_bias, ht_enc, stroke_mask_bias, is_training):
    # Masked decoder self-attention.
    ht_dec = multihead_attention(decoder_inputs, MAX_TOKEN_LEN, decoder_inputs, MAX_TOKEN_LEN, decoder_mask_bias, DECODER_DROPOUT_RATE, is_training)
    ht_dec = SpatialDropout1D(DECODER_DROPOUT_RATE)(ht_dec, training=is_training)
    ht_dec = layer_norm(ht_dec + decoder_inputs)
    # Encoder-decoder attention over the stroke encodings.
    attention_output = multihead_attention(ht_dec, MAX_TOKEN_LEN, ht_enc, MAX_STROKE_NUM, stroke_mask_bias, DECODER_DROPOUT_RATE, is_training)
    attention_output = SpatialDropout1D(DECODER_DROPOUT_RATE)(attention_output, training=is_training)
    attention_output = layer_norm(attention_output + ht_dec)
    # Position-wise feed-forward sub-layer.
    intermediate_output = tf.layers.Dense(2048, activation='relu', use_bias=False)(attention_output)
    layer_output = tf.layers.Dense(TRANSFORMER_D_MODEL, use_bias=False)(intermediate_output)
    layer_output = SpatialDropout1D(DECODER_DROPOUT_RATE)(layer_output, training=is_training)
    return layer_norm(layer_output + attention_output)

def decoder_Transformer(dec_input, ht_enc, stroke_mask_bias, is_training):
    dec_mask_bias = subsequent_mask_bias(MAX_TOKEN_LEN)
    ht_dec = SpatialDropout1D(DECODER_DROPOUT_RATE)(dec_input, training=is_training)
    # Stack of 7 decoder blocks.
    for _ in range(7):
        ht_dec = decoderTransOneBlock(ht_dec, dec_mask_bias, ht_enc, stroke_mask_bias, is_training)
    return ht_dec

def create_model(input_stroke_t, stroke_mask, decoder_input_t, is_training):
    stroke_features = feature_extractor(input_stroke_t, is_training)
    # (batch, MAX_STROKE_NUM, EXTRACTED_FEATURE_DIM)
    stroke_embedded = embed_stroke(stroke_features)
    dec_embedded, embedder = embed_decoder_trans(decoder_input_t)
    stroke_mask_bias = mask_to_maskbias(stroke_mask)
    ht_enc = encoder_SelfAttention(stroke_embedded, stroke_mask_bias, is_training)
    dec_ht = decoder_Transformer(dec_embedded, ht_enc, stroke_mask_bias, is_training)
    logit = embedder.reverse_embed(dec_ht)
    # (batch, MAX_TOKEN_LEN, VOCAB_SIZE)
    return logit
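
# ---------------------------------------------------------------------------
# Minimal usage sketch (an assumption, not part of the model code above): how
# create_model might be wired up with TF1 placeholders. INPUT_TYPE_DIM is
# referenced in feature_extractor's docstring but never defined here; 3 is
# only an assumed example value.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    INPUT_TYPE_DIM = 3  # assumed per-point feature dimension

    input_stroke_t = tf.placeholder(
        tf.float32, [None, MAX_STROKE_NUM, MAX_ONE_STROKE_LEN, INPUT_TYPE_DIM])
    stroke_mask = tf.placeholder(tf.float32, [None, MAX_STROKE_NUM])  # 1 = real stroke, 0 = padding
    decoder_input_t = tf.placeholder(tf.int32, [None, MAX_TOKEN_LEN])

    logits = create_model(input_stroke_t, stroke_mask, decoder_input_t, is_training=True)
    # logits: (batch, MAX_TOKEN_LEN, VOCAB_SIZE); feed to a softmax cross-entropy loss for training.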