Skip to content

Instantly share code, notes, and snippets.

@woshiyyya
Created April 20, 2023 07:09
Show Gist options
  • Select an option

  • Save woshiyyya/794af26853782ca4ce2477ea101bd6f1 to your computer and use it in GitHub Desktop.

Select an option

Save woshiyyya/794af26853782ca4ce2477ea101bd6f1 to your computer and use it in GitHub Desktop.
import ray
import torch
import pytorch_lightning as pl
import torch.nn.functional as F
from torch.utils.data import DataLoader, random_split
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from datasets import load_dataset, load_metric
import numpy as np
from ray.train.torch import TorchTrainer
# Load the GLUE CoLA dataset (single-sentence acceptability classification)
# and its matching metric (Matthews correlation).
dataset = load_dataset("glue", "cola")
# NOTE(review): datasets.load_metric is deprecated upstream in favor of the
# `evaluate` package — works here, but worth migrating.
metric = load_metric("glue", "cola")
# Convert each HF split into a Ray Dataset, keyed by split name.
ray_datasets = ray.data.from_huggingface(dataset)
from ray.data.preprocessors import BatchMapper
# Tokenizer must match the model checkpoint used in SentimentModel below.
tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")
def tokenize_sentence(batch):
    """Tokenize the "sentence" column of a numpy-format batch in place.

    Replaces the raw text column with fixed-length (128 token) numpy
    arrays under "input_ids" and "attention_mask", and coerces "label"
    to a numpy array. Returns the mutated batch dict.
    """
    sentences = batch.pop("sentence").tolist()
    encoded = tokenizer(
        sentences,
        max_length=128,
        truncation=True,
        padding="max_length",
        return_tensors="pt",
    )
    batch["input_ids"] = encoded["input_ids"].numpy()
    batch["attention_mask"] = encoded["attention_mask"].numpy()
    batch["label"] = np.array(batch["label"])
    return batch
# Applied lazily to every batch of the Ray datasets before they reach a trainer.
preprocessor = BatchMapper(tokenize_sentence, batch_format="numpy")
class SentimentModel(pl.LightningModule):
    """BERT-based binary classifier for the GLUE CoLA task.

    Accepts batches either as dicts (keys "input_ids", "attention_mask",
    "label") or as tuples whose first element unpacks to
    (input_ids, attention_mask, labels) — the format produced by
    ``Dataset.to_torch`` in the TorchTrainer path.
    """

    def __init__(self, lr=2e-5, eps=1e-8):
        super().__init__()
        self.lr = lr
        self.eps = eps
        self.num_classes = 2
        self.model = AutoModelForSequenceClassification.from_pretrained(
            "bert-base-cased", num_labels=self.num_classes
        )
        self.metric = load_metric("glue", "cola")
        # Accumulated over validation steps; flushed at epoch end.
        self.predictions = []
        self.references = []

    def forward(self, batch):
        """Return classification logits for one batch."""
        if isinstance(batch, tuple):
            input_ids, attention_mask, _ = batch[0]
        else:
            input_ids = batch["input_ids"]
            attention_mask = batch["attention_mask"]
        return self.model(input_ids, attention_mask=attention_mask).logits

    def training_step(self, batch, batch_idx):
        """Compute and log cross-entropy loss for one training batch."""
        if isinstance(batch, tuple):
            _, _, labels = batch[0]
        else:
            labels = batch["label"]
        logits = self.forward(batch)
        loss = F.cross_entropy(logits.view(-1, self.num_classes), labels)
        self.log("train_loss", loss)
        return loss

    def validation_step(self, batch, batch_idx):
        """Collect per-batch predictions and labels for the epoch metric."""
        if isinstance(batch, tuple):
            _, _, labels = batch[0]
        else:
            labels = batch["label"]
        self.predictions.append(torch.argmax(self.forward(batch), dim=1))
        self.references.append(labels)

    def on_validation_epoch_end(self):
        """Log Matthews correlation over the whole validation epoch."""
        preds = torch.concat(self.predictions).view(-1)
        refs = torch.concat(self.references).view(-1)
        # metric.compute() returns a dict, e.g. {"matthews_correlation": 0.53}
        score = self.metric.compute(predictions=preds, references=refs)
        self.log_dict(score, sync_dist=True)
        self.predictions.clear()
        self.references.clear()

    def configure_optimizers(self):
        """AdamW over all parameters with the constructor's lr and eps."""
        return torch.optim.AdamW(self.parameters(), lr=self.lr, eps=self.eps)
from ray.train.lightning import LightningTrainer, LightningConfigBuilder
from ray.air.config import RunConfig, ScalingConfig, CheckpointConfig

# Save AIR checkpoints according to the performance on validation set:
# keep the 2 checkpoints with the highest matthews_correlation.
run_config = RunConfig(
    name="ptl-sent-classification",
    checkpoint_config=CheckpointConfig(
        num_to_keep=2,
        checkpoint_score_attribute="matthews_correlation",
        checkpoint_score_order="max",
    ),
)

# Scale the training workload across 4 GPU workers (1 CPU + 1 GPU each).
# You can change this config based on your compute resources.
scaling_config = ScalingConfig(
    num_workers=4, use_gpu=True, resources_per_worker={"CPU": 1, "GPU": 1}
)
def test_lightning_trainer(batch_size=16):
    """Fine-tune SentimentModel via Ray AIR's LightningTrainer.

    Uses the module-level ``run_config``/``scaling_config``/``preprocessor``
    and both GLUE CoLA splits from ``ray_datasets``.

    Args:
        batch_size: Per-worker batch size for dataset iteration.

    Returns:
        The ``ray.air.Result`` from ``trainer.fit()``. (Fix: the original
        assigned the result to a local and discarded it.)
    """
    lightning_config = (
        LightningConfigBuilder()
        .module(cls=SentimentModel, lr=1e-5, eps=1e-8)
        .trainer(max_epochs=5, accelerator="gpu")
        # Checkpoint after validation only, not after every training epoch,
        # so checkpoint scoring can use the validation metric.
        .checkpointing(save_on_train_epoch_end=False)
        .build()
    )
    trainer = LightningTrainer(
        lightning_config=lightning_config,
        run_config=run_config,
        scaling_config=scaling_config,
        datasets={"train": ray_datasets["train"], "val": ray_datasets["validation"]},
        datasets_iter_config={"batch_size": batch_size},
        preprocessor=preprocessor,
    )
    return trainer.fit()
def torch_trainer_worker_loop(config):
    """Per-worker loop for TorchTrainer: plain PyTorch Lightning + DDP.

    Reads this worker's "train"/"val" dataset shards, wraps them as torch
    iterables, and fits a fresh SentimentModel for 5 epochs on GPU.
    """

    def _shard_as_torch(split):
        # Convert this worker's shard of `split` into a torch-format iterable.
        return ray.air.session.get_dataset_shard(split).to_torch(
            feature_columns=[["input_ids"], ["attention_mask"], ["label"]],
            batch_size=config["batch_size"],
            unsqueeze_feature_tensors=False,
        )

    train_loader = _shard_as_torch("train")
    val_loader = _shard_as_torch("val")
    model = SentimentModel(lr=1e-5, eps=1e-8)
    trainer = pl.Trainer(max_epochs=5, accelerator="gpu", strategy="ddp")
    trainer.fit(model, train_dataloaders=train_loader, val_dataloaders=val_loader)
def test_torch_trainer(batch_size=16):
    """Fine-tune SentimentModel via Ray AIR's TorchTrainer.

    Runs ``torch_trainer_worker_loop`` on each worker with the shared
    module-level ``run_config``/``scaling_config``/``preprocessor``.

    Args:
        batch_size: Per-worker batch size, forwarded to the worker loop.

    Returns:
        The ``ray.air.Result`` from ``torch_trainer.fit()``. (Fix: the
        original assigned the result to a local and discarded it.)
    """
    torch_trainer = TorchTrainer(
        train_loop_per_worker=torch_trainer_worker_loop,
        train_loop_config={"batch_size": batch_size},
        run_config=run_config,
        scaling_config=scaling_config,
        datasets={"train": ray_datasets["train"], "val": ray_datasets["validation"]},
        preprocessor=preprocessor,
    )
    return torch_trainer.fit()
# Run the LightningTrainer path; the TorchTrainer path is kept commented
# out as the alternative entry point.
# test_torch_trainer(batch_size=64)
test_lightning_trainer(batch_size=64)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment