import tensorflow as tf # TF1.x-style graph API (tf.get_variable, tf.layers, tf.summary)

# this function is used for quantizing activations
def quantize( zr, k ): # zr => number to quantize, k => number of bits to use
    scaling = tf.cast( tf.pow( 2.0, k ) - 1, tf.float32 )
    return tf.round( scaling * zr )/scaling # round the number to the nearest quantized value
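
# Worked example (values chosen for illustration, not from the original gist):
# with k = 2 bits, scaling = 2^2 - 1 = 3, so quantize( 0.4, 2 ) = round( 3 * 0.4 )/3
# = round( 1.2 )/3 = 1/3. The representable activation levels for k bits are
# { 0, 1/(2^k - 1), 2/(2^k - 1), ..., 1 }.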

# this function applies quantization to activations
def shaped_relu( x, k = 1.0 ): # x => number to be quantized, k => number of bits to use
    act = tf.clip_by_value( x, 0, 1 ) # clip the activation between 0 and 1 to stop overflow issues
    quant = quantize( act, k ) # quantize the value
    # tf.stop_gradient(quant - act) = quant - act on the forward path
    #                               = 0 on the backward path
    # so the forward pass sees 'quant' while the backward pass sees the gradient of 'act'
    return act + tf.stop_gradient( quant - act ) # use the stop gradient trick
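
# Minimal sketch (not part of the original gist) of what the stop-gradient trick buys us:
# the forward value is the quantized activation, but the gradient w.r.t. the input is the
# gradient of the clipped, unquantized activation (a straight-through estimator). The
# names '_ste_gradient_demo', 'x_demo' and 'y_demo' are made up for illustration.
def _ste_gradient_demo():
    x_demo = tf.placeholder( tf.float32, [ None ] )
    y_demo = shaped_relu( x_demo, k = 2.0 )
    # gradient is ~1 inside the clipping range [0, 1] and 0 outside, as if no rounding happened
    return tf.gradients( y_demo, x_demo )[ 0 ]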

# use the TWN (Ternary Weight Networks) method
def trinarize( x, nu = 1.0 ): # x => the weights to trinarize, nu => the sparsity factor
    x_shape = x.get_shape() # (unused below)
    thres = nu * tf.reduce_mean( tf.abs( x ) ) # calculate the threshold
    g_e = tf.cast( tf.greater_equal( x, thres ), tf.float32 ) # 1.0 where x >= thres
    l_e = tf.cast( tf.less_equal( x, -thres ), tf.float32 ) # 1.0 where x <= -thres
    unmasked = tf.multiply( g_e + l_e, x ) # keep x where |x| >= thres, otherwise multiply with 0
    # unmasked now has all 0's set correctly
    eta = tf.reduce_mean( tf.abs( unmasked ) ) # average magnitude of the masked tensor (the zeroed entries are included in the mean)
    t_x = tf.multiply( l_e, -eta ) # every weight that had x <= -thres is now set to -eta
    t_x = t_x + tf.multiply( g_e, eta ) # add in every weight that had x >= thres and set them to eta
    return x + tf.stop_gradient( t_x - x ) # use the stop gradient trick to quantize on the forward path and back-propagate to the real weights
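
# Example usage of trinarize (a sketch, not from the original gist; 'w_demo' and its
# values are made up). Forward values are drawn from { -eta, 0, +eta } while gradients
# still flow to the full-precision weights. For this input and nu = 1.0:
# thres = mean(|w|) = 0.434, so only 0.9 and -0.8 survive the mask,
# eta = mean(|[0.9, -0.8, 0, 0, 0]|) = 0.34, and the result is [ 0.34, -0.34, 0, 0, 0 ].
def _trinarize_demo():
    w_demo = tf.constant( [ 0.9, -0.8, 0.05, -0.02, 0.4 ] )
    return trinarize( w_demo, nu = 1.0 )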

# create a convolutional layer
def get_conv_layer( x, training, no_filt = 128, nu = None, act_prec = None ):
    '''
    x => the input activations
    training => a boolean flag to indicate training
    no_filt => the number of filters for the convolution
    nu => the sparsity factor for TWN; if set to None the weights are not quantized
    act_prec => the number of bits to quantize the activations to; if set to None there is no quantization
    '''
    if nu is None: # call another function (not included here) if no TWN
        return get_conv_layer_full_prec( x, training, no_filt )
    filter_shape = [ 3, x.get_shape()[-1], no_filt ] # determine the correct size of the convolution filter
    conv_filter = tf.get_variable( "conv_filter", filter_shape ) # make a variable to store the real valued weights
    tf.summary.histogram( "conv_filter_fp", conv_filter ) # add for debugging
    conv_filter = trinarize( conv_filter, nu = nu ) # apply TWN quantization to the weights
    cnn = tf.nn.conv1d( x, conv_filter, 1, padding = "SAME" ) # use the quantized weights to compute the convolution
    cnn = tf.layers.max_pooling1d( cnn, 2, 2 ) # add max pooling
    cnn = tf.layers.batch_normalization( cnn, training = training ) # add batch normalization
    tf.summary.histogram( "conv_dist", cnn ) # for debugging
    tf.summary.histogram( "conv_filter_tri", conv_filter ) # for debugging
    if act_prec is not None: # if quantizing the activations
        cnn = shaped_relu( cnn, act_prec ) # apply activation quantization
    else:
        cnn = tf.nn.relu( cnn )
    return cnn
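
# Example of wiring up one quantized conv block (a sketch, not from the original gist;
# the placeholder shapes, the variable scope name and the hyper-parameter values below
# are illustrative assumptions). The scope keeps the "conv_filter" variable name unique
# if several blocks are built.
def _build_demo_layer():
    x_demo = tf.placeholder( tf.float32, [ None, 1024, 64 ] ) # [ batch, length, channels ]
    training_demo = tf.placeholder( tf.bool, [] )
    with tf.variable_scope( "conv_block_demo" ):
        return get_conv_layer( x_demo, training_demo, no_filt = 128, nu = 0.7, act_prec = 2.0 )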