This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import sys, os, shutil, pdb, random | |
from tqdm import tqdm | |
def q(text=''):
    '''
    Debugging helper: print a message, then stop the script.

    The message is printed wrapped between '>' and '<' so that empty strings
    and leading/trailing whitespace are visible in the output.

    text: the value to print; it is formatted with an f-string, so any object works.

    Raises SystemExit (through sys.exit() called with no argument), so this function never returns.
    '''
    print(f'>{text}<')
    sys.exit()
from environment import TicTacToe | |
from agent import QLearningAgent, Hoooman | |
import config as cfg | |
from config import display_board |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import sys, os, shutil, pdb, random | |
from tqdm import tqdm | |
def q(text=''):
    '''
    a function that exits the code after printing a message. used for debugging purposes

    text: the value to print; it is shown wrapped between '>' and '<' so that
    empty strings and surrounding whitespace are visible.

    Raises SystemExit (through sys.exit() called with no argument), so this function never returns.
    '''
    print(f'>{text}<')  # f-strings require Python 3.6 or newer
    sys.exit()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
summary_dir = 'summary'
num_episodes = 500000
display = False  # whether to print the Tic-Tac-Toe board on the terminal; setting it to False is suggested for training

# exploration-exploitation trade-off factor
epsilon = 0.4  # must be a real number in the interval (0, 1)

# learning-rate
alpha = 0.3  # must be a real number in the interval (0, 1)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def display_board(board, action, playerID, player1, player2, reward, done, possible_actions, training = True, episode_reward_player1=None, episode_reward_player2=None): | |
''' | |
prints out the Tic-Tac-Toe board in the terminal. | |
prints the action taken by the players, the reward they recieved and the status of the game (Done -> True or False) | |
prints if either of the players have won or lost the game or if it is a tied between the players | |
prints all the possible next actions if the training argument is set to True | |
''' | |
print('\n') | |
for i in range(3): | |
print(' '.join(board[i*3:(i+1)*3])) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import random, pickle | |
import config as cfg | |
class QLearningAgent: | |
def __init__(self, name, epsilon = cfg.epsilon, alpha = cfg.alpha, gamma = cfg.gamma): | |
self.name = name | |
self.epsilon = epsilon # exploration-exploiataion trade-off factor | |
self.alpha = alpha # learning-rate | |
self.gamma = gamma # discount-factor | |
self.Q = {} # Q-Table |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class TicTacToe: | |
def __init__(self): | |
''' | |
the environment starts with 9 empty spaces representing a board of Tic-Tac-Toe | |
''' | |
self.board = ['_']*9 # the initial blank board | |
self.done = False # done = True means the game has ended | |
def reset(self): | |
''' |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import sys | |
# A function for exiting the script after printing a message
def q(text=''):
    '''
    Print a message wrapped between '>' and '<', then exit the script.

    text: the value to print; it is formatted with an f-string, so any object works.

    Raises SystemExit (through sys.exit() called with no argument), so this function never returns.
    '''
    print(f'>{text}<')
    sys.exit()
import argparse, os | |
# Desktop path |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import cv2 | |
import torch.nn.functional as F | |
# Define a 'predict' function that takes the 'model' and the path of a test image as inputs and predicts the class that the test image belongs to.
def predict(model, test_img_path): | |
img = cv2.imread(test_img_path) | |
# Visualizing the test image | |
plt.imshow(cv2.cvtColor(img, cv2.COLOR_BGR2RGB)) | |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# We will now freeze 'layer4' and train just the 'fc' layer of the model for 2 more epochs
for name, param in model.named_parameters():
    if 'layer4' in name:
        param.requires_grad = False  # layer4 parameters will no longer be trained

# Define the new learning rate and a new optimizer that contains only the
# parameters with requires_grad = True
lr = 0.0003
optimizer = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=lr)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Map each phase name to its data loader; this dict is what gets passed to trainer() below
loader = {'train': train_loader, 'val': val_loader}
epochs = 5
log_interval = 2

# Let's train the model for 5 epochs!
# trainer() hands back four loss sequences: per-epoch train/val and per-batch train/val
train_losses, val_losses, batch_train_losses, batch_val_losses = trainer(
    loader, model, loss_fn, optimizer, epochs=epochs, log_interval=log_interval
)

# Plotting the epoch losses
plt.plot(train_losses)
NewerOlder