Created
March 12, 2025 18:49
-
-
Save samuellangajr/337a9dc62f82e35ca9bfcf41df8dec82 to your computer and use it in GitHub Desktop.
Implemente uma classe Python para um dataset personalizado que possa ser usado com PyTorch DataLoader. O dataset deve carregar imagens de veículos e seus rótulos correspondentes.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch | |
from torch.utils.data import Dataset, DataLoader | |
from torchvision import transforms | |
from PIL import Image | |
import os | |
# Definir a classe do Dataset personalizado | |
class VehicleDataset(Dataset): | |
def __init__(self, imagens_dir, rotulos_file, transform=None): | |
""" | |
Args: | |
imagens_dir (str): Caminho para o diretório das imagens. | |
rotulos_file (str): Caminho para o arquivo de rótulos (ex: CSV ou TXT). | |
transform (callable, optional): Função para aplicar transformações nas imagens. | |
""" | |
self.imagens_dir = imagens_dir | |
self.rotulos_file = rotulos_file | |
self.transform = transform | |
# Carregar os rótulos (assumindo que são armazenados em um arquivo CSV ou similar) | |
self.imagens = [] | |
self.rotulos = [] | |
# Supondo que o arquivo de rótulos tenha 2 colunas: nome da imagem e o rótulo | |
with open(rotulos_file, 'r') as file: | |
for linha in file: | |
caminho_imagem, rotulo = linha.strip().split(',') | |
self.imagens.append(caminho_imagem) | |
self.rotulos.append(int(rotulo)) | |
def __len__(self): | |
# Retorna o número de itens no dataset | |
return len(self.imagens) | |
def __getitem__(self, idx): | |
""" | |
Recupera uma imagem e seu rótulo correspondente. | |
Args: | |
idx (int): Índice do item no dataset. | |
Returns: | |
imagem (Tensor): A imagem transformada. | |
rotulo (int): O rótulo da imagem. | |
""" | |
# Carregar a imagem | |
imagem_path = os.path.join(self.imagens_dir, self.imagens[idx]) | |
imagem = Image.open(imagem_path).convert('RGB') | |
# Aplicar as transformações (se houver) | |
if self.transform: | |
imagem = self.transform(imagem) | |
# Obter o rótulo correspondente | |
rotulo = self.rotulos[idx] | |
return imagem, rotulo | |
# Definir as transformações que serão aplicadas nas imagens | |
transformacoes = transforms.Compose([ | |
transforms.Resize((224, 224)), # Redimensionar para 224x224 | |
transforms.ToTensor(), # Converter para tensor | |
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) # Normalização para imagens RGB | |
]) | |
# Caminho para o diretório de imagens e o arquivo de rótulos | |
imagens_dir = 'caminho/para/imagens' # Substitua pelo diretório real | |
rotulos_file = 'caminho/para/rotulos.csv' # Substitua pelo caminho real | |
# Criar o dataset | |
dataset = VehicleDataset(imagens_dir, rotulos_file, transform=transformacoes) | |
# Criar o DataLoader | |
dataloader = DataLoader(dataset, batch_size=32, shuffle=True) | |
# Exemplo de iteração no DataLoader | |
for imagens, rotulos in dataloader: | |
print(imagens.shape) # Imprime o tamanho do batch de imagens | |
print(rotulos) # Imprime os rótulos correspondentes |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment