Created
September 28, 2023 08:01
-
-
Save Ryu1845/8da008ac7e0520aee1313df4fe906145 to your computer and use it in GitHub Desktop.
FSQ (Finite Scalar Quantization) implementation from the paper
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def round_ste(z):
    """Round with straight-through gradients.

    The forward value is `jnp.round(z)`. The rounding residual is wrapped in
    `jax.lax.stop_gradient`, so under differentiation the function behaves
    like the identity and gradients flow to `z` unchanged.

    Args:
        z: Array of values to round.

    Returns:
        An array with the same shape as `z` holding the rounded values.
    """
    zhat = jnp.round(z)
    # z + (zhat - z) equals zhat in the forward pass; only the `z` term is
    # visible to autodiff because the difference is a stopped gradient.
    return z + jax.lax.stop_gradient(zhat - z)
class FSQ:
    """Quantizer that snaps each channel of a vector to a fixed set of levels.

    `levels[i]` is the number of distinct values channel `i` may take. The
    set of all reachable codes (the "implicit codebook") therefore has
    `prod(levels)` entries, and every code maps to a unique integer index by
    treating its per-channel level numbers as digits of a mixed-radix number.

    Each entry of `levels` should be at least 2: `levels // 2` is used as a
    divisor when normalizing codes.

    Attributes:
        implicit_codebook: Array of shape (prod(levels), len(levels)) whose
            row `i` is the normalized code for index `i`.
    """

    def __init__(self, levels: list[int]):
        """Precompute the index basis and the implicit codebook.

        Args:
            levels: Number of quantization levels for each channel.
        """
        self._levels = levels
        self._levels_np = np.asarray(levels)
        # Mixed-radix place values: basis[i] == prod(levels[:i]).
        self._basis = np.concatenate(
            ([1], np.cumprod(self._levels_np[:-1]))
        ).astype(np.uint32)

        codebook_size = np.prod(levels)
        self.implicit_codebook = self.indexes_to_codes(
            np.arange(codebook_size))

    def bound(self, z):
        """Bound `z`, an array of shape (..., d).

        Squashes every channel with `tanh` so that rounding the result
        yields exactly `levels[i]` distinct integers for channel `i`.
        """
        eps = 1e-3
        # Shrinking by (1 - eps) keeps the output strictly inside the outer
        # rounding boundaries.
        half_l = (self._levels_np - 1) * (1 - eps) / 2
        # An even number of levels has no integer at its centre, so the
        # range is shifted by half a step; odd counts need no offset.
        offset = jnp.where(self._levels_np % 2 == 1, 0.0, 0.5)
        # NOTE(review): tanh(tan(x)) only approximates x, so bound(0) is
        # close to, but not exactly, 0 for even level counts; arctanh would
        # be the exact inverse. Kept as `tan` to preserve behaviour.
        shift = jnp.tan(offset / half_l)
        return jnp.tanh(z + shift) * half_l - offset

    def quantize(self, z):
        """Quantizes z, returns quantized zhat, same shape as z."""
        quantized = round_ste(self.bound(z))
        half_width = self._levels_np // 2  # Renormalize to [-1, 1].
        return quantized / half_width

    def _scale_and_shift(self, zhat_normalized):
        """Maps normalized codes to non-negative level numbers."""
        half_width = self._levels_np // 2
        return (zhat_normalized * half_width) + half_width

    def _scale_and_shift_inverse(self, zhat):
        """Maps non-negative level numbers back to normalized codes."""
        half_width = self._levels_np // 2
        return (zhat - half_width) / half_width

    def codes_to_indexes(self, zhat):
        """Converts a `code` to an index in the codebook.

        Args:
            zhat: Normalized codes of shape (..., d) with d == len(levels).

        Returns:
            uint32 indices of shape (...).
        """
        assert zhat.shape[-1] == len(self._levels)
        zhat = self._scale_and_shift(zhat)
        # Round before the integer cast: the cast truncates, so a float sum
        # that is a hair below an integer would otherwise map to the
        # previous index.
        return (zhat * self._basis).sum(axis=-1).round().astype(jnp.uint32)

    def indexes_to_codes(self, indices):
        """Inverse of `codes_to_indexes`.

        Args:
            indices: Integer indices of shape (...).

        Returns:
            Normalized codes of shape (..., d) with d == len(levels).
        """
        indices = indices[..., jnp.newaxis]
        # Extract each mixed-radix digit: divide by the place value, then
        # reduce modulo that channel's level count.
        codes_non_centered = np.mod(
            np.floor_divide(indices, self._basis), self._levels_np
        )
        return self._scale_and_shift_inverse(codes_non_centered)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment