simple classifier
# %%
import numpy as np
import torch
import torch.nn as nn
import torch.utils.data as td
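# %% [markdown]
# Optionally pin the RNG seeds so reruns are reproducible:
# %%
np.random.seed(0)
torch.manual_seed(0)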
# %% [markdown]
# ### Gen data
# %%
INPUT_NUM = 3
OUTPUT_NUM = 8
# %%
def calc_idx(a, b, c):
    # Pack the three thresholded inputs into a 3-bit class index in [0, 7].
    return 4*(a >= 0.5) + 2*(b >= 0.5) + 1*(c >= 0.5)  # accuracy 98%
    # return 4*(0.25 <= a <= 0.75) + 2*(b >= 0.5) + 1*(c >= 0.5)  # accuracy 50%
    # return 4*((a % 0.5) > 0.25) + 2*((b % 0.5) > 0.25) + 1*((c % 0.5) > 0.25)  # accuracy 25%
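# %% [markdown]
# Quick sanity check of the encoding: a >= 0.5 sets bit 2, b >= 0.5 sets bit 1,
# c >= 0.5 sets bit 0, so (0.7, 0.2, 0.9) should map to 4 + 0 + 1 = 5.
# %%
assert calc_idx(0.7, 0.2, 0.9) == 5
assert calc_idx(0.1, 0.1, 0.1) == 0
assert calc_idx(0.9, 0.9, 0.9) == 7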
# %%
x = np.random.random(size=(1000, 3))
y = []
for a, b, c in x:
    idx = calc_idx(a, b, c)
    # lst = [0.] * 8
    # lst[idx] = 1.0
    # y.append(lst)  # one-hot targets, for MSELoss
    y.append(idx)  # class-index targets, for CrossEntropyLoss
split = int(0.8 * len(x))
otrain_x, otrain_y = x[:split], y[:split]
otest_x, otest_y = x[split:], y[split:]
display(x[:5], y[:5])
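# %% [markdown]
# Each bit is an independent coin flip on uniform inputs, so the eight classes
# should come out roughly balanced; a minimal check:
# %%
print(np.bincount(y, minlength=OUTPUT_NUM))  # expect each count near 1000 / 8 = 125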
# %% [markdown]
# ### Dataset
# %%
train_x = torch.tensor(otrain_x, dtype=torch.float)
train_y = torch.tensor(otrain_y, dtype=torch.long)
train_ds = td.TensorDataset(train_x, train_y)
train_ld = td.DataLoader(train_ds, batch_size=8)
test_x = torch.tensor(otest_x, dtype=torch.float)
test_y = torch.tensor(otest_y, dtype=torch.long)
test_ds = td.TensorDataset(test_x, test_y)
test_ld = td.DataLoader(test_ds, batch_size=8)
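# %% [markdown]
# One batch from the loader, to confirm the shapes the network will see:
# %%
xb, yb = next(iter(train_ld))
print(xb.shape, yb.shape)  # torch.Size([8, 3]) torch.Size([8])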
# %% [markdown]
# ### NN
# %%
hl = 15
class Net(nn.Module):
    def __init__(self, input_num, output_num):
        super().__init__()
        self._input = nn.Linear(input_num, hl)
        self._fc1 = nn.Linear(hl, hl)
        self._fc2 = nn.Linear(hl, output_num)

    def forward(self, inputs):
        x = inputs
        x = torch.relu(self._input(x))  # relu here trains faster than sigmoid
        x = torch.relu(self._fc1(x))
        # ReLU on the output gives extremely poor results; sigmoid works. Note that
        # CrossEntropyLoss applies log-softmax internally, so the conventional choice
        # is no output activation at all (raw logits).
        x = torch.sigmoid(self._fc2(x))
        return x

model = Net(INPUT_NUM, OUTPUT_NUM)
print(model)
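# %% [markdown]
# With hl = 15 the model is tiny:
# (3*15 + 15) + (15*15 + 15) + (15*8 + 8) = 60 + 240 + 128 = 428 parameters.
# %%
print(sum(p.numel() for p in model.parameters()))  # 428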
# %% [markdown]
# ### Train
# %%
def train(model, data_loader, optimizer):
    model.train()
    train_loss = 0
    for batch, tensor in enumerate(data_loader):
        data, target = tensor
        # forward
        optimizer.zero_grad()
        out = model(data)
        loss = loss_fn(out, target)
        train_loss += loss.item()
        # backward
        loss.backward()
        optimizer.step()
    avg_loss = train_loss / (batch + 1)  # `batch` keeps the last loop index
    print(f"Training set, avg loss {avg_loss:.6f}")
    return avg_loss
def test(model, data_loader):
    model.eval()
    test_loss = 0
    correct = 0
    batch_count = 0
    with torch.no_grad():
        for _, tensor in enumerate(data_loader, 1):
            batch_count += 1
            data, target = tensor
            out = model(data)
            test_loss += loss_fn(out, target).item()
            predicted = torch.argmax(out, dim=-1)
            # labels = torch.argmax(target, dim=-1)  # if using MSELoss with one-hot targets
            labels = target  # class indices, matching CrossEntropyLoss
            correct += torch.sum(predicted == labels).item()
    avg_loss = test_loss / batch_count
    accu = correct / len(data_loader.dataset)
    print(f"avg_loss : {avg_loss}, accuracy = {accu}")
    return avg_loss, accu
loss_fn = nn.CrossEntropyLoss()
EPOCHS = 200
epoch_nums = []
training_loss = []
validation_loss = []
accus = []
learning_rate = 1e-3
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)  # RMSprop gives the same accuracy as Adam(W)
optimizer.zero_grad()  # redundant here -- train() zeroes gradients every batch
for epoch in range(1, 1 + EPOCHS):
    print(f"epoch {epoch}")
    train_loss = train(model, train_ld, optimizer)
    test_loss, accuracy = test(model, test_ld)
    epoch_nums.append(epoch)
    training_loss.append(train_loss)
    validation_loss.append(test_loss)
    accus.append(accuracy)
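# %% [markdown]
# A quick summary of the run: the epoch with the best validation accuracy.
# %%
best = int(np.argmax(accus))
print(f"best accuracy {accus[best]:.3f} at epoch {epoch_nums[best]}")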
# %%
import matplotlib.pyplot as plt
%matplotlib inline
plt.rcParams["figure.figsize"] = 16, 8
ax = plt.subplot(121)
ax.plot(epoch_nums, training_loss, label="train loss")
ax.plot(epoch_nums, validation_loss, label="validation loss")
plt.legend()
ax = plt.subplot(122)
ax.plot(epoch_nums, accus, label="accuracy")
plt.legend()
plt.show()
# %%
x = np.random.random(size=(100, 3))
y = torch.tensor([calc_idx(*row) for row in x])
model.eval()
with torch.no_grad():
    output = model(torch.tensor(x, dtype=torch.float))
got = torch.argmax(output, dim=-1)
print(torch.sum(y == got).item(), "/", 100)
torch.set_printoptions(sci_mode=False)
display(x[:5], output[:5])
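# %% [markdown]
# Per-class accuracy on the fresh sample, as a finer-grained view than the single
# count above (reuses `y` and `got` from the previous cell):
# %%
for cls in range(OUTPUT_NUM):
    mask = y == cls
    if mask.any():
        print(cls, (got[mask] == y[mask]).float().mean().item())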