Created
December 28, 2021 05:21
-
-
Save ejmejm/5592b577e95cf77de0c77059e4995905 to your computer and use it in GitHub Desktop.
VISR.ipynb
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "view-in-github", | |
"colab_type": "text" | |
}, | |
"source": [ | |
"<a href=\"https://colab.research.google.com/gist/ejmejm/5592b577e95cf77de0c77059e4995905/visr.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "_29r97oJh7zK", | |
"outputId": "6fa74ed6-f6fa-4e22-928e-7e98ac54fbc5" | |
}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"--2021-12-28 02:17:42-- http://www.atarimania.com/roms/Roms.rar\n", | |
"Resolving www.atarimania.com (www.atarimania.com)... 195.154.81.199\n", | |
"Connecting to www.atarimania.com (www.atarimania.com)|195.154.81.199|:80... connected.\n", | |
"HTTP request sent, awaiting response... 200 OK\n", | |
"Length: 11128004 (11M) [application/x-rar-compressed]\n", | |
"Saving to: ‘Roms.rar.1’\n", | |
"\n", | |
"Roms.rar.1 100%[===================>] 10.61M 261KB/s in 44s \n", | |
"\n", | |
"2021-12-28 02:18:26 (247 KB/s) - ‘Roms.rar.1’ saved [11128004/11128004]\n", | |
"\n", | |
"\n", | |
"UNRAR 5.50 freeware Copyright (c) 1993-2017 Alexander Roshal\n", | |
"\n", | |
"\n", | |
"Extracting from Roms.rar\n", | |
"\n", | |
"\n", | |
"Would you like to replace the existing file HC ROMS.zip\n", | |
"11826711 bytes, modified on 2019-12-22 11:24\n", | |
"with a new one\n", | |
"11826711 bytes, modified on 2019-12-22 11:24\n", | |
"\n", | |
"[Y]es, [N]o, [A]ll, n[E]ver, [R]ename, [Q]uit \n", | |
"User break\n", | |
"\n", | |
"User break\n", | |
"replace ROMS/128 in 1 Game Select ROM (128 in 1) (Unknown) ~.bin? [y]es, [n]o, [A]ll, [N]one, [r]ename: Requirement already satisfied: einops in /usr/local/lib/python3.7/dist-packages (0.3.2)\n", | |
"Traceback (most recent call last):\n", | |
" File \"/usr/local/lib/python3.7/dist-packages/pip/_vendor/pkg_resources/__init__.py\", line 3021, in _dep_map\n", | |
" return self.__dep_map\n", | |
" File \"/usr/local/lib/python3.7/dist-packages/pip/_vendor/pkg_resources/__init__.py\", line 2815, in __getattr__\n", | |
" raise AttributeError(attr)\n", | |
"AttributeError: _DistInfoDistribution__dep_map\n", | |
"\n", | |
"During handling of the above exception, another exception occurred:\n", | |
"\n", | |
"Traceback (most recent call last):\n", | |
" File \"/usr/local/lib/python3.7/dist-packages/pip/_internal/cli/base_command.py\", line 180, in _main\n", | |
" status = self.run(options, args)\n", | |
" File \"/usr/local/lib/python3.7/dist-packages/pip/_internal/cli/req_command.py\", line 199, in wrapper\n", | |
" return func(self, options, args)\n", | |
" File \"/usr/local/lib/python3.7/dist-packages/pip/_internal/commands/install.py\", line 385, in run\n", | |
" conflicts = self._determine_conflicts(to_install)\n", | |
" File \"/usr/local/lib/python3.7/dist-packages/pip/_internal/commands/install.py\", line 515, in _determine_conflicts\n", | |
" return check_install_conflicts(to_install)\n", | |
" File \"/usr/local/lib/python3.7/dist-packages/pip/_internal/operations/check.py\", line 103, in check_install_conflicts\n", | |
" package_set, _ = create_package_set_from_installed()\n", | |
" File \"/usr/local/lib/python3.7/dist-packages/pip/_internal/operations/check.py\", line 45, in create_package_set_from_installed\n", | |
" package_set[name] = PackageDetails(dist.version, dist.requires())\n", | |
" File \"/usr/local/lib/python3.7/dist-packages/pip/_vendor/pkg_resources/__init__.py\", line 2736, in requires\n", | |
" dm = self._dep_map\n", | |
" File \"/usr/local/lib/python3.7/dist-packages/pip/_vendor/pkg_resources/__init__.py\", line 3023, in _dep_map\n", | |
" self.__dep_map = self._compute_dependencies()\n", | |
" File \"/usr/local/lib/python3.7/dist-packages/pip/_vendor/pkg_resources/__init__.py\", line 3033, in _compute_dependencies\n", | |
" reqs.extend(parse_requirements(req))\n", | |
" File \"/usr/local/lib/python3.7/dist-packages/pip/_vendor/pkg_resources/__init__.py\", line 3094, in parse_requirements\n", | |
" yield Requirement(line)\n", | |
" File \"/usr/local/lib/python3.7/dist-packages/pip/_vendor/pkg_resources/__init__.py\", line 3101, in __init__\n", | |
" super(Requirement, self).__init__(requirement_string)\n", | |
" File \"/usr/local/lib/python3.7/dist-packages/pip/_vendor/packaging/requirements.py\", line 113, in __init__\n", | |
" req = REQUIREMENT.parseString(requirement_string)\n", | |
" File \"/usr/local/lib/python3.7/dist-packages/pip/_vendor/pyparsing.py\", line 1943, in parseString\n", | |
" loc, tokens = self._parse(instring, 0)\n", | |
" File \"/usr/local/lib/python3.7/dist-packages/pip/_vendor/pyparsing.py\", line 1683, in _parseNoCache\n", | |
" loc, tokens = self.parseImpl(instring, preloc, doActions)\n", | |
" File \"/usr/local/lib/python3.7/dist-packages/pip/_vendor/pyparsing.py\", line 4069, in parseImpl\n", | |
" loc, exprtokens = e._parse(instring, loc, doActions)\n", | |
" File \"/usr/local/lib/python3.7/dist-packages/pip/_vendor/pyparsing.py\", line 1683, in _parseNoCache\n", | |
" loc, tokens = self.parseImpl(instring, preloc, doActions)\n", | |
" File \"/usr/local/lib/python3.7/dist-packages/pip/_vendor/pyparsing.py\", line 4462, in parseImpl\n", | |
" return self.expr._parse(instring, loc, doActions, callPreParse=False)\n", | |
" File \"/usr/local/lib/python3.7/dist-packages/pip/_vendor/pyparsing.py\", line 1683, in _parseNoCache\n", | |
" loc, tokens = self.parseImpl(instring, preloc, doActions)\n", | |
" File \"/usr/local/lib/python3.7/dist-packages/pip/_vendor/pyparsing.py\", line 4069, in parseImpl\n", | |
" loc, exprtokens = e._parse(instring, loc, doActions)\n", | |
" File \"/usr/local/lib/python3.7/dist-packages/pip/_vendor/pyparsing.py\", line 1683, in _parseNoCache\n", | |
" loc, tokens = self.parseImpl(instring, preloc, doActions)\n", | |
" File \"/usr/local/lib/python3.7/dist-packages/pip/_vendor/pyparsing.py\", line 4781, in parseImpl\n", | |
" return super(ZeroOrMore, self).parseImpl(instring, loc, doActions)\n", | |
" File \"/usr/local/lib/python3.7/dist-packages/pip/_vendor/pyparsing.py\", line 4697, in parseImpl\n", | |
" loc, tokens = self_expr_parse(instring, loc, doActions, callPreParse=False)\n", | |
" File \"/usr/local/lib/python3.7/dist-packages/pip/_vendor/pyparsing.py\", line 1683, in _parseNoCache\n", | |
" loc, tokens = self.parseImpl(instring, preloc, doActions)\n", | |
" File \"/usr/local/lib/python3.7/dist-packages/pip/_vendor/pyparsing.py\", line 4254, in parseImpl\n", | |
" ret = e._parse(instring, loc, doActions)\n", | |
" File \"/usr/local/lib/python3.7/dist-packages/pip/_vendor/pyparsing.py\", line 1683, in _parseNoCache\n", | |
" loc, tokens = self.parseImpl(instring, preloc, doActions)\n", | |
" File \"/usr/local/lib/python3.7/dist-packages/pip/_vendor/pyparsing.py\", line 4069, in parseImpl\n", | |
" loc, exprtokens = e._parse(instring, loc, doActions)\n", | |
"KeyboardInterrupt\n", | |
"\n", | |
"During handling of the above exception, another exception occurred:\n", | |
"\n", | |
"Traceback (most recent call last):\n", | |
" File \"/usr/local/bin/pip3\", line 8, in <module>\n", | |
" sys.exit(main())\n", | |
" File \"/usr/local/lib/python3.7/dist-packages/pip/_internal/cli/main.py\", line 71, in main\n", | |
" return command.main(cmd_args)\n", | |
" File \"/usr/local/lib/python3.7/dist-packages/pip/_internal/cli/base_command.py\", line 104, in main\n", | |
" return self._main(args)\n", | |
" File \"/usr/local/lib/python3.7/dist-packages/pip/_internal/cli/base_command.py\", line 212, in _main\n", | |
" logger.critical(\"Operation cancelled by user\")\n", | |
" File \"/usr/lib/python3.7/logging/__init__.py\", line 1415, in critical\n", | |
" def critical(self, msg, *args, **kwargs):\n", | |
"KeyboardInterrupt\n", | |
"fatal: destination path 'gym-gridworld' already exists and is not an empty directory.\n", | |
"/content/gym-gridworld\n", | |
"Obtaining file:///content/gym-gridworld\n", | |
"Requirement already satisfied: gym in /usr/local/lib/python3.7/dist-packages (from gym-gridworld==0.0.1) (0.17.3)\n", | |
"Requirement already satisfied: scipy in /usr/local/lib/python3.7/dist-packages (from gym->gym-gridworld==0.0.1) (1.4.1)\n", | |
"Requirement already satisfied: numpy>=1.10.4 in /usr/local/lib/python3.7/dist-packages (from gym->gym-gridworld==0.0.1) (1.19.5)\n", | |
"Requirement already satisfied: cloudpickle<1.7.0,>=1.2.0 in /usr/local/lib/python3.7/dist-packages (from gym->gym-gridworld==0.0.1) (1.3.0)\n", | |
"Requirement already satisfied: pyglet<=1.5.0,>=1.4.0 in /usr/local/lib/python3.7/dist-packages (from gym->gym-gridworld==0.0.1) (1.5.0)\n", | |
"Requirement already satisfied: future in /usr/local/lib/python3.7/dist-packages (from pyglet<=1.5.0,>=1.4.0->gym->gym-gridworld==0.0.1) (0.16.0)\n", | |
"Traceback (most recent call last):\n", | |
" File \"/usr/local/lib/python3.7/dist-packages/pip/_vendor/pkg_resources/__init__.py\", line 3021, in _dep_map\n", | |
" return self.__dep_map\n", | |
" File \"/usr/local/lib/python3.7/dist-packages/pip/_vendor/pkg_resources/__init__.py\", line 2815, in __getattr__\n", | |
" raise AttributeError(attr)\n", | |
"AttributeError: _DistInfoDistribution__dep_map\n", | |
"\n", | |
"During handling of the above exception, another exception occurred:\n", | |
"\n", | |
"Traceback (most recent call last):\n", | |
" File \"/usr/local/lib/python3.7/dist-packages/pip/_internal/cli/base_command.py\", line 180, in _main\n", | |
" status = self.run(options, args)\n", | |
" File \"/usr/local/lib/python3.7/dist-packages/pip/_internal/cli/req_command.py\", line 199, in wrapper\n", | |
" return func(self, options, args)\n", | |
" File \"/usr/local/lib/python3.7/dist-packages/pip/_internal/commands/install.py\", line 385, in run\n", | |
" conflicts = self._determine_conflicts(to_install)\n", | |
" File \"/usr/local/lib/python3.7/dist-packages/pip/_internal/commands/install.py\", line 515, in _determine_conflicts\n", | |
" return check_install_conflicts(to_install)\n", | |
" File \"/usr/local/lib/python3.7/dist-packages/pip/_internal/operations/check.py\", line 103, in check_install_conflicts\n", | |
" package_set, _ = create_package_set_from_installed()\n", | |
" File \"/usr/local/lib/python3.7/dist-packages/pip/_internal/operations/check.py\", line 45, in create_package_set_from_installed\n", | |
" package_set[name] = PackageDetails(dist.version, dist.requires())\n", | |
" File \"/usr/local/lib/python3.7/dist-packages/pip/_vendor/pkg_resources/__init__.py\", line 2736, in requires\n", | |
" dm = self._dep_map\n", | |
" File \"/usr/local/lib/python3.7/dist-packages/pip/_vendor/pkg_resources/__init__.py\", line 3023, in _dep_map\n", | |
" self.__dep_map = self._compute_dependencies()\n", | |
" File \"/usr/local/lib/python3.7/dist-packages/pip/_vendor/pkg_resources/__init__.py\", line 3033, in _compute_dependencies\n", | |
" reqs.extend(parse_requirements(req))\n", | |
" File \"/usr/local/lib/python3.7/dist-packages/pip/_vendor/pkg_resources/__init__.py\", line 3094, in parse_requirements\n", | |
" yield Requirement(line)\n", | |
" File \"/usr/local/lib/python3.7/dist-packages/pip/_vendor/pkg_resources/__init__.py\", line 3101, in __init__\n", | |
" super(Requirement, self).__init__(requirement_string)\n", | |
" File \"/usr/local/lib/python3.7/dist-packages/pip/_vendor/packaging/requirements.py\", line 113, in __init__\n", | |
" req = REQUIREMENT.parseString(requirement_string)\n", | |
" File \"/usr/local/lib/python3.7/dist-packages/pip/_vendor/pyparsing.py\", line 1943, in parseString\n", | |
" loc, tokens = self._parse(instring, 0)\n", | |
" File \"/usr/local/lib/python3.7/dist-packages/pip/_vendor/pyparsing.py\", line 1683, in _parseNoCache\n", | |
" loc, tokens = self.parseImpl(instring, preloc, doActions)\n", | |
" File \"/usr/local/lib/python3.7/dist-packages/pip/_vendor/pyparsing.py\", line 4069, in parseImpl\n", | |
" loc, exprtokens = e._parse(instring, loc, doActions)\n", | |
" File \"/usr/local/lib/python3.7/dist-packages/pip/_vendor/pyparsing.py\", line 1683, in _parseNoCache\n", | |
" loc, tokens = self.parseImpl(instring, preloc, doActions)\n", | |
" File \"/usr/local/lib/python3.7/dist-packages/pip/_vendor/pyparsing.py\", line 4254, in parseImpl\n", | |
" ret = e._parse(instring, loc, doActions)\n", | |
" File \"/usr/local/lib/python3.7/dist-packages/pip/_vendor/pyparsing.py\", line 1683, in _parseNoCache\n", | |
" loc, tokens = self.parseImpl(instring, preloc, doActions)\n", | |
" File \"/usr/local/lib/python3.7/dist-packages/pip/_vendor/pyparsing.py\", line 4052, in parseImpl\n", | |
" loc, resultlist = self.exprs[0]._parse(instring, loc, doActions, callPreParse=False)\n", | |
" File \"/usr/local/lib/python3.7/dist-packages/pip/_vendor/pyparsing.py\", line 1687, in _parseNoCache\n", | |
" loc, tokens = self.parseImpl(instring, preloc, doActions)\n", | |
" File \"/usr/local/lib/python3.7/dist-packages/pip/_vendor/pyparsing.py\", line 4462, in parseImpl\n", | |
" return self.expr._parse(instring, loc, doActions, callPreParse=False)\n", | |
" File \"/usr/local/lib/python3.7/dist-packages/pip/_vendor/pyparsing.py\", line 1687, in _parseNoCache\n", | |
" loc, tokens = self.parseImpl(instring, preloc, doActions)\n", | |
" File \"/usr/local/lib/python3.7/dist-packages/pip/_vendor/pyparsing.py\", line 2899, in parseImpl\n", | |
" raise ParseException(instring, loc, self.errmsg, self)\n", | |
" File \"/usr/local/lib/python3.7/dist-packages/pip/_vendor/pyparsing.py\", line 304, in __init__\n", | |
" def __init__(self, pstr, loc=0, msg=None, elem=None):\n", | |
"KeyboardInterrupt\n", | |
"\n", | |
"During handling of the above exception, another exception occurred:\n", | |
"\n", | |
"Traceback (most recent call last):\n", | |
" File \"/usr/local/bin/pip3\", line 8, in <module>\n", | |
" sys.exit(main())\n", | |
" File \"/usr/local/lib/python3.7/dist-packages/pip/_internal/cli/main.py\", line 71, in main\n", | |
" return command.main(cmd_args)\n", | |
" File \"/usr/local/lib/python3.7/dist-packages/pip/_internal/cli/base_command.py\", line 104, in main\n", | |
" return self._main(args)\n", | |
" File \"/usr/local/lib/python3.7/dist-packages/pip/_internal/cli/base_command.py\", line 212, in _main\n", | |
" logger.critical(\"Operation cancelled by user\")\n", | |
"KeyboardInterrupt\n" | |
] | |
} | |
], | |
"source": [ | |
"!wget http://www.atarimania.com/roms/Roms.rar\n", | |
"!unrar e Roms.rar\n", | |
"!unzip ROMS.zip > /dev/null\n", | |
"!python -m atari_py.import_roms ./ROMS > /dev/null\n", | |
"!pip install einops\n", | |
"!git clone https://github.com/xinleipan/gym-gridworld.git\n", | |
"%cd gym-gridworld\n", | |
"!pip install -e ." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": { | |
"id": "kHfTkNhEZ23Q" | |
}, | |
"outputs": [], | |
"source": [ | |
"import copy\n", | |
"import gym\n", | |
"from gym.wrappers import AtariPreprocessing, FrameStack, TransformObservation\n", | |
"from gym.wrappers import TimeLimit\n", | |
"import gym_gridworld\n", | |
"import matplotlib.pyplot as plt\n", | |
"import numpy as np\n", | |
"import torch\n", | |
"from torch import nn\n", | |
"import torch.nn.functional as F\n", | |
"import einops\n", | |
"import seaborn as sns\n", | |
"import cv2\n", | |
"from tqdm.notebook import tqdm\n", | |
"\n", | |
"%matplotlib inline\n", | |
"sns.set()" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": { | |
"id": "oy8T5RcWaIEw" | |
}, | |
"outputs": [], | |
"source": [ | |
"### Gridworld Env ###\n", | |
"\n", | |
"class GridWorldWrapper(gym.ObservationWrapper):\n", | |
" def __init__(self, env):\n", | |
" super().__init__(env)\n", | |
"\n", | |
" obs_shape = (1, 16, 16)\n", | |
" self.observation_space = gym.spaces.Box(\n", | |
" low=0, high=1, shape=obs_shape, dtype=np.float32)\n", | |
"\n", | |
" def observation(self, observation):\n", | |
" observation = cv2.cvtColor(observation, cv2.COLOR_RGB2GRAY)\n", | |
" observation = cv2.resize(observation, self.observation_space.shape[1:],\n", | |
" interpolation=cv2.INTER_AREA)\n", | |
" observation = np.expand_dims(observation, 0)\n", | |
" return observation\n", | |
"\n", | |
"\n", | |
"class GridWorldWrapper(gym.ObservationWrapper):\n", | |
" def __init__(self, env):\n", | |
" super().__init__(env)\n", | |
"\n", | |
" obs_shape = (1, 16, 16)\n", | |
" self.observation_space = gym.spaces.Box(\n", | |
" low=0, high=1, shape=obs_shape, dtype=np.float32)\n", | |
"\n", | |
" def observation(self, observation):\n", | |
" observation = cv2.cvtColor(observation, cv2.COLOR_RGB2GRAY)\n", | |
" observation = cv2.resize(observation, self.observation_space.shape[1:],\n", | |
" interpolation=cv2.INTER_AREA)\n", | |
" observation = np.expand_dims(observation, 0)\n", | |
" return observation\n", | |
"\n", | |
"class SimpleMapWrapper(gym.Wrapper):\n", | |
" def __init__(self, env, randomized=False):\n", | |
" super().__init__(env)\n", | |
" self.random_start = randomized\n", | |
" self._reset_start_map()\n", | |
"\n", | |
" def _reset_start_map(self):\n", | |
" map = np.ones((16, 16), dtype=np.int64)\n", | |
" map [1:-1, 1:-1] = 0\n", | |
"\n", | |
" get_random_pos = lambda: np.random.randint(1, 15)\n", | |
" if self.random_start:\n", | |
" map[get_random_pos(), get_random_pos()] = 3\n", | |
" new_pos = (get_random_pos(), get_random_pos())\n", | |
" while map[new_pos[0], new_pos[1]] == 3:\n", | |
" new_pos = (get_random_pos(), get_random_pos())\n", | |
" map[new_pos[0], new_pos[1]] = 4\n", | |
" else:\n", | |
" map[-3, 2] = 3\n", | |
" map[-6, 4] = 4\n", | |
"\n", | |
" uenv = self.unwrapped\n", | |
" uenv.start_grid_map = map\n", | |
" uenv.current_grid_map = copy.deepcopy(uenv.start_grid_map) # current grid map\n", | |
" uenv.observation = uenv._gridmap_to_observation(uenv.start_grid_map)\n", | |
" uenv.grid_map_shape = uenv.start_grid_map.shape\n", | |
"\n", | |
" uenv.agent_start_state, uenv.agent_target_state = \\\n", | |
" uenv._get_agent_start_target_state(uenv.start_grid_map)\n", | |
" uenv.agent_state = copy.deepcopy(uenv.agent_start_state)\n", | |
"\n", | |
" def reset(self):\n", | |
" if self.random_start:\n", | |
" self._reset_start_map()\n", | |
" return super().reset()\n", | |
"\n", | |
"def create_gridworld_env(max_steps=1000):\n", | |
" global N_FRAME_STACK\n", | |
" N_FRAME_STACK = 1\n", | |
"\n", | |
" env = gym.make('gridworld-v0')\n", | |
" env = GridWorldWrapper(env)\n", | |
" env = TimeLimit(env, max_steps)\n", | |
" env = TransformObservation(env, torch.FloatTensor)\n", | |
" return env\n", | |
"\n", | |
"def create_simple_gridworld_env(randomized=False, max_steps=1000):\n", | |
" env = create_gridworld_env(max_steps)\n", | |
" env = SimpleMapWrapper(env, randomized)\n", | |
" return env\n", | |
"\n", | |
"\n", | |
"### Atari Env ###\n", | |
"\n", | |
"atari_wrappers = [\n", | |
" lambda env: AtariPreprocessing(env, scale_obs=True),\n", | |
" lambda env: FrameStack(env, N_FRAME_STACK),\n", | |
" lambda env: TransformObservation(env, torch.FloatTensor)\n", | |
"]\n", | |
"\n", | |
"def create_breakout_env():\n", | |
" global N_FRAME_STACK\n", | |
" N_FRAME_STACK = 4\n", | |
"\n", | |
" env = gym.make('BreakoutNoFrameskip-v4')\n", | |
" for wrapper in atari_wrappers:\n", | |
" env = wrapper(env)\n", | |
" return env\n", | |
"\n", | |
"# env = create_breakout_env()\n", | |
"# env = create_gridworld_env()\n", | |
"env = create_simple_gridworld_env(randomized=True)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": { | |
"id": "z1m8gi53cTpt" | |
}, | |
"outputs": [], | |
"source": [ | |
"class SFModel(nn.Module):\n", | |
" def __init__(self, obs_dim, n_acts, embed_size=128):\n", | |
" super().__init__()\n", | |
" self.embed_size = embed_size\n", | |
" self.n_acts = n_acts\n", | |
"\n", | |
" # Change network size depending on the size of the input\n", | |
" if obs_dim[1] <= 42 and obs_dim[2] <= 42:\n", | |
" create_convs = self.create_small_convs\n", | |
" linear_dim = 128\n", | |
" else:\n", | |
" create_convs = self.create_large_convs\n", | |
" linear_dim = 512\n", | |
"\n", | |
" test_input = torch.zeros(obs_dim)[None]\n", | |
"\n", | |
" # Feature layers\n", | |
" self.conv_feature_layers = create_convs(obs_dim[0])\n", | |
" feature_embed_dim = self.conv_feature_layers(test_input)\n", | |
" self.linear_feature_layers = nn.Sequential(\n", | |
" nn.Linear(feature_embed_dim.shape[1], linear_dim),\n", | |
" nn.ReLU(),\n", | |
" nn.Linear(linear_dim, embed_size)\n", | |
" )\n", | |
" self.feature_layers = nn.Sequential(\n", | |
" self.conv_feature_layers,\n", | |
" self.linear_feature_layers\n", | |
" )\n", | |
"\n", | |
" # Sucessor Feature Layers\n", | |
" self.conv_sf_layers = create_convs(obs_dim[0])\n", | |
" sf_embed_dim = self.conv_sf_layers(test_input)\n", | |
" self.linear_sf_layers = nn.Sequential(\n", | |
" nn.Linear(sf_embed_dim.shape[1] + embed_size, linear_dim),\n", | |
" nn.ReLU(),\n", | |
" nn.Linear(linear_dim, linear_dim),\n", | |
" nn.ReLU(),\n", | |
" nn.Linear(linear_dim, n_acts * embed_size)\n", | |
" )\n", | |
" self.sf_layers = nn.Sequential(\n", | |
" self.conv_sf_layers,\n", | |
" self.linear_sf_layers\n", | |
" )\n", | |
"\n", | |
" def create_large_convs(self, input_dim):\n", | |
" return nn.Sequential(\n", | |
" nn.Conv2d(input_dim, 32, 8, 4),\n", | |
" nn.ReLU(),\n", | |
" nn.Conv2d(32, 64, 4, 2),\n", | |
" nn.ReLU(),\n", | |
" nn.Conv2d(64, 64, 3, 1),\n", | |
" nn.ReLU(),\n", | |
" nn.Flatten()\n", | |
" )\n", | |
" \n", | |
" def create_small_convs(self, input_dim):\n", | |
" return nn.Sequential(\n", | |
" nn.Conv2d(input_dim, 8, 4, 2),\n", | |
" nn.ReLU(),\n", | |
" nn.Conv2d(8, 16, 3, 1),\n", | |
" nn.ReLU(),\n", | |
" nn.Flatten()\n", | |
" )\n", | |
"\n", | |
" def forward(self, obs, goal_vector=None):\n", | |
" features = self.feature_layers(obs)\n", | |
" features = F.normalize(features, dim=1)\n", | |
"\n", | |
" if goal_vector is None:\n", | |
" sfs = None\n", | |
" else:\n", | |
" sf_z = self.conv_sf_layers(obs)\n", | |
" sf_z = torch.cat((sf_z, goal_vector), dim=-1)\n", | |
" sfs = self.linear_sf_layers(sf_z)\n", | |
" sfs = sfs.reshape(-1, self.n_acts, self.embed_size)\n", | |
" return features, sfs\n", | |
"\n", | |
" def train(self, batch_data, optimizer, gamma=0.99, reward_loss_weight=1):\n", | |
" device = next(self.parameters()).device\n", | |
"\n", | |
" batch_data = np.array(batch_data)\n", | |
" obs, acts, next_obs, rewards, dones, skill_vectors = \\\n", | |
" [torch.from_numpy(np.stack(batch_data[:, i])).to(device) \\\n", | |
" for i in range(batch_data.shape[1])]\n", | |
"\n", | |
" _, next_sfs = self(next_obs, skill_vectors)\n", | |
" optimizer.zero_grad()\n", | |
" features, sfs = self(obs, skill_vectors)\n", | |
" q_vals = torch.bmm(next_sfs, skill_vectors.unsqueeze(2))\n", | |
" next_acts = q_vals.argmax(dim=1).unsqueeze(-1).repeat(1, 1, next_sfs.shape[2])\n", | |
" selected_next_sfs = next_sfs.gather(dim=1, index=next_acts)\n", | |
" selected_next_sfs = selected_next_sfs.squeeze(1)\n", | |
" terminals = 1 - dones.int()\n", | |
" y = features + gamma * selected_next_sfs * terminals.unsqueeze(1)\n", | |
"\n", | |
" sf_idxs = acts.unsqueeze(1).unsqueeze(2).repeat(1, 1, sfs.shape[2])\n", | |
" selected_sfs = sfs.gather(dim=1, index=sf_idxs)\n", | |
" selected_sfs = selected_sfs.squeeze(1)\n", | |
"\n", | |
" td_loss = (y.detach() - selected_sfs) ** 2\n", | |
" td_loss = torch.mean(torch.sum(td_loss, dim=1))\n", | |
"\n", | |
" reward_loss = 1.0 - torch.bmm(features.unsqueeze(1), skill_vectors.unsqueeze(2))\n", | |
" reward_loss = reward_loss.mean()\n", | |
"\n", | |
" total_loss = reward_loss_weight * reward_loss + td_loss\n", | |
"\n", | |
" optimizer.zero_grad()\n", | |
" total_loss.backward()\n", | |
" optimizer.step()\n", | |
"\n", | |
" return total_loss, reward_loss, td_loss" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": { | |
"id": "-7OjQ92s6oTN" | |
}, | |
"outputs": [], | |
"source": [ | |
"# Adapted from: https://github.com/deepmind/deepmind-research/blob/master/visr/VISR_ICLR2020.ipynb\n", | |
"def sample_sphere(dim, n=1):\n", | |
" unnormed = np.random.randn(n, dim)\n", | |
" arr = unnormed / np.linalg.norm(unnormed, axis=-1, keepdims=True)\n", | |
" return torch.FloatTensor(arr)\n", | |
"\n", | |
"def sample_batch(buffer, n):\n", | |
" data_idxs = np.random.choice(range(len(buffer)), size=n, replace=False)\n", | |
" batch_data = []\n", | |
" for i in data_idxs:\n", | |
" batch_data.append(buffer[i])\n", | |
" return batch_data" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": { | |
"id": "O6qRisLM6wm-" | |
}, | |
"outputs": [], | |
"source": [ | |
"DEVICE = 'cuda'\n", | |
"\n", | |
"n_episodes = 100000\n", | |
"update_freq = 64\n", | |
"batch_size = 256\n", | |
"print_freq = update_freq * 500\n", | |
"max_steps = 100\n", | |
"randomize_env = True\n", | |
"exp_buffer_size = int(1e6)\n", | |
"\n", | |
"start_act_epsilon = 1\n", | |
"end_act_epsilon = 0.1\n", | |
"act_epsilon_anneal_steps = int(1e5)\n", | |
"\n", | |
"start_reward_weight = 10\n", | |
"end_reward_weight = 10\n", | |
"reward_anneal_steps = int(1e6)\n", | |
"\n", | |
"embed_dim = 32\n", | |
"lr = 1e-4\n", | |
"\n", | |
"\n", | |
"def anneal_func(step, start, end, n_steps):\n", | |
" if step >= n_steps:\n", | |
" return end\n", | |
" return start - ((start - end) * (step / n_steps))\n", | |
"\n", | |
"get_act_epsilon = lambda step: anneal_func(\n", | |
" step, start_act_epsilon, end_act_epsilon, act_epsilon_anneal_steps)\n", | |
"get_reward_weight = lambda step: anneal_func(\n", | |
" step, start_reward_weight, end_reward_weight, reward_anneal_steps)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": { | |
"id": "mGMc0FPPh2IV" | |
}, | |
"outputs": [], | |
"source": [ | |
"env = create_simple_gridworld_env(randomize_env, max_steps)\n", | |
"\n", | |
"obs_dim = env.observation_space.shape\n", | |
"n_acts = env.action_space.n\n", | |
"model = SFModel(obs_dim, n_acts, embed_dim)\n", | |
"model = model.to(DEVICE)\n", | |
"\n", | |
"optimizer = torch.optim.Adam(model.parameters(), lr=lr)\n", | |
"\n", | |
"sample_skill = lambda: sample_sphere(embed_dim)[0]\n", | |
"\n", | |
"exp_buffer = [] # [[obs, act, next_obs, reward, done, skill_vector], ...]\n", | |
"\n", | |
"all_rewards = []\n", | |
"loss_hist = [[], [], []]\n", | |
"step_idx = 0" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": { | |
"colab": { | |
"background_save": true, | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "__s4jY95y7T5", | |
"outputId": "357466e9-640c-4501-efa3-7dcdda3d967e" | |
}, | |
"outputs": [ | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"/usr/local/lib/python3.7/dist-packages/ipykernel_launcher.py:81: VisibleDeprecationWarning: Creating an ndarray from ragged nested sequences (which is a list-or-tuple of lists-or-tuples-or ndarrays with different lengths or shapes) is deprecated. If you meant to do this, you must specify 'dtype=object' when creating the ndarray\n" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Step: 32000\tLoss: 24.7084\tReward Loss: 0.9345\tTD Loss: 15.3632\tEx Reward: 0.15\n", | |
"Step: 64000\tLoss: 29.1705\tReward Loss: 0.9143\tTD Loss: 20.0273\tEx Reward: 0.13\n", | |
"Step: 96000\tLoss: 11.8866\tReward Loss: 0.8425\tTD Loss: 3.4617\tEx Reward: 0.10\n", | |
"Step: 128000\tLoss: 10.2671\tReward Loss: 0.8418\tTD Loss: 1.8491\tEx Reward: 0.09\n", | |
"Step: 160000\tLoss: 9.9769\tReward Loss: 0.8453\tTD Loss: 1.5240\tEx Reward: 0.08\n", | |
"Step: 192000\tLoss: 9.9873\tReward Loss: 0.8460\tTD Loss: 1.5270\tEx Reward: 0.08\n", | |
"Step: 224000\tLoss: 9.9653\tReward Loss: 0.8477\tTD Loss: 1.4880\tEx Reward: 0.07\n", | |
"Step: 256000\tLoss: 9.9219\tReward Loss: 0.8504\tTD Loss: 1.4176\tEx Reward: 0.07\n", | |
"Step: 288000\tLoss: 9.9181\tReward Loss: 0.8491\tTD Loss: 1.4267\tEx Reward: 0.07\n", | |
"Step: 320000\tLoss: 9.8990\tReward Loss: 0.8442\tTD Loss: 1.4570\tEx Reward: 0.06\n", | |
"Step: 352000\tLoss: 9.9144\tReward Loss: 0.8383\tTD Loss: 1.5316\tEx Reward: 0.06\n", | |
"Step: 384000\tLoss: 10.0145\tReward Loss: 0.8352\tTD Loss: 1.6622\tEx Reward: 0.06\n", | |
"Step: 416000\tLoss: 10.2442\tReward Loss: 0.8338\tTD Loss: 1.9065\tEx Reward: 0.06\n", | |
"Step: 448000\tLoss: 10.7727\tReward Loss: 0.8329\tTD Loss: 2.4435\tEx Reward: 0.06\n", | |
"Step: 480000\tLoss: 14.2633\tReward Loss: 0.8334\tTD Loss: 5.9294\tEx Reward: 0.06\n", | |
"Step: 512000\tLoss: 15.3901\tReward Loss: 0.8345\tTD Loss: 7.0455\tEx Reward: 0.06\n", | |
"Step: 544000\tLoss: 16.0420\tReward Loss: 0.8348\tTD Loss: 7.6939\tEx Reward: 0.06\n", | |
"Step: 576000\tLoss: 16.4977\tReward Loss: 0.8364\tTD Loss: 8.1338\tEx Reward: 0.06\n", | |
"Step: 608000\tLoss: 16.8128\tReward Loss: 0.8389\tTD Loss: 8.4240\tEx Reward: 0.05\n", | |
"Step: 640000\tLoss: 17.1703\tReward Loss: 0.8388\tTD Loss: 8.7820\tEx Reward: 0.06\n", | |
"Step: 672000\tLoss: 17.3322\tReward Loss: 0.8384\tTD Loss: 8.9487\tEx Reward: 0.05\n", | |
"Step: 704000\tLoss: 17.4789\tReward Loss: 0.8378\tTD Loss: 9.1005\tEx Reward: 0.05\n", | |
"Step: 736000\tLoss: 17.6233\tReward Loss: 0.8374\tTD Loss: 9.2495\tEx Reward: 0.05\n" | |
] | |
} | |
], | |
"source": [ | |
"# Episode loop\n", | |
"for episode_idx in range(n_episodes):\n", | |
" obs = env.reset()\n", | |
"\n", | |
" skill_vector = sample_skill()\n", | |
"\n", | |
" ep_rewards = []\n", | |
" done = False\n", | |
" episodes_passed = 0\n", | |
" while not done:\n", | |
" # Sample an action\n", | |
" if np.random.rand() < get_act_epsilon(step_idx):\n", | |
" # Random sample\n", | |
" act = env.action_space.sample()\n", | |
" else:\n", | |
" # Sample based off of successor features\n", | |
" with torch.no_grad():\n", | |
" _, sfs = model(obs.unsqueeze(0).to(DEVICE),\n", | |
" skill_vector.unsqueeze(0).to(DEVICE))\n", | |
" sfs = sfs.cpu()\n", | |
" q_vals = torch.matmul(sfs[0], skill_vector.unsqueeze(1))\n", | |
" act = torch.argmax(q_vals).item()\n", | |
"\n", | |
" # Take a step\n", | |
" next_obs, reward, done, _ = env.step(act)\n", | |
" exp_buffer.append([np.array(copy.deepcopy(obs)), act,\n", | |
" np.array(copy.deepcopy(next_obs)),\n", | |
" reward, done, np.array(skill_vector)])\n", | |
" ep_rewards.append(reward)\n", | |
" obs = next_obs\n", | |
"\n", | |
" # Delete extra data from the experience buffer\n", | |
" exp_buffer_overflow = len(exp_buffer) - exp_buffer_size\n", | |
" if exp_buffer_overflow > 0:\n", | |
" exp_buffer = exp_buffer[exp_buffer_overflow:]\n", | |
"\n", | |
" # Update the model\n", | |
" if step_idx % update_freq == 0 and len(exp_buffer) >= batch_size:\n", | |
" batch_data = sample_batch(exp_buffer, batch_size)\n", | |
" loss, reward_loss, td_loss = model.train(\n", | |
" batch_data, optimizer, reward_loss_weight=get_reward_weight(step_idx))\n", | |
" for i, l in enumerate([loss, reward_loss, td_loss]):\n", | |
" loss_hist[i].append(l.item())\n", | |
" batch_data = []\n", | |
"\n", | |
" # Print the agent training status\n", | |
" if step_idx != 0 and step_idx % print_freq == 0:\n", | |
" lookback = print_freq // update_freq\n", | |
" print('Step: {}\\tLoss: {:.4f}\\tReward Loss: {:.4f}\\tTD Loss: {:.4f}\\tEx Reward: {:.2f}'\n", | |
" .format(step_idx, *[np.mean(loss_hist[i][-lookback:]) \\\n", | |
" for i in range(len(loss_hist))],\n", | |
" np.mean(all_rewards[-episodes_passed:])))\n", | |
" episodes_passed = 0\n", | |
" \n", | |
" step_idx += 1\n", | |
" episodes_passed += 1\n", | |
" all_rewards.append(sum(ep_rewards))" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 818 | |
}, | |
"id": "43o_HMRlyre_", | |
"outputId": "f8f1eb5b-795e-47b5-a90d-1b7f3c6fb06e" | |
}, | |
"outputs": [ | |
{ | |
"data": { | |
"image/png": "\n", | |
"text/plain": [ | |
"<Figure size 432x288 with 1 Axes>" | |
] | |
}, | |
"metadata": {}, | |
"output_type": "display_data" | |
}, | |
{ | |
"data": { | |
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAX8AAAELCAYAAAAx94awAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+WH4yJAAAgAElEQVR4nO3deVxU5f7A8c/MAAIKwiDgkFviRqC5lKSFptdyCaMss0t1+2Vptmi3nbrlUmmYZqvm1dLq6q17vVYmuVVmZqG5oxJuuaEjIIsCss3M+f0xMjKyw8Bs3/fr5et1znnOmfme4fidZ57znOdRKYqiIIQQwq2o7R2AEEKI5ifJXwgh3JAkfyGEcEOS/IUQwg1J8hdCCDckyV8IIdyQJH8hbOyrr77ir3/9a5Vl6enpdO/eHYPB0MxRCWFNkr8QQrghSf7CJdirJi01eOGsJPkLpzV06FAWLVrE6NGj6d27NwaDgT179nDvvfdy3XXXcfvtt7Nt2zYAtm7dyujRoy3HPvTQQ9x1112W9fj4eH744QcAFi1axLBhw+jTpw+jRo3i+++/t+z31Vdfce+99zJr1iyio6P54IMPyM3NZdKkSfTt25e7776bkydP1vkcMjIymDRpEv379+eWW27hv//9r6UsJSWFMWPG0LdvXwYOHMibb74JQElJCc899xzR0dFcd9113HXXXZw7d65hH6JwWx72DkCIxvjuu+9YtGgRgYGBZGdn8+ijj/LWW28RExNDcnIyU6ZMYe3atfTu3Zvjx4+Tk5ODn58fBw8eRKPRUFBQgIeHB/v376dfv34AtG/fnuXLlxMcHMy6det4/vnn2bBhAyEhIYA5Kd922238+uuvGAwGXnrpJVq0aMGWLVtIT0/n4Ycfpl27dnWK/5lnnqFr16788ssv/Pnnnzz00EO0b9+eAQMGMHPmTP72t79xxx13UFhYyOHDhwH4+uuvKSgoYNOmTXh5efHHH3/g7e3dNB+wcFlS8xdO7YEHHkCn0+Ht7c2qVasYNGgQgwcPRq1Wc+ONNxIVFcXPP/+Mt7c3PXv2ZMeOHRw4cIAePXrQt29fdu3axZ49e+jYsSOBgYEAjBw5ktDQUNRqNaNGjaJjx46kpKRY3jMkJIQHHngADw8PPD092bBhA1OmTMHX15du3bpx55131il2vV7Prl27eO6552jRogURERGMHTuWVatWAeDh4cHJkyfJycmhZcuW9O7d27I9Ly+PEydOoNFoiIqKolWrVjb+ZIWrk5q/cGo6nc6yfObMGdatW8dPP/1k2WYwGIiOjgbg+uuv5/fffyc0NJTrr78ef39/tm/fjpeXF/3797cc880337B06VJOnz4NwMWLF8nNzbWUt23b1rKck5ODwWCwiiMsLKxOsWdmZtK6dWurxB0WFsb+/fsBmDlzJu+//z4jR46kXbt2PPnkkwwZMoS4uDjOnj3LM888w4ULF7j99tt5+umn8fT0rNP7CgGS/IWTU6lUlmWdTkdcXBxvvPFGlfv279+fxMREwsLCmDBhAq1bt+bVV1/F09OT++67D4DTp0/zyiuv8Omnn9KnTx80Gg1xcXHVvqdWq8XDwwO9Xk94eDhgrtHXRUhICOfPn6egoMDyBaDX6wkNDQWgU6dOzJs3D5PJZPl1sW3bNnx9fXnyySd58sknSU9PZ+LEiVx99dWMHTu2jp+aENLsI1zI7bffzk8//cQvv/yC0WikpKSEbdu2cfbsWQD69OnDsWPHSElJoVevXnTt2pXTp0+TkpLC9ddfD0BRUREqlQqtVgvAypUrLW3tVdFoNNxyyy18+OGHFBUVceTIEb7++us6xavT6ejTpw/z5s2jpKSEtLQ0/ve//3H77bcDsGrVKnJyclCr1fj7+wOgVqvZunUrBw8exGg00qpVKzw8PFCr5b+yqB+p+QuXodPpWLBgAXPmzOHZZ59FrVbTq1cvpk+fDoCvry+RkZF4eXnh5eUFmL8QDh8+TFBQEABdunRh/Pjx3HvvvahUKu644w769u1b4/tOnTqVl156iRtvvJHOnTszZswYSy+j2sybN49p06YRExODv78/kydPZuDAgQD88ssvJCYmUlxcTFhYGO+88w7e3t6cO3eOadOmkZGRga+vL6NGjar060SI2qhkMhchhHA/8ltRCCHckCR/IYRwQ5L8hRDCDUnyF0IINyTJXwgh3JAkfyGEcENO088/N7cQk6n+vVKDglqRnV3QBBHZlsRpO84QI0ictuQMMULzxqlWqwgMbFltudMkf5NJaVDyLz/WGUictuMMMYLEaUvOECM4TpzS7COEEG5Ikr8QQrghSf5CCOGGJPkLIYQbkuQvhBBuSJK/EEK4IUn+QlRjfOJGxidutKxn5RWx8uejmGQUdOECnKafvxBNxWA0sWWfnpheOjRVzIi1NfUsi75Ntax3bRdAr/Cg5gxRCJuTmr9wakdPn2d84kZyLhTXuN83v/zJT7vSq9xv1ZZjfL7uIM/O/82yrWLtvmLiB3h3xV7GJ27kx53plm3lvxJOZuQ39FSEaFZS8xdOZUdaJt3aB+Df0jwN48x/7QRgzpd7eHPiDVUeU7Hp5l8bDgGwJGGoZdt3yScAuFBYStqJXDqE+nHufFGtsSz//hDLvz9ktW360u1Wry2Eo5LkL+xOURROZxXSLqRVjftl5F5kwTf7AVj0/M1k5l5O0Bk5F0k7kUuPjoFWx5QZjFW+VlGJAZ8WHpXK3/piNwA9OgTU+zzK5V8sRVGwfEEJ4Ygk+Qu7e3j2TwBMHtOTPt2Cq90v5Ui2ZXnR6lR2pGValb/1xW5Lrfu9FXvZezSbuwZ3rvK1ftyZzleb/+SOm66usjztZB4AHUJacTLTPBDXtP+7Hn1OYaVmoCs99f4WAPkFIByatPkLAE5m5Ffq3VIdRVGY8el2ikoMjX7fzXvPWJY/+GqfZTkj5yK/7tMD8O/vD/Hs/F/54sfDlvIrE3+58vj3HjV/Uaz8+c8q9/tqs3n7N1uOAXBNp8Aq94vqbL6x26Vdazq29eOGa9ry6oPX1X5il2IpLjV/Rr/sPcOhU3l1Ok6I5iA1fwGYb3qW+3nPaf79w2HKDCaG9WtH/C3drPZ9+oMtXLhYxhPvbK537TY3v4Qfd6YT00tHqNaXA8dyrMpfXPgbr42P5qVFWwEICfThhwo3Vquy8NnBTHr7Z8t6dhXt9a893J//bTrKDdeEsmh15Zr7XYPDGdavlPdXplhtv21ARxQUYgd0smy7WufPJy8O4XD6ebR+LSgsNtDSx4PU47l8ujbN6vgPVu7j5j5XsfTS9o9fHIJaparxfIRoDlLzFwB00vlblj9bd5AygwmgysR74WKZZbmktOo29aoUFpfx7PxfWbP1BC8t2sp/Nx7hyh7zWXnFPDbvciJ/c9muGl/ztgEd8fLUWH0J/d9rGyrt1y64FX8fey03RLat8nU6tvWjd9c2Vq8T1VmLTwsPxt7cBZ8W1vUklUpFt/YBtAnwoWNbP9q09mHQtWH89S9drfb740QuH126TwFwsbjxv5aEsAVJ/m5KURRLggf4Lvl4tfvWNP54xURd0cqfj/JJ0uUa9vjEjUx+9xerfdb9fpIWHmq8PBt2Gfr7enLX4PBqy9sFmyeyiLpaa7X94xeG8OqD1/HRM4Mt2yrWxhc9fzMPjerB02OvrXdMt1zfniUJQ2nT2rvK8oKisiq3C9Hc6tTsc+zYMRISEsjLyyMgIIDZs2fTqVMnq32ysrKYOnUq6enpGAwGJk2aRFxcHADZ2dm89NJL6PV6DAYD0dHRvPLKK3h4SKuTvZTfZH3/qRimvPdLjfs+8tZPtAtuxYTR19CmTeUeOYqiYDCa2Hkwi/7XhKJWqSzdJx+OvabG1047mYefjycT7okkcXnNtXyAkAAfxgzuTN9uwXhorL80boxqy6/7z1rWX7q/X6UaO5hnOLr60i+dqpqtPDRqYnqF1RpLTWZPGkDCP5PJyrN+rqDgYhloqzlIiGZUpyrXtGnTiI+PZ/369cTHxzN16tRK+yQmJhIVFcXq1atZvnw577zzDnq9+YbdwoULCQ8PZ/Xq1Xz77bccOHCADRsq/zQXjVNaZmTlz0ctNxmrYzRdrvHXlvjLpWcVMG3J70z9Z3Klsq9/+ZNH5/7MotWpbDuQgT670FL24850Uo6eq/Z1sy8Uk32hhG7tA1iSMJTH74iylFXsbtn6UrfJZ+/tTf+I0EqJH8xfNGMGdUZ9qRLv7aWp07k1BZVKReKjAyzrV136FSI1f+Eoak3+2dnZpKamEhsbC0BsbCypqank5FjfqEtLSyMmJgYArVZLjx49WLt2LWD+j1BYWIjJZKK0tJSysjJCQ0NtfS5ub95/9vBd8gken7cZRVEYn7iRtdtOVNovN7+kyuM//PsgQgN9AHj8jiheG9+/0j57DmcB0LNzEN3atQYg6bfL77E4KZV/LN5mWV/+/SHeXWF9E7Um1/UIIfHRG5j/9CCe/2sfoq8J5YO/x/DO5JtYkjCU4ACfGo+PHdiJVXPjWJIwFJWdb6yqVCr6Xeq6OuWuXgCVbigLYS+1trvo9XpCQ0PRaMy1KI1GQ0hICHq9Hq328u/XyMhI1qxZQ8+ePUlPT2f37t20a9cOgMcff5zJkydz0003UVRUxH333Ue/fv3qFWhQUM0PANUkONivwcc2p8bGeSj9vGV5a5o5Sa/46Sh/izXXpotLDXh7eXDg1PlKx3p7aejYPpCPX7m1Tu/1xmM3sj31LG8s/b3O8S2bMYJWvl5o1CoKisr46ytrAIgf3sPq3Csuv/Jw1U/t1sZR/ubTHx0IYNUttrpzdWTOEKczxAiOE6fNGt0TEhKYNWsWcXFxhIWFMWDAAMsXxrp16+jevTufffYZhYWFTJgwgXXr1jFixIg6v352dkGDJj4ODvYjK8vxx1tpaJwnM/Lx0KgJCbSuES9edbmHyZzPt1v1p7/SP/7Wj/Cw1jW+f0igj9UTtdnZBXRo41tjbBNvv8bqgajSolJyikoB832C2IGduK57MB1Cbfs3cvS/+cGjWWj9vR0+znLOEKczxAjNG6daraqx0lxr8tfpdGRkZGA0GtFoNBiNRjIzM9HpdFb7abVa5s6da1mfMGECXbp0AWDZsmXMmjULtVqNn58fQ4cOZdu2bfVK/sLs07V/sHmv+V7KjT3b8uu+s7UcQbWJf2R0B77fkU7nCt08r1TxhuiVD4B5aNQsSRhq2T6k71X8tOs0AO9NuanG9m2VSsWYQVU/feuq4od15d8/HOa5Bb/xtxHdGXtLD3uHJNxYrW3+QUFBREREkJSUBEBSUhIRERFWTT4Aubm5GAzmn7bJyckcOnTIcp+gXbt2bN68GYDS0lKSk5Pp2tW6P7Som/LED1SZ+G+Mqrof+5X6dQ9m7JAuLHr+5nq3jX/49xir9cRJA/jw74N44Nbu3Nw7DH9fT1r5eNJWe/mXwVN396rXe7ii8qeOAT5fd9COkQhRx2af6dOnk5CQwIIFC/D392f27NmAuXY/ZcoUevbsSUpKCjNnzkStVhMYGMjChQvx8TE3Rbz88stMmzaN0aNHYzQaiY6O5p577mm6s3IhZQYjJhO08NLUaaTJO2I6ExzgwzdbjvHelJss48xcKbBVi3rH8vELQwjUtuR83kWr7SEVbsL+bUQP/jbico1Wxre57PE7onjinc32DkMIAFSK4hzTErlrm395k0rF5pWaXJlsi0oMPPHOZoL8vXn9kf48Pm9ztfs2Jk5H4sgxbk/LtDzx+8XrIykqrLrnlSNx5M+znDPECE7W5i/sZ8HXlwc623qg+rb9Rc/fzMZdp+nZufLTQz4tPFj0/M2WfvFenmpKy0w8fFuE7QMWtbq+RwjFo3qwdE0aZ84VEugj/wWFfciV54DGJ25EF+SLPvty80rFwcjenXIT2w5kEOjXgt5d2+ChUXPr9e2rfb2KD0R99MxgFJDBxewo4FKT25lzhQS2b23naIS7kuTvYM7mmBN+xcR/JX9fL26pIdnXRKVSIWnfvrpcZU74OeeLQZK/sBMZ2M3BrP/9ZI3lcgPV+ZWPN7T3SJadIxHuTJK/g9Gorevlw/tfruEvfHbwlbsLJ7armglphGgO0uzjYLZfkRC6tQ9g3FB5JkIIYVuS/JtARs5FCorLCA+rf3tu/qWJUj5+cQj6c4VcFdzwMY2E44od2JE1W09iNJnQqOUHuGh+kvxt7MWFv1nGcK9L+3xWXhEZORcZEuxnNbmKWqWSxO/CtP7emEwK5wtK0fpXPfGLEE1Jkr+NVZy8Y922k5w+V8Cv+87y0TODaVHF+PIvLjSPj29ARYi/uQvg3TdXPzuVcA1BlxJ+elahJH9hF/J704byL5Zarf/3pyOW8Xeqmu7w0Kk8y/L7/91DRq65e2f3CpOYCNdU3uPn3RV77RyJcFeS/G2ounF0qpNzwXqKv6Vr0gAIDax5qGTh/K7WOcaY7sJ9SfK3kfKHswAmjK48b235rFcVnS+0/qVQPgRyKx9PG0cnHI3c5BX2Jm3+NnDlgGsDItsyILKt1fZD6eetenYoisLRMxcA6B8Rwu9/SJ9vd6Uoit2nnBTuR5J/I9U0WXp5b5/yL4EJb22qcr9JcVH8/kftI3YK11RQVIafr5e9wxBuRn57NtDBk7mMT9xoNUQyVN1kc2cdZqz65JVbuCq4JZ+8OMRmMQrH9tS4PgB8v+OUnSMR7kiSfwPN/vduq/WObf34v5E9mPfkjZX2jR3QsdbXCwn05fWHo+Xnvxs5cdbc7Jf02wk7RyLckSR/G3n1wesYdG2Y1fDJ5WpK6C/G92nKsIQDeyg20t4hCDcmbf4N8OnatErbahsf/4O/x6BWqXj9sx2czbmIh0bF0L7t6NZe+vS7K7VafuUJ+5Hk3wCb956xLNd1iOWW3uZ7AbMm3kBufgneXhrLgz5C7D6cRZ+uwfYOQ7gRafapo9//yGDnQevx12eM79+g1wr0ayGJXwDQs3MQAMf1jj//rHAtkvzrwGA0sXDVAeZXmFMXoH2IDLwmGmfK3T0BWP3bcfsGItyOVD9rYDIpvLDwN6vRNsuHZOhSxRO7QtSXPOkr7EWSfw3OXSgm50KJ1bY3l+0CoL0MtyxsTMb2F81JrrQavPmvnZW2ZV+q+Ydf5d/c4QgXpb00lPdT721hfOJGzuUV2Tki4Q4k+dfgyoHXKurXPaQZIxGu7M4Y8xPgF0vMQ4W8cGmOByGakiT/GvTu0sayPP/pQVZlLTwrT8wiREME+LWwdwjCDdWpzf/YsWMkJCSQl5dHQEAAs2fPplOnTlb7ZGVlMXXqVNLT0zEYDEyaNIm4uDhL+Zo1a/joo48sIxguXbqUNm3a4MhyC0qIulrLM+N62zsU4cI8r3gqPCTQx06RCHdSp+Q/bdo04uPjiYuLY9WqVUydOpXPP//cap/ExESioqL46KOPyMnJYcyYMfTv3x+dTse+ffv48MMP+eyzzwgODiY/Px8vL8cdxXDdtpP896cjANzcO8yyfdzQLvxn4xEGXRtW3aFC1NvVOuv7R5m50uYvml6tzT7Z2dmkpqYSGxsLQGxsLKmpqeTk5Fjtl5aWRkxMDABarZYePXqwdu1aAD799FPGjx9PcLD5CUY/Pz9atHDcn7rliR+gZYVROof378DHLw7hwRHd7RGWcFGeHmpGD+xkte10VoF9ghFuo9bkr9frCQ0NRaMxt3FrNBpCQkLQ6/VW+0VGRrJmzRoUReHUqVPs3r2bM2fMwyAcPXqUU6dOcd9993HnnXeyYMECFEVpgtOxvSvn5VWrVDLyprC5Owd1ZknCUB6/IwqAMqOpliOEaByb9fNPSEhg1qxZxMXFERYWxoABAyxfGEajkYMHD7J06VJKS0t55JFHCAsL44477qjz6wcFNbxffXBw3edLNZqsv5QeHB1FcDPNqVufOO3JGeJ0hhihcpxhueauxK99uoPVb8dVdYhdOMPn6QwxguPEWWvy1+l0ZGRkYDQa0Wg0GI1GMjMz0el0VvtptVrmzp1rWZ8wYQJdunQBICwsjBEjRuDl5YWXlxd/+ctfSElJqVfyz84uwGSq/6+F4GA/srLqPm5Kpe6dZYZ6Hd9Q9Y3TXpwhTmeIEaqO85T+vGVZf/Y8mblFhLVp2dyhWXGGz9MZYoTmjVOtVtVYaa612ScoKIiIiAiSkpIASEpKIiIiAq1Wa7Vfbm4uBoO5n3JycjKHDh2yuk+wZcsWFEWhrKyMrVu30qNHjwafVFM6X2B+oreFl4bu7QOkiUc0qx4dLg/xPXHOJl75eJtlSBEhbKlOzT7Tp08nISGBBQsW4O/vz+zZswFz7X7KlCn07NmTlJQUZs6ciVqtJjAwkIULF+LjY+6ydtttt7F//35GjRqFWq3mpptu4u677266s2qEgyfzAJg4+hoZYlc0u9atKneE+G7rCR64VToZCNtSKU5y57W5mn1W/nyU75JPMHNCNLqg5vu5LT9bbccZYoTq41y15Rirthyz2lbXeSOagjN8ns4QIzhZs4+7+S7ZPJ9qkL+3nSMR7mrUDR0qPeh1JP18NXsL0TCS/C85l1fElhQ9Hdua78R7yfANwk48PTQkPjrAqra/aPUBO0YkXJEkf+BicRkvLExmyZo/OHE2Hz9fz9oPEqIZzHvyRgA6h8kossK2JPlzuamnXP7FMjtFIoS1gEs3gI+elmYfYVuS/IG1205arcv8usLRZF8xqZAQjSXJvwqd2jrGE3hCVHToVJ69QxAuxO2Tf0GRuYmnVYUB3O6/tZu9whGikut7mCcOSly+y86RCFfi9u0bi1enAuYvgYT7+nL6XGGz9u8XojaT4iLZnpZp7zCEi3H75L/vz2zA/Fh9t/bmf0I4kopDjJgUBbUMOSJswO2bfcqNuqGjvUMQolaPzP4Jgwz3LGzArZN/xf9EUZ2D7BiJEHU3cc4me4cgXIBbJ/+LxQZ7hyBEnSQ+eoO9QxAuxq2Tf3GpOfk/fFuEnSMRomYhgb4MuvbyHBp5BdLvXzSOWyf/sznmibIrztMrhKOKH3a5C/IzH/5qx0iEK3Dr5P/uir0A+MoTvcIJeHlq6NdN5pgQtuHWyf+67ub/SF2uam3nSISomyfG9LQsz/zXDl5c+BslpUY7RiSclVsn/1a+Xvj5eqJWS79p4XyOnr5AVl4xe4+es3cowgm5dfIvLjHg4yVNPsK5XK2zHnsq6bcT1ewpRPXcOvkXlRjwbiGTtgjn8sJf+1qtt2kts86J+nPr5F9casRbav7CybTw0hB5tdayvueINPuI+nPr5F9UasDHS2r+wvk8O6631TSPMuSDqC+3Tv7FJUaZuEW4BBnyQdSXeyf/UgPeUvMXLuLo6fOMT9zIqi3H7B2KcAJunfyLSo14S81fOLFXH7zOsjzzXzsBJPmLOnHp5G9SFErKqn4AxmA0UWYwSZu/cGpX6/yr3G5SlGaORDgbl07+21IzeOi1DVXeDCu+9FSk9PYRrugfi7fZOwTh4Fw6+edfLCP/YimlVdT+i0vMI3pKP3/h7CI7BVqW/XzNgxRm5Fy0VzjCSdQp+R87doxx48YxfPhwxo0bx/Hjxyvtk5WVxWOPPcbo0aMZOXIkq1atqrTPn3/+ybXXXsvs2bMbHXhdaC4N22AwVv4JXF7zlyd8hbN7ckwvXn3wOpYkDOWtSQMt20+fK7RjVMLR1Sn5T5s2jfj4eNavX098fDxTp06ttE9iYiJRUVGsXr2a5cuX884776DX6y3lRqORadOmMWzYMNtFX4vyMXsKisoqlRWVSs1fuIYWXhpL23+LCvewXv1Ymn5E9WpN/tnZ2aSmphIbGwtAbGwsqamp5OTkWO2XlpZGTEwMAFqtlh49erB27VpL+aJFi7j55pvp1KmTDcOv2dqt5jFPXqniP0FRidT8hWsa0b+DZVlu/Irq1Jr59Ho9oaGhaDTmGoVGoyEkJAS9Xo9We/kR88jISNasWUPPnj1JT09n9+7dtGvXDjB/MWzZsoXPP/+cBQsWNCjQoKBW9T7m3Pliy3JwsPVgWPmpGQC0DfWvVGYvjhJHbZwhTmeIEZomzsfG9mbd7ycB+M9PR5kyrk+jX9MZPk9niBEcJ06bVXsTEhKYNWsWcXFxhIWFMWDAADQaDWVlZbz66qu8+eabli+QhsjOLsBkql8t5vm/9mHOF7u59fr2ZGXlW5XtvzQeSmF+EVke9h/SOTjYr1KMjsgZ4nSGGKFp43z+3t7M+XIPP+9O55Z+V9HKx7PBPduc4fN0hhiheeNUq1U1VpprvRp0Oh0ZGRkYjUY0Gg1Go5HMzEx0Op3Vflqtlrlz51rWJ0yYQJcuXcjKyuLkyZNMnDgRgAsXLqAoCgUFBbz++usNPa866RhqPnGtX4tKZW21vgAEyYiIwgVFdDL/Ki8tM/HCR8kAvHx/P7q0k4mLhFmtbf5BQUFERESQlJQEQFJSEhEREVZNPgC5ubkYDOabqMnJyRw6dIjY2FjCwsLYtm0bGzduZOPGjTz44IPcc889TZ74ATw05tMrq6Kff2FRGT4tNGjULt3bVQiLWct22jsE4UDq9Dtw+vTpJCQksGDBAvz9/S1dNSdMmMCUKVPo2bMnKSkpzJw5E7VaTWBgIAsXLsTHx6dJg6+Nh8el5G+onPzPnS+muESmvxNCuKc6Jf/w8HBWrFhRafvixYsty4MHD2bw4MG1vtbkyZPrEV7jqFUqPDSqKvv5yxjowtUtSRjKibP5zPh0u2Xbgq/3seNgFle1acnrj0TbMTphby7f5uHpoZaxzoXb6tjWjyUJQ4m+JhSAHQezAPMDYLn5JfYMTdiZyyd/D42myjb/QL8W3NizrR0iEqL5XdMxsNK2Z+f/aodIhKNw+eTv5amuss2/oKgMP18vO0QkRPPLqvDMS0WKPATmtlw++VfV7CPDOQt307nC0M+BFbo+f7/9lD3CEQ7APZL/FTX/bZee7pUpHIW76N21jWX58TujLNf+lxuP8Pn6g+z7M9teoQk7cf3kr9FU6u3zyXd/AHDwZJ49QhLCLuY8NpA7Y66ms86f6Q9dD5h7xG3afZp3/rvXztGJ5ub6yd9TTZmh6v78N0SGNnM0QthPUNLCSwIAABhYSURBVGtvRt94NSqViuAA8zM4FQd+k15x7sX1k7+HmrIq+vkD9Ose0szRCOG4vvzxsL1DEM3I9ZO/xvqG7/jEjXaMRgjHtXHXaXuHIJqRy9/x9PTQVLrhC9DS2+VPXYgaffTMYPQ5hRQWG3j7yz32Dkc0M5fPgJ6eastDXhXb/t+ZfJO9QhLCIbTw0tCprX/tOwqX5PrNPh6XH/IqLDaPOurbwsMy4qcQ4jK56es+XL/mX6HN/+Kl5H//8G72DEkIh9OtfQCHTuUxcc4mALy9NCx4pvaBGoXzcvnqr5enxlLzv1hiTv4tvT3tGZIQDqfivL8AxaVGtqTo7RSNaA4un/zNwzuYu3rqswsBc7OPEOKynuHaStuWrPmDp97/xQ7RiObgJsnfXPNP+u24ZZsQ4rLqZrTLv1jGqQzHnxtX1J/LZ0FPjRqjScFkUsjKM49s2D6k+kmNhXBXw/q1A+DdK3rCPf7WRowmuRHsalw/+XuaR+6sOKa/SqWyVzhCOKz4W7qxJGEo/i0rD3U+ae7P/P5Hhh2iEk3F9ZP/pSYe6cImRN19/OIQPnlxiGXdaFLYkZZpx4iErblP8jeYCA7wlsHchKgDtUpV6RfyjoNZfLAyxU4RCVtz/eR/6WGuMqOJrLxiysrkF4AQdfXy/f2s1ncfPmenSIStuX7yL2/zv9TXf+ehLHuGI4RT6dKudaVtMjiia3D95H+p2Sf/YhkAdw7qbM9whHA637w1utJYWL/ukwfAnJ3bJP/zhaUABFTRk0EIUT2NRk3rll7cd8vlYVE++e4PygxGvt9+ipLSqidLEo7N9ZP/pTb/vIISAHxlaAchGuQv/doxY3x/y/qHX+3nix8P89i8n6udLU84LpdP/l6X2vwvXKr5yzj+QjRcxQckK076/ujcn+0RjmiEOiX/Y8eOMW7cOIYPH864ceM4fvx4pX2ysrJ47LHHGD16NCNHjmTVqlWWsvnz53PbbbcxevRoxowZwy+/NN94IZfb/M3J30fG9RGiURa/cLO9QxA2UKdMOG3aNOLj44mLi2PVqlVMnTqVzz//3GqfxMREoqKi+Oijj8jJyWHMmDH0798fnU5Hr169GD9+PD4+PqSlpXH//fezZcsWvL29m+SkKipP/jn55mYfH6n5C9Eo1Y0DVFxqwEOj5rg+n4LiMnp3adPMkYn6qLXmn52dTWpqKrGxsQDExsaSmppKTk6O1X5paWnExMQAoNVq6dGjB2vXrgUgJiYGHx8fALp3746iKOTl5dn0RKpTPmnL/j/N8XrJoG5CNNq0/7sef19Ppv7fdZZtj8/bzMQ5m5i1bCfv/y+Fd1fslSfrHVitmVCv1xMaGopGY24712g0hISEoNdbd/WKjIxkzZo1KIrCqVOn2L17N2fOnKn0et988w0dOnSgbdu2NjqFmimKYrXeWnr7CNFoHdv68e6UGDq19efJMT2r3CflaDYvLkxu5shEXdmsDSQhIYFZs2YRFxdHWFgYAwYMsHxhlPv999957733WLJkSb1fPyioYSNxmkzWyT8kxHHnLA0O9rN3CHXiDHE6Q4zgGnHe2qYVH361r8qy3PwSjpzNZ0DPsKYKzcIVPsvmVGvy1+l0ZGRkYDQa0Wg0GI1GMjMz0el0VvtptVrmzp1rWZ8wYQJdunSxrO/evZvnn3+eBQsW0Llz/R+0ys4uqJTI6+LKDzoryzHHJg8O9nPY2CpyhjidIUZwrThvG9CR75JPAPD3sb14d8XlMYBmfbqdJQlD7R6jI2jOONVqVY2V5lqbfYKCgoiIiCApKQmApKQkIiIi0GqtZ/7Jzc3FYDBPk5icnMyhQ4cs9wlSUlJ4+umnef/994mMjGzwyQghHNNdg8NZ+OxgZk6Ipld4G+4d2sWqXB4Eczx1uvs5ffp0li1bxvDhw1m2bBkzZswAzLX7ffvMP/dSUlIYNWoUI0aM4P3332fhwoWWm7wzZsyguLiYqVOnEhcXR1xcHAcPHmyiUxJC2IOXpwZdUEsAbu3fgTcn3mApm/PlbnuFJapRpzb/8PBwVqxYUWn74sWLLcuDBw9m8ODBVR6/cuXKBoZnG+1DWnEqs4A5jw20axxCuJNQrS8DIkNJPpDBn2cu2DsccQW36Pc4/aHrWfT8zQS1bvrnCoQQl00YfbmZ94WPfuPE2XxKy6QJyBG4RfJXqVSW/v5CCPs4d76YGZ9uZ9LbMhSEI5CMKIRoUt3aB1TaVj6/hrAfSf5CiCb1/F97V9qW8E95+MveJPkLIZqURq3m1Qevs9oWHuZPSZkRk1L/Z3eEbcgoZ0KIJne1zp8lCUNRFIWHZ//EjoNZ7Dh4ue1/0fM3o1GbJ4y/cuJ40TQk+Qshmk11iX3inE2oVKAo8MmLQ+QLoBlIs48QwiGUtwA9M/9X+wbiJiT5CyGa1awKT/5W5XxBaTNF4t6k2UcI0azaan35+IUhZOReBCD5wFmSfjthtU9BURnZ54vp2NYxRsB0RZL8hRDNTq1WWcYBOnq68tAPU94zT/X6z+dutszGJ2xLPlUhhF2NuzQCqC7Il78N725VtmLTEcA8KuiVEzOJxpGavxDCrjqE+vHu5Jto5ePJoVPW07v+sCOdVt6efLPlGABvPBJNWJuW9gjT5UjNXwhhd/4tvVCrVbQJqDz4YnniB3jl423NGZZLk+QvhHAYbVr78PYTNzL3cRl+valJ8hdCOJRAvxZo/b1JnDSgyvKtB842c0SuSZK/EMIhhQT4WL4ArqrQzr9odWq1x3zxw2H02YVNHpsrkOQvhHBYIQE+LEkYyuuPRNPKx9Oyfcan2632M5oUkn47zvc7TvGPxeb7AtnnizEYZejo6kjyF0I4hfefiiH40g3hE2fzGZ+40VL22idb+Wrzn5b18Ykbef6j35g4Z1Nzh+k0JPkLIZzGQyMjrNZPZuTz7+8PsSsts9pjln9/qKnDckqS/IUQTqNHx0Cu6x5sWZ++dDs/7Eyv8Zgfd6ZTUFQGQHpWAReLDU0ao7OQh7yEEE7l0bhI+vyRyeIabvxeqXy4iHJLEobaOiynIzV/IYRT0ajV9O0aXGn7wKi2DIxqi9a/BQAjoztU+xpGk9wIlpq/EMLpeHla11t7hrfhkdhrADApCumZBXQI9aN/RGilnkEAv+0/S0yvMMv6z3tO813yCd56zH0eLpOavxDC6ahUKvp1u1z7f/C2yzeC1SoVHULNQ0F3bOvHR88MZuyQcKvjl65JY8P2U4xP3EhRiYHP1h3k3Pli/u1GN4el5i+EcEqP3xnFvj9z6BUeRHCwH1lZ+VXu18JLw8jojoyM7sjR0+eZ+a+dAHz542EANu0+bdn3h53pxN/SremDdwB1qvkfO3aMcePGMXz4cMaNG8fx48cr7ZOVlcVjjz3G6NGjGTlyJKtWrbKUGY1GZsyYwbBhw7jllltYsWKFzU5ACOGeVCoVvcKD6nVM+FWtK21bsemorUJyKnVK/tOmTSM+Pp7169cTHx/P1KlTK+2TmJhIVFQUq1evZvny5bzzzjvo9XoAVq9ezcmTJ9mwYQP/+c9/+OCDD0hPr7l7lhBCNIX73KRmX5tak392djapqanExsYCEBsbS2pqKjk5OVb7paWlERMTA4BWq6VHjx6sXbsWgDVr1jB27FjUajVarZZhw4axbt06W5+LEELUqnOYf43l7jIkRK3JX6/XExoaikajAUCj0RASEmKp1ZeLjIxkzZo1KIrCqVOn2L17N2fOnLG8RljY5TvrOp2Os2dlZD4hRPPrdGle4I6hfrz+cH/LTeGHRvYAIPtCMWDuNeTKbHbDNyEhgVmzZhEXF0dYWBgDBgywfGHYQlBQqwYfGxzsHJNAS5y24wwxgsRpS/WJcfXbcZblD6/RAbD98DkADpzIIzTYn0fe2ABASKAPmblFfD59OIF+lSebaco4m1KtyV+n05GRkYHRaESj0WA0GsnMzESn01ntp9VqmTt3rmV9woQJdOnSxfIaZ86coVevXkDlXwJ1kZ1dgMlU/2/imnoBOBKJ03acIUaQOG3JFjF2aOMLwPHT5/nReNyyPTO3CIDDx7K5Wldzk1FtmvOzVKtVNVaaa232CQoKIiIigqSkJACSkpKIiIhAq9Va7Zebm4vBYB4zIzk5mUOHDlnuE4wYMYIVK1ZgMpnIycnhhx9+YPjw4Q0+KSGEsLXyuYG37NNXORjc+cLS5g6pSdWp2Wf69OkkJCSwYMEC/P39mT17NmCu3U+ZMoWePXuSkpLCzJkzUavVBAYGsnDhQnx8fACIi4tj79693HrrrQA88cQTtG/fvolOSQgh6s9DU3NdON8dk394eHiVffMXL15sWR48eDCDBw+u8niNRsOMGTMaGKIQQthf+Y1gVyHDOwghRBU++Lu56/qQPlcB8O2vx+0Yje3J8A5CCHHJmxNv4NCpPGKuNXdIKR/6+adLQ0AcOJ5DZCdttcc7E6n5CyHEJaFaX0vir8rOGmYMczaS/IUQohbvP2VuAtq05wylZUY7R2MbkvyFEKIWrXw8LcuT3v7ZjpHYjiR/IYSoB/+WXoxP3Mj4xI32DqVRJPkLIUQdhGrNTwBfcJH+/pL8hRCiDt6ceEOlbf9af9AOkdiGdPUUQogG+mn3aTrp/DAaFW6+9DyAs5CavxBC1FFLb3N9ecygzpZtS9ek8fn6g/z+R4a9wmoQqfkLIUQdffD3QZblrzb/aVW2cNUB+keENndIDSY1fyGEsBGjyTwLWFGJgT/PXHDoCWEk+QshRANMf+j6StvSTuZRZjDyxDubeePzHVUODe0oJPkLIUQDdAj14/VHoln8ws1o1CoA3v5yDy98lGzZZ9Ou0/YKr1aS/IUQooGuatMSjVrN/Kcv3wuoOOnL8OgO9girTiT5CyFEI3l5Vj1f+bptJwE4euY8n69LQ3GgewCS/IUQogll5Fxk5uc72bTnDIVFZfYOx0KSvxBC2MATd/a0LD91dy/L8omMyxO2r00+3owR1UySvxBC2EC/7sE8EhsBQLf2AZYvg4WrDlj2+XzNH3aJrSrykJcQQtjIwCgdA6N0APTt1qbKfS4WG/D1tn/qtX8EQgjhglQqVZXbn3x3M/26B7PzYBaP3h5J9DX2eSpYmn2EEKKJjB0SDsDHLwyx2r7zYBYA//z2QKVjmovU/IUQoomMjO7IyOiOAPTp2obdh89ZlXcIaWWPsACp+QshRLOYfFevSttOZhZgMJrsEI3U/IUQotn8LzGWP45kkZdfwtv/2QPA3iPn6Nc9pNljkZq/EEI0kxaeGq5q05LIq7WWbf/ZeMQusUjyF0IIO3j/qRgAzp0vZnziRg4cz8FkUizDQje1OjX7HDt2jISEBPLy8ggICGD27Nl06tTJap/s7Gxeeukl9Ho9BoOB6OhoXnnlFTw8PGosE0IId9Tyir7+b3+5x7K8JGFok79/nWr+06ZNIz4+nvXr1xMfH8/UqVMr7bNw4ULCw8NZvXo13377LQcOHGDDhg21lgkhhDuq7jkAMA8E19RqTf7Z2dmkpqYSGxsLQGxsLKmpqeTk5Fjtp1KpKCwsxGQyUVpaSllZGaGhobWWCSGEu4od2LHK7YtXpzb5e9fa7qLX6wkNDUWjMQ9ZqtFoCAkJQa/Xo9Vevmnx+OOPM3nyZG666SaKioq477776NevX61ldRUU1PD+sMHBfg0+tjlJnLbjDDGCxGlLzhAjWMf56F29uefWHvi3bMHKjYf5Zc9pjusvkJlbhE8rb1r5eDZZHDZrdF+3bh3du3fns88+o7CwkAkTJrBu3TpGjBhRY1ldZWcXYDLVfyzs4GA/srLya9/RziRO23GGGEHitCVniBGqjzOnuIwh1+oYcq2O8YkbAfh11yn6dgtu8Hup1aoaK821NvvodDoyMjIwGo0AGI1GMjMz0el0VvstW7aM22+/HbVajZ+fH0OHDmXbtm21lgkhhKjswPGc2ndqhFqTf1BQEBERESQlJQGQlJRERESEVZMPQLt27di8eTMApaWlJCcn07Vr11rLhBBCXFbeBfSnXacpaMLJX+rU22f69OksW7aM4cOHs2zZMmbMmAHAhAkT2LdvHwAvv/wyO3fuZPTo0dxxxx106tSJe+65p9YyIYQQl1Vs55/y3i9N9j51avMPDw9nxYoVlbYvXrzYstyhQweWLl1a5fE1lQkhhKieyaSgVlffLbSh5AlfIYRwMPcM6WJZzrlQ3CTvIclfCCEczIjoDoQG+gDw894zTfIekvyFEMIBlc8BnJlb1CSvL4PrCCGEA2oX0oq5jw8kwK9Fk7y+JH8hhHBQWn/vJnttafYRQgg3JMlfCCHckCR/IYRwQ5L8hRDCDUnyF0IINyTJXwgh3JDTdPVszNgWTTEuRlOQOG3HGWIEidOWnCFGaL44a3sflaIo9Z8hRQghhFOTZh8hhHBDkvyFEMINSfIXQgg3JMlfCCHckCR/IYRwQ5L8hRDCDUnyF0IINyTJXwgh3JAkfyGEcENOM7xDQxw7doyEhATy8vIICAhg9uzZdOrUqcnfNzc3lxdeeIGTJ0/i5eVFx44dee2119BqtXTv3p1u3bqhVpu/d9966y26d+8OwMaNG3nrrbcwGo1ERkby5ptv4uPjU2tZYwwdOhQvLy9atDBPFffcc88RExPDnj17mDp1KiUlJVx11VXMmTOHoKAggAaXNUR6ejpPPPGEZT0/P5+CggJ+//33amNvrhhnz57N+vXrOX36NKtXr6Zbt25AzdddU5TVN8aark/ALtdodZ9lU/yNG/P3ryrOmq7RpjoHm1Bc2AMPPKB88803iqIoyjfffKM88MADzfK+ubm5ytatWy3riYmJyksvvaQoiqJ069ZNKSgoqHRMQUGBMnDgQOXYsWOKoijKyy+/rHzwwQe1ljXWkCFDlIMHD1ptMxqNyrBhw5Tt27criqIo8+fPVxISEhpVZitvvPGGMmPGjGpjb84Yt2/frpw5c6ZSHDVdd01RVt8Ya7o+FcU+12h1n6Wt/8aN/ftXF2dFFa/RpjgHW3HZ5H/u3DmlX79+isFgUBRFUQwGg9KvXz8lOzu72WNZt26d8uCDDyqKUv1/rDVr1igTJ060rKekpCijRo2qtayxqrow9+7dq9x2222W9ezsbKV3796NKrOFkpISJTo6Wtm/f3+1sdsjxopx1HTdNUVZQ2K8UsXrU1Hse43WNfnb+xqtLq4rr9GmOAdbcdlmH71eT2hoKBqNBgCNRkNISAh6vd7y87Y5mEwmvvjiC4YOHWrZ9sADD2A0Ghk0aBCTJ0/Gy8sLvV5PWFiYZZ+wsDD0er3lXKors4XnnnsORVHo168fzzzzTKX302q1mEwm8vLyGlwWEBDQ6Dg3btxIaGgokZGR1cbu7+9v1xhruu4URbF5WWOv5aquT3Csa9SWf2N7XKO2PgdbxAlyw7fJvf766/j6+nL//fcDsGnTJr766iuWL1/OkSNHmD9/vl3jW758Od9++y0rV65EURRee+01u8ZTk5UrV3LXXXdZ1p0pdkd15fUJjnWNOtvf+MprFBz3HFw2+et0OjIyMjAajQAYjUYyMzPR6XTNFsPs2bM5ceIE7777ruXmWfn7t2rVirFjx7Jr1y7L9jNnzliOPXPmjGXfmsoaq/x1vLy8iI+PZ9euXZXeLycnB7VaTUBAQIPLGisjI4Pt27czevToGmMv326PGMvfu7rrrinKGqOq67P8HMAxrlFb/42b+xptinOwFZdN/kFBQURERJCUlARAUlISERERzdbkM2/ePPbv38/8+fPx8vIC4Pz58xQXFwNgMBhYv349ERERAMTExLBv3z6OHz8OwJdffsnIkSNrLWuMixcvkp+fD4CiKKxZs4aIiAiioqIoLi5mx44dlvcbMWIEQIPLGuvrr79m8ODBBAYG1hi7PWOEmq+7pihrqKquT3Csa7Qp/sbNeY021TnYiktP5nL06FESEhK4cOEC/v7+zJ49m86dOzf5+x4+fJjY2Fg6deqEt7c3AO3ateORRx5h6tSpqFQqDAYDffr04eWXX6Zly5YA/PDDD8yZMweTyURERASJiYn4+vrWWtZQp06dYvLkyRiNRkwmE+Hh4bzyyiuEhISwa9cupk2bZtXNrE2bNgANLmuM4cOH849//INBgwbVGntzxfjGG2+wYcMGzp07R2BgIAEBAXz33Xc1XndNUVbfGN99990qr8/58+eze/duu1yjVcW5cOHCJvkbN+bvX93fHCpfo+AY12l1XDr5CyGEqJrLNvsIIYSoniR/IYRwQ5L8hRDCDUnyF0IINyTJXwgh3JAkfyGEcEOS/IUQwg1J8hdCCDf0/7ZKcoSNVtE5AAAAAElFTkSuQmCC\n", | |
"text/plain": [ | |
"<Figure size 432x288 with 1 Axes>" | |
] | |
}, | |
"metadata": {}, | |
"output_type": "display_data" | |
}, | |
{ | |
"data": { | |
"image/png": "\n", | |
"text/plain": [ | |
"<Figure size 432x288 with 1 Axes>" | |
] | |
}, | |
"metadata": {}, | |
"output_type": "display_data" | |
} | |
], | |
"source": [ | |
"loss_labels = ['total loss', 'reward loss', 'td loss']\n", | |
"window = 100 // 2\n", | |
"for i, label in enumerate(loss_labels):\n", | |
" plt.title(label)\n", | |
" smoothed_loss = [np.mean(loss_hist[i][j-window:j+window]) \\\n", | |
" for j in range(window, len(loss_hist[i])-window)]\n", | |
" plt.plot(smoothed_loss, label=label)\n", | |
" plt.show()" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 302 | |
}, | |
"id": "wdACSHErVls9", | |
"outputId": "4f967bac-f4cc-4809-c1d0-7d4f5fc6901a" | |
}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"Text(0.5, 1.0, 'extrinsic reward')" | |
] | |
}, | |
"execution_count": 44, | |
"metadata": {}, | |
"output_type": "execute_result" | |
}, | |
{ | |
"data": { | |
"image/png": "\n", | |
"text/plain": [ | |
"<Figure size 432x288 with 1 Axes>" | |
] | |
}, | |
"metadata": {}, | |
"output_type": "display_data" | |
} | |
], | |
"source": [ | |
"window = 20 // 2\n", | |
"smoothed_rewards = [np.mean(all_rewards[i-window:i+window]) \\\n", | |
" for i in range(window, len(all_rewards)-window)]\n", | |
"plt.plot(smoothed_rewards)\n", | |
"plt.title('extrinsic reward')" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": { | |
"id": "Ac3yXTB3lzJg" | |
}, | |
"outputs": [], | |
"source": [ | |
"def fit_skill_vector(skill_vector, batch_data, optimizer, model):\n", | |
" device = next(model.parameters()).device\n", | |
"\n", | |
" batch_data = np.array(batch_data)\n", | |
" obs, _, _, rewards, _, _ = \\\n", | |
" [torch.from_numpy(np.stack(batch_data[:, i])).to(device) \\\n", | |
" for i in range(batch_data.shape[1])]\n", | |
"\n", | |
" batch_svs = einops.repeat(skill_vector, 'e -> b e', b=obs.shape[0])\n", | |
" batch_svs = batch_svs.to(device)\n", | |
" with torch.no_grad():\n", | |
" features, _ = model(obs, batch_svs)\n", | |
"\n", | |
" pred_rewards = torch.bmm(features.unsqueeze(1), batch_svs.unsqueeze(2))\n", | |
" pred_rewards = pred_rewards.squeeze()\n", | |
"\n", | |
" residuals = (rewards - pred_rewards) ** 2\n", | |
" loss = residuals.mean()\n", | |
"\n", | |
" optimizer.zero_grad()\n", | |
" loss.backward()\n", | |
" optimizer.step()\n", | |
"\n", | |
" return loss.item()" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": { | |
"id": "6dX5_Ojpvq9n" | |
}, | |
"outputs": [], | |
"source": [ | |
"fit_steps = 500\n", | |
"ds_batch_size = 256\n", | |
"ds_print_freq = 50\n", | |
"test_episodes = 100\n", | |
"\n", | |
"ds_lr = 1e-3\n", | |
"\n", | |
"skill_vector = sample_skill()\n", | |
"skill_vector.requires_grad = True\n", | |
"\n", | |
"sv_optimizer = torch.optim.Adam((skill_vector,), lr=ds_lr)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 759 | |
}, | |
"id": "O_tqI1tlv5Ya", | |
"outputId": "7b5d2687-55be-4405-c685-0d6ad9191506" | |
}, | |
"outputs": [ | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"/usr/local/lib/python3.7/dist-packages/ipykernel_launcher.py:4: VisibleDeprecationWarning: Creating an ndarray from ragged nested sequences (which is a list-or-tuple of lists-or-tuples-or ndarrays with different lengths or shapes) is deprecated. If you meant to do this, you must specify 'dtype=object' when creating the ndarray\n", | |
" after removing the cwd from sys.path.\n" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"SV Reward Loss: 0.0172\n", | |
"SV Reward Loss: 0.0030\n", | |
"SV Reward Loss: 0.0011\n", | |
"SV Reward Loss: 0.0010\n", | |
"SV Reward Loss: 0.0008\n", | |
"SV Reward Loss: 0.0014\n", | |
"SV Reward Loss: 0.0015\n", | |
"SV Reward Loss: 0.0008\n", | |
"SV Reward Loss: 0.0008\n", | |
"SV Reward Loss: 0.0011\n", | |
"SV Reward Loss: 0.0014\n", | |
"SV Reward Loss: 0.0016\n", | |
"SV Reward Loss: 0.0008\n", | |
"SV Reward Loss: 0.0012\n", | |
"SV Reward Loss: 0.0008\n", | |
"SV Reward Loss: 0.0011\n", | |
"SV Reward Loss: 0.0009\n" | |
] | |
}, | |
{ | |
"ename": "KeyboardInterrupt", | |
"evalue": "ignored", | |
"output_type": "error", | |
"traceback": [ | |
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", | |
"\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)", | |
"\u001b[0;32m<ipython-input-47-65009799b424>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m()\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[0msv_losses\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 2\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mstep_idx\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mrange\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfit_steps\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 3\u001b[0;31m \u001b[0mbatch_data\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0msample_batch\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mexp_buffer\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mds_batch_size\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 4\u001b[0m \u001b[0mloss\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mfit_skill_vector\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mskill_vector\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mbatch_data\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0msv_optimizer\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmodel\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 5\u001b[0m \u001b[0msv_losses\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mappend\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mloss\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", | |
"\u001b[0;32m<ipython-input-24-7a61dba9e776>\u001b[0m in \u001b[0;36msample_batch\u001b[0;34m(buffer, n)\u001b[0m\n\u001b[1;32m 6\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 7\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0msample_batch\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mbuffer\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mn\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 8\u001b[0;31m \u001b[0mdata_idxs\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrandom\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mchoice\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mrange\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mlen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mbuffer\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0msize\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mn\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mreplace\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mFalse\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 9\u001b[0m \u001b[0mbatch_data\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 10\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mi\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mdata_idxs\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", | |
"\u001b[0;31mKeyboardInterrupt\u001b[0m: " | |
] | |
} | |
], | |
"source": [ | |
"sv_losses = []\n", | |
"for step_idx in range(fit_steps):\n", | |
" batch_data = sample_batch(exp_buffer, ds_batch_size)\n", | |
" loss = fit_skill_vector(skill_vector, batch_data, sv_optimizer, model)\n", | |
" sv_losses.append(loss)\n", | |
"\n", | |
" if step_idx != 0 and step_idx % ds_print_freq == 0:\n", | |
" print('SV Reward Loss: {:.4f}'.format(np.mean(sv_losses[-ds_print_freq:])))" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 68, | |
"referenced_widgets": [ | |
"b9abaa3d1e714738988eca0eaa59b4f3", | |
"6a3600d8cd21492bacaa5c1a2d6e83b3", | |
"f90af20614914c8f994619174be2b4c8", | |
"07c4759dcb534616a0c497c6f979607d", | |
"f20aa911d3fe489791a6e3006ac28a38", | |
"44735900376d48258c7a5f035f93685f", | |
"1e1a06065e9349039d53bb167a19fd19", | |
"4e2ecd91fbcb4a199ed38751743c8460", | |
"c570c2d58b974e688b6059d6b31a043c", | |
"4221d18fb26d4bdbb4b2f729beea8a03", | |
"eac0871963cd4420a923e77c785f64f4" | |
] | |
}, | |
"id": "WFPJy1oRxd8a", | |
"outputId": "ceaf2f7d-a90c-42af-83aa-5742035cf6c6" | |
}, | |
"outputs": [ | |
{ | |
"data": { | |
"application/vnd.jupyter.widget-view+json": { | |
"model_id": "b9abaa3d1e714738988eca0eaa59b4f3", | |
"version_major": 2, | |
"version_minor": 0 | |
}, | |
"text/plain": [ | |
" 0%| | 0/100 [00:00<?, ?it/s]" | |
] | |
}, | |
"metadata": {}, | |
"output_type": "display_data" | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Test Avg. Reward: 0.13\n" | |
] | |
} | |
], | |
"source": [ | |
"test_rewards = []\n", | |
"for episode_idx in tqdm(range(test_episodes)):\n", | |
" obs = env.reset()\n", | |
"\n", | |
" ep_rewards = []\n", | |
" done = False\n", | |
" while not done:\n", | |
" # Sample an action\n", | |
" if np.random.rand() < end_act_epsilon:\n", | |
" act = env.action_space.sample()\n", | |
" else:\n", | |
" with torch.no_grad():\n", | |
" _, sfs = model(obs.unsqueeze(0).to(DEVICE),\n", | |
" skill_vector.unsqueeze(0).to(DEVICE))\n", | |
" sfs = sfs.cpu()\n", | |
" q_vals = torch.matmul(sfs[0], skill_vector.unsqueeze(1))\n", | |
" act = torch.argmax(q_vals).item()\n", | |
"\n", | |
" # Make a step\n", | |
" next_obs, reward, done, _ = env.step(act)\n", | |
" reward = np.float32(reward)\n", | |
" ep_rewards.append(reward)\n", | |
" obs = next_obs\n", | |
"\n", | |
" test_rewards.append(sum(ep_rewards))\n", | |
"\n", | |
"print('Test Avg. Reward: {:.2f}'.format(np.mean(test_rewards)))" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": { | |
"id": "9G4YTac3DwRG" | |
}, | |
"outputs": [], | |
"source": [ | |
"# ds_fit_episodes = 100\n", | |
"# ds_print_freq = 2500\n", | |
"\n", | |
"# ds_lr = 0.1\n", | |
"\n", | |
"# skill_vector = sample_skill()\n", | |
"# skill_vector.requires_grad = True\n", | |
"\n", | |
"# sv_optimizer = torch.optim.Adam((skill_vector,), lr=ds_lr)\n", | |
"\n", | |
"# ds_all_rewards = []\n", | |
"# step_idx = 0\n", | |
"# for episode_idx in range(ds_fit_episodes):\n", | |
"# obs = env.reset()\n", | |
"\n", | |
"# ep_rewards = []\n", | |
"# batch_data = []\n", | |
"# done = False\n", | |
"# while not done:\n", | |
"# # Sample an action\n", | |
"# if np.random.rand() < end_act_epsilon:\n", | |
"# act = env.action_space.sample()\n", | |
"# else:\n", | |
"# with torch.no_grad():\n", | |
"# features, sfs = model(obs.unsqueeze(0).to(DEVICE),\n", | |
"# skill_vector.unsqueeze(0).to(DEVICE))\n", | |
"# sfs = sfs.cpu()\n", | |
"# q_vals = torch.matmul(sfs[0], skill_vector.unsqueeze(1))\n", | |
"# act = torch.argmax(q_vals).item()\n", | |
"\n", | |
"# # Make a step\n", | |
"# next_obs, reward, done, _ = env.step(act)\n", | |
"# reward = np.float32(reward)\n", | |
"# batch_data.append([features.detach().cpu().numpy(), reward])\n", | |
"# ep_rewards.append(reward)\n", | |
"# obs = next_obs\n", | |
"\n", | |
"# if step_idx % ds_print_freq == 0:\n", | |
"# loss, r_2 = fit_skill_vector(skill_vector, batch_data, sv_optimizer)\n", | |
"# print('Step: {}\\t\\tReward Fit Loss: {:.4f}\\tr^2: {:.4f}'\n", | |
"# .format(step_idx, loss, r_2))\n", | |
"# batch_data = []\n", | |
"\n", | |
"# step_idx += 1\n", | |
"# ds_all_rewards.append(sum(ep_rewards))\n", | |
"# if episode_idx % 10 == 0:\n", | |
"# print('Avg Reward: {:.1f}'.format(np.mean(ds_all_rewards[-10:])))\n", | |
"\n" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": { | |
"id": "PeCUpqlbuT4D" | |
}, | |
"outputs": [], | |
"source": [ | |
"" | |
] | |
} | |
], | |
"metadata": { | |
"accelerator": "GPU", | |
"colab": { | |
"collapsed_sections": [], | |
"name": "VISR.ipynb", | |
"provenance": [], | |
"authorship_tag": "ABX9TyPUL/Y6PHTPv8KB7tLzGzom", | |
"include_colab_link": true | |
}, | |
"kernelspec": { | |
"display_name": "Python 3", | |
"name": "python3" | |
}, | |
"language_info": { | |
"name": "python" | |
}, | |
"widgets": { | |
"application/vnd.jupyter.widget-state+json": { | |
"07c4759dcb534616a0c497c6f979607d": { | |
"model_module": "@jupyter-widgets/controls", | |
"model_module_version": "1.5.0", | |
"model_name": "HTMLModel", | |
"state": { | |
"_dom_classes": [], | |
"_model_module": "@jupyter-widgets/controls", | |
"_model_module_version": "1.5.0", | |
"_model_name": "HTMLModel", | |
"_view_count": null, | |
"_view_module": "@jupyter-widgets/controls", | |
"_view_module_version": "1.5.0", | |
"_view_name": "HTMLView", | |
"description": "", | |
"description_tooltip": null, | |
"layout": "IPY_MODEL_4221d18fb26d4bdbb4b2f729beea8a03", | |
"placeholder": "", | |
"style": "IPY_MODEL_eac0871963cd4420a923e77c785f64f4", | |
"value": " 100/100 [00:18<00:00, 5.00it/s]" | |
} | |
}, | |
"1e1a06065e9349039d53bb167a19fd19": { | |
"model_module": "@jupyter-widgets/controls", | |
"model_module_version": "1.5.0", | |
"model_name": "DescriptionStyleModel", | |
"state": { | |
"_model_module": "@jupyter-widgets/controls", | |
"_model_module_version": "1.5.0", | |
"_model_name": "DescriptionStyleModel", | |
"_view_count": null, | |
"_view_module": "@jupyter-widgets/base", | |
"_view_module_version": "1.2.0", | |
"_view_name": "StyleView", | |
"description_width": "" | |
} | |
}, | |
"4221d18fb26d4bdbb4b2f729beea8a03": { | |
"model_module": "@jupyter-widgets/base", | |
"model_module_version": "1.2.0", | |
"model_name": "LayoutModel", | |
"state": { | |
"_model_module": "@jupyter-widgets/base", | |
"_model_module_version": "1.2.0", | |
"_model_name": "LayoutModel", | |
"_view_count": null, | |
"_view_module": "@jupyter-widgets/base", | |
"_view_module_version": "1.2.0", | |
"_view_name": "LayoutView", | |
"align_content": null, | |
"align_items": null, | |
"align_self": null, | |
"border": null, | |
"bottom": null, | |
"display": null, | |
"flex": null, | |
"flex_flow": null, | |
"grid_area": null, | |
"grid_auto_columns": null, | |
"grid_auto_flow": null, | |
"grid_auto_rows": null, | |
"grid_column": null, | |
"grid_gap": null, | |
"grid_row": null, | |
"grid_template_areas": null, | |
"grid_template_columns": null, | |
"grid_template_rows": null, | |
"height": null, | |
"justify_content": null, | |
"justify_items": null, | |
"left": null, | |
"margin": null, | |
"max_height": null, | |
"max_width": null, | |
"min_height": null, | |
"min_width": null, | |
"object_fit": null, | |
"object_position": null, | |
"order": null, | |
"overflow": null, | |
"overflow_x": null, | |
"overflow_y": null, | |
"padding": null, | |
"right": null, | |
"top": null, | |
"visibility": null, | |
"width": null | |
} | |
}, | |
"44735900376d48258c7a5f035f93685f": { | |
"model_module": "@jupyter-widgets/base", | |
"model_module_version": "1.2.0", | |
"model_name": "LayoutModel", | |
"state": { | |
"_model_module": "@jupyter-widgets/base", | |
"_model_module_version": "1.2.0", | |
"_model_name": "LayoutModel", | |
"_view_count": null, | |
"_view_module": "@jupyter-widgets/base", | |
"_view_module_version": "1.2.0", | |
"_view_name": "LayoutView", | |
"align_content": null, | |
"align_items": null, | |
"align_self": null, | |
"border": null, | |
"bottom": null, | |
"display": null, | |
"flex": null, | |
"flex_flow": null, | |
"grid_area": null, | |
"grid_auto_columns": null, | |
"grid_auto_flow": null, | |
"grid_auto_rows": null, | |
"grid_column": null, | |
"grid_gap": null, | |
"grid_row": null, | |
"grid_template_areas": null, | |
"grid_template_columns": null, | |
"grid_template_rows": null, | |
"height": null, | |
"justify_content": null, | |
"justify_items": null, | |
"left": null, | |
"margin": null, | |
"max_height": null, | |
"max_width": null, | |
"min_height": null, | |
"min_width": null, | |
"object_fit": null, | |
"object_position": null, | |
"order": null, | |
"overflow": null, | |
"overflow_x": null, | |
"overflow_y": null, | |
"padding": null, | |
"right": null, | |
"top": null, | |
"visibility": null, | |
"width": null | |
} | |
}, | |
"4e2ecd91fbcb4a199ed38751743c8460": { | |
"model_module": "@jupyter-widgets/base", | |
"model_module_version": "1.2.0", | |
"model_name": "LayoutModel", | |
"state": { | |
"_model_module": "@jupyter-widgets/base", | |
"_model_module_version": "1.2.0", | |
"_model_name": "LayoutModel", | |
"_view_count": null, | |
"_view_module": "@jupyter-widgets/base", | |
"_view_module_version": "1.2.0", | |
"_view_name": "LayoutView", | |
"align_content": null, | |
"align_items": null, | |
"align_self": null, | |
"border": null, | |
"bottom": null, | |
"display": null, | |
"flex": null, | |
"flex_flow": null, | |
"grid_area": null, | |
"grid_auto_columns": null, | |
"grid_auto_flow": null, | |
"grid_auto_rows": null, | |
"grid_column": null, | |
"grid_gap": null, | |
"grid_row": null, | |
"grid_template_areas": null, | |
"grid_template_columns": null, | |
"grid_template_rows": null, | |
"height": null, | |
"justify_content": null, | |
"justify_items": null, | |
"left": null, | |
"margin": null, | |
"max_height": null, | |
"max_width": null, | |
"min_height": null, | |
"min_width": null, | |
"object_fit": null, | |
"object_position": null, | |
"order": null, | |
"overflow": null, | |
"overflow_x": null, | |
"overflow_y": null, | |
"padding": null, | |
"right": null, | |
"top": null, | |
"visibility": null, | |
"width": null | |
} | |
}, | |
"6a3600d8cd21492bacaa5c1a2d6e83b3": { | |
"model_module": "@jupyter-widgets/controls", | |
"model_module_version": "1.5.0", | |
"model_name": "HTMLModel", | |
"state": { | |
"_dom_classes": [], | |
"_model_module": "@jupyter-widgets/controls", | |
"_model_module_version": "1.5.0", | |
"_model_name": "HTMLModel", | |
"_view_count": null, | |
"_view_module": "@jupyter-widgets/controls", | |
"_view_module_version": "1.5.0", | |
"_view_name": "HTMLView", | |
"description": "", | |
"description_tooltip": null, | |
"layout": "IPY_MODEL_44735900376d48258c7a5f035f93685f", | |
"placeholder": "", | |
"style": "IPY_MODEL_1e1a06065e9349039d53bb167a19fd19", | |
"value": "100%" | |
} | |
}, | |
"b9abaa3d1e714738988eca0eaa59b4f3": { | |
"model_module": "@jupyter-widgets/controls", | |
"model_module_version": "1.5.0", | |
"model_name": "HBoxModel", | |
"state": { | |
"_dom_classes": [], | |
"_model_module": "@jupyter-widgets/controls", | |
"_model_module_version": "1.5.0", | |
"_model_name": "HBoxModel", | |
"_view_count": null, | |
"_view_module": "@jupyter-widgets/controls", | |
"_view_module_version": "1.5.0", | |
"_view_name": "HBoxView", | |
"box_style": "", | |
"children": [ | |
"IPY_MODEL_6a3600d8cd21492bacaa5c1a2d6e83b3", | |
"IPY_MODEL_f90af20614914c8f994619174be2b4c8", | |
"IPY_MODEL_07c4759dcb534616a0c497c6f979607d" | |
], | |
"layout": "IPY_MODEL_f20aa911d3fe489791a6e3006ac28a38" | |
} | |
}, | |
"c570c2d58b974e688b6059d6b31a043c": { | |
"model_module": "@jupyter-widgets/controls", | |
"model_module_version": "1.5.0", | |
"model_name": "ProgressStyleModel", | |
"state": { | |
"_model_module": "@jupyter-widgets/controls", | |
"_model_module_version": "1.5.0", | |
"_model_name": "ProgressStyleModel", | |
"_view_count": null, | |
"_view_module": "@jupyter-widgets/base", | |
"_view_module_version": "1.2.0", | |
"_view_name": "StyleView", | |
"bar_color": null, | |
"description_width": "" | |
} | |
}, | |
"eac0871963cd4420a923e77c785f64f4": { | |
"model_module": "@jupyter-widgets/controls", | |
"model_module_version": "1.5.0", | |
"model_name": "DescriptionStyleModel", | |
"state": { | |
"_model_module": "@jupyter-widgets/controls", | |
"_model_module_version": "1.5.0", | |
"_model_name": "DescriptionStyleModel", | |
"_view_count": null, | |
"_view_module": "@jupyter-widgets/base", | |
"_view_module_version": "1.2.0", | |
"_view_name": "StyleView", | |
"description_width": "" | |
} | |
}, | |
"f20aa911d3fe489791a6e3006ac28a38": { | |
"model_module": "@jupyter-widgets/base", | |
"model_module_version": "1.2.0", | |
"model_name": "LayoutModel", | |
"state": { | |
"_model_module": "@jupyter-widgets/base", | |
"_model_module_version": "1.2.0", | |
"_model_name": "LayoutModel", | |
"_view_count": null, | |
"_view_module": "@jupyter-widgets/base", | |
"_view_module_version": "1.2.0", | |
"_view_name": "LayoutView", | |
"align_content": null, | |
"align_items": null, | |
"align_self": null, | |
"border": null, | |
"bottom": null, | |
"display": null, | |
"flex": null, | |
"flex_flow": null, | |
"grid_area": null, | |
"grid_auto_columns": null, | |
"grid_auto_flow": null, | |
"grid_auto_rows": null, | |
"grid_column": null, | |
"grid_gap": null, | |
"grid_row": null, | |
"grid_template_areas": null, | |
"grid_template_columns": null, | |
"grid_template_rows": null, | |
"height": null, | |
"justify_content": null, | |
"justify_items": null, | |
"left": null, | |
"margin": null, | |
"max_height": null, | |
"max_width": null, | |
"min_height": null, | |
"min_width": null, | |
"object_fit": null, | |
"object_position": null, | |
"order": null, | |
"overflow": null, | |
"overflow_x": null, | |
"overflow_y": null, | |
"padding": null, | |
"right": null, | |
"top": null, | |
"visibility": null, | |
"width": null | |
} | |
}, | |
"f90af20614914c8f994619174be2b4c8": { | |
"model_module": "@jupyter-widgets/controls", | |
"model_module_version": "1.5.0", | |
"model_name": "FloatProgressModel", | |
"state": { | |
"_dom_classes": [], | |
"_model_module": "@jupyter-widgets/controls", | |
"_model_module_version": "1.5.0", | |
"_model_name": "FloatProgressModel", | |
"_view_count": null, | |
"_view_module": "@jupyter-widgets/controls", | |
"_view_module_version": "1.5.0", | |
"_view_name": "ProgressView", | |
"bar_style": "success", | |
"description": "", | |
"description_tooltip": null, | |
"layout": "IPY_MODEL_4e2ecd91fbcb4a199ed38751743c8460", | |
"max": 100, | |
"min": 0, | |
"orientation": "horizontal", | |
"style": "IPY_MODEL_c570c2d58b974e688b6059d6b31a043c", | |
"value": 100 | |
} | |
} | |
} | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 0 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment