Created
July 25, 2022 08:59
-
-
Save enakai00/1407e48e0ce1607d48bb6a8906a8e2bc to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "code", | |
"execution_count": 6, | |
"id": "2b31bfe9-3b3d-4313-9435-cead3071ee41", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"import gym\n", | |
"import numpy as np\n", | |
"import copy, random, time, subprocess, os\n", | |
"from tensorflow.keras import layers, models" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 2, | |
"id": "6dcb0907-a790-47b0-8ac6-a6c2eb912d61", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"class QValue:\n", | |
" def __init__(self):\n", | |
" self.model = None\n", | |
"\n", | |
" def get_action(self, state):\n", | |
" states = []\n", | |
" actions = []\n", | |
" for a in range(5):\n", | |
" states.append(np.array(state))\n", | |
" action_onehot = np.zeros(5)\n", | |
" action_onehot[a] = 1\n", | |
" actions.append(action_onehot)\n", | |
" \n", | |
" q_values = self.model.predict([np.array(states), np.array(actions)])\n", | |
" optimal_action = np.argmax(q_values)\n", | |
" return optimal_action, q_values[optimal_action][0]" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 3, | |
"id": "cdc3cc3a-04dc-463b-87b0-5669d962e2fa", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"def join_frames(o0, o1):\n", | |
" return np.r_[o0.transpose(), o1.transpose()].transpose() " | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 4, | |
"id": "14a92b6f-27a3-4bcf-b3e1-3e9c0dc85389", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"q_value = QValue()" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 9, | |
"id": "c9364323-e246-4abd-bfbc-74abe0f2d5e7", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"Copying gs://etsuji-car-racing-v2-model01/model01/car-racing-v2-model01-104.hd5...\n", | |
"- [1 files][292.5 MiB/292.5 MiB] \n", | |
"Operation completed over 1 objects/292.5 MiB. \n" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"load model car-racing-v2-model01-104.hd5\n" | |
] | |
}, | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"/opt/conda/lib/python3.7/site-packages/gym/core.py:330: DeprecationWarning: \u001b[33mWARN: Initializing wrapper in old step API which returns one bool instead of two. It is recommended to set `new_step_api=True` to use new step API. This will be the default behaviour in future.\u001b[0m\n", | |
" \"Initializing wrapper in old step API which returns one bool instead of two. It is recommended to set `new_step_api=True` to use new step API. This will be the default behaviour in future.\"\n", | |
"/opt/conda/lib/python3.7/site-packages/gym/wrappers/step_api_compatibility.py:40: DeprecationWarning: \u001b[33mWARN: Initializing environment in old step API which returns one bool instead of two. It is recommended to set `new_step_api=True` to use new step API. This will be the default behaviour in future.\u001b[0m\n", | |
" \"Initializing environment in old step API which returns one bool instead of two. It is recommended to set `new_step_api=True` to use new step API. This will be the default behaviour in future.\"\n", | |
"/opt/conda/lib/python3.7/site-packages/gym/core.py:52: DeprecationWarning: \u001b[33mWARN: The argument mode in render method is deprecated; use render_mode during environment initialization instead.\n", | |
"See here for more information: https://www.gymlibrary.ml/content/api/\u001b[0m\n", | |
" \"The argument mode in render method is deprecated; \"\n" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"3 11.814814814814826\n", | |
"3 27.333333333333314\n", | |
"3 46.555555555555486\n", | |
"3 65.77777777777771\n", | |
"3 62.77777777777779\n", | |
"3 59.77777777777775\n", | |
"3 56.77777777777771\n", | |
"3 53.777777777777665\n", | |
"3 50.77777777777762\n", | |
"3 47.77777777777758\n", | |
"3 44.77777777777754\n" | |
] | |
} | |
], | |
"source": [ | |
"import datetime \n", | |
"import imageio\n", | |
"\n", | |
"checkpoint = 104\n", | |
"model = 'model01'\n", | |
"\n", | |
"BUCKET = 'gs://etsuji-car-racing-v2-{}'.format(model)\n", | |
"filename = 'car-racing-v2-{}-{}.hd5'.format(model, checkpoint)\n", | |
"subprocess.run(['gsutil', 'cp', '{}/{}/{}'.format(BUCKET, model, filename), './'])\n", | |
"print('load model {}'.format(filename))\n", | |
"q_value.model = models.load_model(filename)\n", | |
"os.remove(filename)\n", | |
"\n", | |
"env = gym.make(\"CarRacing-v2\", continuous=False)\n", | |
"o0 = env.reset()\n", | |
"o1 = copy.deepcopy(o0)\n", | |
"done = 0\n", | |
"total_r = 0\n", | |
"c = 0\n", | |
"\n", | |
"frames = []\n", | |
"\n", | |
"while not done: \n", | |
" a, _ = q_value.get_action(join_frames(o0, o1))\n", | |
" o_new, r, done, i = env.step(a)\n", | |
" total_r += r\n", | |
" o0, o1 = o1, o_new \n", | |
" c += 1\n", | |
" frame = env.render('rgb_array')\n", | |
" frames.append(frame) \n", | |
" if c % 30 == 0:\n", | |
" print(a, total_r)\n", | |
"\n", | |
"now = datetime.datetime.now()\n", | |
"imageio.mimsave('car-racing-v2-{}-{}-{}-{}.gif'.format(model, int(total_r), checkpoint, now.strftime('%Y%m%d-%H%M%S')),\n", | |
" frames, 'GIF' , **{'duration': 1.0/30.0})" | |
] | |
} | |
], | |
"metadata": { | |
"environment": { | |
"kernel": "python3", | |
"name": "tf2-gpu.2-8.m94", | |
"type": "gcloud", | |
"uri": "gcr.io/deeplearning-platform-release/tf2-gpu.2-8:m94" | |
}, | |
"kernelspec": { | |
"display_name": "Python 3", | |
"language": "python", | |
"name": "python3" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.7.12" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 5 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment