{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"id": "bd0b6703-2f23-4f57-83aa-f6eb323e6958",
"metadata": {
"execution": {
"iopub.execute_input": "2025-01-19T11:19:59.565050Z",
"iopub.status.busy": "2025-01-19T11:19:59.563942Z",
"iopub.status.idle": "2025-01-19T11:19:59.651501Z",
"shell.execute_reply": "2025-01-19T11:19:59.650376Z",
"shell.execute_reply.started": "2025-01-19T11:19:59.565050Z"
}
},
"outputs": [
{
"data": {
"text/plain": [
"<pyvirtualdisplay.display.Display at 0x7f9c60386f10>"
]
},
"execution_count": 1,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Virtual display\n",
"from pyvirtualdisplay import Display\n",
"\n",
"virtual_display = Display(visible=0, size=(1400, 900))\n",
"virtual_display.start()"
]
},
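{
"cell_type": "markdown",
"id": "note-virtual-display",
"metadata": {},
"source": [
"*Editor's note:* the virtual display above lets environments render on a headless machine. As a hedged sketch (not part of the original run), an episode could be recorded to video with Gymnasium's `RecordVideo` wrapper; the wrapper, its `video_folder` argument, and the `videos` folder name are assumptions for illustration, not something the original gist does.\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "sketch-record-video",
"metadata": {},
"outputs": [],
"source": [
"# Sketch (editor's addition): record one random-policy episode to ./videos.\n",
"# Assumes gymnasium.wrappers.RecordVideo is available; not part of the original run.\n",
"import gymnasium as gym\n",
"from gymnasium.wrappers import RecordVideo\n",
"\n",
"demo_env = RecordVideo(gym.make('CartPole-v1', render_mode='rgb_array'), video_folder='videos')\n",
"observation, info = demo_env.reset()\n",
"terminated = truncated = False\n",
"while not (terminated or truncated):\n",
"    observation, reward, terminated, truncated, info = demo_env.step(demo_env.action_space.sample())\n",
"demo_env.close()"
]
},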
{
"cell_type": "code",
"execution_count": 2,
"id": "9f9bf992-5c21-4c05-a87e-69e1dcabcec6",
"metadata": {},
"outputs": [],
"source": [
"import torch.nn as nn\n",
"\n",
"# Policy network: maps an observation to one logit per action\n",
"class Net(nn.Module):\n",
"    def __init__(self, obs_size, actions_size, hidden_size):\n",
"        super().__init__()\n",
"\n",
"        # A single hidden layer of width hidden_size\n",
"        self.model = nn.Sequential(\n",
"            nn.Linear(obs_size, hidden_size),\n",
"            nn.ReLU(),\n",
"            nn.Linear(hidden_size, actions_size)\n",
"        )\n",
"\n",
"    def forward(self, x):\n",
"        return self.model(x)\n"
]
},
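{
"cell_type": "markdown",
"id": "note-policy-net",
"metadata": {},
"source": [
"*Editor's note:* a quick sanity check of the policy network, added as a sketch (it was not part of the original run). For CartPole the observation has 4 features and there are 2 actions, so `Net` maps a `(batch, 4)` tensor to `(batch, 2)` logits, and a softmax over dimension 1 turns those logits into the action distribution sampled during play.\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "sketch-net-forward",
"metadata": {},
"outputs": [],
"source": [
"# Sketch (editor's addition): forward pass on a dummy CartPole observation.\n",
"import torch\n",
"\n",
"net = Net(obs_size=4, actions_size=2, hidden_size=128)\n",
"dummy_observation = torch.zeros(1, 4)  # minibatch of one 4-feature observation\n",
"logits = net(dummy_observation)  # raw scores, shape (1, 2)\n",
"print(torch.softmax(logits, dim=1))  # action probabilities summing to 1"
]
},
{
"cell_type": "markdown",
"id": "note-cross-entropy-method",
"metadata": {},
"source": [
"The next cell trains the policy with the cross-entropy method: play a batch of episodes with the current (epsilon-greedy) policy, keep the top `TAKE_TOP_P` fraction by total reward, fit the network to the (observation, action) pairs of those elite episodes with a cross-entropy loss, and repeat until the median reward reaches the CartPole-v1 solve threshold of 475.\n"
]
},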
{
"cell_type": "code",
"execution_count": 29,
"id": "2fe6e83c-2a5e-4670-afd6-e6ef43ba094e",
"metadata": {
"execution": {
"iopub.execute_input": "2025-01-19T11:22:38.199101Z",
"iopub.status.busy": "2025-01-19T11:22:38.197219Z",
"iopub.status.idle": "2025-01-19T11:22:38.214770Z",
"shell.execute_reply": "2025-01-19T11:22:38.214093Z",
"shell.execute_reply.started": "2025-01-19T11:22:38.198884Z"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Played 100 episodes. Median reward: 20.0, Best reward: 20.0, Worst reward: 10.0\n",
"Selected 30 best episodes\n",
"Rewards: 20.0, 20.0, 20.0, 20.0, 20.0, 20.0, 20.0, 20.0, 20.0, 20.0, 20.0, 20.0, 20.0, 20.0, 20.0, 20.0, 20.0, 20.0, 20.0, 20.0, 20.0, 20.0, 20.0, 20.0, 20.0, 20.0, 20.0, 20.0, 20.0, 20.0\n",
"Iteration 0 - loss: 0.7031133770942688\n",
"Played 100 episodes. Median reward: 15.5, Best reward: 20.0, Worst reward: 9.0\n",
"Selected 30 best episodes\n",
"Rewards: 20.0, 20.0, 20.0, 20.0, 20.0, 20.0, 20.0, 20.0, 20.0, 20.0, 20.0, 20.0, 20.0, 20.0, 20.0, 20.0, 20.0, 20.0, 20.0, 20.0, 20.0, 20.0, 20.0, 20.0, 20.0, 20.0, 20.0, 20.0, 19.0, 19.0\n",
"Iteration 1 - loss: 0.654332160949707\n",
"Played 100 episodes. Median reward: 14.0, Best reward: 20.0, Worst reward: 8.0\n",
"Selected 30 best episodes\n",
"Rewards: 20.0, 20.0, 20.0, 20.0, 20.0, 20.0, 20.0, 20.0, 20.0, 20.0, 20.0, 20.0, 20.0, 20.0, 20.0, 20.0, 20.0, 20.0, 19.0, 19.0, 19.0, 18.0, 18.0, 18.0, 18.0, 18.0, 18.0, 17.0, 17.0, 16.0\n",
"Iteration 2 - loss: 0.5910254120826721\n",
"Played 100 episodes. Median reward: 12.0, Best reward: 20.0, Worst reward: 8.0\n",
"Selected 30 best episodes\n",
"Rewards: 20.0, 20.0, 20.0, 20.0, 20.0, 20.0, 20.0, 20.0, 20.0, 20.0, 19.0, 19.0, 19.0, 19.0, 19.0, 17.0, 17.0, 17.0, 16.0, 16.0, 16.0, 16.0, 16.0, 15.0, 15.0, 15.0, 15.0, 15.0, 14.0, 14.0\n",
"Iteration 3 - loss: 0.5180516242980957\n",
"Played 100 episodes. Median reward: 12.0, Best reward: 20.0, Worst reward: 8.0\n",
"Selected 30 best episodes\n",
"Rewards: 20.0, 20.0, 20.0, 20.0, 20.0, 20.0, 20.0, 20.0, 19.0, 19.0, 19.0, 19.0, 18.0, 18.0, 17.0, 17.0, 17.0, 17.0, 17.0, 17.0, 17.0, 17.0, 16.0, 16.0, 16.0, 16.0, 16.0, 15.0, 15.0, 15.0\n",
"Iteration 4 - loss: 0.4938879609107971\n",
"Played 100 episodes. Median reward: 11.5, Best reward: 20.0, Worst reward: 8.0\n",
"Selected 30 best episodes\n",
"Rewards: 20.0, 20.0, 20.0, 20.0, 20.0, 19.0, 18.0, 17.0, 16.0, 16.0, 16.0, 15.0, 15.0, 15.0, 15.0, 15.0, 14.0, 14.0, 14.0, 13.0, 13.0, 13.0, 13.0, 13.0, 13.0, 13.0, 13.0, 13.0, 13.0, 13.0\n",
"Iteration 5 - loss: 0.38373544812202454\n",
"Played 100 episodes. Median reward: 11.0, Best reward: 20.0, Worst reward: 8.0\n",
"Selected 30 best episodes\n",
"Rewards: 20.0, 20.0, 18.0, 18.0, 18.0, 17.0, 16.0, 16.0, 16.0, 16.0, 16.0, 15.0, 15.0, 15.0, 14.0, 14.0, 14.0, 14.0, 14.0, 14.0, 14.0, 14.0, 14.0, 14.0, 14.0, 13.0, 13.0, 13.0, 13.0, 13.0\n",
"Iteration 6 - loss: 0.39132174849510193\n",
"Played 100 episodes. Median reward: 11.0, Best reward: 20.0, Worst reward: 8.0\n",
"Selected 30 best episodes\n",
"Rewards: 20.0, 20.0, 20.0, 18.0, 18.0, 17.0, 16.0, 16.0, 16.0, 16.0, 16.0, 16.0, 15.0, 15.0, 15.0, 15.0, 15.0, 14.0, 14.0, 14.0, 13.0, 13.0, 13.0, 13.0, 13.0, 13.0, 13.0, 13.0, 12.0, 12.0\n",
"Iteration 7 - loss: 0.35876840353012085\n",
"Played 100 episodes. Median reward: 10.0, Best reward: 20.0, Worst reward: 8.0\n",
"Selected 30 best episodes\n",
"Rewards: 20.0, 18.0, 18.0, 18.0, 15.0, 15.0, 15.0, 14.0, 14.0, 14.0, 14.0, 13.0, 13.0, 13.0, 13.0, 13.0, 13.0, 13.0, 13.0, 13.0, 12.0, 12.0, 12.0, 12.0, 12.0, 12.0, 12.0, 12.0, 12.0, 12.0\n",
"Iteration 8 - loss: 0.3280767500400543\n",
"Played 100 episodes. Median reward: 10.0, Best reward: 20.0, Worst reward: 8.0\n",
"Selected 30 best episodes\n",
"Rewards: 20.0, 18.0, 16.0, 15.0, 15.0, 15.0, 14.0, 14.0, 14.0, 14.0, 14.0, 14.0, 13.0, 13.0, 13.0, 13.0, 13.0, 13.0, 13.0, 13.0, 12.0, 12.0, 12.0, 12.0, 12.0, 12.0, 12.0, 12.0, 12.0, 12.0\n",
"Iteration 9 - loss: 0.3035016357898712\n",
"Played 100 episodes. Median reward: 10.0, Best reward: 20.0, Worst reward: 8.0\n",
"Selected 30 best episodes\n",
"Rewards: 20.0, 20.0, 19.0, 16.0, 15.0, 14.0, 14.0, 14.0, 14.0, 14.0, 14.0, 14.0, 14.0, 13.0, 13.0, 13.0, 13.0, 12.0, 12.0, 12.0, 12.0, 12.0, 12.0, 12.0, 12.0, 12.0, 12.0, 12.0, 11.0, 11.0\n",
"Iteration 10 - loss: 0.23971830308437347\n",
"Played 100 episodes. Median reward: 10.0, Best reward: 20.0, Worst reward: 8.0\n",
"Selected 30 best episodes\n",
"Rewards: 20.0, 19.0, 18.0, 16.0, 16.0, 15.0, 15.0, 15.0, 14.0, 14.0, 14.0, 13.0, 13.0, 13.0, 13.0, 13.0, 13.0, 12.0, 12.0, 12.0, 12.0, 12.0, 12.0, 12.0, 12.0, 12.0, 12.0, 12.0, 12.0, 11.0\n",
"Iteration 11 - loss: 0.36385658383369446\n",
"Played 100 episodes. Median reward: 10.0, Best reward: 16.0, Worst reward: 8.0\n",
"Selected 30 best episodes\n",
"Rewards: 16.0, 15.0, 15.0, 14.0, 14.0, 14.0, 14.0, 14.0, 13.0, 13.0, 13.0, 12.0, 12.0, 12.0, 12.0, 12.0, 12.0, 12.0, 12.0, 12.0, 11.0, 11.0, 11.0, 11.0, 11.0, 11.0, 11.0, 11.0, 11.0, 11.0\n",
"Iteration 12 - loss: 0.24887436628341675\n",
"Played 100 episodes. Median reward: 10.0, Best reward: 16.0, Worst reward: 8.0\n",
"Selected 30 best episodes\n",
"Rewards: 16.0, 16.0, 15.0, 14.0, 14.0, 14.0, 14.0, 14.0, 14.0, 13.0, 13.0, 13.0, 12.0, 12.0, 12.0, 12.0, 12.0, 12.0, 12.0, 12.0, 11.0, 11.0, 11.0, 11.0, 11.0, 11.0, 11.0, 11.0, 11.0, 11.0\n",
"Iteration 13 - loss: 0.31917455792427063\n",
"Played 100 episodes. Median reward: 10.0, Best reward: 16.0, Worst reward: 8.0\n",
"Selected 30 best episodes\n",
"Rewards: 16.0, 15.0, 15.0, 14.0, 13.0, 13.0, 13.0, 13.0, 13.0, 12.0, 12.0, 12.0, 12.0, 12.0, 12.0, 12.0, 11.0, 11.0, 11.0, 11.0, 11.0, 11.0, 11.0, 11.0, 11.0, 11.0, 11.0, 11.0, 10.0, 10.0\n",
"Iteration 14 - loss: 0.2883206605911255\n",
"Played 100 episodes. Median reward: 10.0, Best reward: 17.0, Worst reward: 8.0\n",
"Selected 30 best episodes\n",
"Rewards: 17.0, 17.0, 15.0, 15.0, 15.0, 14.0, 14.0, 14.0, 14.0, 14.0, 13.0, 13.0, 13.0, 13.0, 13.0, 13.0, 13.0, 13.0, 13.0, 12.0, 12.0, 12.0, 12.0, 12.0, 11.0, 11.0, 11.0, 11.0, 11.0, 11.0\n",
"Iteration 15 - loss: 0.33113226294517517\n",
"Played 100 episodes. Median reward: 10.0, Best reward: 19.0, Worst reward: 8.0\n",
"Selected 30 best episodes\n",
"Rewards: 19.0, 19.0, 18.0, 18.0, 16.0, 16.0, 14.0, 14.0, 13.0, 13.0, 13.0, 13.0, 13.0, 13.0, 13.0, 13.0, 12.0, 12.0, 12.0, 12.0, 12.0, 12.0, 12.0, 12.0, 12.0, 12.0, 11.0, 11.0, 11.0, 11.0\n",
"Iteration 16 - loss: 0.2757459580898285\n",
"Played 100 episodes. Median reward: 10.0, Best reward: 19.0, Worst reward: 8.0\n",
"Selected 30 best episodes\n",
"Rewards: 19.0, 18.0, 17.0, 16.0, 16.0, 16.0, 15.0, 14.0, 14.0, 14.0, 14.0, 13.0, 13.0, 13.0, 13.0, 13.0, 13.0, 13.0, 12.0, 12.0, 12.0, 12.0, 12.0, 12.0, 12.0, 12.0, 11.0, 11.0, 11.0, 11.0\n",
"Iteration 17 - loss: 0.28519508242607117\n",
"Played 100 episodes. Median reward: 10.0, Best reward: 15.0, Worst reward: 8.0\n",
"Selected 30 best episodes\n",
"Rewards: 15.0, 15.0, 14.0, 14.0, 13.0, 13.0, 13.0, 13.0, 13.0, 13.0, 12.0, 12.0, 12.0, 12.0, 12.0, 12.0, 12.0, 12.0, 12.0, 12.0, 12.0, 12.0, 11.0, 11.0, 11.0, 11.0, 11.0, 11.0, 11.0, 11.0\n",
"Iteration 18 - loss: 0.2898578345775604\n"
]
},
{
"ename": "KeyboardInterrupt",
"evalue": "",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)",
"Cell \u001b[0;32mIn[29], line 98\u001b[0m\n\u001b[1;32m 94\u001b[0m \u001b[38;5;28mprint\u001b[39m(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mIteration \u001b[39m\u001b[38;5;132;01m{\u001b[39;00miteration\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m - loss: \u001b[39m\u001b[38;5;124m\"\u001b[39m, loss\u001b[38;5;241m.\u001b[39mitem())\n\u001b[1;32m 95\u001b[0m iteration \u001b[38;5;241m+\u001b[39m\u001b[38;5;241m=\u001b[39m \u001b[38;5;241m1\u001b[39m\n\u001b[0;32m---> 98\u001b[0m \u001b[43mtrain_model\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 100\u001b[0m env\u001b[38;5;241m.\u001b[39mclose()\n",
"Cell \u001b[0;32mIn[29], line 71\u001b[0m, in \u001b[0;36mtrain_model\u001b[0;34m()\u001b[0m\n\u001b[1;32m 68\u001b[0m \u001b[38;5;28;01mwhile\u001b[39;00m median_reward \u001b[38;5;241m<\u001b[39m \u001b[38;5;241m475\u001b[39m:\n\u001b[1;32m 69\u001b[0m \u001b[38;5;66;03m# Play with current model\u001b[39;00m\n\u001b[1;32m 70\u001b[0m episodes_generator \u001b[38;5;241m=\u001b[39m generate_episodes(\u001b[38;5;28;01mlambda\u001b[39;00m observation: sample_model_actions_distribution(model, observation))\n\u001b[0;32m---> 71\u001b[0m episodes \u001b[38;5;241m=\u001b[39m \u001b[43m[\u001b[49m\u001b[38;5;28;43mnext\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43mepisodes_generator\u001b[49m\u001b[43m)\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43;01mfor\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[43m_\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;129;43;01min\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[38;5;28;43mrange\u001b[39;49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m0\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mBATCH_LEN\u001b[49m\u001b[43m)\u001b[49m\u001b[43m]\u001b[49m\n\u001b[1;32m 72\u001b[0m median_reward \u001b[38;5;241m=\u001b[39m np\u001b[38;5;241m.\u001b[39mmedian([x[\u001b[38;5;241m1\u001b[39m] \u001b[38;5;28;01mfor\u001b[39;00m x \u001b[38;5;129;01min\u001b[39;00m episodes])\n\u001b[1;32m 73\u001b[0m best_reward \u001b[38;5;241m=\u001b[39m np\u001b[38;5;241m.\u001b[39mmax([x[\u001b[38;5;241m1\u001b[39m] \u001b[38;5;28;01mfor\u001b[39;00m x \u001b[38;5;129;01min\u001b[39;00m episodes])\n",
"Cell \u001b[0;32mIn[29], line 71\u001b[0m, in \u001b[0;36m<listcomp>\u001b[0;34m(.0)\u001b[0m\n\u001b[1;32m 68\u001b[0m \u001b[38;5;28;01mwhile\u001b[39;00m median_reward \u001b[38;5;241m<\u001b[39m \u001b[38;5;241m475\u001b[39m:\n\u001b[1;32m 69\u001b[0m \u001b[38;5;66;03m# Play with current model\u001b[39;00m\n\u001b[1;32m 70\u001b[0m episodes_generator \u001b[38;5;241m=\u001b[39m generate_episodes(\u001b[38;5;28;01mlambda\u001b[39;00m observation: sample_model_actions_distribution(model, observation))\n\u001b[0;32m---> 71\u001b[0m episodes \u001b[38;5;241m=\u001b[39m [\u001b[38;5;28mnext\u001b[39m(episodes_generator) \u001b[38;5;28;01mfor\u001b[39;00m _ \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mrange\u001b[39m(\u001b[38;5;241m0\u001b[39m, BATCH_LEN)]\n\u001b[1;32m 72\u001b[0m median_reward \u001b[38;5;241m=\u001b[39m np\u001b[38;5;241m.\u001b[39mmedian([x[\u001b[38;5;241m1\u001b[39m] \u001b[38;5;28;01mfor\u001b[39;00m x \u001b[38;5;129;01min\u001b[39;00m episodes])\n\u001b[1;32m 73\u001b[0m best_reward \u001b[38;5;241m=\u001b[39m np\u001b[38;5;241m.\u001b[39mmax([x[\u001b[38;5;241m1\u001b[39m] \u001b[38;5;28;01mfor\u001b[39;00m x \u001b[38;5;129;01min\u001b[39;00m episodes])\n",
"Cell \u001b[0;32mIn[29], line 32\u001b[0m, in \u001b[0;36mgenerate_episodes\u001b[0;34m(predict)\u001b[0m\n\u001b[1;32m 30\u001b[0m next_action \u001b[38;5;241m=\u001b[39m env\u001b[38;5;241m.\u001b[39maction_space\u001b[38;5;241m.\u001b[39msample()\n\u001b[1;32m 31\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m---> 32\u001b[0m next_action \u001b[38;5;241m=\u001b[39m \u001b[43mpredict\u001b[49m\u001b[43m(\u001b[49m\u001b[43mobservation\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 33\u001b[0m observation, reward, terminated, truncated, info \u001b[38;5;241m=\u001b[39m env\u001b[38;5;241m.\u001b[39mstep(next_action)\n\u001b[1;32m 34\u001b[0m episode\u001b[38;5;241m.\u001b[39mappend((observation, next_action))\n",
"Cell \u001b[0;32mIn[29], line 70\u001b[0m, in \u001b[0;36mtrain_model.<locals>.<lambda>\u001b[0;34m(observation)\u001b[0m\n\u001b[1;32m 67\u001b[0m median_reward \u001b[38;5;241m=\u001b[39m \u001b[38;5;241m0\u001b[39m\n\u001b[1;32m 68\u001b[0m \u001b[38;5;28;01mwhile\u001b[39;00m median_reward \u001b[38;5;241m<\u001b[39m \u001b[38;5;241m475\u001b[39m:\n\u001b[1;32m 69\u001b[0m \u001b[38;5;66;03m# Play with current model\u001b[39;00m\n\u001b[0;32m---> 70\u001b[0m episodes_generator \u001b[38;5;241m=\u001b[39m generate_episodes(\u001b[38;5;28;01mlambda\u001b[39;00m observation: \u001b[43msample_model_actions_distribution\u001b[49m\u001b[43m(\u001b[49m\u001b[43mmodel\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mobservation\u001b[49m\u001b[43m)\u001b[49m)\n\u001b[1;32m 71\u001b[0m episodes \u001b[38;5;241m=\u001b[39m [\u001b[38;5;28mnext\u001b[39m(episodes_generator) \u001b[38;5;28;01mfor\u001b[39;00m _ \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mrange\u001b[39m(\u001b[38;5;241m0\u001b[39m, BATCH_LEN)]\n\u001b[1;32m 72\u001b[0m median_reward \u001b[38;5;241m=\u001b[39m np\u001b[38;5;241m.\u001b[39mmedian([x[\u001b[38;5;241m1\u001b[39m] \u001b[38;5;28;01mfor\u001b[39;00m x \u001b[38;5;129;01min\u001b[39;00m episodes])\n",
"Cell \u001b[0;32mIn[29], line 45\u001b[0m, in \u001b[0;36msample_model_actions_distribution\u001b[0;34m(model, observation)\u001b[0m\n\u001b[1;32m 43\u001b[0m observation_minibatch \u001b[38;5;241m=\u001b[39m observation_tensor\u001b[38;5;241m.\u001b[39munsqueeze(\u001b[38;5;241m0\u001b[39m)\n\u001b[1;32m 44\u001b[0m action_probability_distribution \u001b[38;5;241m=\u001b[39m dim_one_softmax(model(observation_minibatch))\u001b[38;5;241m.\u001b[39mto(\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mcpu\u001b[39m\u001b[38;5;124m'\u001b[39m)\u001b[38;5;241m.\u001b[39mdata\u001b[38;5;241m.\u001b[39mnumpy()[\u001b[38;5;241m0\u001b[39m]\n\u001b[0;32m---> 45\u001b[0m action_sampled \u001b[38;5;241m=\u001b[39m np\u001b[38;5;241m.\u001b[39mrandom\u001b[38;5;241m.\u001b[39mchoice(\u001b[38;5;28mlen\u001b[39m(action_probability_distribution), p\u001b[38;5;241m=\u001b[39maction_probability_distribution)\n\u001b[1;32m 46\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m action_sampled\n",
"\u001b[0;31mKeyboardInterrupt\u001b[0m: "
]
}
],
"source": [
"import gymnasium as gym\n",
"from random import random\n",
"\n",
"import torch\n",
"import torch.nn as nn\n",
"import torch.optim as optim\n",
"\n",
"import numpy as np\n",
"\n",
"DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
"\n",
"env = gym.make(\"CartPole-v1\")\n",
"observation, info = env.reset()\n",
"\n",
"EPSILON = 0.05  # probability of taking a random exploratory action\n",
"EPISODE_LEN = 500  # CartPole-v1 truncates at 500 steps; the original value of 20 capped episode rewards at 20, making the 475 target below unreachable\n",
"def generate_episodes(predict):\n",
"    while True:\n",
"        i = 0\n",
"        truncated = False\n",
"        terminated = False\n",
"\n",
"        episode = list()\n",
"        episode_reward = 0\n",
"\n",
"        observation, info = env.reset()\n",
"        while i < EPISODE_LEN and not truncated and not terminated:\n",
"            if random() <= EPSILON:\n",
"                next_action = env.action_space.sample()\n",
"            else:\n",
"                next_action = predict(observation)\n",
"            # Record the pair before stepping so the action is matched with the\n",
"            # observation it was chosen from (the original appended the post-step\n",
"            # observation, an off-by-one that mislabels the training data)\n",
"            episode.append((observation, next_action))\n",
"            observation, reward, terminated, truncated, info = env.step(next_action)\n",
"            episode_reward += reward\n",
"            i += 1\n",
"\n",
"        yield (episode, episode_reward)\n",
"\n",
"dim_one_softmax = nn.Softmax(dim=1)\n",
"def sample_model_actions_distribution(model, observation):\n",
"    observation_tensor = torch.tensor(observation, dtype=torch.float32).to(DEVICE)\n",
"    observation_minibatch = observation_tensor.unsqueeze(0)\n",
"    action_probability_distribution = dim_one_softmax(model(observation_minibatch)).to('cpu').data.numpy()[0]\n",
"    action_sampled = np.random.choice(len(action_probability_distribution), p=action_probability_distribution)\n",
"    return action_sampled\n",
"\n",
"\n",
"BATCH_LEN = 100\n",
"HIDDEN_SIZE = 128\n",
"LEARNING_RATE = 0.01\n",
"TAKE_TOP_P = 0.3  # best 30% of episodes (the elite set) are used for training\n",
"def train_model():\n",
"    obs_size = env.observation_space.shape[0]\n",
"    n_actions = int(env.action_space.n)\n",
"    model = Net(\n",
"        obs_size=obs_size,\n",
"        actions_size=n_actions,\n",
"        hidden_size=HIDDEN_SIZE\n",
"    ).to(DEVICE)\n",
"\n",
"    objective = nn.CrossEntropyLoss()\n",
"    optimizer = optim.Adam(params=model.parameters(), lr=LEARNING_RATE)\n",
"\n",
"    # Train model\n",
"    iteration = 0\n",
"    median_reward = 0\n",
"    while median_reward < 475:\n",
"        # Play a batch of episodes with the current model\n",
"        episodes_generator = generate_episodes(lambda observation: sample_model_actions_distribution(model, observation))\n",
"        episodes = [next(episodes_generator) for _ in range(BATCH_LEN)]\n",
"        median_reward = np.median([x[1] for x in episodes])\n",
"        best_reward = np.max([x[1] for x in episodes])\n",
"        worst_reward = np.min([x[1] for x in episodes])\n",
"        print(f\"Played {BATCH_LEN} episodes. Median reward: {median_reward}, Best reward: {best_reward}, Worst reward: {worst_reward}\")\n",
"\n",
"        # Pick the best TAKE_TOP_P fraction of episodes\n",
"        episodes_sorted = sorted(episodes, key=lambda x: x[1], reverse=True)\n",
"        episodes_top_p = episodes_sorted[0:int(TAKE_TOP_P * BATCH_LEN)]\n",
"        print(f\"Selected {len(episodes_top_p)} best episodes\")\n",
"        print(f\"Rewards: {', '.join([str(x[1]) for x in episodes_top_p])}\")\n",
"\n",
"        # Train the model on the elite (obs, action) pairs.\n",
"        # episodes is a list of ([(obs, action), ...], total_reward) tuples\n",
"        pairs = [x[0] for x in episodes_top_p]\n",
"        flat_pairs = [item for sublist in pairs for item in sublist]\n",
"        minibatch_observations = torch.tensor([pair[0] for pair in flat_pairs], dtype=torch.float32).to(DEVICE)\n",
"        minibatch_actions = torch.tensor([pair[1] for pair in flat_pairs], dtype=torch.long).to(DEVICE)\n",
"\n",
"        optimizer.zero_grad()\n",
"        predicted_actions = model(minibatch_observations)\n",
"        loss = objective(predicted_actions, minibatch_actions)  # cross-entropy between predicted logits and the actions taken in the elite episodes\n",
"        loss.backward()\n",
"        optimizer.step()\n",
"        print(f\"Iteration {iteration} - loss:\", loss.item())\n",
"        iteration += 1\n",
"\n",
"    return model\n",
"\n",
"\n",
"model = train_model()\n",
"\n",
"env.close()"
]
},
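{
"cell_type": "markdown",
"id": "note-evaluation",
"metadata": {},
"source": [
"*Editor's note:* the recorded run above was interrupted before reaching the target. Once `train_model()` returns, the learned policy can be inspected; the greedy-evaluation sketch below is an editorial addition rather than part of the original gist. It builds its own environment (the training cell closes `env`) and assumes the `model` variable returned by `train_model()` in the edited cell above.\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "sketch-greedy-eval",
"metadata": {},
"outputs": [],
"source": [
"# Sketch (editor's addition): greedy evaluation of the trained policy.\n",
"def evaluate(model, n_episodes=10):\n",
"    eval_env = gym.make('CartPole-v1')\n",
"    rewards = []\n",
"    for _ in range(n_episodes):\n",
"        observation, info = eval_env.reset()\n",
"        terminated = truncated = False\n",
"        total_reward = 0.0\n",
"        while not (terminated or truncated):\n",
"            with torch.no_grad():\n",
"                logits = model(torch.tensor(observation, dtype=torch.float32).to(DEVICE).unsqueeze(0))\n",
"            action = int(logits.argmax(dim=1).item())  # greedy: take the most probable action\n",
"            observation, reward, terminated, truncated, info = eval_env.step(action)\n",
"            total_reward += reward\n",
"        rewards.append(total_reward)\n",
"    eval_env.close()\n",
"    print(f'Mean reward over {n_episodes} greedy episodes:', np.mean(rewards))\n",
"\n",
"evaluate(model)"
]
}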
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.7"
}
},
"nbformat": 4,
"nbformat_minor": 5
} |