Skip to content

Instantly share code, notes, and snippets.

@ishikawash
Last active August 26, 2025 11:54
Show Gist options
  • Select an option

  • Save ishikawash/ddc576892e26264bd0da8366aa5d9148 to your computer and use it in GitHub Desktop.

Select an option

Save ishikawash/ddc576892e26264bd0da8366aa5d9148 to your computer and use it in GitHub Desktop.
PyTorch: CNN example
import timeit
import itertools
from pathlib import Path
from dataclasses import dataclass
from argparse import ArgumentParser
import torch
from torch import nn
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
# ref: https://docs.pytorch.org/tutorials/beginner/blitz/neural_networks_tutorial.html
class LeNet(nn.Module):
    """LeNet-style convolutional classifier producing 10 logits per image.

    The shape notes below assume an input batch of shape (N, 3, 32, 32),
    which is what the 16 * 5 * 5 input size of the first linear layer
    requires.
    """

    def __init__(self):
        super().__init__()
        convolutional_stage = [
            # C1: (N, 3, 32, 32) -> (N, 6, 28, 28)
            nn.Conv2d(3, 6, kernel_size=5),
            nn.ReLU(),
            # S2: -> (N, 6, 14, 14)
            nn.MaxPool2d(2),
            # C3: -> (N, 16, 10, 10)
            nn.Conv2d(6, 16, kernel_size=5),
            nn.ReLU(),
            # S4: -> (N, 16, 5, 5)
            nn.MaxPool2d(2),
        ]
        dense_stage = [
            # Flatten: -> (N, 400)
            nn.Flatten(start_dim=1),
            # F5: -> (N, 120)
            nn.Linear(16 * 5 * 5, 120),
            nn.ReLU(),
            # F6: -> (N, 84)
            nn.Linear(120, 84),
            nn.ReLU(),
            # F7: -> (N, 10)
            nn.Linear(84, 10),
        ]
        # Everything lives in one Sequential so parameters keep the
        # "network.<index>.*" names used by saved state dicts.
        self.network = nn.Sequential(*convolutional_stage, *dense_stage)

    def forward(self, x):
        """Return the raw (unnormalized) class scores for batch ``x``."""
        return self.network(x)
def begin_train_model(
    model: nn.Module,
    dataloader: DataLoader,
    loss_fn: nn.CrossEntropyLoss,
    optimizer: torch.optim.SGD,
    device: str,
):
    """Run one training pass over ``dataloader`` as a generator.

    For every batch the model is evaluated, the loss is back-propagated and
    the optimizer takes one step; the batch's loss tensor is then yielded.
    Nothing runs until the caller iterates the generator.

    Args:
        model: Network to train; it is moved to ``device``.
        dataloader: Iterable of batches where item 0 holds the inputs and
            item 1 the labels.
        loss_fn: Callable computing the loss from (outputs, labels).
        optimizer: Optimizer bound to the model's parameters.
        device: Device identifier the model and batches are moved to.

    Yields:
        The loss tensor of each processed batch.
    """
    net = model.to(device)
    # Training mode enables train-time behavior of layers such as dropout.
    net.train()
    for batch in dataloader:
        features, targets = batch[0].to(device), batch[1].to(device)
        # Clear gradients left over from the previous step before backprop.
        optimizer.zero_grad()
        batch_loss = loss_fn(net(features), targets)
        batch_loss.backward()
        optimizer.step()
        yield batch_loss
def train_model(model: nn.Module, data_loader: DataLoader, device: str, epochs: int):
    """Train ``model`` with SGD and print the mean batch loss of every epoch.

    Uses cross-entropy loss and SGD (lr=0.001, momentum=0.9).

    Args:
        model: Network to train (updated in place).
        data_loader: Source of (inputs, labels) training batches.
        device: Device identifier the training runs on.
        epochs: Number of full passes over ``data_loader``.
    """
    loss_fn = nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
    for epoch in range(epochs):
        total_loss = 0.0
        batch_count = 0
        for loss in begin_train_model(model, data_loader, loss_fn, optimizer, device):
            total_loss += loss.item()
            batch_count += 1
        # Average over the batches actually processed rather than
        # len(data_loader): that avoids a ZeroDivisionError on an empty
        # loader and also works for loaders that do not define a length.
        mean_loss = total_loss / batch_count if batch_count else float("nan")
        print(f"[{epoch + 1}] loss: {mean_loss:.3f}")
def test_model(model: nn.Module, data_loader: DataLoader, device: str):
model_on_device = model.to(device)
model.eval()
loss_fn = nn.CrossEntropyLoss()
total_losses = 0.0
total_corrections = 0.0
with torch.no_grad():
for data in data_loader:
inputs = data[0].to(device)
labels = data[1].to(device)
outputs = model_on_device(inputs)
total_losses += loss_fn(outputs, labels).item()
total_corrections += (
(outputs.argmax(1) == labels).type(torch.float).sum().item()
)
average_loss = total_losses / len(data_loader)
accuracy = total_corrections / len(data_loader.dataset)
print(f"Accuracy: {(100*accuracy):>0.1f}%, Loss: {average_loss:>8f}")
def doit(model: nn.Module, data_loader: DataLoader):
    """Print label/prediction pairs for the first three batches.

    Each printed line has the form
    ``<true class> <predicted class> <1 if equal else 0>``.
    Batches are passed to ``model`` exactly as the loader yields them (no
    device transfer happens here).

    Args:
        model: Network producing one score per class; put in evaluation mode.
        data_loader: Source of (inputs, labels) batches with labels in 0..9.
    """
    classes = (
        "plane",
        "car",
        "bird",
        "cat",
        "deer",
        "dog",
        "frog",
        "horse",
        "ship",
        "truck",
    )
    model.eval()
    with torch.no_grad():
        for data in itertools.islice(data_loader, 3):
            inputs = data[0]
            labels = data[1]
            answers = model(inputs).argmax(1)
            # Walk the items actually present in the batch; indexing up to
            # data_loader.batch_size fails on a short final batch.
            for label, answer in zip(labels.tolist(), answers.tolist()):
                print(classes[label], classes[answer], int(label == answer))
def query_available_device():
return (
torch.accelerator.current_accelerator().type
if torch.accelerator.is_available()
else "cpu"
)
@dataclass
class CommandArguments:
    """Typed container for the command-line options this script accepts."""

    # Number of training epochs (--epochs).
    epochs: int
    # Batch size used for the data loaders (--batch-size).
    batch_size: int
    # When true, run on the CPU instead of querying for an accelerator
    # (--force-cpu).
    force_cpu: bool
    # Location of the saved model weights (--model-file).
    model_file: Path
def parse_arguments(argv=None):
    """Parse command-line options into a ``CommandArguments`` instance.

    Args:
        argv: Optional sequence of argument strings. When ``None`` (the
            default) argparse reads the process command line.

    Returns:
        A ``CommandArguments`` holding ``epochs`` (default 3), ``batch_size``
        (default 4), ``force_cpu`` (default False) and ``model_file``
        (default ``model.pth`` in the current working directory).
    """
    default_model_file_path = Path.cwd().joinpath("model.pth")
    parser = ArgumentParser()
    parser.add_argument("--epochs", type=int, default=3)
    parser.add_argument("--batch-size", type=int, default=4)
    parser.add_argument("--force-cpu", action="store_true")
    parser.add_argument("--model-file", type=Path, default=default_model_file_path)
    # Parse into a fresh namespace and build an instance. Handing the
    # CommandArguments class itself to parse_args(namespace=...) stores the
    # values as class attributes, so they persist across calls and suppress
    # the defaults on the next parse.
    parsed = parser.parse_args(argv)
    return CommandArguments(
        epochs=parsed.epochs,
        batch_size=parsed.batch_size,
        force_cpu=parsed.force_cpu,
        model_file=parsed.model_file,
    )
def main():
    """Train LeNet on CIFAR-10 if no saved weights exist, then evaluate it.

    Steps: parse the command line, pick the device, train and save the
    model when the model file is missing, print test-set accuracy/loss, and
    finally print a few sample predictions.
    """
    args = parse_arguments()
    device_name = "cpu" if args.force_cpu else query_available_device()
    print(f"Device: {device_name}")

    transform = transforms.Compose(
        [
            transforms.ToTensor(),
            transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
        ]
    )
    batch_size = args.batch_size

    model_file_path = args.model_file
    if not model_file_path.exists():
        print(f"{model_file_path} is not found.")
        # The training split is only needed when there are no saved weights,
        # so it is built here instead of on every run.
        trainset = torchvision.datasets.CIFAR10(
            root="data", train=True, download=True, transform=transform
        )
        train_dataloader = torch.utils.data.DataLoader(
            trainset, batch_size=batch_size, shuffle=True, num_workers=2
        )
        model = LeNet()
        train_model(model, train_dataloader, device_name, args.epochs)
        torch.save(model.state_dict(), model_file_path)

    testset = torchvision.datasets.CIFAR10(
        root="data", train=False, download=True, transform=transform
    )
    test_dataloader = torch.utils.data.DataLoader(
        testset, batch_size=batch_size, shuffle=False, num_workers=2
    )

    test_model(_load_trained_model(model_file_path), test_dataloader, device_name)
    # test_model() moves its model to the target device, while doit() feeds
    # batches without moving them, so doit() gets its own freshly loaded
    # CPU-resident copy.
    doit(_load_trained_model(model_file_path), test_dataloader)


def _load_trained_model(model_file_path):
    """Return a new LeNet holding the weights stored at ``model_file_path``."""
    model = LeNet()
    # map_location="cpu" lets a checkpoint that was written from an
    # accelerator be read on a host without one; callers move the model to
    # the target device afterwards.
    state_dict = torch.load(model_file_path, map_location="cpu", weights_only=True)
    model.load_state_dict(state_dict)
    return model
if __name__ == "__main__":
    # Time the run with a plain wall clock. timeit.timeit() would switch off
    # garbage collection for the whole duration of main(), which is not
    # wanted for a long training run.
    start_time = timeit.default_timer()
    main()
    elapsed_time = timeit.default_timer() - start_time
    print(f"Elapsed: {elapsed_time}")
@AdelMmdi
Copy link

can receive news from webpage or what

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment