-
-
Save kjoth/08f1eef89dc91cad4ebe8d5e6af562c1 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import os
import shutil

import chess
import chess.engine
import dspy
from dotenv import load_dotenv
from dspy.functional import TypedPredictor
from pydantic import BaseModel, Field
# Load environment variables (e.g. OPENAI_API_KEY) from a .env file.
load_dotenv()
# Language model backing the chess agent. The model name can be overridden via
# the OPENAI_MODEL environment variable; it defaults to the original "gpt-4".
llm = dspy.OpenAI(model=os.getenv("OPENAI_MODEL", "gpt-4"),
                  api_key=os.getenv("OPENAI_API_KEY"))
# Make this LM the default for every DSPy module in the script.
dspy.settings.configure(lm=llm)
class NextMove(BaseModel):
    # Structured answer the agent must return for each of its turns.
    # NOTE(review): these description strings are presumably surfaced to the LM
    # through the JSON schema TypedPredictor builds, so treat them as prompt text.

    # The proposed move; play_game parses it with board.parse_san, so it must be SAN.
    move: str = Field(
        default=...,
        description="Best next move to win the chess game. It should be in Standard Algebraic Notation",
    )
    # Free-form justification; not read anywhere else in this script.
    reasoning: str = Field(
        default=...,
        description="Reasoning explaining why the move is the best",
    )
class ChessAgentSignature(dspy.Signature):
    """Generate the best next move in the chess game given the current board state,
    the history of moves so far, the list of legal moves, and the feedback
    on the move previously generated.
    """
    # NOTE: DSPy uses the class docstring above as the task instructions and the
    # `desc` strings below as field descriptions in the prompt, so editing any of
    # this text changes what the LM is asked to do.

    # Text rendering of the board (the caller passes str(board)).
    board_state = dspy.InputField(desc="The current state of the chess board")
    # Moves the agent is allowed to choose from, as text.
    legal_moves = dspy.InputField(desc="list of legal moves")
    # Moves played so far in the game, as text.
    history = dspy.InputField(desc="history of moves so far in the game")
    # Why the previous proposal was rejected; empty string on the first attempt.
    feedback = dspy.InputField(desc="feedback on the move previously generated")
    # Typed output: the caller reads response.next_move.move from the parsed NextMove.
    next_move: NextMove = dspy.OutputField()
# The chess agent: a typed DSPy predictor whose output is parsed into NextMove.
chess_agent = TypedPredictor(ChessAgentSignature)

# Stockfish location used by the original setup (looks like a versioned
# Homebrew Cellar path, which breaks whenever Stockfish is upgraded).
DEFAULT_STOCKFISH_PATH = "/opt/homebrew/Cellar/stockfish/16/bin/stockfish"
# Resolve the engine binary with the following precedence:
#   1. the STOCKFISH_PATH environment variable,
#   2. a `stockfish` executable found on PATH,
#   3. the hard-coded default above.
STOCKFISH_PATH = (
    os.getenv("STOCKFISH_PATH")
    or shutil.which("stockfish")
    or DEFAULT_STOCKFISH_PATH
)
# Launch Stockfish as a UCI subprocess; the run loop is responsible for quitting it.
engine = chess.engine.SimpleEngine.popen_uci(STOCKFISH_PATH)
def _legal_moves_san(board):
    """Return the legal moves of ``board`` as a comma-separated SAN string.

    ``str(board.legal_moves)`` yields the generator's repr (object address plus
    moves); a plain list is cleaner prompt text.
    """
    return ", ".join(board.san(move) for move in board.legal_moves)


def _get_agent_move(board, history, max_attempts=10):
    """Ask the LLM agent for a move on ``board`` until it proposes a legal one.

    After each rejected proposal the agent is queried again with feedback that
    explains the rejection.

    Args:
        board: position to move in; the agent plays the side to move.
        history: moves played so far (SAN strings), passed to the agent as context.
        max_attempts: maximum number of agent queries before giving up.

    Returns:
        The first legal ``chess.Move`` the agent proposes.

    Raises:
        RuntimeError: if no legal move is produced within ``max_attempts`` queries,
            so a misbehaving model cannot loop (and spend API calls) forever.
    """
    legal_san = _legal_moves_san(board)
    feedback = ""
    for _ in range(max_attempts):
        response = chess_agent(board_state=str(board),
                               legal_moves=legal_san,
                               history=str(history),
                               feedback=feedback)
        proposed = response.next_move.move
        try:
            move = board.parse_san(proposed)
        except ValueError as e:
            # parse_san signals invalid, illegal or ambiguous SAN with ValueError
            # (or subclasses of it, depending on the python-chess version).
            feedback = (f"Failed to parse the Agent's generated move {proposed!r}. "
                        f"{e} Retrying...")
            continue
        # parse_san may return a null move (e.g. for "--"), which is never legal.
        if move in board.legal_moves:
            return move
        feedback = (f"Agent's generated move {proposed!r} is not valid currently. "
                    f"Should be a move in this list: {legal_san}")
    raise RuntimeError(
        f"Agent failed to produce a legal move after {max_attempts} attempts")


def _game_result(board):
    """Describe the finished game on ``board``; the agent plays Black."""
    if board.is_checkmate():
        # The side to move is the side that has been mated.
        if board.turn == chess.WHITE:
            return "Agent wins by checkmate."
        return "Stockfish wins by checkmate."
    # Every other game-over condition (stalemate, insufficient material,
    # 75-move rule, fivefold repetition, variant draw) is a draw.
    return "The game is a draw."


def play_game(engine_time=2.0, max_attempts=10):
    """Play one game: Stockfish as White against the LLM agent as Black.

    The board is printed after every move.

    Args:
        engine_time: thinking time per move for Stockfish, in seconds.
        max_attempts: agent queries allowed per move before giving up.

    Returns:
        "Agent wins by checkmate.", "Stockfish wins by checkmate." or
        "The game is a draw.".

    Raises:
        RuntimeError: if the agent cannot produce a legal move within
            ``max_attempts`` queries.
    """
    # Move history in SAN, the same notation the agent is asked to answer in.
    moves = []
    board = chess.Board()
    while not board.is_game_over():
        if board.turn == chess.WHITE:
            move = engine.play(board, chess.engine.Limit(time=engine_time)).move
        else:
            move = _get_agent_move(board, moves, max_attempts)
        # san() needs the pre-move position, so record it before pushing.
        moves.append(board.san(move))
        board.push(move)
        print(board)
        print("\n\n")
    return _game_result(board)
# Number of games to play
n_games = 1
# Tally of outcomes, keyed by the prefix of the string play_game() returns.
results = {"Agent wins": 0, "Stockfish wins": 0, "Draw": 0}
try:
    for i in range(n_games):
        print(f"Starting game {i+1}...")
        result = play_game()
        print(result)
        # Update the tally based on the outcome of the game.
        if "Agent wins" in result:
            results["Agent wins"] += 1
        elif "Stockfish wins" in result:
            results["Stockfish wins"] += 1
        else:
            results["Draw"] += 1
        print(f"Game {i+1} finished.\n\n")
finally:
    # Always shut down the Stockfish subprocess, even if a game raises.
    engine.quit()
# Print the final results
print("Final results after playing", n_games, "games:")
print("Agent won:", results["Agent wins"], "games")
print("Stockfish won:", results["Stockfish wins"], "games")
print("Draw:", results["Draw"], "games")
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment