Skip to content

Instantly share code, notes, and snippets.

@samuellangajr
Created March 12, 2025 18:49
Show Gist options
  • Save samuellangajr/337a9dc62f82e35ca9bfcf41df8dec82 to your computer and use it in GitHub Desktop.
Save samuellangajr/337a9dc62f82e35ca9bfcf41df8dec82 to your computer and use it in GitHub Desktop.
Implemente uma classe Python para um dataset personalizado que possa ser usado com PyTorch DataLoader. O dataset deve carregar imagens de veículos e seus rótulos correspondentes.
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import os
# Definir a classe do Dataset personalizado
class VehicleDataset(Dataset):
def __init__(self, imagens_dir, rotulos_file, transform=None):
"""
Args:
imagens_dir (str): Caminho para o diretório das imagens.
rotulos_file (str): Caminho para o arquivo de rótulos (ex: CSV ou TXT).
transform (callable, optional): Função para aplicar transformações nas imagens.
"""
self.imagens_dir = imagens_dir
self.rotulos_file = rotulos_file
self.transform = transform
# Carregar os rótulos (assumindo que são armazenados em um arquivo CSV ou similar)
self.imagens = []
self.rotulos = []
# Supondo que o arquivo de rótulos tenha 2 colunas: nome da imagem e o rótulo
with open(rotulos_file, 'r') as file:
for linha in file:
caminho_imagem, rotulo = linha.strip().split(',')
self.imagens.append(caminho_imagem)
self.rotulos.append(int(rotulo))
def __len__(self):
# Retorna o número de itens no dataset
return len(self.imagens)
def __getitem__(self, idx):
"""
Recupera uma imagem e seu rótulo correspondente.
Args:
idx (int): Índice do item no dataset.
Returns:
imagem (Tensor): A imagem transformada.
rotulo (int): O rótulo da imagem.
"""
# Carregar a imagem
imagem_path = os.path.join(self.imagens_dir, self.imagens[idx])
imagem = Image.open(imagem_path).convert('RGB')
# Aplicar as transformações (se houver)
if self.transform:
imagem = self.transform(imagem)
# Obter o rótulo correspondente
rotulo = self.rotulos[idx]
return imagem, rotulo
# Definir as transformações que serão aplicadas nas imagens
transformacoes = transforms.Compose([
transforms.Resize((224, 224)), # Redimensionar para 224x224
transforms.ToTensor(), # Converter para tensor
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) # Normalização para imagens RGB
])
# Caminho para o diretório de imagens e o arquivo de rótulos
imagens_dir = 'caminho/para/imagens' # Substitua pelo diretório real
rotulos_file = 'caminho/para/rotulos.csv' # Substitua pelo caminho real
# Criar o dataset
dataset = VehicleDataset(imagens_dir, rotulos_file, transform=transformacoes)
# Criar o DataLoader
dataloader = DataLoader(dataset, batch_size=32, shuffle=True)
# Exemplo de iteração no DataLoader
for imagens, rotulos in dataloader:
print(imagens.shape) # Imprime o tamanho do batch de imagens
print(rotulos) # Imprime os rótulos correspondentes
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment