import flash
from flash.core.data import download_data
from flash.text import TextClassificationData, TextClassifier
# Download and unpack the IMDB sentiment dataset into data/
download_data('https://pl-flash-data.s3.amazonaws.com/imdb.zip', 'data/')
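# Optional sanity check, a minimal sketch that is not part of the Flash API:
# confirm the archive unpacked the three CSV splits used below before building
# the datamodule. `missing_splits` is a hypothetical helper; the file names
# mirror the paths passed to TextClassificationData.from_files.
from pathlib import Path


def missing_splits(data_dir, names=("train.csv", "valid.csv", "test.csv")):
    """Return the expected split files that are not present in data_dir."""
    root = Path(data_dir)
    return [name for name in names if not (root / name).is_file()]


missing = missing_splits("data/imdb")
if missing:
    print(f"Missing IMDB splits: {missing}")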
# Build the datamodule: "review" holds the input text, "sentiment" the label
datamodule = TextClassificationData.from_files(
    train_file="data/imdb/train.csv",
    valid_file="data/imdb/valid.csv",
    test_file="data/imdb/test.csv",
    input="review",
    target="sentiment",
)
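# The datamodule reads the "review" column as input and the "sentiment" column
# as the label. This is a minimal stdlib sketch, and `label_values` is a
# hypothetical helper rather than a Flash function. It checks that a CSV has
# both columns and lists its distinct labels. Their count should match the
# num_classes passed to the model.
import csv


def label_values(csv_path, input_col="review", target_col="sentiment"):
    """Return the sorted distinct values of target_col after checking both columns exist."""
    with open(csv_path, newline="", encoding="utf-8") as f:
        reader = csv.DictReader(f)
        header = reader.fieldnames or []
        for col in (input_col, target_col):
            if col not in header:
                raise ValueError(f"{csv_path} has no {col!r} column (found {header})")
        return sorted({row[target_col] for row in reader})


# Usage: label_values("data/imdb/train.csv") should return exactly two labels
# for this binary sentiment task.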
# Binary sentiment classifier on top of a pretrained RoBERTa backbone
model = TextClassifier(num_classes=2, backbone='roberta-base')
# Fine-tune the pretrained model for one epoch
trainer = flash.Trainer(max_epochs=1)
trainer.finetune(model, datamodule=datamodule)
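# Scale up by changing Trainer flags only (pick one of these):
trainer = flash.Trainer(gpus=8)                # 8 GPUs on one machine
trainer = flash.Trainer(gpus=8, num_nodes=32)  # 8 GPUs on each of 32 nodes (256 GPUs)
trainer = flash.Trainer(tpu_cores=1)           # a single TPU core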
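# The arithmetic behind the scaling flags, as a minimal sketch. It assumes
# distributed data parallel training, where every GPU runs one process on its
# own batch. gpus=8 with num_nodes=32 gives 8 * 32 = 256 processes, and the
# number of samples per optimizer step grows by the same factor. The helper
# names and the per-device batch size of 16 are illustrative only.


def world_size(gpus, num_nodes=1):
    """Total number of training processes (one per GPU)."""
    return gpus * num_nodes


def global_batch_size(per_device_batch, gpus, num_nodes=1):
    """Samples consumed per optimizer step across all devices."""
    return per_device_batch * world_size(gpus, num_nodes)


print(world_size(8, 32))             # 256
print(global_batch_size(16, 8, 32))  # 4096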
# Evaluate the fine-tuned model on the test split
trainer.test()
# Save checkpoint
trainer.save_checkpoint("text_class_model.pt")
# Load model from checkpoint
model = TextClassifier.load_from_checkpoint("text_class_model.pt")
# Predict the sentiment of new text
prediction = model.predict("This movie is great!")
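# Batch inference at scale; millions_of_reviews is a placeholder for a DataLoader over your own data
trainer = flash.Trainer(gpus=32)
predictions = trainer.predict(model, millions_of_reviews)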
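# "Millions of reviews" usually have to be streamed in chunks instead of being
# held in memory all at once. This is a minimal stdlib sketch of fixed-size
# batching. `batched` is a hypothetical helper; Python 3.12 ships a similar
# itertools.batched that yields tuples. How each chunk is then passed to the
# model is not covered here.
from itertools import islice


def batched(iterable, size):
    """Yield lists of up to `size` items from `iterable`, preserving order."""
    if size < 1:
        raise ValueError("size must be at least 1")
    it = iter(iterable)
    while chunk := list(islice(it, size)):
        yield chunk


# Usage: for chunk in batched(review_stream, 1024): ...  (review_stream is hypothetical)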