Created
April 18, 2019 05:25
-
-
Save theeluwin/9bf5f53fd581c6e5e953743973578b5f to your computer and use it in GitHub Desktop.
A challenge.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| # https://twitter.com/CppDDoLgi/status/1118731208790138880 | |
| # a = 6 b = 3 일 때 2 반환하고 | |
| # a = 7 b = 3 일 때 3 반환하는거 구현하고 싶은데 | |
| import multiprocessing | |
| import torch.nn as nn | |
| from tqdm import tqdm | |
| from torch import Tensor | |
| from torch.utils.data import Dataset | |
| from torch.utils.data import DataLoader | |
| from torch.optim import Adam | |
| from torch.nn import CrossEntropyLoss | |
class Ddolgi(nn.Module):
    """Small fully connected classifier over a pair of numbers ``(a, b)``.

    The network takes an input whose last dimension is 2 and produces 10
    logits, passing through hidden layers of 128 and 64 units with
    LeakyReLU (negative slope 0.2) activations.
    """

    def __init__(self):
        super().__init__()
        self.network = nn.Sequential(
            nn.Linear(2, 128),
            nn.LeakyReLU(0.2),
            nn.Linear(128, 64),
            nn.LeakyReLU(0.2),
            nn.Linear(64, 10),
        )

    def forward(self, x):
        """Return the 10 logits computed by the network for ``x``."""
        return self.network(x)

    def infer(self, a, b):
        """Return the predicted class index for a single ``(a, b)`` pair.

        Args:
            a: First input value.
            b: Second input value.

        Returns:
            The index (as a Python ``int``) of the largest logit.
        """
        # Build a batch of one sample: shape (1, 2).
        x = Tensor([a, b]).unsqueeze(0)
        # Place the input on whatever device the weights live on instead of
        # hard-coding CUDA, so inference also works for a CPU-resident model.
        x = x.to(next(self.parameters()).device)
        logits = self.forward(x)
        _, c = logits.max(dim=1)
        return int(c)
class DdolgiDataset(Dataset):
    """Two-example dataset of ``(a, b, label)`` triples.

    The ``target`` argument is accepted but not used by this class.
    """

    def __init__(self, target='train'):
        samples = []
        samples.append((6, 3, 2))
        samples.append((7, 3, 3))
        self.data = samples

    def __len__(self):
        """Return the number of stored examples."""
        count = len(self.data)
        return count

    def __getitem__(self, idx):
        """Return ``(Tensor([a, b]), label)`` for the example at ``idx``."""
        record = self.data[idx]
        first, second, label = record
        features = Tensor([first, second])
        return features, label
def main():
    """Train ``Ddolgi`` on ``DdolgiDataset`` and print its two predictions.

    Runs 100 epochs of Adam with cross-entropy loss, printing the loss once
    per epoch, then switches the model to eval mode and prints what it
    predicts for ``(6, 3)`` and ``(7, 3)``.
    """
    # Local import: the module only binds ``torch.nn`` (as ``nn``), not ``torch``.
    import torch

    # Use the GPU when one is present, otherwise fall back to the CPU rather
    # than failing on the first ``.cuda()`` call.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    use_cuda = device.type == "cuda"

    model = Ddolgi().to(device)
    optim = Adam(model.parameters())
    dataset = DdolgiDataset()
    dataloader = DataLoader(
        dataset,
        batch_size=2,
        shuffle=True,
        # Pinned host memory only helps host-to-GPU copies.
        pin_memory=use_cuda,
        num_workers=multiprocessing.cpu_count(),
    )
    losser = CrossEntropyLoss()

    model.train()
    for epoch in range(100):
        for x, c in dataloader:
            x = x.to(device)
            c = c.to(device)
            logits = model(x)
            loss = losser(logits, c)
            optim.zero_grad()
            loss.backward()
            optim.step()
        # Reports the loss of the last batch seen in this epoch.
        print(f"Epoch {epoch}: {loss.item()}")

    model.eval()
    print("a = 6 b = 3 일 때 " + str(model.infer(6, 3)) + " 반환하고")
    print("a = 7 b = 3 일 때 " + str(model.infer(7, 3)) + " 반환하는거 구현하고 싶은데")
# Run the training demo only when this file is executed as a script.
if "__main__" == __name__:
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment