Skip to content

Instantly share code, notes, and snippets.

@karino2
Last active August 1, 2019 09:44
Show Gist options
  • Save karino2/2bdaf2e01bcfac12b06610a1e2b8c2da to your computer and use it in GitHub Desktop.
Save karino2/2bdaf2e01bcfac12b06610a1e2b8c2da to your computer and use it in GitHub Desktop.
# ---------------------------------------------------------------------------
# Input / output sizes
# ---------------------------------------------------------------------------
VOCAB_SIZE = 115
MAX_ONE_STROKE_LEN = 50
MAX_STROKE_NUM = 22
# NOTE(review): presumably 10 tokens plus 2 special (start/end) tokens — confirm.
MAX_TOKEN_LEN = 10 + 2

# ---------------------------------------------------------------------------
# Feature extractor
# ---------------------------------------------------------------------------
# Must match the feature dimension the checkpoint was trained with.
EXTRACTED_FEATURE_DIM = 256
FE_DROPOUT_RATE = 0.5
# NOTE(review): not referenced by the code in this section (fe_conv1d
# regularizes with L2_REGULARIZATION_RATE) — confirm whether it is still used.
FE_L2_REGULARIZATION_RATE = 0.01
FEATURE_EXTRACTER_KERNEL_SIZE = 7

# ---------------------------------------------------------------------------
# Transformer encoder / decoder regularization
# ---------------------------------------------------------------------------
ENCODER_DROPOUT_RATE = 0.1
DECODER_DROPOUT_RATE = 0.1
L2_REGULARIZATION_RATE = 1e-6

# NOTE(review): not referenced by the code in this section — confirm usage.
EMBEDDING_SIZE = 32
def fe_conv1d(filternum, kernelsize, x, l2_rate=None):
    """Apply an L2-regularized ``tf.layers.Conv1D`` to ``x``.

    Args:
        filternum: number of output filters.
        kernelsize: convolution window length.
        x: input tensor fed straight into the Conv1D layer.
        l2_rate: L2 factor used for the kernel, bias and activity (output)
            regularizers. ``None`` (the default) means
            ``L2_REGULARIZATION_RATE``, which is what this helper always used
            before, so existing callers are unaffected. Pass e.g.
            ``FE_L2_REGULARIZATION_RATE`` to give the feature extractor its
            own rate.

    Returns:
        The Conv1D output tensor. No ``padding`` argument is passed, so the
        layer's default padding applies.
    """
    # Resolved at call time (not as a def-time default) so the global can
    # still be changed before the graph is built, exactly as before.
    if l2_rate is None:
        l2_rate = L2_REGULARIZATION_RATE
    return tf.layers.Conv1D(
        filternum,
        kernelsize,
        kernel_regularizer=regularizers.l2(l2_rate),
        bias_regularizer=regularizers.l2(l2_rate),
        activity_regularizer=regularizers.l2(l2_rate),
    )(x)
def feature_extractor(input_stroke_t, is_training):
    """Encode each stroke independently into an EXTRACTED_FEATURE_DIM vector.

    input_stroke_t shape (batch, MAX_STROKE_NUM, MAX_ONE_STROKE_LEN, INPUT_TYPE_DIM).

    The stroke axis is folded into the batch axis, three convolution stages
    run along each stroke's points, a global max-pool collapses the point
    axis, and the result is unfolded back to
    (batch, input_stroke_t.shape[1], EXTRACTED_FEATURE_DIM).

    Each stage is fe_conv1d -> BatchNormalization -> relu -> [MaxPooling1D]
    -> Dropout(FE_DROPOUT_RATE); only the first two stages pool.

    NOTE: fe_conv1d passes no padding, so (per the tf.layers.Conv1D default,
    'valid') every conv shortens the point axis by
    FEATURE_EXTRACTER_KERNEL_SIZE - 1. With the current constants the point
    axis goes 50 -> 44 -> 22 -> 16 -> 8 -> 2 before the global max-pool.
    """
    with tf.variable_scope("feature_extractor"):
        static_shape = input_stroke_t.shape
        strokes_per_sample = static_shape[1]
        points_per_stroke = static_shape[2]
        channels_in = static_shape[3]

        # Treat every stroke as its own sequence.
        h = tf.reshape(input_stroke_t, [-1, points_per_stroke, channels_in])

        # (filters, apply max-pool?) for each stage. The layer creation order
        # is the same as a hand-unrolled version, so variable names match.
        stages = (
            (32, True),
            (64, True),
            (EXTRACTED_FEATURE_DIM, False),
        )
        for filters, downsample in stages:
            h = fe_conv1d(filters, FEATURE_EXTRACTER_KERNEL_SIZE, h)
            h = tf.layers.BatchNormalization()(h, training=is_training)
            h = Activation('relu')(h)
            if downsample:
                h = tf.layers.MaxPooling1D(pool_size=2, strides=2)(h)
            h = tf.layers.Dropout(FE_DROPOUT_RATE)(h, training=is_training)

        # (batch*strokes, EXTRACTED_FEATURE_DIM)
        h = GlobalMaxPooling1D()(h)
        return tf.reshape(h, [-1, strokes_per_sample, EXTRACTED_FEATURE_DIM])
# Transformer geometry: TRANSFORMER_H attention heads of TRANSFORMER_D dims each.
TRANSFORMER_H = 8
TRANSFORMER_D = 64
# Model width: all heads concatenated.
TRANSFORMER_D_MODEL = TRANSFORMER_D * TRANSFORMER_H
def layer_norm(x):
    """Layer-normalize ``x`` over its last axis only.

    Both the normalization statistics and the learned scale/offset
    parameters are restricted to the final dimension.
    """
    last_axis = -1
    return tf.contrib.layers.layer_norm(
        inputs=x,
        begin_norm_axis=last_axis,
        begin_params_axis=last_axis,
    )
# Mirrors the TensorFlow official Transformer implementation, except that every
# dimension is spelled out explicitly so the shapes are static for XLA:
# https://github.com/tensorflow/models/blob/master/official/transformer/model/attention_layer.py
def split_heads(x, seqlen, num_heads=TRANSFORMER_H, one_depth=TRANSFORMER_D):
    """Split the last axis of ``x`` into attention heads.

    ``x`` is reshaped to ``[-1, seqlen, num_heads, one_depth]`` and the head
    axis is moved in front of the sequence axis, giving
    ``[batch, num_heads, seqlen, one_depth]``.
    """
    heads_last = tf.reshape(x, [-1, seqlen, num_heads, one_depth])
    swap_seq_and_heads = [0, 2, 1, 3]
    return tf.transpose(heads_last, swap_seq_and_heads)
def multihead_attention(query, query_seq_len, y, y_seq_len, mask_bias, dropout_rate, is_training,
                        num_head=None, size_per_head=None):
    """Multi-head scaled dot-product attention of ``query`` over ``y``.

    Args:
        query: tensor projected to queries; reshaped with ``query_seq_len``.
        query_seq_len: static query sequence length.
        y: tensor projected to keys and values; reshaped with ``y_seq_len``.
        y_seq_len: static key/value sequence length.
        mask_bias: additive bias added to the attention logits, which have
            shape ``[batch, num_head, query_seq_len, y_seq_len]``. It must
            broadcast to that shape, e.g. ``[batch, 1, 1, y_seq_len]`` or
            ``[1, 1, query_seq_len, y_seq_len]``.
        dropout_rate: dropout applied to the attention weights.
        is_training: enables dropout.
        num_head: number of heads; ``None`` means ``TRANSFORMER_H``.
        size_per_head: width of each head; ``None`` means ``TRANSFORMER_D``.

    Returns:
        Tensor of shape ``[-1, query_seq_len, num_head * size_per_head]``.
        With the defaults this is TRANSFORMER_D_MODEL, as before.
    """
    # Previously the head geometry was read partly from locals and partly
    # from the globals TRANSFORMER_D / TRANSFORMER_D_MODEL. One source of
    # truth is now used throughout; the defaults keep the old behavior.
    if num_head is None:
        num_head = TRANSFORMER_H
    if size_per_head is None:
        size_per_head = TRANSFORMER_D
    d_model = num_head * size_per_head

    # Projection order (q, k, v, then output) is unchanged. Unnamed tf.layers
    # are auto-named in creation order, so this keeps variable names stable.
    multi_q = tf.layers.Dense(d_model, use_bias=False)(query)
    multi_k = tf.layers.Dense(d_model, use_bias=False)(y)
    multi_v = tf.layers.Dense(d_model, use_bias=False)(y)

    # -> [batch, num_head, seq_len, size_per_head]
    multi_q = split_heads(multi_q, query_seq_len, num_head, size_per_head)
    multi_k = split_heads(multi_k, y_seq_len, num_head, size_per_head)
    multi_v = split_heads(multi_v, y_seq_len, num_head, size_per_head)

    # Scale queries by 1/sqrt(head width) before the dot product.
    multi_q *= size_per_head ** -0.5

    # [batch, num_head, query_seq_len, y_seq_len]
    logits = tf.matmul(multi_q, multi_k, transpose_b=True)
    logits += mask_bias
    weights = tf.nn.softmax(logits)
    weights = tf.layers.Dropout(dropout_rate)(weights, training=is_training)
    attention_output = tf.matmul(weights, multi_v)

    # Recombine heads: [batch, num_head, query_seq_len, size_per_head]
    #   -> [batch, query_seq_len, num_head, size_per_head] -> [batch, query_seq_len, d_model]
    attention_output = tf.transpose(attention_output, [0, 2, 1, 3])
    attention_output = tf.reshape(attention_output, [-1, query_seq_len, d_model])
    return tf.layers.Dense(d_model, use_bias=False)(attention_output)
def myembedding(input, num_classes, embedding_size, seq_num, name):
    """Embed integer ids with a one-hot x table matrix product.

    A ``[num_classes, embedding_size]`` variable called ``name`` is created
    (or reused, via AUTO_REUSE) inside variable scope ``name``. It is
    initialized uniformly between -0.05 and 0.05.

    NOTE(review): a one-hot matmul is used instead of a gather, presumably for
    TPU/export friendliness — confirm.

    Returns:
        The embedded ids reshaped to ``[-1, seq_num, embedding_size]``.
    """
    with tf.variable_scope(name, reuse=tf.AUTO_REUSE):
        def _uniform_table():
            return tf.random_uniform([num_classes, embedding_size], -0.05, 0.05)

        table = tf.get_variable(name, initializer=_uniform_table)
        flat_onehot = tf.reshape(tf.one_hot(input, num_classes), [-1, num_classes])
        vectors = tf.matmul(flat_onehot, table)
        return tf.reshape(vectors, [-1, seq_num, embedding_size])
def embed_stroke(stroke_features):
    """Add a learned per-stroke position embedding to ``stroke_features``.

    Position ids ``0 .. tf.shape(stroke_features)[1] - 1`` are looked up in
    the "stroke_pos_embed" table (MAX_STROKE_NUM rows of width
    EXTRACTED_FEATURE_DIM). The result is cast to the features' dtype and
    added, broadcasting over the batch axis.
    """
    num_positions = tf.shape(stroke_features)[1]
    position_ids = tf.expand_dims(
        tf.range(0, num_positions, delta=1, dtype=tf.int32, name='range'),
        axis=0,
    )
    position_vecs = myembedding(
        position_ids, MAX_STROKE_NUM, EXTRACTED_FEATURE_DIM, MAX_STROKE_NUM, "stroke_pos_embed")
    return stroke_features + tf.cast(x=position_vecs, dtype=stroke_features.dtype)
def mask_to_maskbias(mask):
    """Convert a 0/1 keep-mask into an additive attention bias.

    Entries where ``mask`` is 1 get bias 0. Entries where it is 0 get -1e9,
    which drives their softmax weight to (effectively) zero. Two singleton
    axes are inserted after axis 0, so a ``[batch, length]`` mask becomes
    ``[batch, 1, 1, length]`` and broadcasts over heads and query positions.
    """
    keep = tf.cast(mask, tf.float32)
    bias = (-1e9) * (1. - keep)
    bias = tf.expand_dims(bias, axis=1)
    bias = tf.expand_dims(bias, axis=1)
    return bias
def encSelfAttenOneBlock(input, mask_bias, is_training):
    """One post-LN Transformer encoder block over MAX_STROKE_NUM positions.

    Sub-layers:
      1. multi-head self-attention (masked by ``mask_bias``);
      2. bias-free feed-forward Dense(2048, relu) -> Dense(TRANSFORMER_D_MODEL).

    Each sub-layer output gets SpatialDropout1D(ENCODER_DROPOUT_RATE), is added
    to its input, and is passed through layer_norm.
    """
    def _dropout_add_norm(sublayer_out, residual):
        dropped = SpatialDropout1D(ENCODER_DROPOUT_RATE)(sublayer_out, training=is_training)
        return layer_norm(dropped + residual)

    attended = multihead_attention(
        input, MAX_STROKE_NUM, input, MAX_STROKE_NUM, mask_bias, ENCODER_DROPOUT_RATE, is_training)
    attended = _dropout_add_norm(attended, input)

    hidden = tf.layers.Dense(2048, activation='relu', use_bias=False)(attended)
    ffn_out = tf.layers.Dense(TRANSFORMER_D_MODEL, use_bias=False)(hidden)
    return _dropout_add_norm(ffn_out, attended)
def encoder_SelfAttention(input, mask_bias, is_training):
    """Transformer encoder over the stroke sequence.

    The input is projected (no bias) to TRANSFORMER_D_MODEL, then
    SpatialDropout1D is applied, followed by six stacked encSelfAttenOneBlock
    layers that share ``mask_bias``.
    """
    num_blocks = 6
    h = tf.layers.Dense(TRANSFORMER_D_MODEL, use_bias=False)(input)
    h = SpatialDropout1D(ENCODER_DROPOUT_RATE)(h, training=is_training)
    for _ in range(num_blocks):
        h = encSelfAttenOneBlock(h, mask_bias, is_training)
    return h
# Dynamic shapes caused TPUEstimator export to fail (presumably why the
# reshapes below use the fixed ``seq_num``).
# For the Transformer decoder, one weight matrix serves both as the token
# embedding and as the reverse (output-logit) projection.
class SharedEmbedder:
    """Token embedding whose matrix is reused for the output projection.

    Args:
        num_classes: vocabulary size, i.e. rows of the embedding matrix.
        embeding_size: embedding width, i.e. columns. The spelling is kept
            for backward compatibility with keyword callers.
        seq_num: fixed sequence length used by the reshapes.
        name: used as both the variable scope and the variable name. The
            scope uses AUTO_REUSE, so instances with the same ``name`` share
            one matrix.
    """

    def __init__(self, num_classes, embeding_size=TRANSFORMER_D_MODEL, seq_num=MAX_TOKEN_LEN, name="dec_embed"):
        self.num_classes = num_classes
        self.embedding_size = embeding_size
        self.seq_num = seq_num

        def _uniform_table():
            return tf.random_uniform([num_classes, embeding_size], -0.05, 0.05)

        with tf.variable_scope(name, reuse=tf.AUTO_REUSE):
            self.embedmat = tf.get_variable(name, initializer=_uniform_table)

    def embed(self, input):
        """Embed integer ids; returns ``[-1, seq_num, embedding_size]``."""
        flat_onehot = tf.reshape(tf.one_hot(input, self.num_classes), [-1, self.num_classes])
        vectors = tf.matmul(flat_onehot, self.embedmat)
        return tf.reshape(vectors, [-1, self.seq_num, self.embedding_size])

    def reverse_embed(self, x):
        """Project ``x`` onto the vocabulary using the transposed embedding.

        Returns logits shaped ``[-1, seq_num, num_classes]``.
        """
        flat = tf.reshape(x, [-1, self.embedding_size])
        scores = tf.matmul(flat, self.embedmat, transpose_b=True)
        return tf.reshape(scores, [-1, self.seq_num, self.num_classes])
def embed_decoder_trans(decoder_input_t):
    """Embed decoder token ids and add learned position embeddings.

    Token vectors come from a SharedEmbedder ("dec_embed",
    VOCAB_SIZE x TRANSFORMER_D_MODEL). Position vectors come from the
    "dec_pos_embed" table (MAX_TOKEN_LEN rows) for ids
    ``0 .. tf.shape(decoder_input_t)[1] - 1``.

    Returns:
        (embedded, embedder): the summed embeddings, and the SharedEmbedder so
        the caller can reuse its matrix for the output projection.
    """
    embedder = SharedEmbedder(VOCAB_SIZE, TRANSFORMER_D_MODEL, MAX_TOKEN_LEN, "dec_embed")
    token_vecs = embedder.embed(decoder_input_t)

    step_ids = tf.range(0, tf.shape(decoder_input_t)[1], delta=1, dtype=tf.int32, name='range')
    step_ids = tf.expand_dims(step_ids, axis=0)
    step_vecs = myembedding(step_ids, MAX_TOKEN_LEN, TRANSFORMER_D_MODEL, MAX_TOKEN_LEN, "dec_pos_embed")

    embedded = token_vecs + tf.cast(x=step_vecs, dtype=token_vecs.dtype)
    return embedded, embedder
# Causal mask for the decoder's self-attention.
def subsequent_mask_bias(size):
    """Return a ``[1, 1, size, size]`` look-ahead attention bias.

    Entry (i, j) is 0 when j <= i and -1e9 when j > i, so query position i
    can attend only to itself and to earlier positions.
    """
    lower_triangle = tf.linalg.band_part(tf.ones([size, size], dtype=tf.float32), -1, 0)
    lower_triangle = tf.reshape(lower_triangle, [1, 1, size, size])
    return (-1e9) * (1. - lower_triangle)
def decoderTransOneBlock(decoder_inputs, decoder_mask_bias, ht_enc, stroke_mask_bias, is_training):
    """One post-LN Transformer decoder block.

    Sub-layers:
      1. self-attention over the MAX_TOKEN_LEN decoder positions, masked by
         ``decoder_mask_bias``;
      2. attention from the decoder onto the encoder output ``ht_enc``
         (MAX_STROKE_NUM positions), masked by ``stroke_mask_bias``;
      3. bias-free feed-forward Dense(2048, relu) -> Dense(TRANSFORMER_D_MODEL).

    Each sub-layer output gets SpatialDropout1D(DECODER_DROPOUT_RATE), is added
    to its input, and is passed through layer_norm.
    """
    def _dropout_add_norm(sublayer_out, residual):
        dropped = SpatialDropout1D(DECODER_DROPOUT_RATE)(sublayer_out, training=is_training)
        return layer_norm(dropped + residual)

    self_attended = multihead_attention(
        decoder_inputs, MAX_TOKEN_LEN, decoder_inputs, MAX_TOKEN_LEN,
        decoder_mask_bias, DECODER_DROPOUT_RATE, is_training)
    self_attended = _dropout_add_norm(self_attended, decoder_inputs)

    cross_attended = multihead_attention(
        self_attended, MAX_TOKEN_LEN, ht_enc, MAX_STROKE_NUM,
        stroke_mask_bias, DECODER_DROPOUT_RATE, is_training)
    cross_attended = _dropout_add_norm(cross_attended, self_attended)

    hidden = tf.layers.Dense(2048, activation='relu', use_bias=False)(cross_attended)
    ffn_out = tf.layers.Dense(TRANSFORMER_D_MODEL, use_bias=False)(hidden)
    return _dropout_add_norm(ffn_out, cross_attended)
def decoder_Transformer(dec_input, ht_enc, stroke_mask_bias, is_training):
    """Transformer decoder: input dropout followed by 7 decoder blocks.

    Every block uses the same causal bias (subsequent_mask_bias(MAX_TOKEN_LEN))
    for self-attention and ``stroke_mask_bias`` for attention over ``ht_enc``.

    NOTE(review): 7 blocks here vs 6 in encoder_SelfAttention. Changing either
    count changes the set of model variables, so it must match any checkpoint
    being restored.
    """
    num_blocks = 7
    causal_bias = subsequent_mask_bias(MAX_TOKEN_LEN)
    h = SpatialDropout1D(DECODER_DROPOUT_RATE)(dec_input, training=is_training)
    for _ in range(num_blocks):
        h = decoderTransOneBlock(h, causal_bias, ht_enc, stroke_mask_bias, is_training)
    return h
def create_model(input_stroke_t, stroke_mask, decoder_input_t, is_training):
    """Build the stroke-to-token Transformer graph and return decoder logits.

    Pipeline:
      strokes -> feature_extractor -> embed_stroke -> encoder_SelfAttention
      (masked by ``stroke_mask``);
      decoder ids -> embed_decoder_trans -> decoder_Transformer (attending to
      the encoder output) -> shared embedder's reverse_embed.

    Returns:
        Logits shaped ``[-1, MAX_TOKEN_LEN, VOCAB_SIZE]``.

    The construction order is left as-is: unnamed layers/variables are named
    in creation order, which checkpoint restoring relies on.
    """
    # (batch, MAX_STROKE_NUM, EXTRACTED_FEATURE_DIM)
    stroke_vectors = embed_stroke(feature_extractor(input_stroke_t, is_training))
    dec_embedded, embedder = embed_decoder_trans(decoder_input_t)
    stroke_mask_bias = mask_to_maskbias(stroke_mask)

    encoded = encoder_SelfAttention(stroke_vectors, stroke_mask_bias, is_training)
    decoded = decoder_Transformer(dec_embedded, encoded, stroke_mask_bias, is_training)
    return embedder.reverse_embed(decoded)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment