Skip to content

Instantly share code, notes, and snippets.

@qiaoxu123
Last active June 15, 2025 12:07
Show Gist options
  • Save qiaoxu123/d35d35414df45158ac06699e9fea13cf to your computer and use it in GitHub Desktop.
Save qiaoxu123/d35d35414df45158ac06699e9fea13cf to your computer and use it in GitHub Desktop.
手写数字识别
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms
from torchvision.transforms import functional as TF
import matplotlib.pyplot as plt
import tkinter as tk
from PIL import Image, ImageDraw, ImageOps
import warnings
# 1. Pick the compute device: CUDA when torch reports it as available, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")
# 2. Model definition
class MyModel(nn.Module):
    """Small CNN classifier: two conv/ReLU/max-pool stages followed by an MLP head.

    Args:
        num_classes: Width of the final linear layer (number of output logits).

    The first linear layer takes 64 * 5 * 5 features, which is what the two
    unpadded 3x3 convolutions and 2x2 poolings produce for a 1x28x28 input
    (28 -> 26 -> 13 -> 11 -> 5).
    """

    def __init__(self, num_classes=10):
        super().__init__()

        # Feature extractor. Positions inside the Sequential are kept as-is so
        # that saved state_dict keys ("features.0", "features.3") still match.
        conv_stages = [
            nn.Conv2d(1, 32, kernel_size=3, padding=0),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(32, 64, kernel_size=3, padding=0),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),
        ]
        self.features = nn.Sequential(*conv_stages)

        # Classifier head: flatten, one hidden layer, then the class logits.
        head = [
            nn.Flatten(),
            nn.Linear(64 * 5 * 5, 64),
            nn.ReLU(inplace=True),
            nn.Linear(64, num_classes),
        ]
        self.classifier = nn.Sequential(*head)

    def forward(self, x):
        """Return the raw class logits for a batch ``x``."""
        return self.classifier(self.features(x))
# 3. Build the network and place it on the selected device.
model = MyModel(num_classes=10)
model = model.to(device)

# 4. Training configuration: hyperparameters, loss, and optimizer.
epochs = 20
learning_rate = 0.1
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)
# 5. Training and evaluation loops (batches are moved to the model's device unless one is given)
def train(dataloader, model, loss_fn, optimizer, device=None):
    """Run one optimisation pass over ``dataloader`` and report its metrics.

    Args:
        dataloader: Iterable yielding ``(inputs, labels)`` tensor batches.
        model: ``nn.Module`` to optimise; it is switched to training mode.
        loss_fn: Callable ``loss_fn(predictions, labels)`` returning a scalar tensor.
        optimizer: Optimizer whose gradients are zeroed and which is stepped
            once per batch.
        device: Device the batches are moved to before the forward pass.
            Defaults to the device of the model's first parameter (CPU when
            the model has no parameters).

    Returns:
        ``(accuracy, mean_loss)``: the fraction of processed samples whose
        argmax prediction equals the label, and the mean of the per-batch
        losses. Both are ``0.0`` when the dataloader yields no batches.
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")

    model.train()
    loss_sum = 0.0
    correct = 0
    seen = 0
    batches = 0
    for inputs, labels in dataloader:
        inputs, labels = inputs.to(device), labels.to(device)
        optimizer.zero_grad()
        pred = model(inputs)
        loss = loss_fn(pred, labels)
        loss.backward()
        optimizer.step()
        loss_sum += loss.item()
        correct += (pred.argmax(1) == labels).sum().item()
        # Count the samples actually visited: with drop_last or a sampler this
        # differs from len(dataloader.dataset) and would skew the accuracy.
        seen += labels.size(0)
        batches += 1

    if batches == 0:
        # Empty loader: report zeros instead of dividing by zero.
        return 0.0, 0.0
    accuracy = correct / seen if seen else 0.0
    return accuracy, loss_sum / batches
def test(dataloader, model, loss_fn, device=None):
    """Evaluate ``model`` on ``dataloader`` without updating any weights.

    Args:
        dataloader: Iterable yielding ``(inputs, labels)`` tensor batches.
        model: ``nn.Module`` to evaluate; it is switched to eval mode.
        loss_fn: Callable ``loss_fn(predictions, labels)`` returning a scalar tensor.
        device: Device the batches are moved to before the forward pass.
            Defaults to the device of the model's first parameter (CPU when
            the model has no parameters).

    Returns:
        ``(accuracy, mean_loss)``: the fraction of processed samples whose
        argmax prediction equals the label, and the mean of the per-batch
        losses. Both are ``0.0`` when the dataloader yields no batches.
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")

    model.eval()
    loss_sum = 0.0
    correct = 0
    seen = 0
    batches = 0
    with torch.no_grad():
        for inputs, labels in dataloader:
            inputs, labels = inputs.to(device), labels.to(device)
            pred = model(inputs)
            loss_sum += loss_fn(pred, labels).item()
            correct += (pred.argmax(1) == labels).sum().item()
            # Count the samples actually visited rather than trusting
            # len(dataloader.dataset), which is wrong with drop_last/samplers.
            seen += labels.size(0)
            batches += 1

    if batches == 0:
        # Empty loader: report zeros instead of dividing by zero.
        return 0.0, 0.0
    accuracy = correct / seen if seen else 0.0
    return accuracy, loss_sum / batches


# The name matches pytest's default test-function pattern; mark the helper so
# it is not collected as a test case if this module is imported by a test file.
test.__test__ = False
# 6. Reuse a saved checkpoint when one exists; otherwise train, save, and plot.
model_path = "learn_mnist.pth"
if os.path.exists(model_path):
    # weights_only=True restricts unpickling to tensors/containers.
    model.load_state_dict(torch.load(model_path, map_location=device, weights_only=True))
    model.eval()
    print("模型已加载,无需重新训练。")
else:
    # Data preparation: convert to tensors and normalise with the same
    # constants predict_digit applies at inference time.
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
    ])
    train_dataset = datasets.MNIST('', train=True, download=True, transform=transform)
    test_dataset = datasets.MNIST('', train=False, transform=transform)
    train_dl = torch.utils.data.DataLoader(train_dataset, batch_size=256, shuffle=True)
    test_dl = torch.utils.data.DataLoader(test_dataset, batch_size=1000, shuffle=False)

    # Per-epoch metrics, consumed by the plots below.
    train_acc_history = []
    train_loss_history = []
    test_acc_history = []
    test_loss_history = []

    # 7. Run
    for epoch in range(1, epochs + 1):
        t_acc, t_loss = train(train_dl, model, loss_fn, optimizer)
        v_acc, v_loss = test(test_dl, model, loss_fn)
        train_acc_history.append(t_acc)
        train_loss_history.append(t_loss)
        test_acc_history.append(v_acc)
        test_loss_history.append(v_loss)
        print(f"Epoch {epoch:2d}: "
              f"Train Acc {t_acc*100:5.1f}% Loss {t_loss:.3f} | "
              f"Val Acc {v_acc*100:5.1f}% Loss {v_loss:.3f}")
    print("Training complete.")

    # Save before plotting so the weights are on disk even if the plot fails.
    torch.save(model.state_dict(), model_path)
    print("模型已训练并保存。")

    # 7. Plotting (only possible here: the histories exist only after training).
    # Warnings are silenced just for the plotting code; the previous filters
    # are restored afterwards so later warnings stay visible.
    with warnings.catch_warnings():
        warnings.filterwarnings("ignore")
        plt.rcParams['font.sans-serif'] = ['SimHei']
        plt.rcParams['axes.unicode_minus'] = False
        plt.rcParams['figure.dpi'] = 100
        epochs_range = range(1, epochs + 1)
        plt.figure(figsize=(12, 3))

        plt.subplot(1, 2, 1)
        plt.plot(epochs_range, train_acc_history, label='Training Acc')
        plt.plot(epochs_range, test_acc_history, label='Validation Acc')
        plt.legend(loc='lower right')
        plt.title('Accuracy')

        plt.subplot(1, 2, 2)
        plt.plot(epochs_range, train_loss_history, label='Training Loss')
        plt.plot(epochs_range, test_loss_history, label='Validation Loss')
        plt.legend(loc='upper right')
        plt.title('Loss')
        plt.show()
# 8. Tkinter
def predict_digit(img):
    """Classify a hand-drawn digit image with the module-level ``model``.

    The drawing (dark strokes on a light background) is converted to
    grayscale and inverted so the strokes are bright on black. The strokes are
    then cropped to their bounding box, scaled so the longer side is 20 px,
    and pasted into the middle of a black 28x28 frame. The MNIST description
    states its digits are size-normalised into a 20x20 box inside a 28x28
    field; squashing the whole canvas to 28x28 instead makes the digit's size
    and position depend on where the user happened to draw. (MNIST centres by
    centre of mass; bounding-box centring is used here as an approximation.)

    Args:
        img: PIL image of the drawing.

    Returns:
        The predicted class index as a Python ``int``.
    """
    strokes = ImageOps.invert(img.convert('L'))
    bbox = strokes.getbbox()  # None when every pixel is zero, i.e. nothing was drawn
    if bbox is None:
        # Blank drawing: there is nothing to crop, just shrink the canvas.
        framed = strokes.resize((28, 28))
    else:
        glyph = strokes.crop(bbox)
        width, height = glyph.size
        scale = 20 / max(width, height)
        # Keep the aspect ratio; clamp to 1 px so thin strokes do not vanish.
        fitted = (max(1, round(width * scale)), max(1, round(height * scale)))
        glyph = glyph.resize(fitted)
        framed = Image.new('L', (28, 28), 0)
        framed.paste(glyph, ((28 - fitted[0]) // 2, (28 - fitted[1]) // 2))

    batch = TF.to_tensor(framed).unsqueeze(0)
    # Same normalisation constants as the training transform.
    batch = TF.normalize(batch, [0.1307], [0.3081])
    batch = batch.to(device)
    with torch.no_grad():
        output = model(batch)
    pred = torch.argmax(output, dim=1)
    return pred.item()
# Drawing-pad window
class App(tk.Tk):
    """Tk window with a 200x200 drawing canvas, two buttons, and a result label.

    Every stroke is drawn twice: on the visible canvas and on an off-screen
    PIL image (``self.image``), which is what gets passed to ``predict_digit``.
    """

    def __init__(self):
        super().__init__()
        self.title("手写数字识别")

        side = 200
        self.canvas = tk.Canvas(self, width=side, height=side, bg="white")
        self.canvas.pack()

        # Off-screen copy of the drawing, kept in sync by paint() and clear().
        self.image = Image.new("RGB", (side, side), "white")
        self.draw = ImageDraw.Draw(self.image)
        self.canvas.bind("<B1-Motion>", self.paint)

        for caption, handler in (("识别", self.recognize), ("清除", self.clear)):
            tk.Button(self, text=caption, command=handler).pack()

        self.result = tk.Label(self, text="", font=("Helvetica", 20))
        self.result.pack()

    def paint(self, event):
        """Stamp a black dot of radius 8 at the pointer position on both surfaces."""
        radius = 8
        left, top = event.x - radius, event.y - radius
        right, bottom = event.x + radius, event.y + radius
        self.canvas.create_oval(left, top, right, bottom, fill='black')
        self.draw.ellipse([left, top, right, bottom], fill='black')

    def recognize(self):
        """Classify the current drawing and show the digit in the result label."""
        digit = predict_digit(self.image)
        self.result.config(text=f"识别结果:{digit}")

    def clear(self):
        """Wipe the canvas, repaint the off-screen image white, and blank the label."""
        self.canvas.delete("all")
        self.draw.rectangle([0, 0, 200, 200], fill="white")
        self.result.config(text="")
# Start the GUI: create the window, then block in Tk's event loop.
window = App()
window.mainloop()
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torchvision.transforms import functional as TF
import tkinter as tk
from PIL import Image, ImageDraw, ImageOps
import numpy as np
# Device selection (CUDA when torch reports it as available) and checkpoint file name
model_path = "lenet5_mnist.pth"
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
# LeNet-5 style model
class LeNet5(nn.Module):
    """LeNet-5 style CNN: two 5x5 conv + 2x2 max-pool stages, then three linear layers.

    ``fc1`` takes 256 features, which is what the conv stack produces for a
    1x28x28 input (28 -> 24 -> 12 -> 8 -> 4, with 16 channels: 16 * 4 * 4).
    The network outputs 10 raw class logits per sample.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 6, 5)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(256, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        """Return the class logits for a batch ``x`` of single-channel images."""
        x = F.max_pool2d(F.relu(self.conv1(x)), 2)
        x = F.max_pool2d(F.relu(self.conv2(x)), 2)
        # Flatten per sample. A fixed view(-1, 256) would silently regroup
        # features across samples (changing the batch size) whenever the input
        # is not 28x28; this way fc1 raises a shape error instead.
        x = x.view(x.size(0), -1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x
# Instantiate the model on the selected device
model = LeNet5().to(device)

# Reuse a saved checkpoint when one exists; otherwise train, save, and evaluate.
if os.path.exists(model_path):
    # weights_only=True restricts unpickling to tensors/containers, so a
    # tampered checkpoint file cannot execute arbitrary code on load.
    model.load_state_dict(torch.load(model_path, map_location=device, weights_only=True))
    model.eval()
    print("模型已加载,无需重新训练。")
else:
    # Data preparation: convert to tensors and normalise with the same
    # constants predict_digit applies at inference time.
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
    ])
    train_dataset = datasets.MNIST('', train=True, download=True, transform=transform)
    test_dataset = datasets.MNIST('', train=False, transform=transform)
    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=256, shuffle=True)
    test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=1000, shuffle=False)

    # Optimizer and loss function
    optimizer = optim.Adam(model.parameters(), lr=0.001)
    criterion = nn.CrossEntropyLoss()

    # Train the model
    print("开始训练模型...")
    epochs = 5
    for epoch in range(epochs):
        running_loss = 0.0
        model.train()
        for inputs, labels in train_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
        # Mean of the per-batch losses for this epoch.
        print(f'Epoch {epoch + 1}, Loss: {running_loss / len(train_loader)}')

    # Save the trained weights
    torch.save(model.state_dict(), model_path)
    print("模型已训练并保存。")

    # Measure accuracy on the test split (test_loader only exists in this branch).
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in test_loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    print(f'测试准确率: {100 * correct / total:.2f}%')
# Interactive handwriting recognition (Tkinter)
def predict_digit(img):
    """Classify a hand-drawn digit image with the module-level ``model``.

    The drawing (dark strokes on a light background) is converted to
    grayscale and inverted so the strokes are bright on black. The strokes are
    then cropped to their bounding box, scaled so the longer side is 20 px,
    and pasted into the middle of a black 28x28 frame. The MNIST description
    states its digits are size-normalised into a 20x20 box inside a 28x28
    field; squashing the whole canvas to 28x28 instead makes the digit's size
    and position depend on where the user happened to draw. (MNIST centres by
    centre of mass; bounding-box centring is used here as an approximation.)

    Args:
        img: PIL image of the drawing.

    Returns:
        The predicted class index as a Python ``int``.
    """
    strokes = ImageOps.invert(img.convert('L'))
    bbox = strokes.getbbox()  # None when every pixel is zero, i.e. nothing was drawn
    if bbox is None:
        # Blank drawing: there is nothing to crop, just shrink the canvas.
        framed = strokes.resize((28, 28))
    else:
        glyph = strokes.crop(bbox)
        width, height = glyph.size
        scale = 20 / max(width, height)
        # Keep the aspect ratio; clamp to 1 px so thin strokes do not vanish.
        fitted = (max(1, round(width * scale)), max(1, round(height * scale)))
        glyph = glyph.resize(fitted)
        framed = Image.new('L', (28, 28), 0)
        framed.paste(glyph, ((28 - fitted[0]) // 2, (28 - fitted[1]) // 2))

    batch = TF.to_tensor(framed).unsqueeze(0)
    # Same normalisation constants as the training transform.
    batch = TF.normalize(batch, [0.1307], [0.3081])
    batch = batch.to(device)
    with torch.no_grad():
        output = model(batch)
    pred = torch.argmax(output, dim=1)
    return pred.item()
# Drawing-pad window
class App(tk.Tk):
    """Tk window with a 200x200 drawing canvas, two buttons, and a result label.

    Every stroke is drawn twice: on the visible canvas and on an off-screen
    PIL image (``self.image``), which is what gets passed to ``predict_digit``.
    """

    def __init__(self):
        super().__init__()
        self.title("手写数字识别")

        side = 200
        self.canvas = tk.Canvas(self, width=side, height=side, bg="white")
        self.canvas.pack()

        # Off-screen copy of the drawing, kept in sync by paint() and clear().
        self.image = Image.new("RGB", (side, side), "white")
        self.draw = ImageDraw.Draw(self.image)
        self.canvas.bind("<B1-Motion>", self.paint)

        for caption, handler in (("识别", self.recognize), ("清除", self.clear)):
            tk.Button(self, text=caption, command=handler).pack()

        self.result = tk.Label(self, text="", font=("Helvetica", 20))
        self.result.pack()

    def paint(self, event):
        """Stamp a black dot of radius 8 at the pointer position on both surfaces."""
        radius = 8
        left, top = event.x - radius, event.y - radius
        right, bottom = event.x + radius, event.y + radius
        self.canvas.create_oval(left, top, right, bottom, fill='black')
        self.draw.ellipse([left, top, right, bottom], fill='black')

    def recognize(self):
        """Classify the current drawing and show the digit in the result label."""
        digit = predict_digit(self.image)
        self.result.config(text=f"识别结果:{digit}")

    def clear(self):
        """Wipe the canvas, repaint the off-screen image white, and blank the label."""
        self.canvas.delete("all")
        self.draw.rectangle([0, 0, 200, 200], fill="white")
        self.result.config(text="")
# Start the GUI: create the window, then block in Tk's event loop.
window = App()
window.mainloop()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment