Skip to content

Instantly share code, notes, and snippets.

@joecummings
Created May 6, 2024 21:54
Show Gist options
  • Save joecummings/df85bf16cce22f10482fef1ec8a65dd2 to your computer and use it in GitHub Desktop.
model = TransformerDecoder(...)
model.setup_caches(bsz=8, dtype=torch.float)
generate_0(...)
def generate_0(
model: TransformerDecoder,
prompt: torch.Tensor,
...
) -> torch.Tensor:
bsz = prompt.size(1)
if model.layer[0].kv_cache is not None:
curr_kv_cache_bsz = model.layer[0].kv_cache
if curr_kv_cache_bsz != bsz:
model.setup_caches(max_batch_size=bsz, dtype=prompt.dtype)
else:
model.reset_caches()
# do generation
return output
# User doesn't have to worry about any additional
# calls or configuration for cache"
generate_0(...)
# ---------------------------------------- #
model = TransformerDecoder(...)
model.setup_caches(bsz=8, dtype=torch.float)
generate_1(...)
def generate_1(
model: TransformerDecoder,
prompt: torch.Tensor,
...
) -> torch.Tensor:
# Do a bunch of generation
# Then, reset the cache every time
model.reset_caches()
return output
# User can just call generate_1 if bsz is the same
generate_1(...)
# If bsz changes, they need to change
model.setup_caches(bsz=12, dtype=torch.float)
generate_1(...)
# ---------------------------------------- #
model = TransformerDecoder(...)
model.setup_cache(bsz=8, dtype=torch.float)
generate_2(...)
def generate_2(
model: TransformerDecoder,
prompt: torch.Tensor,
...
) -> torch.Tensor:
# Assume the user has setup the correct kv cache
# Do a bunch of generation
return output
# If same bsz, like for most eval
model.reset_cache()
generate_2(...)
# If different bsz, setup new caches
model.setup_caches(bsz=12, dtype=torch.float)
generate_2(...)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment