# SARSA on Gymnasium's Taxi-v3, from a gist by @BexTuychiev (August 30, 2024).
import gymnasium as gym
import numpy as np
import matplotlib.pyplot as plt
from moviepy.editor import ImageSequenceClip


def create_environment(env_name='Taxi-v3', render_mode='rgb_array'):
    """Create and return a Gymnasium environment."""
    return gym.make(env_name, render_mode=render_mode)

def initialize_q_table(env):
    """Initialize and return a Q-table for the given environment."""
    n_states = env.observation_space.n
    n_actions = env.action_space.n
    return np.zeros((n_states, n_actions))
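
# For the default Taxi-v3 environment this yields a 500 x 6 table (500
# discrete states, 6 actions), with every entry starting at zero:
#
#   env = create_environment()
#   initialize_q_table(env).shape  # (500, 6)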

def epsilon_greedy(env, Q_table, state, epsilon):
    """Epsilon-greedy action selection: explore with probability epsilon,
    otherwise exploit the best known action from the Q-table."""
    if np.random.random() < epsilon:
        return env.action_space.sample()  # explore: uniformly random action
    else:
        return np.argmax(Q_table[state])  # exploit: greedy action

def sarsa_update(Q_table, state, action, reward, next_state, next_action, alpha, gamma):
    """Perform an in-place SARSA update on the Q-table:
    Q(s, a) += alpha * (r + gamma * Q(s', a') - Q(s, a))."""
    Q_table[state, action] += alpha * (
        reward + gamma * Q_table[next_state, next_action] - Q_table[state, action]
    )
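
# A worked instance of the rule above with hypothetical numbers (not part of
# training): with alpha=0.1, gamma=0.99, reward=-1, Q[s, a]=0.0 and
# Q[s', a']=2.0, the TD target is -1 + 0.99 * 2.0 = 0.98, so Q[s, a] moves
# from 0.0 to 0.0 + 0.1 * (0.98 - 0.0) = 0.098.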

def train_sarsa(env, n_episodes=20000, alpha=0.1, gamma=0.99, epsilon=0.1):
    """Train the agent using the SARSA algorithm."""
    Q_table = initialize_q_table(env)
    episode_rewards = []
    episode_lengths = []
    for episode in range(n_episodes):
        state, _ = env.reset()
        action = epsilon_greedy(env, Q_table, state, epsilon)
        done = False
        total_reward = 0
        steps = 0
        while not done:
            next_state, reward, terminated, truncated, _ = env.step(action)
            done = terminated or truncated
            # SARSA is on-policy: the next action is drawn from the same
            # epsilon-greedy policy and used in the update itself.
            next_action = epsilon_greedy(env, Q_table, next_state, epsilon)
            sarsa_update(Q_table, state, action, reward, next_state, next_action, alpha, gamma)
            state = next_state
            action = next_action
            total_reward += reward
            steps += 1
        episode_rewards.append(total_reward)
        episode_lengths.append(steps)
        if episode % 2000 == 0:
            avg_reward = np.mean(episode_rewards[-1000:])
            avg_length = np.mean(episode_lengths[-1000:])
            print(f"Episode {episode}, Avg Reward: {avg_reward:.2f}, Avg Length: {avg_length:.2f}")
    return Q_table, episode_rewards, episode_lengths
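
# A minimal sketch (not in the original gist) of reading the learned policy
# out of the Q-table; greedy_policy is a hypothetical helper that picks the
# highest-valued action for each state:
def greedy_policy(Q_table):
    """Return the greedy action for every state as a 1-D integer array."""
    return np.argmax(Q_table, axis=1)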

def plot_learning_curve(episode_rewards, episode_lengths):
    """Plot the learning curve."""
    plt.figure(figsize=(12, 5))
    plt.subplot(1, 2, 1)
    plt.plot(episode_rewards)
    plt.title("Episode Rewards")
    plt.xlabel("Episode")
    plt.ylabel("Total Reward")
    plt.subplot(1, 2, 2)
    plt.plot(episode_lengths)
    plt.title("Episode Lengths")
    plt.xlabel("Episode")
    plt.ylabel("Number of Steps")
    plt.tight_layout()
    plt.show()

def create_gif(frames, filename, fps=5):
    """Create a GIF animation from a list of RGB frames."""
    clip = ImageSequenceClip(frames, fps=fps)
    clip.write_gif(filename, fps=fps)
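
# Note: this matches MoviePy 1.x, which the import at the top assumes; in
# MoviePy >= 2.0 the moviepy.editor module was removed and the import
# becomes `from moviepy import ImageSequenceClip`.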

def run_episode(env, Q_table, epsilon=0):
    """Run a single episode using the learned Q-table, collecting rendered
    frames; epsilon=0 means the policy acts fully greedily."""
    state, _ = env.reset()
    done = False
    total_reward = 0
    frames = [env.render()]
    while not done:
        action = epsilon_greedy(env, Q_table, state, epsilon)
        next_state, reward, terminated, truncated, _ = env.step(action)
        done = terminated or truncated
        frames.append(env.render())
        total_reward += reward
        state = next_state
    return frames, total_reward

# Main execution
if __name__ == "__main__":
    env = create_environment()
    Q_table, episode_rewards, episode_lengths = train_sarsa(env)
    plot_learning_curve(episode_rewards, episode_lengths)

    # Run a single episode with the trained Q-table and save it as a GIF
    frames, total_reward = run_episode(env, Q_table)
    create_gif(frames, "trained_taxi.gif", fps=5)
    print(f"Episode completed with total reward: {total_reward}")
    env.close()
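
# To reuse the trained agent later without retraining, the Q-table could be
# persisted with NumPy (a sketch; the filename is arbitrary):
#
#   np.save("taxi_q_table.npy", Q_table)
#   Q_table = np.load("taxi_q_table.npy")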