Created
June 16, 2022 12:43
-
-
Save hexagit/247c69ee28630034e0ae791ad9cc94a4 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Load libraries
import os | |
import torch | |
import torch.nn as nn | |
from torch.utils.data import DataLoader | |
from torchvision import datasets, transforms | |
from torchvision.transforms import ToTensor, Lambda, Compose | |
# Definitions
class NNModel(nn.Module):
    """Fully connected classifier producing 10 output logits.

    The input is flattened with ``nn.Flatten()`` and then passed through
    three linear layers (``x*y -> 120 -> 60 -> 10``) with a ReLU after each
    of the first two.
    """

    def __init__(self, x, y):
        """Create the flatten step and the linear/ReLU stack.

        Args:
            x: First factor of the flattened input width.
            y: Second factor of the flattened input width; the first linear
                layer expects ``x * y`` input features.
        """
        super().__init__()
        self.flatten = nn.Flatten()
        # Layers are created in forward order so that seeded initialisation
        # and the state-dict indices (0, 2, 4) stay stable.
        layers = [
            nn.Linear(x * y, 120),
            nn.ReLU(),
            nn.Linear(120, 60),
            nn.ReLU(),
            nn.Linear(60, 10),
        ]
        self.linear_relu_stack = nn.Sequential(*layers)

    def forward(self, x):
        """Return the raw (un-normalised) logits for the batch ``x``."""
        return self.linear_relu_stack(self.flatten(x))
def Training(dataLoader, model, lossFunc, optimizer):
    """Run one optimisation pass over every batch yielded by ``dataLoader``.

    Args:
        dataLoader: Iterable yielding ``(X, y)`` batches.
        model: Module called as ``model(X)`` to produce predictions.
        lossFunc: Callable ``lossFunc(pred, y)`` returning a scalar tensor
            that supports ``backward()``.
        optimizer: Optimizer over ``model``'s parameters; it is zeroed and
            stepped once per batch.

    Returns:
        None. ``model`` is updated in place and left in training mode.
    """
    # Test() in this file puts the model into eval mode; switch back so that
    # mode-dependent layers (dropout, batch norm) train correctly on every
    # epoch, not just the first one.
    model.train()
    for X, y in dataLoader:
        # Compute the loss for this batch
        pred = model(X)
        loss = lossFunc(pred, y)
        # Backpropagation
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
def Test(dataloader, model, lossFunc):
    """Evaluate ``model`` on ``dataloader`` and print accuracy and average loss.

    The model is put into eval mode and gradients are disabled for the pass.
    Accuracy is computed over the samples actually yielded by the loader, so
    it stays correct when the loader skips samples (e.g. ``drop_last=True``).
    The average loss is the mean of the per-batch loss values; this assumes
    ``lossFunc`` returns a per-batch mean, as ``nn.CrossEntropyLoss`` does by
    default. Dividing the summed batch losses by the dataset size instead
    would under-report the loss by roughly the batch size.

    Args:
        dataloader: Iterable yielding ``(X, y)`` batches.
        model: Module called as ``model(X)``; predictions are taken as
            ``argmax`` over dimension 1.
        lossFunc: Callable ``lossFunc(pred, y)`` returning a scalar tensor.

    Returns:
        Tuple ``(accuracy, avg_loss)`` with accuracy as a fraction in [0, 1].

    Raises:
        ZeroDivisionError: If ``dataloader`` yields no batches.
    """
    model.eval()
    loss_sum = 0.0
    num_correct = 0
    num_batches = 0
    num_samples = 0
    with torch.no_grad():
        for X, y in dataloader:
            pred = model(X)
            loss_sum += lossFunc(pred, y).item()
            num_correct += (pred.argmax(1) == y).sum().item()
            num_batches += 1
            num_samples += len(y)
    testLoss = loss_sum / num_batches
    correct = num_correct / num_samples
    print(f"Accuracy: {(100*correct):>0.1f}%, Avg loss: {testLoss:>8f}")
    return correct, testLoss
# Data loading
# Build the data directory with os.path.join on the absolute script location:
# the previous '\data' literal relied on an invalid escape sequence, only
# worked with Windows separators, and resolved to a root-level path when
# __file__ was relative.
dataPath = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'data')
transform = Compose([ToTensor(), transforms.Normalize((0.5,), (0.5,))])  # normalization step
trainDataset = datasets.MNIST(
    root=dataPath,
    train=True,
    download=True,
    transform=transform
)
trainDataLoader = DataLoader(
    dataset=trainDataset,
    batch_size=100,
    shuffle=True,
    drop_last=True
)
# Held-out split for evaluation, so the reported accuracy is not measured on
# the same samples the model was trained on.
testDataset = datasets.MNIST(
    root=dataPath,
    train=False,
    download=True,
    transform=transform
)
testDataLoader = DataLoader(
    dataset=testDataset,
    batch_size=100,
    shuffle=False
)
# Model creation
model = NNModel(28, 28)
lossFunc = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.005)
for epoch in range(50):
    # Training
    Training(trainDataLoader, model, lossFunc, optimizer)
    # Test
    Test(testDataLoader, model, lossFunc)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment