Last active
June 15, 2025 12:07
-
-
Save qiaoxu123/d35d35414df45158ac06699e9fea13cf to your computer and use it in GitHub Desktop.
手写数字识别
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import os | |
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
from torchvision import datasets, transforms | |
from torchvision.transforms import functional as TF | |
import matplotlib.pyplot as plt | |
import tkinter as tk | |
from PIL import Image, ImageDraw, ImageOps | |
import warnings | |
# 1. Pick the compute device: CUDA when torch reports it as available, CPU otherwise.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")
# 2. Define your model | |
class MyModel(nn.Module):
    """Small convolutional classifier for single-channel images.

    Two conv / ReLU / max-pool stages feed a two-layer fully connected head.
    The first linear layer expects 64 * 5 * 5 flattened features, which is
    what the feature extractor produces for 28x28 inputs
    (28 -> 26 -> 13 -> 11 -> 5 along each spatial axis).

    Args:
        num_classes: Number of logits produced per sample.
    """

    def __init__(self, num_classes=10):
        super().__init__()
        # The position of each layer inside its Sequential determines the
        # state_dict key names ("features.0", "features.3", "classifier.1",
        # "classifier.3"), so the order below must stay fixed for saved
        # checkpoints to keep loading.
        conv_stages = [
            nn.Conv2d(1, 32, 3),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2),
            nn.Conv2d(32, 64, 3),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2),
        ]
        head = [
            nn.Flatten(),
            nn.Linear(64 * 5 * 5, 64),
            nn.ReLU(inplace=True),
            nn.Linear(64, num_classes),
        ]
        self.features = nn.Sequential(*conv_stages)
        self.classifier = nn.Sequential(*head)

    def forward(self, x):
        """Return the class logits for a batch of images ``x``."""
        return self.classifier(self.features(x))
# 3. Build the network and place it on the selected device
model = MyModel(num_classes=10).to(device)
# 4. Hyperparameters, loss function and optimizer
learning_rate = 0.1
epochs = 20
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(params=model.parameters(), lr=learning_rate)
# 5. Training and testing loops
def train(dataloader, model, loss_fn, optimizer):
    """Run one optimisation pass over ``dataloader``.

    Batches are moved to whatever device the model's parameters live on, so
    the function does not rely on a module-level ``device`` variable.

    Args:
        dataloader: Iterable yielding ``(inputs, labels)`` batches.
        model: ``nn.Module`` being optimised; it is switched to train mode.
        loss_fn: Callable ``loss_fn(pred, labels)`` returning a scalar tensor.
        optimizer: Optimizer bound to ``model``'s parameters.

    Returns:
        A ``(accuracy, mean_loss)`` tuple: the fraction of processed samples
        whose arg-max prediction equals the label, and the mean of the
        per-batch loss values.

    Raises:
        ValueError: If ``dataloader`` yields no batches.
    """
    model.train()
    first_param = next(model.parameters(), None)
    target = None if first_param is None else first_param.device

    loss_total = 0.0
    correct = 0
    seen = 0
    batches = 0
    for inputs, labels in dataloader:
        if target is not None:
            inputs, labels = inputs.to(target), labels.to(target)
        optimizer.zero_grad()
        pred = model(inputs)
        loss = loss_fn(pred, labels)
        loss.backward()
        optimizer.step()
        loss_total += loss.item()
        correct += (pred.argmax(1) == labels).sum().item()
        # Count the samples actually processed: len(dataloader.dataset) is
        # the wrong denominator when the loader drops the last batch or
        # draws through a sampler.
        seen += labels.size(0)
        batches += 1

    if batches == 0:
        raise ValueError("dataloader yielded no batches")
    return correct / seen, loss_total / batches
def test(dataloader, model, loss_fn): | |
model.eval() | |
size = len(dataloader.dataset) | |
num_batches = len(dataloader) | |
test_loss = 0.0 | |
test_acc = 0.0 | |
with torch.no_grad(): | |
for inputs, labels in dataloader: | |
inputs, labels = inputs.to(device), labels.to(device) | |
pred = model(inputs) | |
loss = loss_fn(pred, labels) | |
test_loss += loss.item() | |
test_acc += (pred.argmax(1) == labels).float().sum().item() | |
test_loss /= num_batches | |
test_acc /= size | |
return test_acc, test_loss | |
# 6. Load a saved checkpoint if one exists; otherwise train, save and plot
model_path = "learn_mnist.pth"
if os.path.exists(model_path):
    # weights_only=True restricts unpickling to tensor data.
    model.load_state_dict(torch.load(model_path, map_location=device, weights_only=True))
    model.eval()
    print("模型已加载,无需重新训练。")
else:
    # Data preparation: the same normalisation constants are applied to
    # drawn digits in predict_digit.
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
    ])
    train_dataset = datasets.MNIST('', train=True, download=True, transform=transform)
    test_dataset = datasets.MNIST('', train=False, transform=transform)
    train_dl = torch.utils.data.DataLoader(train_dataset, batch_size=256, shuffle=True)
    test_dl = torch.utils.data.DataLoader(test_dataset, batch_size=1000, shuffle=False)

    train_acc_history = []
    train_loss_history = []
    test_acc_history = []
    test_loss_history = []

    # 7. Run
    for epoch in range(1, epochs + 1):
        t_acc, t_loss = train(train_dl, model, loss_fn, optimizer)
        v_acc, v_loss = test(test_dl, model, loss_fn)
        train_acc_history.append(t_acc)
        train_loss_history.append(t_loss)
        test_acc_history.append(v_acc)
        test_loss_history.append(v_loss)
        print(f"Epoch {epoch:2d}: "
              f"Train Acc {t_acc*100:5.1f}% Loss {t_loss:.3f} | "
              f"Val Acc {v_acc*100:5.1f}% Loss {v_loss:.3f}")
    print("Training complete.")
    torch.save(model.state_dict(), model_path)
    print("模型已训练并保存。")
    # Make the inference mode explicit instead of relying on the last
    # test() call having left the model in eval mode.
    model.eval()

    # Plotting. This stays inside the training branch on purpose: the
    # history lists only exist when training actually ran, so plotting after
    # a checkpoint load would raise NameError.
    # NOTE(review): this silences every warning for the rest of the process,
    # presumably to hide font/glyph warnings from matplotlib -- confirm.
    warnings.filterwarnings("ignore")
    plt.rcParams['font.sans-serif'] = ['SimHei']
    plt.rcParams['axes.unicode_minus'] = False
    plt.rcParams['figure.dpi'] = 100
    epochs_range = range(1, epochs + 1)
    plt.figure(figsize=(12, 3))
    plt.subplot(1, 2, 1)
    plt.plot(epochs_range, train_acc_history, label='Training Acc')
    plt.plot(epochs_range, test_acc_history, label='Validation Acc')
    plt.legend(loc='lower right')
    plt.title('Accuracy')
    plt.subplot(1, 2, 2)
    plt.plot(epochs_range, train_loss_history, label='Training Loss')
    plt.plot(epochs_range, test_loss_history, label='Validation Loss')
    plt.legend(loc='upper right')
    plt.title('Loss')
    plt.show()
# 8. Tkinter | |
def predict_digit(img):
    """Classify a drawn digit with the module-level ``model``.

    Args:
        img: PIL image of the drawing (the GUI passes its off-screen image).

    Returns:
        The arg-max class index for the single image, as a Python int.
    """
    # Shrink to 28x28 grayscale, then invert: the pad draws black ink on a
    # white background (presumably the opposite polarity of the training
    # images -- confirm).
    gray = ImageOps.invert(img.resize((28, 28)).convert('L'))
    # Add a batch axis and apply the same mean/std as the training transform.
    batch = TF.normalize(TF.to_tensor(gray).unsqueeze(0), [0.1307], [0.3081])
    batch = batch.to(device)
    with torch.no_grad():
        logits = model(batch)
    return torch.argmax(logits, dim=1).item()
# Drawing-pad window
class App(tk.Tk):
    """Tk window with a drawing canvas plus "recognise" and "clear" buttons.

    Everything painted on the canvas is mirrored into an off-screen PIL image
    (``self.image``), because that image -- not the canvas widget -- is what
    ``recognize`` passes to ``predict_digit``.
    """

    CANVAS_SIZE = 200  # width and height of the drawing area, in pixels
    BRUSH_RADIUS = 8   # radius of the round brush, in pixels

    def __init__(self):
        super().__init__()
        self.title("手写数字识别")
        size = self.CANVAS_SIZE
        self.canvas = tk.Canvas(self, width=size, height=size, bg="white")
        self.canvas.pack()
        self.image = Image.new("RGB", (size, size), "white")
        self.draw = ImageDraw.Draw(self.image)
        # Previous pointer position of the stroke in progress (None between
        # strokes); used to join consecutive motion events with a line.
        self._last_point = None
        # <Button-1> makes a plain click leave a dot; <B1-Motion> alone only
        # fires while the pointer is moving.
        self.canvas.bind("<Button-1>", self.paint)
        self.canvas.bind("<B1-Motion>", self.paint)
        self.canvas.bind("<ButtonRelease-1>", self._end_stroke)
        tk.Button(self, text="识别", command=self.recognize).pack()
        tk.Button(self, text="清除", command=self.clear).pack()
        self.result = tk.Label(self, text="", font=("Helvetica", 20))
        self.result.pack()

    def paint(self, event):
        """Draw the brush at the pointer position on canvas and PIL image."""
        x, y = event.x, event.y
        r = self.BRUSH_RADIUS
        if self._last_point is not None:
            # Bridge the gap to the previous event so fast strokes stay
            # continuous instead of turning into separate dots.
            px, py = self._last_point
            self.canvas.create_line(px, py, x, y, width=2 * r, fill='black',
                                    capstyle=tk.ROUND)
            self.draw.line([px, py, x, y], fill='black', width=2 * r)
        self.canvas.create_oval(x - r, y - r, x + r, y + r, fill='black')
        self.draw.ellipse([x - r, y - r, x + r, y + r], fill='black')
        self._last_point = (x, y)

    def _end_stroke(self, event):
        """Forget the last point so the next stroke starts fresh."""
        self._last_point = None

    def recognize(self):
        """Run the classifier on the current drawing and show the result."""
        digit = predict_digit(self.image)
        self.result.config(text=f"识别结果:{digit}")

    def clear(self):
        """Wipe the canvas, the mirrored image and the result label."""
        self.canvas.delete("all")
        size = self.CANVAS_SIZE
        self.draw.rectangle([0, 0, size, size], fill="white")
        self._last_point = None
        self.result.config(text="")
# Start the GUI: build the window and block in Tk's event loop until it is closed.
App().mainloop()
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import os | |
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
import torch.optim as optim | |
from torchvision import datasets, transforms | |
from torchvision.transforms import functional as TF | |
import tkinter as tk | |
from PIL import Image, ImageDraw, ImageOps | |
import numpy as np | |
# Device selection (CUDA when torch reports it as available, else CPU) and
# the checkpoint file name.
model_path = "lenet5_mnist.pth"
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
# LeNet-5 style model definition
class LeNet5(nn.Module):
    """Two 5x5 conv layers followed by three fully connected layers.

    ``fc1`` expects 256 flattened features per sample, which is what the two
    conv + 2x2 max-pool stages produce for 28x28 single-channel inputs
    (28 -> 24 -> 12 -> 8 -> 4 along each axis, times 16 channels).
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 6, 5)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(256, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        """Return the 10 class logits for each image in the batch ``x``.

        Raises:
            RuntimeError: If the input size does not yield 256 features per
                sample (raised by ``fc1``).
        """
        x = F.max_pool2d(F.relu(self.conv1(x)), 2)
        x = F.max_pool2d(F.relu(self.conv2(x)), 2)
        # Flatten everything except the batch axis. view(-1, 256) would
        # silently regroup elements across samples for wrong-sized inputs
        # and return a different batch size instead of failing.
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return self.fc3(x)
# Instantiate the model on the selected device
model = LeNet5().to(device)
# Load the checkpoint when it exists; otherwise train, save and evaluate
if os.path.exists(model_path):
    # weights_only=True restricts unpickling to tensor data, so a tampered
    # checkpoint file cannot run arbitrary code while being loaded.
    model.load_state_dict(torch.load(model_path, map_location=device, weights_only=True))
    model.eval()
    print("模型已加载,无需重新训练。")
else:
    # Data preparation: the same normalisation constants are applied to
    # drawn digits in predict_digit.
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
    ])
    train_dataset = datasets.MNIST('', train=True, download=True, transform=transform)
    test_dataset = datasets.MNIST('', train=False, transform=transform)
    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=256, shuffle=True)
    test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=1000, shuffle=False)

    # Optimizer and loss function
    optimizer = optim.Adam(model.parameters(), lr=0.001)
    criterion = nn.CrossEntropyLoss()

    # Train the model
    print("开始训练模型...")
    epochs = 5
    for epoch in range(epochs):
        running_loss = 0.0
        model.train()
        for inputs, labels in train_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
        print(f'Epoch {epoch + 1}, Loss: {running_loss / len(train_loader)}')

    # Save the trained weights
    torch.save(model.state_dict(), model_path)
    print("模型已训练并保存。")

    # Test accuracy. This stays inside the training branch on purpose:
    # test_loader only exists here, so evaluating after a checkpoint load
    # would raise NameError.
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in test_loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    print(f'测试准确率: {100 * correct / total:.2f}%')
# Interactive handwriting recognition (Tkinter)
def predict_digit(img):
    """Return the digit the module-level ``model`` predicts for ``img``.

    Args:
        img: PIL image of the drawing (the GUI passes its off-screen image).

    Returns:
        The arg-max class index for the single image, as a Python int.
    """
    small = img.resize((28, 28))
    gray = small.convert('L')
    # The pad draws black ink on white; flip it (presumably to match the
    # polarity of the training images -- confirm).
    inverted = ImageOps.invert(gray)
    tensor = TF.to_tensor(inverted).unsqueeze(0)  # add the batch axis
    # Same mean/std as the training transform.
    tensor = TF.normalize(tensor, [0.1307], [0.3081]).to(device)
    with torch.no_grad():
        scores = model(tensor)
    return scores.argmax(dim=1).item()
# Drawing-pad window
class App(tk.Tk):
    """Main window: a square drawing canvas, two buttons and a result label.

    Strokes are drawn twice -- on the visible canvas and on ``self.image``,
    an off-screen PIL image -- because ``recognize`` feeds ``self.image`` to
    ``predict_digit``.
    """

    CANVAS_SIZE = 200  # side length of the drawing area, in pixels
    BRUSH_RADIUS = 8   # radius of the round brush, in pixels

    def __init__(self):
        super().__init__()
        self.title("手写数字识别")
        side = self.CANVAS_SIZE
        self.canvas = tk.Canvas(self, width=side, height=side, bg="white")
        self.canvas.pack()
        self.image = Image.new("RGB", (side, side), "white")
        self.draw = ImageDraw.Draw(self.image)
        # Where the current stroke was last painted; None when no stroke is
        # in progress.
        self._last_point = None
        # Binding only <B1-Motion> ignores plain clicks and leaves gaps in
        # fast strokes, so press and release are handled as well.
        for sequence, handler in (
            ("<Button-1>", self.paint),
            ("<B1-Motion>", self.paint),
            ("<ButtonRelease-1>", self._end_stroke),
        ):
            self.canvas.bind(sequence, handler)
        tk.Button(self, text="识别", command=self.recognize).pack()
        tk.Button(self, text="清除", command=self.clear).pack()
        self.result = tk.Label(self, text="", font=("Helvetica", 20))
        self.result.pack()

    def paint(self, event):
        """Paint at the pointer, joining it to the stroke's previous point."""
        x, y = event.x, event.y
        r = self.BRUSH_RADIUS
        previous = self._last_point
        if previous is not None:
            # A thick segment keeps the stroke continuous when the pointer
            # moves more than one brush width between two events.
            segment = [previous[0], previous[1], x, y]
            self.canvas.create_line(*segment, width=2 * r, fill='black',
                                    capstyle=tk.ROUND)
            self.draw.line(segment, fill='black', width=2 * r)
        self.canvas.create_oval(x - r, y - r, x + r, y + r, fill='black')
        self.draw.ellipse([x - r, y - r, x + r, y + r], fill='black')
        self._last_point = (x, y)

    def _end_stroke(self, event):
        """Mark the stroke as finished so the next press starts a new one."""
        self._last_point = None

    def recognize(self):
        """Classify the current drawing and display the predicted digit."""
        digit = predict_digit(self.image)
        self.result.config(text=f"识别结果:{digit}")

    def clear(self):
        """Reset the canvas, the off-screen image and the result label."""
        self.canvas.delete("all")
        side = self.CANVAS_SIZE
        self.draw.rectangle([0, 0, side, side], fill="white")
        self._last_point = None
        self.result.config(text="")
# Start the GUI: build the window and block in Tk's event loop until it is closed.
App().mainloop()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment