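# Three options for handling KV-cache setup/reset in a generate() utility
# for a TransformerDecoder. Each option moves the cache bookkeeping to a
# different place: fully inside generate (option 0), split between generate
# and the user (option 1), or fully on the user (option 2).
#
# Assumed imports for the sketch; TransformerDecoder stands in for a
# torchtune-style decoder exposing setup_caches()/reset_caches().
import torch
from torchtune.modules import TransformerDecoder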
# ---------------------------------------- #
# Option 0: generate() manages the caches itself
model = TransformerDecoder(...)
model.setup_caches(bsz=8, dtype=torch.float)
generate_0(...)
def generate_0(
    model: TransformerDecoder,
    prompt: torch.Tensor,
    ...
) -> torch.Tensor:
    bsz = prompt.size(0)  # prompt is [bsz, seq_len]
    if model.layers[0].kv_cache is not None:
        # Batch dim of the cached keys gives the cache's current bsz
        curr_kv_cache_bsz = model.layers[0].kv_cache.k_cache.size(0)
        if curr_kv_cache_bsz != bsz:
            # Cache was built for a different batch size -> rebuild it
            model.setup_caches(bsz=bsz, dtype=prompt.dtype)
        else:
            # Cache already matches -> just zero it out
            model.reset_caches()
    else:
        # No cache yet -> build one for this batch
        model.setup_caches(bsz=bsz, dtype=prompt.dtype)
    # do generation
    return output

# User doesn't have to worry about any additional
# calls or configuration for the cache
generate_0(...)
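# For illustration (hypothetical prompts; vocab size 32_000 is arbitrary),
# option 0 absorbs a bsz change entirely inside generate:
prompt_a = torch.randint(0, 32_000, (8, 128))   # bsz=8
prompt_b = torch.randint(0, 32_000, (12, 128))  # bsz=12
generate_0(model, prompt_a)
generate_0(model, prompt_b)  # cache is rebuilt for bsz=12 internally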
# ---------------------------------------- #
# Option 1: generate() resets the cache; the user handles bsz changes
model = TransformerDecoder(...)
model.setup_caches(bsz=8, dtype=torch.float)
generate_1(...)
def generate_1(
    model: TransformerDecoder,
    prompt: torch.Tensor,
    ...
) -> torch.Tensor:
    # Do a bunch of generation,
    # then reset the cache every time
    model.reset_caches()
    return output

# User can just call generate_1 if bsz stays the same
generate_1(...)
# If bsz changes, they need to set up new caches themselves
model.setup_caches(bsz=12, dtype=torch.float)
generate_1(...)
# ---------------------------------------- #
# Option 2: the user fully manages cache setup and resets
model = TransformerDecoder(...)
model.setup_caches(bsz=8, dtype=torch.float)
generate_2(...)
def generate_2(
    model: TransformerDecoder,
    prompt: torch.Tensor,
    ...
) -> torch.Tensor:
    # Assume the user has set up the correct kv cache
    # Do a bunch of generation
    return output

# If same bsz, like for most eval, just reset between calls
model.reset_caches()
generate_2(...)
# If different bsz, set up new caches
model.setup_caches(bsz=12, dtype=torch.float)
generate_2(...)
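# ---------------------------------------- #
# Appendix: a minimal sketch of the cache plumbing the options above rely
# on. Names and shapes (KVCache, k_cache, the helper class) follow the
# calls used in this file but are illustrative, not torchtune's exact
# implementation.

class KVCache(torch.nn.Module):
    def __init__(
        self, bsz: int, max_seq_len: int, num_heads: int,
        head_dim: int, dtype: torch.dtype,
    ) -> None:
        super().__init__()
        shape = (bsz, num_heads, max_seq_len, head_dim)
        # Preallocated key/value buffers; the batch dim of k_cache is
        # what generate_0 inspects to detect a bsz mismatch
        self.register_buffer("k_cache", torch.zeros(shape, dtype=dtype))
        self.register_buffer("v_cache", torch.zeros(shape, dtype=dtype))

    def reset(self) -> None:
        # Zero the buffers in place rather than reallocating them
        self.k_cache.zero_()
        self.v_cache.zero_()

class CacheHelpersSketch:
    # Decoder-side methods assumed above, written against hypothetical
    # self.layers / self.max_seq_len / self.num_heads / self.head_dim
    def setup_caches(self, bsz: int, dtype: torch.dtype) -> None:
        # (Re)allocate one cache per layer at the requested batch size
        for layer in self.layers:
            layer.kv_cache = KVCache(
                bsz, self.max_seq_len, self.num_heads, self.head_dim, dtype
            )

    def reset_caches(self) -> None:
        # Zero existing caches without changing their batch size
        for layer in self.layers:
            if layer.kv_cache is not None:
                layer.kv_cache.reset()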