Skip to content

Instantly share code, notes, and snippets.

View Chris-hughes10's full-sized avatar

Chris Hughes Chris-hughes10

View GitHub Profile
@Chris-hughes10
Chris-hughes10 / train_with_metrics_in_loop.py
Created November 24, 2021 11:11
pytorch_accelerated_blog_metrics_in_trainer_script
# https://github.com/Chris-hughes10/pytorch-accelerated/blob/main/examples/metrics/train_with_metrics_in_loop.py
import os
from torch import nn, optim
from torch.utils.data import random_split
from torchmetrics import MetricCollection, Accuracy, Precision, Recall
from torchvision import transforms
from torchvision.datasets import MNIST
from pytorch_accelerated import Trainer
@Chris-hughes10
Chris-hughes10 / trainer_with_metrics.py
Created November 24, 2021 11:02
pytorch_accelerated_blog_trainer_metrics_snippet
from torchmetrics import MetricCollection, Accuracy, Precision, Recall
class TrainerWithMetrics(Trainer):
def __init__(self, num_classes, *args, **kwargs):
super().__init__(*args, **kwargs)
# this will be moved to the correct device automatically by the
# MoveModulesToDeviceCallback callback, which is used by default
self.metrics = MetricCollection(
{
@Chris-hughes10
Chris-hughes10 / train_mnist.py
Last active November 24, 2021 10:53
pytorch-accelerated_blog_mnist_quickstart
# this example is taken from
# https://github.com/Chris-hughes10/pytorch-accelerated/blob/main/examples/train_mnist.py
import os
from torch import nn, optim
from torch.utils.data import random_split
from torchvision import transforms
from torchvision.datasets import MNIST
@Chris-hughes10
Chris-hughes10 / EfficientDet Pytorch-lightning with EfficientNet v2 backbone Blog Post.ipynb
Last active February 5, 2025 01:44
EfficientDet Pytorch-lightning with EfficientNet v2 backbone Blog Post.ipynb
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
from objdetecteval.metrics.coco_metrics import get_coco_stats
@patch
def validation_epoch_end(self: EfficientDetModel, outputs):
"""Compute and log training loss and accuracy at the epoch level."""
validation_loss_mean = torch.stack(
[output["loss"] for output in outputs]
).mean()
@Chris-hughes10
Chris-hughes10 / effdet_aggregate_outputs.py
Created July 16, 2021 09:52
Effdet_blog_aggregate_outputs
from fastcore.basics import patch
@patch
def aggregate_prediction_outputs(self: EfficientDetModel, outputs):
detections = torch.cat(
[output["batch_predictions"]["predictions"] for output in outputs]
)
image_ids = []
@Chris-hughes10
Chris-hughes10 / effdet_run_inference.py
Created July 16, 2021 09:46
Effdet_blog_inference
def _run_inference(self, images_tensor, image_sizes):
dummy_targets = self._create_dummy_inference_targets(
num_images=images_tensor.shape[0]
)
detections = self.model(images_tensor.to(self.device), dummy_targets)[
"detections"
]
(
predicted_bboxes,
@typedispatch
def predict(self, images: List):
"""
For making predictions from images
Args:
images: a list of PIL images
Returns: a tuple of lists containing bboxes, predicted_class_labels, predicted_class_confidences
"""
@Chris-hughes10
Chris-hughes10 / effdet_model_1.py
Created July 16, 2021 09:40
Effdet_blog_model_1
import torch
from pytorch_lightning import LightningModule
from pytorch_lightning.core.decorators import auto_move_data
class EfficientDetModel(LightningModule):
def __init__(
self,
num_classes=1,
img_size=512,
prediction_confidence_threshold=0.2,
@Chris-hughes10
Chris-hughes10 / effdet_datamodule.py
Created July 16, 2021 09:36
Effdet_blog_datamodule
from pytorch_lightning import LightningDataModule
from torch.utils.data import DataLoader
class EfficientDetDataModule(LightningDataModule):
def __init__(self,
train_dataset_adaptor,
validation_dataset_adaptor,
train_transforms=get_train_transforms(target_img_size=512),
valid_transforms=get_valid_transforms(target_img_size=512),