Skip to content

Instantly share code, notes, and snippets.

@XinDongol
Created October 8, 2018 23:44
Show Gist options
  • Save XinDongol/14b5e6273fd8af5a3b7683947038a2eb to your computer and use it in GitHub Desktop.
Save XinDongol/14b5e6273fd8af5a3b7683947038a2eb to your computer and use it in GitHub Desktop.
HWGQ-TF
def get_hwgq(bitA):
    """Return an HWGQ-style activation quantizer for ``bitA``-bit activations.

    Args:
        bitA: Activation bit width. ``32`` disables quantization (the
            returned function is the identity). Otherwise it must be one of
            2, 3, 4 or 5.

    Returns:
        A function ``fa(x)`` that maps a float tensor to its quantized value,
        with step size ``delta`` chosen by ``bitA``:

        * ``x <= 0`` maps to 0.
        * ``x > 0`` maps to ``delta * n``, where
          ``n = max(1, floor(x / delta + 0.5))`` clipped to ``2**bitA - 1``.
          Positive inputs therefore never quantize to 0.

        The gradient is a straight-through estimate: ``dy`` passes unchanged
        where ``0 < x < delta * (2**bitA - 1)`` and is 0 elsewhere.

    Raises:
        ValueError: if ``bitA`` is neither 32 nor a supported bit width.
            This is raised when the quantizer is built, not when it is first
            applied.
    """
    # Quantization step size per supported bit width.
    # NOTE(review): these look like the HWGQ paper's step sizes for a
    # half-Gaussian input; their derivation is not recorded here -- confirm.
    step_sizes = {2: 0.5380, 3: 0.3218, 4: 0.1813, 5: 0.1029}

    if bitA == 32:
        def fa(x):
            return x
        return fa

    # Raise rather than assert so the check survives `python -O`. Integer keys
    # also accept an equal float width such as 2.0.
    if bitA not in step_sizes:
        raise ValueError('Does not support %r bits' % (bitA,))

    delta = step_sizes[bitA]
    minv = 0.
    maxv = delta * (2 ** bitA - 1)  # highest level: 2**bitA - 1 steps

    @tf.custom_gradient
    def _quantize(x):
        def mask(cond):
            # Cast masks to x's dtype so non-float32 inputs (e.g. float16)
            # are never mixed with float32 tensors.
            return tf.cast(cond, x.dtype)

        # Round x / delta half-up to a level index. Positive inputs below half
        # a step would round to level 0; bump them to level 1 so that only
        # non-positive inputs produce 0.
        level = tf.floor(x / delta + 0.5) + mask(x < 0.5 * delta)
        y = mask(x > 0.) * tf.clip_by_value(level * delta, minv, maxv)

        def grad(dy):
            # Straight-through estimator, restricted to the unclipped range
            # (minv, maxv).
            return dy * mask(x > minv) * mask(x < maxv)

        return y, grad

    def fa(x):
        # _quantize reads x.dtype, so make sure it receives a tensor even for
        # Python scalars or NumPy arrays.
        return _quantize(tf.convert_to_tensor(x))

    return fa
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment