@Helw150
Last active May 10, 2023 14:52
Flan T5 Parallel Usage
from transformers import AutoTokenizer, T5ForConditionalGeneration
# Model Init
n_gpu = 8
tokenizer = AutoTokenizer.from_pretrained("google/flan-ul2")
model = T5ForConditionalGeneration.from_pretrained("google/flan-ul2")
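# Split the encoder's transformer blocks evenly across the GPUs.
# Assumes the block count is divisible by n_gpu; otherwise the
# trailing blocks are left out of the map.
layers_per_gpu = len(model.encoder.block) // n_gpu
device_map = {
    gpu: list(range(gpu * layers_per_gpu, (gpu + 1) * layers_per_gpu))
    for gpu in range(n_gpu)
}
# Note: parallelize() is deprecated in recent transformers releases.
model.parallelize(device_map)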
# training
prompt = "{}\nWhat is the sentiment of this review?\npositive\nnegative\nneutral\n".format("I hate this movie!")
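# Inputs must sit on the first GPU, where parallelize() places the embeddings.
input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to("cuda:0")
labels = tokenizer("positive", return_tensors="pt").input_ids.to("cuda:0")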
outputs = model(input_ids=input_ids, labels=labels)
loss = outputs.loss
logits = outputs.logits
loss.backward()
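
The device map above relies on the layer count dividing evenly by the number of GPUs. As a rough sketch of one way to relax that, the helper below spreads any remainder over the first few GPUs. The name `build_device_map` and its parameters are hypothetical, not part of transformers, and the function only builds the dictionary; it does not touch the model.

```python
def build_device_map(n_layers, n_gpu):
    """Split layer indices 0..n_layers-1 into n_gpu contiguous chunks.

    Hypothetical helper: when n_layers is not divisible by n_gpu,
    the first (n_layers % n_gpu) GPUs each take one extra layer.
    """
    if n_gpu < 1 or n_layers < n_gpu:
        raise ValueError("need at least one layer per GPU")
    base, extra = divmod(n_layers, n_gpu)
    device_map = {}
    start = 0
    for gpu in range(n_gpu):
        size = base + (1 if gpu < extra else 0)
        device_map[gpu] = list(range(start, start + size))
        start += size
    return device_map
```

Under these assumptions it could stand in for the dictionary comprehension, e.g. `model.parallelize(build_device_map(len(model.encoder.block), n_gpu))`.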