Skip to content

Instantly share code, notes, and snippets.

@manuel-delverme
Created April 22, 2018 09:04
Show Gist options
  • Save manuel-delverme/a1b6b93bd5b4d607920b045b039fcb98 to your computer and use it in GitHub Desktop.
Save manuel-delverme/a1b6b93bd5b4d607920b045b039fcb98 to your computer and use it in GitHub Desktop.
zero_training
def train_network(samples, neural_network, nr_epochs=10, batch_size=64):
    """Fit ``neural_network`` on ``samples`` with Adam for ``nr_epochs`` epochs.

    Every epoch draws a fresh random ordering of the samples and walks it in
    consecutive mini-batches, so each sample is used exactly once per epoch.
    The last mini-batch of an epoch may hold fewer than ``batch_size`` samples.

    Args:
        samples: Indexable, sized collection of ``(board, pi, v)`` triples.
        neural_network: Module whose forward call on a batch of boards
            returns a ``(out_pi, out_v)`` pair.
        nr_epochs: Number of passes over ``samples``.
        batch_size: Maximum number of samples per optimisation step.

    Raises:
        ValueError: If ``batch_size`` is smaller than 1.
    """
    if batch_size < 1:
        raise ValueError("batch_size must be a positive integer")

    optimizer = optim.Adam(neural_network.parameters())
    neural_network.train()
    nr_samples = len(samples)

    for _ in range(nr_epochs):
        # np.random.shuffle works in place and returns None; permutation
        # returns the freshly shuffled index array we actually need.
        sample_ids = np.random.permutation(nr_samples)

        # Step through sample positions (not batch counts) so the whole
        # data set is covered.
        for start in range(0, nr_samples, batch_size):
            batch_ids = sample_ids[start:start + batch_size]
            # Gather one element at a time so that `samples` may be a plain
            # list, which does not support indexing by an index array.
            mini_batch = [samples[i] for i in batch_ids]
            boards, pis, vs = zip(*mini_batch)

            # Build float32 arrays directly instead of going through float64.
            boards = torch.from_numpy(np.asarray(boards, dtype=np.float32))
            target_pis = torch.from_numpy(np.asarray(pis, dtype=np.float32))
            target_vs = torch.from_numpy(np.asarray(vs, dtype=np.float32))
            boards = Variable(boards)
            target_pis = Variable(target_pis)
            target_vs = Variable(target_vs)

            # NOTE(review): if the network contains batch-norm layers, a
            # trailing mini-batch of a single sample may fail in train mode
            # -- confirm against the model definition.
            out_pi, out_v = neural_network(boards)

            # loss_pi / loss_v are defined elsewhere in this file; the
            # (target, output) argument order is kept as in the original.
            l_pi = loss_pi(target_pis, out_pi)
            l_v = loss_v(target_vs, out_v)
            total_loss = l_pi + l_v

            optimizer.zero_grad()
            total_loss.backward()
            optimizer.step()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment