Last active
February 20, 2025 05:34
-
-
Save ugo-nama-kun/5e2a62e744cc9860e44b88e6d26b886c to your computer and use it in GitHub Desktop.
mnist handler example
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch | |
import torchvision | |
import random | |
import matplotlib.pyplot as plt | |
class MNISTHandler: | |
def __init__(self, train=True): | |
""" | |
Initializes the MNIST dataset. | |
Args: | |
train (bool): If True, loads the training dataset; otherwise, loads the test dataset. | |
""" | |
self.dataset = torchvision.datasets.MNIST(root="./data", train=train, download=True) | |
self.class_indices = self._index_classes() | |
def _index_classes(self): | |
""" | |
Creates a dictionary mapping each class to the indices of its occurrences in the dataset. | |
Returns: | |
dict: A dictionary where keys are class labels and values are lists of indices. | |
""" | |
class_indices = {i: [] for i in range(10)} | |
for i, label in enumerate(self.dataset.targets): | |
class_indices[label.item()].append(i) | |
return class_indices | |
def get_random_image(self, class_number): | |
""" | |
Returns a random image from the specified class. | |
Args: | |
class_number (int): The class label (0-9) to retrieve an image from. | |
Returns: | |
PIL.Image: The image corresponding to the specified class. | |
""" | |
if class_number not in self.class_indices: | |
raise ValueError("Invalid class number. Choose a number between 0 and 9.") | |
random_index = random.choice(self.class_indices[class_number]) | |
image, label = self.dataset[random_index] | |
return image | |
if __name__ == '__main__': | |
# Example usage | |
mnist_handler = MNISTHandler(train=True) | |
for i in range(10): | |
img = mnist_handler.get_random_image(i) | |
# Display the image | |
plt.imshow(img, cmap="gray") | |
plt.title(f"Random MNIST Image of {i}") | |
plt.axis("off") | |
plt.pause(0.3) |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment