# Imports (flash itself is needed because the snippets below call flash.Trainer)
import flash
from flash.core.data import download_data
from flash.text import TextClassificationData, TextClassifier
from pytorch_lightning import Trainer
# Download the IMDB dataset into data/
download_data('https://pl-flash-data.s3.amazonaws.com/imdb.zip', 'data/')
# Load Data: "review" is the text column, "sentiment" is the label column
datamodule = TextClassificationData.from_files(
    train_file="data/imdb/train.csv",
    valid_file="data/imdb/valid.csv",
    test_file="data/imdb/test.csv",
    input="review",
    target="sentiment",
)
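The `input` and `target` arguments name columns in the CSV files. As a rough, framework-free illustration of what that mapping implies, the sketch below assumes each file is a CSV with a header row containing `review` and `sentiment` columns; the helper `read_split` and the sample rows are hypothetical and invented for this example, and Flash's own label encoding may differ.

```python
import csv
import io

def read_split(f, input_col="review", target_col="sentiment"):
    """Hypothetical helper: return (texts, labels) from a CSV file-like object."""
    reader = csv.DictReader(f)
    texts, labels = [], []
    for row in reader:
        texts.append(row[input_col])
        labels.append(row[target_col])
    return texts, labels

# Made-up rows standing in for data/imdb/train.csv; not real IMDB data.
sample = io.StringIO(
    "review,sentiment\n"
    "\"Loved every minute.\",positive\n"
    "\"Dull and far too long.\",negative\n"
    "\"A solid, fun watch.\",positive\n"
)
texts, labels = read_split(sample)

# Map each distinct label to an integer index; the mapping's size is the class count.
label_to_index = {label: i for i, label in enumerate(sorted(set(labels)))}
num_classes = len(label_to_index)
print(num_classes)     # 2
print(label_to_index)  # {'negative': 0, 'positive': 1}
```

The derived class count is what `num_classes` in the model constructor has to agree with.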
# Build Model: a roberta-base backbone with a two-class head for binary sentiment
model = TextClassifier(num_classes=2, backbone="roberta-base")
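`num_classes=2` sets the size of the classification head, so the model produces one score (logit) per class for each review. A minimal sketch of how two such logits become class probabilities via softmax; the logit values are arbitrary, and this is a conceptual illustration rather than Flash's internal code.

```python
import math

def softmax(logits):
    """Convert raw scores into probabilities that sum to 1."""
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]

# Arbitrary example logits for a num_classes=2 head.
probs = softmax([-1.0, 2.0])
print([round(p, 3) for p in probs])  # [0.047, 0.953]
```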
# Fine-tune: a single pass over the training data
trainer = flash.Trainer(max_epochs=1)
trainer.finetune(model, datamodule=datamodule)
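With `max_epochs=1`, fine-tuning makes one pass over the training set. The number of optimizer steps in that pass depends on the training-set size and the batch size, neither of which the snippet sets explicitly. The sketch below uses hypothetical values for both and assumes no gradient accumulation.

```python
import math

def steps_per_epoch(num_samples, batch_size, drop_last=False):
    """Number of batches (and thus optimizer steps) in one epoch."""
    if drop_last:
        return num_samples // batch_size  # the final partial batch is discarded
    return math.ceil(num_samples / batch_size)  # the final partial batch is kept

# Hypothetical numbers, not taken from the IMDB files.
max_epochs = 1
total_steps = max_epochs * steps_per_epoch(num_samples=1000, batch_size=16)
print(total_steps)  # 63
```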
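# Scale Up: alternative Trainer configurations; use one in place of the Trainer above
# trainer = flash.Trainer(gpus=8)                 # 8 GPUs on a single machine
# trainer = flash.Trainer(gpus=8, num_nodes=32)   # 8 GPUs on each of 32 machines
# trainer = flash.Trainer(tpu_cores=1)            # a single TPU core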
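Under a data-parallel strategy with one process per GPU (an assumption; the snippet does not set a strategy), the total number of training processes is `gpus * num_nodes`, and the effective global batch grows by the same factor. A back-of-the-envelope sketch; the per-device batch size of 16 is hypothetical.

```python
def world_size(gpus_per_node, num_nodes=1):
    """Total number of training processes, assuming one process per GPU."""
    return gpus_per_node * num_nodes

def global_batch_size(per_device_batch, gpus_per_node, num_nodes=1):
    """Samples consumed per optimizer step across all processes."""
    return per_device_batch * world_size(gpus_per_node, num_nodes)

print(world_size(8))                 # 8
print(world_size(8, 32))             # 256
print(global_batch_size(16, 8, 32))  # 4096
```

A larger global batch means fewer steps per epoch, which is one reason learning-rate settings are often revisited when scaling out.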
# Test: evaluate the fine-tuned model on the test split
trainer.test()
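`trainer.test()` runs the model over the test split and reports metrics. Which metrics are logged depends on the task's configuration; accuracy is a common one for classification. As an illustration only, with invented predictions and labels, accuracy is the fraction of predictions that match the targets.

```python
def accuracy(predictions, targets):
    """Fraction of predictions equal to the corresponding target."""
    if len(predictions) != len(targets):
        raise ValueError("predictions and targets must have the same length")
    if not targets:
        raise ValueError("cannot compute accuracy on an empty set")
    correct = sum(p == t for p, t in zip(predictions, targets))
    return correct / len(targets)

# Invented class indices (0 = negative, 1 = positive) for illustration.
print(accuracy([1, 0, 1, 1], [1, 0, 0, 1]))  # 0.75
```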
# Save Checkpoint
trainer.save_checkpoint("text_class_model.pt")
# Load Model From Checkpoint
model = TextClassifier.load_from_checkpoint("text_class_model.pt")
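A checkpoint lets the fine-tuned model be restored later without retraining. Lightning checkpoints typically store the weights together with the hyperparameters the model was built with, which is why `load_from_checkpoint` can rebuild the model without `num_classes` or `backbone` being passed again (assuming the task saves its hyperparameters). The toy round trip below mimics that idea with `pickle` and a temporary file; the key names mirror Lightning's, but this is an analogy, not Lightning's checkpoint format.

```python
import os
import pickle
import tempfile

# Toy stand-ins: a real checkpoint holds tensors, not Python lists.
checkpoint = {
    "state_dict": {"classifier.weight": [[0.1, -0.2], [0.3, 0.4]]},
    "hyper_parameters": {"num_classes": 2, "backbone": "roberta-base"},
}

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "toy_checkpoint.pkl")
    with open(path, "wb") as f:
        pickle.dump(checkpoint, f)
    with open(path, "rb") as f:
        restored = pickle.load(f)

# The hyperparameters come back with the weights, so the model can be rebuilt.
print(restored["hyper_parameters"])
```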
# Predict: classify a single review
prediction = model.predict("This movie is great!")
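If the model's output for a review is a vector of per-class probabilities, turning it into a readable label takes an argmax followed by a lookup. The exact form of Flash's return value depends on its postprocessing, so the probabilities and the `index_to_label` mapping below are hypothetical.

```python
def to_label(probs, index_to_label):
    """Pick the most probable class and map its index to a label name."""
    best = max(range(len(probs)), key=lambda i: probs[i])
    return index_to_label[best]

# Hypothetical mapping and probabilities for "This movie is great!".
index_to_label = {0: "negative", 1: "positive"}
print(to_label([0.08, 0.92], index_to_label))  # positive
```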
# Inference at scale (millions_of_reviews stands for a large collection of reviews)
Trainer(gpus=32).predict(millions_of_reviews)
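Scoring a very large collection is usually done in fixed-size batches rather than all at once, so that memory use stays bounded. `millions_of_reviews` is a placeholder; the sketch below shows only the batching step with the standard library. It uses a hypothetical `batch_size` and a generator, so batches are produced lazily instead of being built up front.

```python
from itertools import islice

def batched(iterable, batch_size):
    """Yield lists of up to batch_size items from any iterable."""
    if batch_size < 1:
        raise ValueError("batch_size must be at least 1")
    it = iter(iterable)
    while True:
        batch = list(islice(it, batch_size))
        if not batch:
            return
        yield batch

# Stand-in for a large stream of reviews; only the batch sizes matter here.
reviews = (f"review {i}" for i in range(10))
sizes = [len(b) for b in batched(reviews, 4)]
print(sizes)  # [4, 4, 2]
```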