Skip to content

Instantly share code, notes, and snippets.

@reuben
Created May 6, 2025 07:32
Show Gist options
  • Save reuben/088160021ad121d69cd11c997c3e92b9 to your computer and use it in GitHub Desktop.
Save reuben/088160021ad121d69cd11c997c3e92b9 to your computer and use it in GitHub Desktop.
from torch import nn
from torch.utils.data import DataLoader, Dataset
class MyModel(nn.Module):
    """Placeholder network whose forward pass hands back its input untouched."""

    def __init__(self):
        # No layers are registered here; only the nn.Module base is set up.
        super(MyModel, self).__init__()

    def forward(self, x):
        """Return ``x`` itself; no computation is applied."""
        output = x
        return output
class MyDataset(Dataset):
    """Dataset skeleton: size and item access are both left unimplemented."""

    def __init__(self):
        super(MyDataset, self).__init__()

    def __len__(self):
        """Always raise ``NotImplementedError``; no size is defined yet."""
        raise NotImplementedError

    def __getitem__(self, idx):
        """Always raise ``NotImplementedError``, whatever ``idx`` is."""
        raise NotImplementedError
def train_loop(model, dataloader, criterion, optimizer):
    """Run one training epoch and return the mean per-batch loss.

    Args:
        model: Callable module; put into training mode via ``model.train()``
            and invoked on each batch's inputs.
        dataloader: Iterable of batches. Each batch is assumed to be an
            ``(inputs, targets)`` pair -- TODO confirm against the dataset's
            ``__getitem__`` once it is implemented.
        criterion: Callable ``criterion(outputs, targets)`` returning a loss
            object that supports ``.backward()`` and ``.item()``.
        optimizer: Object exposing ``zero_grad()`` and ``step()``.

    Returns:
        The sum of ``loss.item()`` over all batches divided by the number of
        batches processed.

    Raises:
        ValueError: If ``dataloader`` yields no batches.
    """
    model.train()
    running_loss = 0.0
    # Count batches while iterating so iterables without ``__len__`` also work.
    num_batches = 0
    for sample in dataloader:
        inputs, targets = sample  # assumed batch layout, see docstring
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
        num_batches += 1
    if num_batches == 0:
        # Fail with a clear message instead of an opaque ZeroDivisionError.
        raise ValueError("dataloader yielded no batches")
    epoch_loss = running_loss / num_batches
    return epoch_loss
def main():
    """Build the dataset, loader and model, then train for a fixed number of epochs.

    The loss function and optimizer are still ``None`` placeholders and must
    be filled in before this can run end to end.
    """
    num_epochs = 10

    train_dataset = MyDataset()
    # Wrap the dataset built above so the loader actually iterates over it.
    train_loader = DataLoader(train_dataset)  # TODO: batch_size, shuffle, ...

    # Model, loss, optimizer
    model = MyModel()
    criterion = None  # TODO
    optimizer = None  # TODO

    # Training loop
    for epoch in range(num_epochs):
        train_loss = train_loop(model, train_loader, criterion, optimizer)
        # Report progress so the computed epoch loss is not silently discarded.
        print(f"epoch {epoch + 1}/{num_epochs}: train_loss={train_loss:.4f}")
# Run the training entry point only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment