Last active
December 11, 2023 01:02
-
-
Save 903124/cfbefa24da95e2316e0d5e8ef8ed360d to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import chess | |
import re | |
import outlines.text.generate as generate | |
import outlines.models as models | |
import chess.engine | |
import chessboard.display | |
import os | |
# Turn off HuggingFace tokenizers parallelism before the model/tokenizer is
# loaded below (presumably to avoid its fork/parallelism warnings).
os.environ["TOKENIZERS_PARALLELISM"] = "false"

# Pythia-160M checkpoint for chess (per its name), loaded through the
# outlines transformers backend so generation can be regex-constrained.
model_id = "BlueSunflower/Pythia-160M-chess"
model = models.transformers(model_id)

# Path to the UCI engine binary used as the opponent. It can be overridden
# with the CHESS_ENGINE_PATH environment variable instead of editing this
# placeholder; the default is unchanged.
ENGINE_PATH = os.environ.get("CHESS_ENGINE_PATH", "/path/to/engine")
def generate_regex(board):
    """Return a regex alternation matching every legal move of *board* in SAN.

    Check (``+``) and checkmate (``#``) markers are removed from each SAN
    string before it is regex-escaped, so the pattern matches the bare move
    text (e.g. ``Qxf7`` rather than ``Qxf7#``).

    Args:
        board: A ``chess.Board`` (any object exposing ``legal_moves`` and
            ``san(move)`` works).

    Returns:
        The escaped SAN strings joined with ``|``, in legal-move order, or an
        empty string when there are no legal moves.
    """
    # Materialize the moves before calling board.san() on each one, rather
    # than calling it while the lazy legal-move generator is being consumed.
    legal_moves = list(board.legal_moves)
    return "|".join(
        re.escape(re.sub(r"[+#]", "", board.san(move))) for move in legal_moves
    )
# Conditioning prefix for the model; moves are appended after the "1." marker.
prompt = "Score: 1-0 WhiteElo: 1600 BlackElo: 1600 Timecontrol: 1800+0 Moves: 1."

# Standard initial position, White to move (White is played by the LLM).
board = chess.Board(
    "rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR w KQkq - 0 1"
)

# Start the chessboard.display viewer; the returned handle is passed to
# chessboard.display.update() after every move.
game_board = chessboard.display.start(board.fen())

# Launch the UCI engine opponent and ask it to limit its playing strength to
# about 1320 Elo (standard UCI options; support depends on the engine).
engine = chess.engine.SimpleEngine.popen_uci(ENGINE_PATH)
engine.configure(
    {
        "UCI_LimitStrength": True,
        "UCI_Elo": 1320,
    }
)
def get_max_turn_number(prompt):
    """Return the largest move number written as ``N.`` in *prompt*, or None.

    Every run of digits immediately followed by a period counts as a move
    number; None is returned when there is none.
    """
    turn_numbers = [int(match) for match in re.findall(r"(\d+)\.", prompt)]
    return max(turn_numbers) if turn_numbers else None


def get_current_turn_moves(prompt):
    """Return ``(max_turn, moves)`` for the text after the last ``N.`` marker.

    ``moves`` is the whitespace-separated list of tokens following the final
    move-number marker (empty right after ``N.``). Returns ``(None, None)``
    when *prompt* contains no move number at all.
    """
    max_turn = get_max_turn_number(prompt)
    if max_turn is None:
        return None, None
    # Text after the final "N." marker. Markers are appended in increasing
    # order below, so this is also the segment of the largest turn number.
    moves_segment = re.split(r"\d+\.", prompt)[-1]
    return max_turn, moves_segment.split()


try:
    while not board.is_game_over():
        if board.turn == chess.WHITE:
            # LLM's turn (White): constrain generation to the legal moves.
            regex_pattern = generate_regex(board)
            guided = generate.regex(model, regex_pattern, max_tokens=10)(prompt)
            try:
                move = board.parse_san(guided)
            except ValueError:
                print(f"Invalid move: {guided}")
                break
        else:
            # Engine's turn (Black)
            result = engine.play(board, chess.engine.Limit(time=0.6))
            move = result.move

        # Extend the move text in "1.e4 e5 2.Nf3 ..." form. SAN is computed
        # against the position before the move is pushed.
        turn_number, current_turn_moves = get_current_turn_moves(prompt)
        san = board.san(move)
        if not current_turn_moves:
            # Nothing written yet after the last "N." (White's half-move), or
            # no marker at all (None): just append the move.
            prompt += san + " "
        else:
            # Second half-move of the turn: append it and open the next turn.
            prompt += f"{san} {turn_number + 1}."

        board.push(move)
        print(prompt)
        chessboard.display.update(board.fen(), game_board)
finally:
    # Always terminate the UCI engine subprocess, including after the early
    # break on an invalid LLM move or an exception.
    engine.quit()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment