Created
December 17, 2022 13:16
-
-
Save kasuganosora/96bdd7c5aed53d3a104e99264827afbb to your computer and use it in GitHub Desktop.
教会AI 加减乘除
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch | |
import torch.nn as nn | |
import random | |
import math | |
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Seed the random number generators for reproducibility. torch's RNG covers
# weight initialisation; the standard-library ``random`` module is seeded as
# well because seeding torch alone leaves any ``random``-based sampling
# different on every run.
torch.manual_seed(42)
random.seed(42)

# Hyperparameters
input_size = 3  # features per time step fed to nn.RNN
hidden_size = 100  # width of each recurrent layer
num_layers = 2  # stacked recurrent layers
num_classes = 1  # size of the final linear output
batch_size = 1000
num_epochs = 100000
learning_rate = 0.001
class RNNModel(nn.Module):
    """Stacked ``nn.RNN`` followed by a linear layer on the last time step.

    Args:
        input_size: Number of features per time step.
        hidden_size: Number of features in each recurrent layer's hidden state.
        num_layers: Number of stacked recurrent layers.
        num_classes: Size of the linear layer's output.
    """

    def __init__(self, input_size, hidden_size, num_layers, num_classes):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.rnn = nn.RNN(input_size, hidden_size, num_layers, batch_first=True)
        self.fc = nn.Linear(hidden_size, num_classes)

    def forward(self, x):
        """Run the network on a batch of sequences.

        Args:
            x: Tensor of shape ``(batch, seq_len, input_size)``; the RNN is
                built with ``batch_first=True``.

        Returns:
            Tensor of shape ``(batch, num_classes)`` computed from the RNN
            output at the last time step.
        """
        # Build the zero initial hidden state on the same device and with the
        # same dtype as the input, so the module does not rely on a
        # module-level ``device`` global and keeps working after
        # ``.to(...)`` / ``.double()``.
        h0 = torch.zeros(
            self.num_layers,
            x.size(0),
            self.hidden_size,
            device=x.device,
            dtype=x.dtype,
        )
        out, _ = self.rnn(x, h0)
        # Only the final time step feeds the output layer.
        return self.fc(out[:, -1, :])
# Build the network from the module-level hyperparameters and move it to the
# selected device.
model = RNNModel(
    input_size=input_size,
    hidden_size=hidden_size,
    num_layers=num_layers,
    num_classes=num_classes,
)
model = model.to(device)

# Adam over all model parameters, with a mean-squared-error objective.
optimizer = torch.optim.Adam(params=model.parameters(), lr=learning_rate)
loss_fn = nn.MSELoss()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch | |
import model | |
import os | |
# Refuse to run without a saved model file.
if not os.path.exists('model.pt'):
    print("model file is not exists")
    # The os module has no exit(); raising SystemExit ends the script with
    # status 1 without needing an extra import.
    raise SystemExit(1)

# NOTE: model.pt holds a whole pickled module, so torch.load unpickles
# arbitrary objects -- only load checkpoints from a trusted source.
# map_location lets a checkpoint saved on a GPU load on a CPU-only host.
model.model = torch.load('model.pt', map_location=model.device)
model.model.eval()

# 1 + 1 = ?
# The input is [a, b, op]; the two unsqueeze calls add the batch and sequence
# dimensions, giving shape (1, 1, 3).
inputs = torch.tensor([1, 1, 0], dtype=torch.float).to(model.device).unsqueeze(0).unsqueeze(0)
with torch.no_grad():  # inference only, no autograd graph needed
    output = model.model(inputs)
result = output.item()
print(result)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch | |
import os | |
import torch.nn as nn | |
import random | |
import math | |
import model | |
class ArithmeticDataset(torch.utils.data.Dataset):
    """Randomly generated ``a <op> b`` arithmetic problems.

    Every item is a pair ``(features, target)``: ``features`` is the float
    tensor ``[a, b, op_code]`` and ``target`` is a one-element float tensor
    holding the result. Samples are drawn afresh on each ``__getitem__`` call;
    the index is ignored and ``num_samples`` only sets the reported length.

    Args:
        num_samples: Value returned by ``len(dataset)``.
        min_value: Inclusive lower bound for both operands.
        max_value: Inclusive upper bound for both operands.
        operations: Operation names to sample from. ``'add'``, ``'sub'`` and
            ``'mul'`` are recognised; any other name is treated as division.
    """

    def __init__(self, num_samples=10000, min_value=0, max_value=10,
                 operations=('add', 'sub', 'mul', 'div')):
        self.num_samples = num_samples
        self.min_value = min_value
        self.max_value = max_value
        # Copy into a fresh list so instances never share one default object
        # and never alias a list owned by the caller.
        self.operations = list(operations)
        # Map each operation to its position in the list, stored as a float
        # so it can sit in the feature tensor next to the operands.
        self.op_to_int = {op: float(i) for i, op in enumerate(self.operations)}

    def __len__(self):
        return self.num_samples

    def __getitem__(self, idx):
        """Return a freshly sampled ``(features, target)`` pair; ``idx`` is unused."""
        a = random.randint(self.min_value, self.max_value)
        b = random.randint(self.min_value, self.max_value)
        op = random.choice(self.operations)

        if op == 'add':
            result = a + b
        elif op == 'sub':
            result = a - b
        elif op == 'mul':
            result = a * b
        else:  # division
            # Division by zero and inexact quotients are both labelled 0.
            # Integer floor division keeps exact quotients exact instead of
            # round-tripping through float.
            result = a // b if b != 0 and a % b == 0 else 0

        op_int = self.op_to_int[op]
        features = torch.tensor([a, b, op_int], dtype=torch.float)
        target = torch.tensor([result], dtype=torch.float)
        return features, target
# Create the dataset and dataloader
dataset = ArithmeticDataset()
dataloader = torch.utils.data.DataLoader(dataset, batch_size=model.batch_size, shuffle=True)

# Resume from previously saved weights when a checkpoint exists.
if os.path.exists('model_weights.pt'):
    # map_location lets weights saved on a GPU load on a CPU-only host.
    state_dict = torch.load('model_weights.pt', map_location=model.device)
    model.model.load_state_dict(state_dict, strict=False)

# Train the model
# NOTE(review): the loop nesting was reconstructed from flattened text; the
# progress print and the checkpoint are assumed to run once per epoch.
for epoch in range(model.num_epochs):
    for x, y in dataloader:
        # Insert a sequence dimension of length 1: (batch, 3) -> (batch, 1, 3).
        # -1 infers the batch size, so a short final batch does not crash.
        x = x.view(-1, 1, 3).to(model.device)
        y = y.to(model.device)

        # Forward pass
        output = model.model(x)
        loss = model.loss_fn(output, y)

        # Backward and optimize
        model.optimizer.zero_grad()
        loss.backward()
        model.optimizer.step()

    # Reports the loss of the last batch in the epoch.
    print(f'Epoch [{epoch+1}/{model.num_epochs}], Loss: {loss.item():.4f}')

    # Checkpoint every 100 epochs. Compare by value: `is` tests object
    # identity and is not a reliable integer comparison.
    if epoch % 100 == 0:
        print("save model by epoch {}".format(epoch))
        torch.save(model.model, 'model.pt')
        torch.save(model.model.state_dict(), 'model_weights.pt')
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment