Created September 2, 2022 15:54
Get the number of tokens using the same tokenizer that Stable Diffusion uses.
""" | |
Get the number of tokens using the same tokenizer that Stable Diffusion uses. | |
author: opparco | |
""" | |
import argparse | |
from transformers import CLIPTokenizer | |
def main(): | |
parser = argparse.ArgumentParser() | |
parser.add_argument( | |
"--prompt", | |
type=str, | |
nargs="?", | |
default="a painting of a virus monster playing guitar", | |
help="the prompt to render" | |
) | |
parser.add_argument( | |
"--from-file", | |
type=str, | |
help="if specified, load prompts from this file", | |
) | |
opt = parser.parse_args() | |
if not opt.from_file: | |
prompt = opt.prompt | |
assert prompt is not None | |
prompts = [prompt] | |
else: | |
print(f"reading prompts from {opt.from_file}") | |
with open(opt.from_file, "r") as f: | |
prompts = f.read().splitlines() | |
tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14") | |
# print(f"max length of tokenizer: {tokenizer.model_max_length}") | |
for prompt in prompts: | |
print(f"prompt: {prompt}") | |
tokens = tokenizer.tokenize(prompt) | |
print(f"length of tokens: {len(tokens)}") | |
if __name__ == "__main__": | |
main() |
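
For reference, a minimal sketch of how the count relates to the text encoder's context window. It assumes the usual Stable Diffusion setup, where the encoder input is capped at tokenizer.model_max_length (77 for this checkpoint) positions, two of which are taken by the <|startoftext|> and <|endoftext|> special tokens; the check_prompt helper below is illustrative and not part of the original gist.

"""
Sketch: warn when a prompt would be truncated by the CLIP text encoder.
Assumption: usable budget = model_max_length - 2 (BOS/EOS special tokens).
"""
from transformers import CLIPTokenizer

tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")


def check_prompt(prompt: str) -> None:
    # tokenize() returns the subword pieces without special tokens,
    # so the count here matches the gist's output.
    tokens = tokenizer.tokenize(prompt)
    budget = tokenizer.model_max_length - 2
    print(f"prompt: {prompt}")
    print(f"length of tokens: {len(tokens)} / {budget}")
    if len(tokens) > budget:
        print("warning: prompt will be truncated by the text encoder")


if __name__ == "__main__":
    check_prompt("a painting of a virus monster playing guitar")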