Last active
February 25, 2017 05:55
-
-
Save tall-josh/63411cec48eb9efe7afae83e452be307 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"## Reinforcement Learning For Self Driving Cars" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"At the moment this code does not work because I'm very new to jupyter notebooks and I haven't figured out how to import classes." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": { | |
"collapsed": true | |
}, | |
"outputs": [], | |
"source": [ | |
##### Hyper-parameters and training state #####
leave_program = False
total_frames = 0
epochs = 50000
epoch_cnt = 0
gamma = 0.9        # discount factor for future rewards
epsilon = 1        # exploration rate, annealed toward 0.1
batch_size = 30
buffer = 30000     # experience-replay capacity
replay = []        # list of (state_0, action_idx, reward, state_1) tuples
h = 0              # ring-buffer write index into `replay` once it is full
reward = 0

# BUGFIX: this was read inside the loop (keyboard handlers and the print
# throttle below) before ever being assigned, which raised NameError on the
# very first frame.
__console_data_print_frequency = 1

for i in range(epochs):

    initSimulation(agent, state, filling_buffer=len(replay) < buffer)
    collision_detected = False
    frames_this_epoch = 0

    while not collision_detected:
        # concatenating a string that is printed at the bottom of the loop
        __console_string = ""
        __console_string += "FRAME: {0} -- ".format(frames_this_epoch)

        ##### PYGAME HOUSE KEEPING #####
        # Keep loop time constant
        clock.tick(CONST.SCREEN_FPS)
        screen.fill(CONST.COLOR_BLACK)

        # Quality estimates for all possible actions; snapshot the state as s
        qMatrix = dqnn.getQMat(state.state)
        state_0 = copy.deepcopy(state.state)

        ##### SELECT ACTION (epsilon-greedy) #####
        # Select a random action with probability epsilon, otherwise the
        # best action from qMatrix.
        action_idx = 0
        if random.random() < epsilon:
            action_idx = random.randint(0, len(CONST.ACTION_AND_COSTS) - 1)
            __console_string += "random action: {0} -- ".format(CONST.ACTION_NAMES[action_idx])
        else:
            action_idx = np.argmax(qMatrix)
            __console_string += "selected action: {0} -- ".format(CONST.ACTION_NAMES[action_idx])

        ##### Take action #####
        agent.updateAction(action_idx)  # Apply action selected above
        all_sprites.update()
        __console_string += "speed: {0} -- ".format(agent.speed)

        # Check for agent/obstacle collisions
        collisions = pygame.sprite.spritecollide(agent, obstacles, False)

        ##### Observe new state (s') #####
        agent.updateSensors(obstacles)   # Sensor update
        state.update(agent.sensor_data)  # Update state with new data

        ##### GET maxQ' from DQN #####
        next_qMatrix = dqnn.getQMat(state.state)

        # Get reward from agent
        reward = agent.reward

        # Terminal conditions: crash / out of bounds, goal reached, or timeout.
        if collisions or agent.out_of_bounds:
            collision_detected = True
            reward = CONST.REWARDS['terminal_crash']
            print("terminal_crash*********************************************************************************")

        if agent.isAtGoal():
            collision_detected = True
            reward = CONST.REWARDS['terminal_goal']
            agent.terminal = True
            print("terminal_goal!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!")

        if frames_this_epoch > CONST.TAKING_TOO_LONG:
            collision_detected = True
            print("TAKING TOO LONG :-( :-( :-( :-( :-( :-( :-( :-( :-( :-( :-( :-( :-( :-( :-( :-( :-( :-( :-( :-( ")
        __console_string += "Reward: {0} -- Epsilon: {1} -- Epoch: {2} -- Total_Frames: {3}".format(reward, epsilon, epoch_cnt, total_frames)

        # If the buffer is not full, keep filling. Else, overwrite the oldest
        # element and begin learning.
        if len(replay) < buffer:
            replay.append((copy.deepcopy(state_0), copy.deepcopy(action_idx),
                           copy.deepcopy(reward), copy.deepcopy(state.state)))

            if log_data:
                # Data logging
                states0.append(copy.deepcopy(state_0))
                actions.append(copy.deepcopy(action_idx))
                rewards.append(copy.deepcopy(reward))
                states1.append(copy.deepcopy(state.state))

            epoch_cnt = 0  # keep epochs at zero until buffer is full
        else:
            # BUGFIX: the original appended here too, so the replay buffer
            # grew without bound past `buffer`. The ring index `h` exists
            # precisely to overwrite the oldest element instead.
            h = (h + 1) % buffer
            replay[h] = (copy.deepcopy(state_0), copy.deepcopy(action_idx),
                         copy.deepcopy(reward), copy.deepcopy(state.state))

            batch = random.sample(replay, batch_size)
            target_batch = []
            for element in batch:
                # Break up tuple into more readable elements
                replay_state, replay_action_idx, replay_reward, replay_new_state = element

                # First feed-forward pass: qualities of each action in the
                # initial state.
                q_mat = dqnn.getQMat(replay_state)

                # This will be used as the 'target' for back-prop
                y = copy.deepcopy(q_mat[0])

                # Second feed-forward pass: qualities of each action in the
                # new state.
                q_mat_new = dqnn.getQMat(replay_new_state)[0]
                q_val_new = max(q_mat_new)

                # A reward of -1 indicates a terminal state (collision):
                # no future value is added then.
                if replay_reward == -1:
                    q_update = replay_reward
                else:
                    q_update = replay_reward + (gamma * q_val_new)

                # Overwrite index of action taken with q_update
                y[replay_action_idx] = q_update
                target_batch.append(y)

            # Train on the batch; periodically verbose, periodically saving.
            if total_frames % 100 == 0:
                dqnn.fitBatch([row[0] for row in batch], target_batch, save=False, verbose=True, iteration_count=total_frames - buffer)
            elif total_frames % 10001 == 0:
                # NOTE(review): 10001 looks like it was chosen so the save
                # branch is not shadowed by the %100 branch above -- confirm
                # the intended save cadence.
                dqnn.fitBatch([row[0] for row in batch], target_batch, save=True, verbose=False, iteration_count=total_frames - buffer)
            else:
                dqnn.fitBatch([row[0] for row in batch], target_batch)

            # Anneal epsilon down to a floor of 0.1
            if epsilon > 0.1:
                epsilon -= 1 / 100000

        ##### MORE PYGAME HOUSE KEEPING #####
        # Respawn obstacles if they have moved out of range
        if moving_obstacles:
            for obs in obstacles:
                if obs.out_of_range:
                    # BUGFIX: the stdlib `random` module has no `rand`
                    # (that is numpy); `random.randint` is the intended call.
                    obs.reInitObs(0, CONST.LANES[random.randint(CONST.CAR_LANE_MIN, CONST.CAR_LANE_MAX)], obstacles)
                    # print("Reload")

        # Check if agent is out of bounds
        if agent.rect.x > CONST.SCREEN_WIDTH + CONST.SCREEN_PADDING:
            collision_detected = True

        # Draw / render
        all_sprites.draw(screen)

        ### Drawing lane markers
        center_guard = CONST.LANES[3] + CONST.LANE_WIDTH // 2
        color = CONST.COLOR_ORANGE
        for lane in CONST.LANES:
            pygame.draw.line(screen, color, (0, lane - CONST.LANE_WIDTH // 2), (CONST.SCREEN_WIDTH, lane - CONST.LANE_WIDTH // 2))
            color = CONST.COLOR_WHITE
        pygame.draw.line(screen, CONST.COLOR_ORANGE, (0, CONST.LANES[len(CONST.LANES) - 1] + CONST.LANE_WIDTH // 2), (CONST.SCREEN_WIDTH, CONST.LANES[len(CONST.LANES) - 1] + CONST.LANE_WIDTH // 2))

        ## Draw carrot (the point the PID follows to track lanes)
        pygame.draw.circle(screen, CONST.COLOR_ORANGE, (agent.carrot), 5)
        pygame.draw.circle(screen, CONST.COLOR_ORANGE, (300, int(CONST.LANES[3] + CONST.LANE_WIDTH // 2)), 4)

        ## Draw most recent LiDAR beams
        for beam in agent.lidar.beams:
            pygame.draw.line(screen, beam.color, (beam.x1, beam.y1), (agent.rect.centerx, agent.rect.centery))

        # Plot lidar data in console if state.setLivePlot
        if state.setLivePlot:
            print("PLOTTING")
            state.plotState(True)

        ##### User input (debug controls) #####
        # Process input (events)
        for event in pygame.event.get():
            # Check for closing window
            if event.type == pygame.QUIT:
                collision_detected = True
                leave_program = True
                dqnn.session.close()
            if event.type == pygame.KEYDOWN:
                if event.key == pygame.K_p:
                    state.setLivePlot = not state.setLivePlot  # toggle live plotting
                    action_idx = 1
                if event.key == pygame.K_UP:
                    __console_data_print_frequency += 1
                    print("Print Frequ every {0} frames".format(__console_data_print_frequency))
                if event.key == pygame.K_DOWN:
                    __console_data_print_frequency -= 1
                    print("Print Frequ every {0} frames".format(__console_data_print_frequency))
                if event.key == pygame.K_RIGHT:
                    __console_data_print_frequency += 10
                    print("Print Frequ every {0} frames".format(__console_data_print_frequency))
                if event.key == pygame.K_LEFT:
                    __console_data_print_frequency -= 10
                    print("Print Frequ every {0} frames".format(__console_data_print_frequency))
                if event.key == pygame.K_q:
                    # Manually raise epsilon (more exploration), capped at 1
                    epsilon += 0.05
                    if epsilon > 1:
                        epsilon = 1
                    print("Epsilon now: {0}".format(epsilon))
                if event.key == pygame.K_a:
                    # Manually lower epsilon (more exploitation), floored at 0.1
                    epsilon -= 0.05
                    if epsilon < 0.1:
                        epsilon = 0.1
                    print("Epsilon now: {0}".format(epsilon))

        if __console_data_print_frequency <= 0:
            __console_data_print_frequency = 1

        if total_frames % __console_data_print_frequency == 0:
            print(__console_string, os.linesep)
            print("q_matrix: {0} -- ".format(qMatrix))
            print("_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _", os.linesep)
        frames_this_epoch += 1
        total_frames += 1

        # After everything, flip display
        pygame.display.flip()

    epoch_cnt += 1
    # Clamp the counter so it never exceeds the range of the outer loop
    if epoch_cnt == epochs - 1:
        epoch_cnt = epochs - 2
    if leave_program:
        break
    if log_data:
        log.logData(fileNames, toLog)
        # Data logging buffers are flushed every epoch after being written out
        states0.clear()
        actions.clear()
        rewards.clear()
        states1.clear()

dqnn.session.close()
pygame.quit();
] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "Python [conda root]", | |
"language": "python", | |
"name": "conda-root-py" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.5.2" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 1 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment