Skip to content

Instantly share code, notes, and snippets.

@woshiyyya
Last active May 12, 2023 20:10
Show Gist options
  • Select an option

  • Save woshiyyya/7b158c31926b9c6da4af99bdf6dbe224 to your computer and use it in GitHub Desktop.

Select an option

Save woshiyyya/7b158c31926b9c6da4af99bdf6dbe224 to your computer and use it in GitHub Desktop.
class MNISTClassifier(pl.LightningModule):
    """LightningModule for MNIST digit classification with a 3-layer MLP.

    The data-augmentation pipeline is chosen via ``config["augmentation_strategy"]``
    so a tuner can sweep different strategies per trial.

    Expected config keys: ``batch_size``, ``augmentation_strategy``,
    ``layer_1_size``, ``layer_2_size``, ``lr``; optional ``data_dir``.
    """

    def __init__(self, config):
        super().__init__()
        # NOTE(review): newer torchmetrics versions require
        # Accuracy(task="multiclass", num_classes=10) — confirm against the
        # pinned torchmetrics version before upgrading.
        self.accuracy = Accuracy()
        self.batch_size = config["batch_size"]
        # Fix: data_dir was read by the dataloaders but never assigned,
        # raising AttributeError. Default keeps the interface backward-compatible.
        self.data_dir = config.get("data_dir", "./data")
        # [!] Determine your data augmentation strategy here
        self.aug_strategy = config["augmentation_strategy"]
        if self.aug_strategy == "strategy_a":
            self.transform = transforms.Compose(
                [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
            )
        elif self.aug_strategy == "strategy_b":
            # Fix: the original cropped to 32x32 and normalized 3 channels, but
            # MNIST images are 1x28x28 and forward() flattens to 28*28 features,
            # so layer_1 would reject the 1024-feature input. Keep 28x28 and
            # single-channel MNIST statistics.
            self.transform = transforms.Compose([
                transforms.RandomHorizontalFlip(),
                transforms.RandomCrop(28, padding=4),
                transforms.ToTensor(),
                transforms.Normalize((0.1307,), (0.3081,)),
            ])
        else:
            self.transform = any_customized_augmentation_func()
        self.layer_1_size = config["layer_1_size"]
        self.layer_2_size = config["layer_2_size"]
        self.lr = config["lr"]
        # mnist images are (1, 28, 28) (channels, width, height)
        self.layer_1 = torch.nn.Linear(28 * 28, self.layer_1_size)
        self.layer_2 = torch.nn.Linear(self.layer_1_size, self.layer_2_size)
        self.layer_3 = torch.nn.Linear(self.layer_2_size, 10)

    def cross_entropy_loss(self, logits, labels):
        # forward() returns log-probabilities, so NLL loss is the correct pairing
        # (log_softmax + nll_loss == cross-entropy).
        return F.nll_loss(logits, labels)

    def forward(self, x):
        """Return per-class log-probabilities for a batch of images."""
        batch_size, channels, width, height = x.size()
        x = x.view(batch_size, -1)  # flatten to (batch, 784)
        x = torch.relu(self.layer_1(x))
        x = torch.relu(self.layer_2(x))
        x = self.layer_3(x)
        return torch.log_softmax(x, dim=1)

    def training_step(self, train_batch, batch_idx):
        x, y = train_batch
        logits = self.forward(x)
        loss = self.cross_entropy_loss(logits, y)
        accuracy = self.accuracy(logits, y)
        self.log("ptl/train_loss", loss)
        self.log("ptl/train_accuracy", accuracy)
        return loss

    def validation_step(self, val_batch, batch_idx):
        x, y = val_batch
        logits = self.forward(x)
        loss = self.cross_entropy_loss(logits, y)
        accuracy = self.accuracy(logits, y)
        return {"val_loss": loss, "val_accuracy": accuracy}

    def validation_epoch_end(self, outputs):
        # Aggregate the per-batch validation metrics into epoch-level scalars.
        avg_loss = torch.stack([x["val_loss"] for x in outputs]).mean()
        avg_acc = torch.stack([x["val_accuracy"] for x in outputs]).mean()
        self.log("ptl/val_loss", avg_loss)
        self.log("ptl/val_accuracy", avg_acc)

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=self.lr)

    def train_dataloader(self):
        # [!] Initialize your PyTorch Dataset on each worker
        mnist_train = MNIST(self.data_dir, train=True, download=True, transform=self.transform)
        return DataLoader(mnist_train, batch_size=self.batch_size, num_workers=4)

    def val_dataloader(self):
        mnist_val = MNIST(self.data_dir, train=False, download=True, transform=self.transform)
        # Fix: the original referenced self.mnist_val, which is never assigned.
        return DataLoader(mnist_val, batch_size=self.batch_size, num_workers=4)
# ....
# [!] Tune different `augmentation_strategy` according to the flag
# Ray Tune search space: each trial samples layer sizes and learning rate, and
# grid-searches over the two augmentation strategies handled in
# MNISTClassifier.__init__.
config = {
    # Fix: MNISTClassifier.__init__ reads config["batch_size"] unconditionally,
    # so omitting it raised KeyError on every trial.
    "batch_size": 64,
    "layer_1_size": tune.choice([32, 64, 128]),
    "layer_2_size": tune.choice([64, 128, 256]),
    "lr": tune.loguniform(1e-4, 1e-1),
    "augmentation_strategy": tune.grid_search(["strategy_a", "strategy_b"]),
}
# Build the per-trial Lightning configuration step by step: the module class
# with its (tunable) config, trainer settings, and top-k checkpointing keyed
# on validation accuracy.
builder = LightningConfigBuilder()
builder = builder.module(cls=MNISTClassifier, config=config)
builder = builder.trainer(max_epochs=num_epochs, accelerator=accelerator, logger=logger)
builder = builder.checkpointing(monitor="ptl/val_accuracy", save_top_k=2, mode="max")
lightning_config = builder.build()
# Nothing changes needed for Trainer and Tuner initialization
# Three CPU-only data-parallel workers, one CPU each.
scaling = ScalingConfig(
    num_workers=3,
    use_gpu=False,
    resources_per_worker={"CPU": 1},
)
lightning_trainer = LightningTrainer(scaling_config=scaling, run_config=run_config)
# Tuner: maximizes validation accuracy over `num_samples` trials, pruned by
# the provided scheduler; per-trial Lightning settings come in via param_space.
tuning = tune.TuneConfig(
    metric="ptl/val_accuracy",
    mode="max",
    num_samples=num_samples,
    scheduler=scheduler,
)
experiment_run = air.RunConfig(
    name="tune_mnist_asha",
)
tuner = tune.Tuner(
    lightning_trainer,
    param_space={"lightning_config": lightning_config},
    tune_config=tuning,
    run_config=experiment_run,
)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment