Skip to content

Instantly share code, notes, and snippets.

@geyang
Created May 13, 2020 00:02
Show Gist options
  • Save geyang/00a6a4b2a1e0dd7a94cf9b44b4d3995d to your computer and use it in GitHub Desktop.
Save geyang/00a6a4b2a1e0dd7a94cf9b44b4d3995d to your computer and use it in GitHub Desktop.
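Diff for the MAC implementation: initialize `net_running` as a `deepcopy` of `net` instead of constructing a new `MACNetwork` and copying the weights over with `momentum_pull(net_running, net, 0)` (that call is left commented out).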
diff --git a/mac/mac_vanilla.py b/mac/mac_vanilla.py
index 1565db7..28d5970 100644
--- a/mac/mac_vanilla.py
+++ b/mac/mac_vanilla.py
@@ -1,3 +1,5 @@
+from copy import deepcopy
+
import torch
from torch import nn
from torch import optim
@@ -109,14 +111,9 @@ def run(deps=None, **args):
logger.log_line(net, file="models/net.txt")
if Args.load_checkpoint:
logger.load_module(net, Args.load_checkpoint)
- net_running = MACNetwork(n_vocab=Args.n_words,
- dim=Args.dim,
- classes=t.n_answers,
- net_length=Args.net_length,
- dropout=Args.dropout,
- self_attention=Args.write_self_attention,
- memory_gate=Args.write_memory_gate).to(Args.device)
- momentum_pull(net_running, net, 0)
+
+ net_running = deepcopy(net)
+ # momentum_pull(net_running, net, 0)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(net.parameters(), lr=Args.lr)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment