Flan T5 Parallel Usage
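Splits google/flan-ul2 (a Flan model in the T5 family) across multiple GPUs with T5ForConditionalGeneration.parallelize() and runs a single training step on a sentiment-classification prompt.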
from transformers import AutoTokenizer, T5ForConditionalGeneration

# Model init: load Flan-UL2 and spread its layers across the available GPUs.
n_gpu = 8
tokenizer = AutoTokenizer.from_pretrained("google/flan-ul2")
model = T5ForConditionalGeneration.from_pretrained("google/flan-ul2")

# Assign the encoder blocks (transformer layers, not attention heads) evenly
# across GPUs, e.g. {0: [0, 1, 2, 3], 1: [4, 5, 6, 7], ...}.
layers_per_gpu = len(model.encoder.block) // n_gpu
device_map = {
    gpu: list(range(gpu * layers_per_gpu, (gpu + 1) * layers_per_gpu))
    for gpu in range(n_gpu)
}
model.parallelize(device_map)

# Training: one forward/backward pass on a sentiment-classification example.
prompt = "{}\nWhat is the sentiment of this review?\npositive\nnegative\nneutral\n".format(
    "I hate this movie!"
)
# Inputs must live on the first device, where parallelize() places the embeddings.
input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to("cuda:0")
labels = tokenizer("positive", return_tensors="pt").input_ids.to("cuda:0")
outputs = model(input_ids=input_ids, labels=labels)
loss = outputs.loss
logits = outputs.logits
loss.backward()
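The same parallelized model can also be used for inference. A minimal sketch; the review text and generation settings below are illustrative additions, not part of the original gist:

# Inference with the same parallelized model (illustrative example).
prompt = "{}\nWhat is the sentiment of this review?\npositive\nnegative\nneutral\n".format(
    "What a wonderful film!"
)
input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to("cuda:0")
generated = model.generate(input_ids, max_new_tokens=5)
print(tokenizer.decode(generated[0], skip_special_tokens=True))  # e.g. "positive"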