Created October 8, 2018 23:44
Save XinDongol/14b5e6273fd8af5a3b7683947038a2eb to your computer and use it in GitHub Desktop.
HWGQ-TF
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def get_hwgq(bitA):
    """Return an HWGQ (half-wave Gaussian quantization) activation function.

    Args:
        bitA: Activation bit width. ``32`` means full precision, in which case
            the returned function is the identity. Any other value must be one
            of 2, 3, 4 or 5; it is validated when the returned function is
            applied, not here.

    Returns:
        A function ``fa(x)`` that quantizes ``x`` to ``bitA`` bits (see
        ``quantize`` below), or returns ``x`` unchanged when ``bitA == 32``.

    NOTE(review): relies on a module-level ``tf`` (TensorFlow) import that is
    not visible in this snippet.
    """
    # Quantization step (delta) per supported bit width. The output levels
    # are 0, delta, 2*delta, ..., (2**k - 1) * delta.
    # NOTE(review): these appear to be the HWGQ paper's step sizes for a
    # half-Gaussian input distribution -- confirm against the paper.
    step_sizes = {2: 0.5380, 3: 0.3218, 4: 0.1813, 5: 0.1029}

    def quantize(x, k):
        """Quantize ``x`` to ``k`` bits with a clipped straight-through gradient.

        Forward: inputs <= 0 map to 0. Positive inputs round to the nearest
        level in ``[delta, (2**k - 1) * delta]``, so no positive input maps
        to 0.
        Backward: the incoming gradient passes through where
        ``0 < x < (2**k - 1) * delta`` and is zeroed elsewhere.

        Raises:
            ValueError: if ``k`` is not a supported bit width.
        """
        # Explicit raise instead of ``assert``: asserts are stripped under
        # ``python -O``, which would turn a bad width into a bare KeyError.
        if k not in step_sizes:
            raise ValueError('Does not support %d bits' % k)
        delta = step_sizes[k]
        minv = 0.
        maxv = delta * (2 ** k - 1)

        @tf.custom_gradient
        def _quantize(x):
            # tf.cast(..., tf.float32) is the documented replacement for
            # tf.to_float, which no longer exists in TensorFlow 2.
            positive = tf.cast(x > 0., tf.float32)
            # Round to the nearest level index. Inputs below delta/2 would
            # round to index 0, so bump them to index 1. Non-positive inputs
            # are zeroed by `positive` afterwards anyway.
            bump = tf.cast(x < 0.5 * delta, tf.float32)
            level_index = tf.floor(x / delta + 0.5) + bump
            y = positive * tf.clip_by_value(level_index * delta, minv, maxv)

            def grad(dy):
                # Clipped-identity gradient: pass dy only inside (minv, maxv).
                inside = (tf.cast(x > minv, tf.float32)
                          * tf.cast(x < maxv, tf.float32))
                return dy * inside

            return y, grad

        return _quantize(x)

    def fa(x):
        """Quantize ``x`` to ``bitA`` bits, or return it unchanged for 32 bits."""
        if bitA == 32:
            return x
        return quantize(x, bitA)

    return fa
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.