"""Distributed text generation with Hugging Face Accelerate: each process
generates a completion for a different prompt, under FSDP or DeepSpeed."""
import contextlib

import torch
from accelerate import Accelerator
from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel as FSDP
from transformers import AutoModelForCausalLM, AutoTokenizer

MODEL_NAME = "meta-llama/Llama-2-70b-chat-hf"  # or "HuggingFaceH4/zephyr-7b-beta"


def main():
    accelerator = Accelerator()
    model = AutoModelForCausalLM.from_pretrained(MODEL_NAME)
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

    # DeepSpeed's config requires a micro-batch size even though we only generate.
    if accelerator.state.deepspeed_plugin is not None:
        accelerator.state.deepspeed_plugin.deepspeed_config["train_micro_batch_size_per_gpu"] = 1
    model = accelerator.prepare(model)

    sample_texts = [
        [{"role": "user", "content": "Explain Deep Learning like a Pirate."}],
        [{"role": "user", "content": "Why is it important to eat socks daily?"}],
        [{"role": "user", "content": "Write a tweet about the latest model by Google Gemini which is topping all the benchmarks"}],
        [{"role": "user", "content": "How do I convert a Python dictionary into a string representation?"}],
    ]
    for i in range(len(sample_texts)):
        sample_texts[i] = tokenizer.apply_chat_template(sample_texts[i], add_generation_prompt=True, tokenize=False)
    accelerator.print(sample_texts)

    # Each rank picks its own prompt, so the four generations run in parallel.
    inputs = tokenizer(sample_texts[accelerator.process_index], return_tensors="pt").to(accelerator.device)

    # Under FSDP, gather the full (unsharded) parameters for the duration of
    # generation; otherwise use a no-op context.
    if getattr(accelerator.state, "fsdp_plugin", None) is not None:
        ctx = FSDP.summon_full_params(model, writeback=False, recurse=False)
    else:
        ctx = contextlib.nullcontext()
    unwrapped_model = accelerator.unwrap_model(model)
    with ctx:
        outputs = unwrapped_model.generate(
            **inputs,
            do_sample=True,
            temperature=0.2,
            top_p=0.95,
            eos_token_id=tokenizer.eos_token_id,
            pad_token_id=tokenizer.eos_token_id,
            max_new_tokens=128,
            synced_gpus=True,  # keep all ranks stepping together so sharded generation doesn't hang
        )
    print(f"{accelerator.process_index=} {tokenizer.decode(outputs[0], skip_special_tokens=False)}")
    print("-" * 100)


if __name__ == "__main__":
    main()
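To run this across GPUs, launch it with Accelerate's CLI rather than plain `python`. A minimal example, assuming the file is saved as `generate.py` and that `fsdp_config.yaml` is a placeholder for your own Accelerate config:

accelerate launch --config_file fsdp_config.yaml generate.py

Choosing an FSDP or DeepSpeed config via `--config_file` (or interactively with `accelerate config`) selects the sharding backend. Because each rank indexes `sample_texts` with `accelerator.process_index`, launch with at most four processes, one per prompt.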