This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# coding=utf-8 | |
# Copyright 2019 The SEED Authors | |
# Licensed under the Apache License, Version 2.0 (the "License"); | |
# you may not use this file except in compliance with the License. | |
# You may obtain a copy of the License at | |
# | |
# http://www.apache.org/licenses/LICENSE-2.0 | |
# | |
# Unless required by applicable law or agreed to in writing, software | |
# distributed under the License is distributed on an "AS IS" BASIS, |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
net = nn.Sequential( | |
nn.Conv2d(1, 64, 3), | |
nn.BatchNorm2d(64, momentum=1, affine=True), | |
nn.ReLU(inplace=True), | |
nn.MaxPool2d(2, 2), | |
nn.Conv2d(64, 64, 3), | |
nn.BatchNorm2d(64, momentum=1, affine=True), | |
nn.ReLU(inplace=True), | |
nn.MaxPool2d(2, 2), | |
nn.Conv2d(64, 64, 3), |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
index ede2865..de5eb9f 100755 | |
--- a/examples/maml-omniglot.py | |
+++ b/examples/maml-omniglot.py | |
@@ -30,6 +30,7 @@ import higher | |
from omniglot_loaders import OmniglotNShot | |
+ | |
def main(): | |
argparser = argparse.ArgumentParser() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import dmc2gym | |
import numpy as np | |
import gym | |
import sys | |
seed = int(sys.argv[1]) | |
env = dmc2gym.make( | |
'point_mass', | |
'easy', | |
seed, |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import dmc2gym | |
import numpy as np | |
import gym | |
import sys | |
seed = int(sys.argv[1]) | |
env = dmc2gym.make( | |
'point_mass', | |
'easy', | |
seed, |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch.nn as nn | |
import torch.nn.functional as F | |
import torch.nn.init | |
from torch.autograd import Variable | |
from models.utils import * | |
class LayerNormGRUCell(nn.GRUCell): | |
def __init__(self, input_size, hidden_size, bias=True): |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python | |
""" | |
PyTorch implementation of DQN | |
Paper: https://www.cs.toronto.edu/~vmnih/docs/dqn.pdf | |
""" | |
import argparse | |
import gym | |
from gym import wrappers |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import argparse | |
import pdb | |
import torch | |
import torch.nn as nn | |
import torch.optim as optim | |
from torch.autograd import Variable | |
import torch.nn.functional as F | |
import numpy as np |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/local/bin/python | |
""" | |
Q-learning with value fucntion approximation | |
""" | |
import argparse | |
import numpy as np | |
import matplotlib | |
from matplotlib import pyplot as plt | |
from mpl_toolkits.mplot3d import Axes3D |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/local/bin/python | |
""" | |
Q-learning - off policy TD(0) learning. | |
Q(S, A) <- Q(S, A) + alpha * ((R + gamma * max(Q(S', A'))) - Q(S, A)) | |
A ~ e-greedy from pi(A|S) | |
""" | |
import argparse | |
import numpy as np |
NewerOlder