Skip to content

Instantly share code, notes, and snippets.

@Ryu1845
Created October 10, 2024 16:19
Show Gist options
  • Save Ryu1845/5a8774877487602a6fd92261453e181a to your computer and use it in GitHub Desktop.
Save Ryu1845/5a8774877487602a6fd92261453e181a to your computer and use it in GitHub Desktop.
def apply_p_rope(
    inputs: jax.Array,  # presumably [B, L, N, head_dim] -- TODO confirm with callers
    positions: jax.Array,  # [B, L]
    head_dim: int,
    max_wavelength: int = _MAX_WAVELENGTH,
    rope_percentage: float = 1.0,
) -> jax.Array:
    """Applies p-RoPE: rotary position embedding on a fraction of the angles.

    The last axis of `inputs` is split into two halves and element `i` of the
    first half is rotated together with element `i` of the second half. Only
    the first `rope_angles` pairs (the ones with the shortest wavelengths) are
    rotated; the remaining `nope_angles` pairs get an infinite timescale, so
    their angle is 0 (sin = 0, cos = 1) and they pass through unchanged.

    Args:
        inputs: Array to rotate. The code broadcasts the halves of its last
            axis against angles of shape [B, L, 1, head_dim // 2], so it is
            presumably [B, L, num_heads, head_dim] (the original comment said
            [B, L]; verify against the caller).
        positions: Token positions, [B, L].
        head_dim: Size of the last axis of `inputs`; `head_dim // 2` angles
            are produced in total.
        max_wavelength: Base of the geometric progression of timescales.
        rope_percentage: Fraction of the `head_dim // 2` angles that receive a
            rotation. 1.0 rotates every pair; 0.0 rotates none.

    Returns:
        The rotated array, cast back to `inputs.dtype`.
    """
    rope_angles = int(rope_percentage * head_dim // 2)
    nope_angles = head_dim // 2 - rope_angles

    # Exponents 0, 2/head_dim, 4/head_dim, ... for the rotated pairs only.
    fraction = 2.0 * jnp.arange(0, rope_angles) / head_dim

    # Right-pad with +inf so that position / timescale == 0 for the pairs that
    # must not be rotated. The left pad width is 0, so the first entry of
    # `constant_values` is never used.
    timescale = jnp.pad(
        max_wavelength**fraction,
        (0, nope_angles),
        mode="constant",
        constant_values=(0, jnp.inf),
    )

    # [B, L, 1] / [1, 1, head_dim // 2] -> [B, L, head_dim // 2]
    sinusoid_inp = (
        positions[..., jnp.newaxis] / timescale[jnp.newaxis, jnp.newaxis, :]
    )
    # Insert a singleton axis before the last one so the same angles broadcast
    # over the second-to-last axis of `inputs`.
    sinusoid_inp = sinusoid_inp[..., jnp.newaxis, :]
    sin = jnp.sin(sinusoid_inp)
    cos = jnp.cos(sinusoid_inp)

    # 2-D rotation of each (first_half[i], second_half[i]) pair.
    first_half, second_half = jnp.split(inputs, 2, axis=-1)
    first_part = first_half * cos - second_half * sin
    second_part = second_half * cos + first_half * sin
    out = jnp.concatenate([first_part, second_part], axis=-1)
    return out.astype(inputs.dtype)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment