You should replace the arguments in pl.Trainer with Ray Train's implementations.
import pytorch_lightning as pl
+ from ray.train.lightning import (
+ get_devices,
+ prepare_trainer,| import torch | |
| import torch.nn as nn | |
| import torch.optim as optim | |
| import torch_xla.core.xla_model as xm | |
| import torch_xla.distributed.xla_backend # noqa: F401 | |
| from ray.train import ScalingConfig | |
| from ray.train.torch import TorchTrainer, prepare_model | |
| from ray.train.torch.xla import TorchXLAConfig |
| import os | |
| import tempfile | |
| import torch | |
| from torch import nn | |
| from torch.nn.parallel import DistributedDataParallel | |
| import ray | |
| from ray.train import Checkpoint, CheckpointConfig, RunConfig, ScalingConfig | |
| from ray.train.torch import TorchTrainer |
| accelerate==0.19.0 | |
| adal==1.2.7 | |
| aiofiles==22.1.0 | |
| aiohttp==3.8.5 | |
| aiohttp-cors==0.7.0 | |
| aiorwlock==1.3.0 | |
| aiosignal==1.3.1 | |
| aiosqlite==0.19.0 | |
| alabaster==0.7.13 | |
| anyio==3.7.1 |
| accelerate==0.19.0 | |
| adal==1.2.7 | |
| aiofiles==22.1.0 | |
| aiohttp==3.8.5 | |
| aiohttp-cors==0.7.0 | |
| aiorwlock==1.3.0 | |
| aiosignal==1.3.1 | |
| aiosqlite==0.19.0 | |
| alabaster==0.7.13 | |
| anyio==3.7.1 |
| about-time==4.2.1 | |
| absl-py==1.4.0 | |
| accelerate==0.19.0 | |
| adal==1.2.7 | |
| aim==3.17.5 | |
| aim-ui==3.17.5 | |
| aimrecords==0.0.7 | |
| aimrocks==0.4.0 | |
| aioboto3==11.2.0 | |
| aiobotocore==2.5.0 |
| # Minimal Example adapted from https://huggingface.co/docs/transformers/training | |
| import deepspeed | |
| import evaluate | |
| import torch | |
| from datasets import load_dataset | |
| from deepspeed.accelerator import get_accelerator | |
| from torch.utils.data import DataLoader | |
| from tqdm import tqdm | |
| from transformers import ( | |
| AutoModelForSequenceClassification, |
| import evaluate | |
| import torch | |
| from datasets import load_dataset | |
| from torch.optim import AdamW | |
| from torch.utils.data import DataLoader | |
| from transformers import ( | |
| AutoModelForSequenceClassification, | |
| AutoTokenizer, | |
| get_linear_schedule_with_warmup, | |
| set_seed, |
| import os | |
| import evaluate | |
| import numpy as np | |
| from datasets import load_dataset | |
| from ray.train import RunConfig, ScalingConfig, CheckpointConfig, Checkpoint | |
| from ray.train.torch import TorchTrainer | |
| from transformers import AutoTokenizer | |
| from transformers import ( | |
| AutoModelForSequenceClassification, | |
| DataCollatorWithPadding, |
| import os | |
| import time | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from filelock import FileLock | |
| from torch.utils.data import DataLoader, random_split | |
| from torchmetrics import Accuracy | |
| from torchvision.datasets import MNIST | |
| from torchvision import transforms |