# Lightning Flash text-classification (IMDB sentiment) example.
# Shared as a GitHub Gist by aribornstein (PythicCoder).
import flash
from flash.core.data import download_data
from flash.text import TextClassificationData, TextClassifier
from pytorch_lightning import Trainer
# Fetch the IMDB sentiment dataset into ./data/; the CSV splits below are read
# from data/imdb/.
download_data("https://pl-flash-data.s3.amazonaws.com/imdb.zip", "data/")

# Build train/validation/test splits from the CSVs: the "review" column is the
# input text and the "sentiment" column is the label to predict.
datamodule = TextClassificationData.from_files(
    train_file="data/imdb/train.csv",
    valid_file="data/imdb/valid.csv",
    test_file="data/imdb/test.csv",
    input="review",
    target="sentiment",
)

# Two-class sentiment classifier on top of the pretrained "roberta-base" backbone.
model = TextClassifier(num_classes=2, backbone="roberta-base")

# Fine-tune for one epoch. Other hardware setups only change the Trainer
# arguments, e.g.:
#   flash.Trainer(max_epochs=1, gpus=8)                # 8 GPUs on one machine
#   flash.Trainer(max_epochs=1, gpus=8, num_nodes=32)  # 32 machines x 8 GPUs
#   flash.Trainer(max_epochs=1, tpu_cores=1)           # a single TPU core
# Choose ONE configuration: creating another Trainer after fine-tuning would
# discard the trainer that holds the trained model, and the checkpoint saved
# below would no longer come from the fine-tuning run.
trainer = flash.Trainer(max_epochs=1)
trainer.finetune(model, datamodule=datamodule)

# Save the weights from the same trainer that ran fine-tuning.
trainer.save_checkpoint("text_class_model.pt")

# Reload the fine-tuned model from the checkpoint and classify a single text.
model = TextClassifier.load_from_checkpoint("text_class_model.pt")
prediction = model.predict("This movie is great!")

# Batch inference: pass a list of texts. Replace this sample list with your own
# corpus (chunk very large corpora into several calls).
# NOTE(review): for multi-GPU inference, use a Trainer configured with `gpus=`
# (not `num_gpus=`) and the predict API of the installed Flash/Lightning
# version -- confirm against its docs.
reviews = [
    "This movie is great!",
    "The worst movie in the history of cinema.",
]
predictions = model.predict(reviews)