import gymnasium as gym
import numpy as np
import matplotlib.pyplot as plt
from moviepy.editor import ImageSequenceClip

def create_environment(env_name='Taxi-v3', render_mode='rgb_array'):
    """Create and return a Gymnasium environment."""
    return gym.make(env_name, render_mode=render_mode)

def initialize_q_table(env):
    """Initialize and return a Q-table for the given environment."""
    n_states = env.observation_space.n
    n_actions = env.action_space.n
    return np.zeros((n_states, n_actions))
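# Note: for the default Taxi-v3 environment this produces a 500 x 6 table
# (500 discrete states, 6 discrete actions), with all estimates initialised to zero.
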
def epsilon_greedy(env, Q_table, state, epsilon):
    """Epsilon-greedy action selection."""
    if np.random.random() < epsilon:
        return env.action_space.sample()
    else:
        return np.argmax(Q_table[state])
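# With probability epsilon a random action is sampled (exploration); otherwise the
# action with the highest current Q-value is taken (exploitation). np.argmax breaks
# ties by returning the lowest-indexed action.
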
def sarsa_update(Q_table, state, action, reward, next_state, next_action, alpha, gamma):
    """Perform SARSA update on Q-table."""
    Q_table[state, action] += alpha * (
        reward + gamma * Q_table[next_state, next_action] - Q_table[state, action]
    )
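# On-policy SARSA target: Q(s, a) <- Q(s, a) + alpha * (r + gamma * Q(s', a') - Q(s, a)),
# where a' is the action the behaviour policy actually selects in the next state.
# Q-learning would instead bootstrap from max over a' of Q(s', a').
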
def train_sarsa(env, n_episodes=20000, alpha=0.1, gamma=0.99, epsilon=0.1):
    """Train the agent using the SARSA algorithm."""
    Q_table = initialize_q_table(env)
    episode_rewards = []
    episode_lengths = []

    for episode in range(n_episodes):
        state, _ = env.reset()
        action = epsilon_greedy(env, Q_table, state, epsilon)
        done = False
        total_reward = 0
        steps = 0

        while not done:
            next_state, reward, terminated, truncated, _ = env.step(action)
            done = terminated or truncated
            next_action = epsilon_greedy(env, Q_table, next_state, epsilon)
            sarsa_update(Q_table, state, action, reward, next_state, next_action, alpha, gamma)
            state = next_state
            action = next_action
            total_reward += reward
            steps += 1

        episode_rewards.append(total_reward)
        episode_lengths.append(steps)

        if episode % 2000 == 0:
            avg_reward = np.mean(episode_rewards[-1000:])
            avg_length = np.mean(episode_lengths[-1000:])
            print(f"Episode {episode}, Avg Reward: {avg_reward:.2f}, Avg Length: {avg_length:.2f}")

    return Q_table, episode_rewards, episode_lengths
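# For reference, Taxi-v3's reward scheme (per the Gymnasium docs) is -1 per step,
# +20 for a successful drop-off, and -10 for illegal pickup/drop-off actions, so the
# average episode reward should climb from large negative values toward a small
# positive value as the policy improves.
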
def plot_learning_curve(episode_rewards, episode_lengths):
    """Plot the learning curve."""
    plt.figure(figsize=(12, 5))

    plt.subplot(1, 2, 1)
    plt.plot(episode_rewards)
    plt.title("Episode Rewards")
    plt.xlabel("Episode")
    plt.ylabel("Total Reward")

    plt.subplot(1, 2, 2)
    plt.plot(episode_lengths)
    plt.title("Episode Lengths")
    plt.xlabel("Episode")
    plt.ylabel("Number of Steps")

    plt.tight_layout()
    plt.show()

def create_gif(frames, filename, fps=5):
    """Create a GIF animation from a list of frames."""
    clip = ImageSequenceClip(frames, fps=fps)
    clip.write_gif(filename, fps=fps)
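# Note: the moviepy.editor import at the top matches moviepy 1.x; on moviepy >= 2.0
# the equivalent import is `from moviepy import ImageSequenceClip`.
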
def run_episode(env, Q_table, epsilon=0):
    """Run a single episode using the learned Q-table."""
    state, _ = env.reset()
    done = False
    total_reward = 0
    frames = [env.render()]

    while not done:
        action = epsilon_greedy(env, Q_table, state, epsilon)
        next_state, reward, terminated, truncated, _ = env.step(action)
        done = terminated or truncated
        frames.append(env.render())
        total_reward += reward
        state = next_state

    return frames, total_reward
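# With the default epsilon=0 this rollout is fully greedy with respect to the learned
# Q-table. Taxi-v3 truncates episodes after 200 steps, so the rollout terminates even
# if the learned policy never delivers the passenger.
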
# Main execution
if __name__ == "__main__":
    env = create_environment()

    Q_table, episode_rewards, episode_lengths = train_sarsa(env)
    plot_learning_curve(episode_rewards, episode_lengths)

    # Run a single episode with the trained Q-table
    frames, total_reward = run_episode(env, Q_table)
    create_gif(frames, "trained_taxi.gif", fps=5)
    print(f"Episode completed with total reward: {total_reward}")

    env.close()
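# Assumed dependencies for running this script end to end:
#   pip install "gymnasium[toy-text]" numpy matplotlib moviepy
# Running it trains for 20,000 episodes, shows the learning curves, and writes
# trained_taxi.gif to the working directory.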