Created
April 20, 2023 07:09
-
-
Save woshiyyya/794af26853782ca4ce2477ea101bd6f1 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| import ray | |
| import torch | |
| import pytorch_lightning as pl | |
| import torch.nn.functional as F | |
| from torch.utils.data import DataLoader, random_split | |
| from transformers import AutoTokenizer, AutoModelForSequenceClassification | |
| from datasets import load_dataset, load_metric | |
| import numpy as np | |
| from ray.train.torch import TorchTrainer | |
# Load the GLUE CoLA dataset and the matching GLUE metric
# (CoLA is scored with Matthews correlation).
dataset = load_dataset("glue", "cola")
metric = load_metric("glue", "cola")
# Convert the HuggingFace DatasetDict into Ray Datasets (one per split).
ray_datasets = ray.data.from_huggingface(dataset)
from ray.data.preprocessors import BatchMapper
# Tokenizer shared by the batch-preprocessing function below.
tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")
def tokenize_sentence(batch):
    """Tokenize one numpy-format batch in place for BERT.

    Replaces the "sentence" column with "input_ids" and "attention_mask"
    (padded/truncated to 128 tokens) and normalizes "label" to a numpy
    array. Returns the mutated batch dict.
    """
    encoded = tokenizer(
        batch["sentence"].tolist(),
        max_length=128,
        truncation=True,
        padding="max_length",
        return_tensors="pt",
    )
    batch["input_ids"] = encoded["input_ids"].numpy()
    batch["attention_mask"] = encoded["attention_mask"].numpy()
    batch["label"] = np.array(batch["label"])
    # The raw text column is no longer needed downstream.
    batch.pop("sentence")
    return batch
# Batch-level preprocessor that Ray applies to each dataset shard
# before feeding it to the training workers.
preprocessor = BatchMapper(tokenize_sentence, batch_format="numpy")
class SentimentModel(pl.LightningModule):
    """BERT-based binary classifier for the GLUE CoLA task.

    Batches may arrive either as dicts with "input_ids"/"attention_mask"/
    "label" keys, or as tuples whose first element is a
    (input_ids, attention_mask, labels) triple.
    """

    def __init__(self, lr=2e-5, eps=1e-8):
        super().__init__()
        self.lr = lr
        self.eps = eps
        self.num_classes = 2
        self.model = AutoModelForSequenceClassification.from_pretrained(
            "bert-base-cased", num_labels=self.num_classes
        )
        self.metric = load_metric("glue", "cola")
        # Accumulated over validation steps; cleared at epoch end.
        self.predictions = []
        self.references = []

    def forward(self, batch):
        """Return classification logits for one batch."""
        if isinstance(batch, tuple):
            input_ids, attention_mask, _ = batch[0]
        else:
            input_ids = batch["input_ids"]
            attention_mask = batch["attention_mask"]
        outputs = self.model(input_ids, attention_mask=attention_mask)
        return outputs.logits

    def training_step(self, batch, batch_idx):
        """Compute and log the cross-entropy training loss."""
        labels = batch[0][2] if isinstance(batch, tuple) else batch["label"]
        logits = self.forward(batch)
        loss = F.cross_entropy(logits.view(-1, self.num_classes), labels)
        self.log("train_loss", loss)
        return loss

    def validation_step(self, batch, batch_idx):
        """Collect argmax predictions and labels for the epoch-end metric."""
        labels = batch[0][2] if isinstance(batch, tuple) else batch["label"]
        logits = self.forward(batch)
        self.predictions.append(torch.argmax(logits, dim=1))
        self.references.append(labels)

    def on_validation_epoch_end(self):
        """Log Matthews correlation over the full validation epoch."""
        preds = torch.cat(self.predictions).view(-1)
        refs = torch.cat(self.references).view(-1)
        # self.metric.compute() returns a dictionary:
        # e.g. {"matthews_correlation": 0.53}
        scores = self.metric.compute(predictions=preds, references=refs)
        self.log_dict(scores, sync_dist=True)
        self.predictions.clear()
        self.references.clear()

    def configure_optimizers(self):
        """AdamW over all parameters with the configured lr/eps."""
        return torch.optim.AdamW(self.parameters(), lr=self.lr, eps=self.eps)
from ray.train.lightning import LightningTrainer, LightningConfigBuilder
from ray.air.config import RunConfig, ScalingConfig, CheckpointConfig

# Save AIR checkpoints according to the performance on validation set:
# keep only the 2 checkpoints with the highest Matthews correlation.
run_config = RunConfig(
    name="ptl-sent-classification",
    checkpoint_config=CheckpointConfig(
        num_to_keep=2,
        checkpoint_score_attribute="matthews_correlation",
        checkpoint_score_order="max",
    ),
)

# Scale the training workload across 4 GPUs (1 CPU + 1 GPU per worker).
# You can change this config based on your compute resources.
scaling_config = ScalingConfig(
    num_workers=4, use_gpu=True, resources_per_worker={"CPU": 1, "GPU": 1}
)
def test_lightning_trainer(batch_size=16):
    """Fine-tune SentimentModel with Ray's LightningTrainer.

    Args:
        batch_size: Per-worker batch size for the dataset iterators.

    Returns:
        The ray.air ``Result`` from ``trainer.fit()``. (Previously the
        result was assigned and silently discarded; returning it lets
        callers inspect metrics and the best checkpoint.)
    """
    # Define the configs for LightningTrainer.
    lightning_config = (
        LightningConfigBuilder()
        .module(cls=SentimentModel, lr=1e-5, eps=1e-8)
        .trainer(max_epochs=5, accelerator="gpu")
        # Checkpoint after validation rather than after each training epoch,
        # so checkpoint scores line up with the logged validation metric.
        .checkpointing(save_on_train_epoch_end=False)
        .build()
    )
    trainer = LightningTrainer(
        lightning_config=lightning_config,
        run_config=run_config,
        scaling_config=scaling_config,
        datasets={"train": ray_datasets["train"], "val": ray_datasets["validation"]},
        datasets_iter_config={"batch_size": batch_size},
        preprocessor=preprocessor,
    )
    result = trainer.fit()
    return result
def torch_trainer_worker_loop(config):
    """Per-worker loop for TorchTrainer: fit SentimentModel with DDP.

    Reads this worker's "train" and "val" dataset shards from the Ray
    session and trains for 5 epochs on GPU.
    """
    columns = [["input_ids"], ["attention_mask"], ["label"]]

    def shard_as_torch(name):
        # Fetch the named shard as a torch-compatible iterator.
        return ray.air.session.get_dataset_shard(name).to_torch(
            feature_columns=columns,
            batch_size=config["batch_size"],
            unsqueeze_feature_tensors=False,
        )

    train_loader = shard_as_torch("train")
    val_loader = shard_as_torch("val")
    model = SentimentModel(lr=1e-5, eps=1e-8)
    trainer = pl.Trainer(max_epochs=5, accelerator="gpu", strategy="ddp")
    trainer.fit(model, train_dataloaders=train_loader, val_dataloaders=val_loader)
def test_torch_trainer(batch_size=16):
    """Fine-tune SentimentModel via TorchTrainer and a custom worker loop.

    Args:
        batch_size: Per-worker batch size, passed through train_loop_config.

    Returns:
        The ray.air ``Result`` from ``torch_trainer.fit()``. (Previously the
        result was assigned and silently discarded; returning it lets
        callers inspect metrics and checkpoints.)
    """
    torch_trainer = TorchTrainer(
        train_loop_per_worker=torch_trainer_worker_loop,
        train_loop_config={"batch_size": batch_size},
        run_config=run_config,
        scaling_config=scaling_config,
        datasets={"train": ray_datasets["train"], "val": ray_datasets["validation"]},
        preprocessor=preprocessor,
    )
    result = torch_trainer.fit()
    return result
# Entry point: run the LightningTrainer variant. Uncomment the line
# below (and comment the other) to run the plain TorchTrainer variant.
# test_torch_trainer(batch_size=64)
test_lightning_trainer(batch_size=64)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.