{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"id": "bd0b6703-2f23-4f57-83aa-f6eb323e6958",
"metadata": {
"execution": {
"iopub.execute_input": "2025-01-19T11:19:59.565050Z",
"iopub.status.busy": "2025-01-19T11:19:59.563942Z",
"iopub.status.idle": "2025-01-19T11:19:59.651501Z",
"shell.execute_reply": "2025-01-19T11:19:59.650376Z",
"shell.execute_reply.started": "2025-01-19T11:19:59.565050Z"
}
},
"outputs": [
{
"data": {
"text/plain": [
"<pyvirtualdisplay.display.Display at 0x7f9c60386f10>"
]
},
"execution_count": 1,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Virtual display\n",
"from pyvirtualdisplay import Display\n",
"\n",
"virtual_display = Display(visible=0, size=(1400, 900))\n",
"virtual_display.start()"
]
},
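{
"cell_type": "markdown",
"id": "virtual-display-demo-md",
"metadata": {},
"source": [
"The virtual display lets rendering backends that expect a screen run headlessly. Below is a minimal sketch that grabs one rendered CartPole frame; it assumes `gymnasium`'s classic-control extras are installed, and the names `demo_env` and `frame` are illustrative."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "virtual-display-demo-code",
"metadata": {},
"outputs": [],
"source": [
"# Sketch: render one CartPole frame headlessly (assumes gymnasium's classic-control extras are installed)\n",
"import gymnasium as gym\n",
"\n",
"demo_env = gym.make(\"CartPole-v1\", render_mode=\"rgb_array\")\n",
"demo_env.reset()\n",
"frame = demo_env.render()  # RGB array, e.g. shape (400, 600, 3)\n",
"print(frame.shape)\n",
"demo_env.close()"
]
},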
{
"cell_type": "code",
"execution_count": 2,
"id": "9f9bf992-5c21-4c05-a87e-69e1dcabcec6",
"metadata": {},
"outputs": [],
"source": [
"import torch.nn as nn\n",
"\n",
"class Net(nn.Module):\n",
" def __init__(self, obs_size, actions_size, hidden_layers):\n",
" super(Net, self).__init__()\n",
" \n",
" self.model = nn.Sequential(\n",
" nn.Linear(obs_size, hidden_layers),\n",
" nn.ReLU(),\n",
" nn.Linear(hidden_layers, actions_size)\n",
" )\n",
"\n",
" def forward(self, x):\n",
" return self.model(x)\n"
]
},
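{
"cell_type": "markdown",
"id": "net-sanity-check-md",
"metadata": {},
"source": [
"A quick sanity check of `Net`: a minimal sketch assuming CartPole-v1's 4-dimensional observation and 2 discrete actions. The dummy tensor and the `_net` name are illustrative only."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "net-sanity-check-code",
"metadata": {},
"outputs": [],
"source": [
"# Sanity-check sketch: assumes CartPole-v1 sizes (4-dim observation, 2 actions)\n",
"import torch\n",
"\n",
"_net = Net(obs_size=4, actions_size=2, hidden_size=128)\n",
"_dummy_obs = torch.zeros(1, 4)   # batch of one fake observation\n",
"print(_net(_dummy_obs).shape)    # expected: torch.Size([1, 2]), one logit per action"
]
},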
{
"cell_type": "code",
"execution_count": 29,
"id": "2fe6e83c-2a5e-4670-afd6-e6ef43ba094e",
"metadata": {
"execution": {
"iopub.execute_input": "2025-01-19T11:22:38.199101Z",
"iopub.status.busy": "2025-01-19T11:22:38.197219Z",
"iopub.status.idle": "2025-01-19T11:22:38.214770Z",
"shell.execute_reply": "2025-01-19T11:22:38.214093Z",
"shell.execute_reply.started": "2025-01-19T11:22:38.198884Z"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Played 100 episodes. Median reward: 20.0, Best reward: 20.0, Worst reward: 10.0\n",
"Selected 30 best episodes\n",
"Rewards: 20.0, 20.0, 20.0, 20.0, 20.0, 20.0, 20.0, 20.0, 20.0, 20.0, 20.0, 20.0, 20.0, 20.0, 20.0, 20.0, 20.0, 20.0, 20.0, 20.0, 20.0, 20.0, 20.0, 20.0, 20.0, 20.0, 20.0, 20.0, 20.0, 20.0\n",
"Iteration 0 - loss: 0.7031133770942688\n",
"Played 100 episodes. Median reward: 15.5, Best reward: 20.0, Worst reward: 9.0\n",
"Selected 30 best episodes\n",
"Rewards: 20.0, 20.0, 20.0, 20.0, 20.0, 20.0, 20.0, 20.0, 20.0, 20.0, 20.0, 20.0, 20.0, 20.0, 20.0, 20.0, 20.0, 20.0, 20.0, 20.0, 20.0, 20.0, 20.0, 20.0, 20.0, 20.0, 20.0, 20.0, 19.0, 19.0\n",
"Iteration 1 - loss: 0.654332160949707\n",
"Played 100 episodes. Median reward: 14.0, Best reward: 20.0, Worst reward: 8.0\n",
"Selected 30 best episodes\n",
"Rewards: 20.0, 20.0, 20.0, 20.0, 20.0, 20.0, 20.0, 20.0, 20.0, 20.0, 20.0, 20.0, 20.0, 20.0, 20.0, 20.0, 20.0, 20.0, 19.0, 19.0, 19.0, 18.0, 18.0, 18.0, 18.0, 18.0, 18.0, 17.0, 17.0, 16.0\n",
"Iteration 2 - loss: 0.5910254120826721\n",
"Played 100 episodes. Median reward: 12.0, Best reward: 20.0, Worst reward: 8.0\n",
"Selected 30 best episodes\n",
"Rewards: 20.0, 20.0, 20.0, 20.0, 20.0, 20.0, 20.0, 20.0, 20.0, 20.0, 19.0, 19.0, 19.0, 19.0, 19.0, 17.0, 17.0, 17.0, 16.0, 16.0, 16.0, 16.0, 16.0, 15.0, 15.0, 15.0, 15.0, 15.0, 14.0, 14.0\n",
"Iteration 3 - loss: 0.5180516242980957\n",
"Played 100 episodes. Median reward: 12.0, Best reward: 20.0, Worst reward: 8.0\n",
"Selected 30 best episodes\n",
"Rewards: 20.0, 20.0, 20.0, 20.0, 20.0, 20.0, 20.0, 20.0, 19.0, 19.0, 19.0, 19.0, 18.0, 18.0, 17.0, 17.0, 17.0, 17.0, 17.0, 17.0, 17.0, 17.0, 16.0, 16.0, 16.0, 16.0, 16.0, 15.0, 15.0, 15.0\n",
"Iteration 4 - loss: 0.4938879609107971\n",
"Played 100 episodes. Median reward: 11.5, Best reward: 20.0, Worst reward: 8.0\n",
"Selected 30 best episodes\n",
"Rewards: 20.0, 20.0, 20.0, 20.0, 20.0, 19.0, 18.0, 17.0, 16.0, 16.0, 16.0, 15.0, 15.0, 15.0, 15.0, 15.0, 14.0, 14.0, 14.0, 13.0, 13.0, 13.0, 13.0, 13.0, 13.0, 13.0, 13.0, 13.0, 13.0, 13.0\n",
"Iteration 5 - loss: 0.38373544812202454\n",
"Played 100 episodes. Median reward: 11.0, Best reward: 20.0, Worst reward: 8.0\n",
"Selected 30 best episodes\n",
"Rewards: 20.0, 20.0, 18.0, 18.0, 18.0, 17.0, 16.0, 16.0, 16.0, 16.0, 16.0, 15.0, 15.0, 15.0, 14.0, 14.0, 14.0, 14.0, 14.0, 14.0, 14.0, 14.0, 14.0, 14.0, 14.0, 13.0, 13.0, 13.0, 13.0, 13.0\n",
"Iteration 6 - loss: 0.39132174849510193\n",
"Played 100 episodes. Median reward: 11.0, Best reward: 20.0, Worst reward: 8.0\n",
"Selected 30 best episodes\n",
"Rewards: 20.0, 20.0, 20.0, 18.0, 18.0, 17.0, 16.0, 16.0, 16.0, 16.0, 16.0, 16.0, 15.0, 15.0, 15.0, 15.0, 15.0, 14.0, 14.0, 14.0, 13.0, 13.0, 13.0, 13.0, 13.0, 13.0, 13.0, 13.0, 12.0, 12.0\n",
"Iteration 7 - loss: 0.35876840353012085\n",
"Played 100 episodes. Median reward: 10.0, Best reward: 20.0, Worst reward: 8.0\n",
"Selected 30 best episodes\n",
"Rewards: 20.0, 18.0, 18.0, 18.0, 15.0, 15.0, 15.0, 14.0, 14.0, 14.0, 14.0, 13.0, 13.0, 13.0, 13.0, 13.0, 13.0, 13.0, 13.0, 13.0, 12.0, 12.0, 12.0, 12.0, 12.0, 12.0, 12.0, 12.0, 12.0, 12.0\n",
"Iteration 8 - loss: 0.3280767500400543\n",
"Played 100 episodes. Median reward: 10.0, Best reward: 20.0, Worst reward: 8.0\n",
"Selected 30 best episodes\n",
"Rewards: 20.0, 18.0, 16.0, 15.0, 15.0, 15.0, 14.0, 14.0, 14.0, 14.0, 14.0, 14.0, 13.0, 13.0, 13.0, 13.0, 13.0, 13.0, 13.0, 13.0, 12.0, 12.0, 12.0, 12.0, 12.0, 12.0, 12.0, 12.0, 12.0, 12.0\n",
"Iteration 9 - loss: 0.3035016357898712\n",
"Played 100 episodes. Median reward: 10.0, Best reward: 20.0, Worst reward: 8.0\n",
"Selected 30 best episodes\n",
"Rewards: 20.0, 20.0, 19.0, 16.0, 15.0, 14.0, 14.0, 14.0, 14.0, 14.0, 14.0, 14.0, 14.0, 13.0, 13.0, 13.0, 13.0, 12.0, 12.0, 12.0, 12.0, 12.0, 12.0, 12.0, 12.0, 12.0, 12.0, 12.0, 11.0, 11.0\n",
"Iteration 10 - loss: 0.23971830308437347\n",
"Played 100 episodes. Median reward: 10.0, Best reward: 20.0, Worst reward: 8.0\n",
"Selected 30 best episodes\n",
"Rewards: 20.0, 19.0, 18.0, 16.0, 16.0, 15.0, 15.0, 15.0, 14.0, 14.0, 14.0, 13.0, 13.0, 13.0, 13.0, 13.0, 13.0, 12.0, 12.0, 12.0, 12.0, 12.0, 12.0, 12.0, 12.0, 12.0, 12.0, 12.0, 12.0, 11.0\n",
"Iteration 11 - loss: 0.36385658383369446\n",
"Played 100 episodes. Median reward: 10.0, Best reward: 16.0, Worst reward: 8.0\n",
"Selected 30 best episodes\n",
"Rewards: 16.0, 15.0, 15.0, 14.0, 14.0, 14.0, 14.0, 14.0, 13.0, 13.0, 13.0, 12.0, 12.0, 12.0, 12.0, 12.0, 12.0, 12.0, 12.0, 12.0, 11.0, 11.0, 11.0, 11.0, 11.0, 11.0, 11.0, 11.0, 11.0, 11.0\n",
"Iteration 12 - loss: 0.24887436628341675\n",
"Played 100 episodes. Median reward: 10.0, Best reward: 16.0, Worst reward: 8.0\n",
"Selected 30 best episodes\n",
"Rewards: 16.0, 16.0, 15.0, 14.0, 14.0, 14.0, 14.0, 14.0, 14.0, 13.0, 13.0, 13.0, 12.0, 12.0, 12.0, 12.0, 12.0, 12.0, 12.0, 12.0, 11.0, 11.0, 11.0, 11.0, 11.0, 11.0, 11.0, 11.0, 11.0, 11.0\n",
"Iteration 13 - loss: 0.31917455792427063\n",
"Played 100 episodes. Median reward: 10.0, Best reward: 16.0, Worst reward: 8.0\n",
"Selected 30 best episodes\n",
"Rewards: 16.0, 15.0, 15.0, 14.0, 13.0, 13.0, 13.0, 13.0, 13.0, 12.0, 12.0, 12.0, 12.0, 12.0, 12.0, 12.0, 11.0, 11.0, 11.0, 11.0, 11.0, 11.0, 11.0, 11.0, 11.0, 11.0, 11.0, 11.0, 10.0, 10.0\n",
"Iteration 14 - loss: 0.2883206605911255\n",
"Played 100 episodes. Median reward: 10.0, Best reward: 17.0, Worst reward: 8.0\n",
"Selected 30 best episodes\n",
"Rewards: 17.0, 17.0, 15.0, 15.0, 15.0, 14.0, 14.0, 14.0, 14.0, 14.0, 13.0, 13.0, 13.0, 13.0, 13.0, 13.0, 13.0, 13.0, 13.0, 12.0, 12.0, 12.0, 12.0, 12.0, 11.0, 11.0, 11.0, 11.0, 11.0, 11.0\n",
"Iteration 15 - loss: 0.33113226294517517\n",
"Played 100 episodes. Median reward: 10.0, Best reward: 19.0, Worst reward: 8.0\n",
"Selected 30 best episodes\n",
"Rewards: 19.0, 19.0, 18.0, 18.0, 16.0, 16.0, 14.0, 14.0, 13.0, 13.0, 13.0, 13.0, 13.0, 13.0, 13.0, 13.0, 12.0, 12.0, 12.0, 12.0, 12.0, 12.0, 12.0, 12.0, 12.0, 12.0, 11.0, 11.0, 11.0, 11.0\n",
"Iteration 16 - loss: 0.2757459580898285\n",
"Played 100 episodes. Median reward: 10.0, Best reward: 19.0, Worst reward: 8.0\n",
"Selected 30 best episodes\n",
"Rewards: 19.0, 18.0, 17.0, 16.0, 16.0, 16.0, 15.0, 14.0, 14.0, 14.0, 14.0, 13.0, 13.0, 13.0, 13.0, 13.0, 13.0, 13.0, 12.0, 12.0, 12.0, 12.0, 12.0, 12.0, 12.0, 12.0, 11.0, 11.0, 11.0, 11.0\n",
"Iteration 17 - loss: 0.28519508242607117\n",
"Played 100 episodes. Median reward: 10.0, Best reward: 15.0, Worst reward: 8.0\n",
"Selected 30 best episodes\n",
"Rewards: 15.0, 15.0, 14.0, 14.0, 13.0, 13.0, 13.0, 13.0, 13.0, 13.0, 12.0, 12.0, 12.0, 12.0, 12.0, 12.0, 12.0, 12.0, 12.0, 12.0, 12.0, 12.0, 11.0, 11.0, 11.0, 11.0, 11.0, 11.0, 11.0, 11.0\n",
"Iteration 18 - loss: 0.2898578345775604\n"
]
},
{
"ename": "KeyboardInterrupt",
"evalue": "",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)",
"Cell \u001b[0;32mIn[29], line 98\u001b[0m\n\u001b[1;32m 94\u001b[0m \u001b[38;5;28mprint\u001b[39m(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mIteration \u001b[39m\u001b[38;5;132;01m{\u001b[39;00miteration\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m - loss: \u001b[39m\u001b[38;5;124m\"\u001b[39m, loss\u001b[38;5;241m.\u001b[39mitem())\n\u001b[1;32m 95\u001b[0m iteration \u001b[38;5;241m+\u001b[39m\u001b[38;5;241m=\u001b[39m \u001b[38;5;241m1\u001b[39m\n\u001b[0;32m---> 98\u001b[0m \u001b[43mtrain_model\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 100\u001b[0m env\u001b[38;5;241m.\u001b[39mclose()\n",
"Cell \u001b[0;32mIn[29], line 71\u001b[0m, in \u001b[0;36mtrain_model\u001b[0;34m()\u001b[0m\n\u001b[1;32m 68\u001b[0m \u001b[38;5;28;01mwhile\u001b[39;00m median_reward \u001b[38;5;241m<\u001b[39m \u001b[38;5;241m475\u001b[39m:\n\u001b[1;32m 69\u001b[0m \u001b[38;5;66;03m# Play with current model\u001b[39;00m\n\u001b[1;32m 70\u001b[0m episodes_generator \u001b[38;5;241m=\u001b[39m generate_episodes(\u001b[38;5;28;01mlambda\u001b[39;00m observation: sample_model_actions_distribution(model, observation))\n\u001b[0;32m---> 71\u001b[0m episodes \u001b[38;5;241m=\u001b[39m \u001b[43m[\u001b[49m\u001b[38;5;28;43mnext\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43mepisodes_generator\u001b[49m\u001b[43m)\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43;01mfor\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[43m_\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;129;43;01min\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[38;5;28;43mrange\u001b[39;49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m0\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mBATCH_LEN\u001b[49m\u001b[43m)\u001b[49m\u001b[43m]\u001b[49m\n\u001b[1;32m 72\u001b[0m median_reward \u001b[38;5;241m=\u001b[39m np\u001b[38;5;241m.\u001b[39mmedian([x[\u001b[38;5;241m1\u001b[39m] \u001b[38;5;28;01mfor\u001b[39;00m x \u001b[38;5;129;01min\u001b[39;00m episodes])\n\u001b[1;32m 73\u001b[0m best_reward \u001b[38;5;241m=\u001b[39m np\u001b[38;5;241m.\u001b[39mmax([x[\u001b[38;5;241m1\u001b[39m] \u001b[38;5;28;01mfor\u001b[39;00m x \u001b[38;5;129;01min\u001b[39;00m episodes])\n",
"Cell \u001b[0;32mIn[29], line 71\u001b[0m, in \u001b[0;36m<listcomp>\u001b[0;34m(.0)\u001b[0m\n\u001b[1;32m 68\u001b[0m \u001b[38;5;28;01mwhile\u001b[39;00m median_reward \u001b[38;5;241m<\u001b[39m \u001b[38;5;241m475\u001b[39m:\n\u001b[1;32m 69\u001b[0m \u001b[38;5;66;03m# Play with current model\u001b[39;00m\n\u001b[1;32m 70\u001b[0m episodes_generator \u001b[38;5;241m=\u001b[39m generate_episodes(\u001b[38;5;28;01mlambda\u001b[39;00m observation: sample_model_actions_distribution(model, observation))\n\u001b[0;32m---> 71\u001b[0m episodes \u001b[38;5;241m=\u001b[39m [\u001b[38;5;28mnext\u001b[39m(episodes_generator) \u001b[38;5;28;01mfor\u001b[39;00m _ \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mrange\u001b[39m(\u001b[38;5;241m0\u001b[39m, BATCH_LEN)]\n\u001b[1;32m 72\u001b[0m median_reward \u001b[38;5;241m=\u001b[39m np\u001b[38;5;241m.\u001b[39mmedian([x[\u001b[38;5;241m1\u001b[39m] \u001b[38;5;28;01mfor\u001b[39;00m x \u001b[38;5;129;01min\u001b[39;00m episodes])\n\u001b[1;32m 73\u001b[0m best_reward \u001b[38;5;241m=\u001b[39m np\u001b[38;5;241m.\u001b[39mmax([x[\u001b[38;5;241m1\u001b[39m] \u001b[38;5;28;01mfor\u001b[39;00m x \u001b[38;5;129;01min\u001b[39;00m episodes])\n",
"Cell \u001b[0;32mIn[29], line 32\u001b[0m, in \u001b[0;36mgenerate_episodes\u001b[0;34m(predict)\u001b[0m\n\u001b[1;32m 30\u001b[0m next_action \u001b[38;5;241m=\u001b[39m env\u001b[38;5;241m.\u001b[39maction_space\u001b[38;5;241m.\u001b[39msample()\n\u001b[1;32m 31\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m---> 32\u001b[0m next_action \u001b[38;5;241m=\u001b[39m \u001b[43mpredict\u001b[49m\u001b[43m(\u001b[49m\u001b[43mobservation\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 33\u001b[0m observation, reward, terminated, truncated, info \u001b[38;5;241m=\u001b[39m env\u001b[38;5;241m.\u001b[39mstep(next_action)\n\u001b[1;32m 34\u001b[0m episode\u001b[38;5;241m.\u001b[39mappend((observation, next_action))\n",
"Cell \u001b[0;32mIn[29], line 70\u001b[0m, in \u001b[0;36mtrain_model.<locals>.<lambda>\u001b[0;34m(observation)\u001b[0m\n\u001b[1;32m 67\u001b[0m median_reward \u001b[38;5;241m=\u001b[39m \u001b[38;5;241m0\u001b[39m\n\u001b[1;32m 68\u001b[0m \u001b[38;5;28;01mwhile\u001b[39;00m median_reward \u001b[38;5;241m<\u001b[39m \u001b[38;5;241m475\u001b[39m:\n\u001b[1;32m 69\u001b[0m \u001b[38;5;66;03m# Play with current model\u001b[39;00m\n\u001b[0;32m---> 70\u001b[0m episodes_generator \u001b[38;5;241m=\u001b[39m generate_episodes(\u001b[38;5;28;01mlambda\u001b[39;00m observation: \u001b[43msample_model_actions_distribution\u001b[49m\u001b[43m(\u001b[49m\u001b[43mmodel\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mobservation\u001b[49m\u001b[43m)\u001b[49m)\n\u001b[1;32m 71\u001b[0m episodes \u001b[38;5;241m=\u001b[39m [\u001b[38;5;28mnext\u001b[39m(episodes_generator) \u001b[38;5;28;01mfor\u001b[39;00m _ \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mrange\u001b[39m(\u001b[38;5;241m0\u001b[39m, BATCH_LEN)]\n\u001b[1;32m 72\u001b[0m median_reward \u001b[38;5;241m=\u001b[39m np\u001b[38;5;241m.\u001b[39mmedian([x[\u001b[38;5;241m1\u001b[39m] \u001b[38;5;28;01mfor\u001b[39;00m x \u001b[38;5;129;01min\u001b[39;00m episodes])\n",
"Cell \u001b[0;32mIn[29], line 45\u001b[0m, in \u001b[0;36msample_model_actions_distribution\u001b[0;34m(model, observation)\u001b[0m\n\u001b[1;32m 43\u001b[0m observation_minibatch \u001b[38;5;241m=\u001b[39m observation_tensor\u001b[38;5;241m.\u001b[39munsqueeze(\u001b[38;5;241m0\u001b[39m)\n\u001b[1;32m 44\u001b[0m action_probability_distribution \u001b[38;5;241m=\u001b[39m dim_one_softmax(model(observation_minibatch))\u001b[38;5;241m.\u001b[39mto(\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mcpu\u001b[39m\u001b[38;5;124m'\u001b[39m)\u001b[38;5;241m.\u001b[39mdata\u001b[38;5;241m.\u001b[39mnumpy()[\u001b[38;5;241m0\u001b[39m]\n\u001b[0;32m---> 45\u001b[0m action_sampled \u001b[38;5;241m=\u001b[39m np\u001b[38;5;241m.\u001b[39mrandom\u001b[38;5;241m.\u001b[39mchoice(\u001b[38;5;28mlen\u001b[39m(action_probability_distribution), p\u001b[38;5;241m=\u001b[39maction_probability_distribution)\n\u001b[1;32m 46\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m action_sampled\n",
"\u001b[0;31mKeyboardInterrupt\u001b[0m: "
]
}
],
"source": [
"import gymnasium\n",
"import gymnasium as gym\n",
"from random import random\n",
"\n",
"import torch\n",
"import torch.nn as nn\n",
"import torch.optim as optim\n",
"\n",
"import numpy as np\n",
"\n",
"DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
"\n",
"env = gym.make(\"CartPole-v1\")\n",
"observation, info = env.reset()\n",
"\n",
"EPSILON = 0.05\n",
"EPISODE_LEN = 20\n",
"def generate_episodes(predict):\n",
" while True:\n",
" i = 0\n",
" truncated = False\n",
" terminated = False\n",
" \n",
" episode = list()\n",
" episode_reward = 0\n",
" \n",
" observation, info = env.reset()\n",
" while i < EPISODE_LEN and not truncated and not terminated:\n",
" if random() <= EPSILON:\n",
" next_action = env.action_space.sample()\n",
" else:\n",
" next_action = predict(observation)\n",
" observation, reward, terminated, truncated, info = env.step(next_action)\n",
" episode.append((observation, next_action))\n",
" episode_reward += reward\n",
" i += 1\n",
" \n",
" yield (episode, episode_reward)\n",
"\n",
"dim_one_softmax = nn.Softmax(dim=1)\n",
"def sample_model_actions_distribution(model, observation):\n",
" observation_tensor = torch.tensor(observation, dtype=torch.float32).to(DEVICE)\n",
" observation_minibatch = observation_tensor.unsqueeze(0)\n",
" action_probability_distribution = dim_one_softmax(model(observation_minibatch)).to('cpu').data.numpy()[0]\n",
" action_sampled = np.random.choice(len(action_probability_distribution), p=action_probability_distribution)\n",
" return action_sampled\n",
"\n",
" \n",
"BATCH_LEN = 100\n",
"HIDDEN_SIZE = 128\n",
"LEARNING_RATE = 0.01\n",
"TAKE_TOP_P = 0.3 # Best 20% of episodes used for training\n",
"def train_model():\n",
" obs_size = env.observation_space.shape[0]\n",
" n_actions = int(env.action_space.n)\n",
" model = Net(\n",
" obs_size=obs_size,\n",
" actions_size=n_actions,\n",
" hidden_layers=HIDDEN_SIZE\n",
" ).to(DEVICE)\n",
"\n",
" objective = nn.CrossEntropyLoss()\n",
" optimizer = optim.Adam(params=model.parameters(), lr=LEARNING_RATE)\n",
" \n",
" # Train model\n",
" iteration = 0\n",
" median_reward = 0\n",
" while median_reward < 475:\n",
" # Play with current model\n",
" episodes_generator = generate_episodes(lambda observation: sample_model_actions_distribution(model, observation))\n",
" episodes = [next(episodes_generator) for _ in range(0, BATCH_LEN)]\n",
" median_reward = np.median([x[1] for x in episodes])\n",
" best_reward = np.max([x[1] for x in episodes])\n",
" worst_reward = np.min([x[1] for x in episodes])\n",
" print(f\"Played {BATCH_LEN} episodes. Median reward: {median_reward}, Best reward: {best_reward}, Worst reward: {worst_reward}\")\n",
"\n",
" # Pick best p episodes\n",
" episodes_sorted = sorted(episodes, key=lambda x: x[1], reverse=True)\n",
" episodes_top_p = episodes_sorted[0:int(TAKE_TOP_P * BATCH_LEN)]\n",
" print(f\"Selected {len(episodes_top_p)} best episodes\")\n",
" print(f\"Rewards: {', '.join([str(x[1]) for x in episodes_top_p])}\")\n",
"\n",
" # Train the model on the best (obs, action) pairs. Episodes is a list of ((obs, action), total_reward) pairs\n",
" pairs = [x[0] for x in episodes_top_p]\n",
" flat_pairs = [item for sublist in pairs for item in sublist]\n",
" minibatch_observations = torch.tensor([pair[0] for pair in flat_pairs], dtype=torch.float32).to(DEVICE)\n",
" minibatch_actions = torch.tensor([pair[1] for pair in flat_pairs], dtype=torch.long).to(DEVICE)\n",
"\n",
" optimizer.zero_grad()\n",
" predicted_actions = model(minibatch_observations)\n",
" loss = objective(predicted_actions, minibatch_actions) # CrossEntropyLoss -> difference between predicted and actual actions\n",
" loss.backward()\n",
" optimizer.step()\n",
" print(f\"Iteration {iteration} - loss: \", loss.item())\n",
" iteration += 1\n",
"\n",
" \n",
"train_model()\n",
"\n",
"env.close()"
]
},
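{
"cell_type": "markdown",
"id": "greedy-eval-md",
"metadata": {},
"source": [
"A minimal evaluation sketch, assuming `train_model()` has returned a trained `model` as above: play one greedy episode on a fresh environment without the 20-step cap and report its total reward. The `eval_env` name and the greedy argmax policy are illustrative choices, not taken from the training loop."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "greedy-eval-code",
"metadata": {},
"outputs": [],
"source": [
"# Evaluation sketch: greedy rollout of the trained policy (assumes `model`, `gym`, `DEVICE` from the cells above)\n",
"import torch\n",
"\n",
"eval_env = gym.make(\"CartPole-v1\")\n",
"obs, info = eval_env.reset()\n",
"total_reward, done = 0.0, False\n",
"while not done:\n",
"    with torch.no_grad():\n",
"        logits = model(torch.tensor(obs, dtype=torch.float32, device=DEVICE).unsqueeze(0))\n",
"    action = int(torch.argmax(logits, dim=1).item())  # pick the most likely action instead of sampling\n",
"    obs, reward, terminated, truncated, info = eval_env.step(action)\n",
"    total_reward += reward\n",
"    done = terminated or truncated\n",
"print(f\"Greedy episode reward: {total_reward}\")\n",
"eval_env.close()"
]
},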
{
"cell_type": "code",
"execution_count": null,
"id": "4369934b",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.7"
}
},
"nbformat": 4,
"nbformat_minor": 5
}