Skip to content

Instantly share code, notes, and snippets.

@903124
Last active December 11, 2023 01:02
Show Gist options
  • Save 903124/cfbefa24da95e2316e0d5e8ef8ed360d to your computer and use it in GitHub Desktop.
Save 903124/cfbefa24da95e2316e0d5e8ef8ed360d to your computer and use it in GitHub Desktop.
import chess
import re
import outlines.text.generate as generate
import outlines.models as models
import chess.engine
import chessboard.display
import os
# Disable the HuggingFace tokenizers library's internal parallelism
# (presumably to silence its fork/parallelism warning — confirm).
os.environ["TOKENIZERS_PARALLELISM"] = "false"

# Hugging Face model id of the chess-trained language model that plays White.
model_id = "BlueSunflower/Pythia-160M-chess"
model = models.transformers(model_id)

# Path to a UCI chess engine binary (the opponent, playing Black).
# "/path/to/engine" is only a placeholder: set the CHESS_ENGINE_PATH
# environment variable to point at a real engine instead of editing this file.
ENGINE_PATH = os.environ.get("CHESS_ENGINE_PATH", "/path/to/engine")
def generate_regex(board):
    """Build a regex that matches exactly one legal move for ``board`` in SAN.

    The pattern is an alternation (``"A|B|C"``) of the SAN string of every
    legal move in the current position. The check (``+``) and checkmate (``#``)
    suffixes are stripped, so constrained generation only has to produce the
    bare move. The caller turns the generated text back into a move with
    ``board.parse_san``.

    Args:
        board: a ``chess.Board`` whose side to move is choosing a move.

    Returns:
        The alternation pattern, with each move passed through ``re.escape``
        so characters such as ``-`` in ``O-O`` are matched literally.
        Returns an empty string if the position has no legal moves.
    """
    # In SAN, "+" and "#" only ever appear as a trailing suffix, so rstrip is
    # sufficient to remove them. board.san() must be called while the move is
    # still legal in this position, which holds while iterating legal_moves.
    return "|".join(
        re.escape(board.san(move).rstrip("+#")) for move in board.legal_moves
    )
# Game state: the standard starting position with White to move.
board = chess.Board("rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR w KQkq - 0 1")

# Header-style prefix fed to the model; moves are appended to it as the game
# progresses. NOTE(review): "Score: 1-0" presumably conditions the model
# towards a White win — confirm this is intentional.
prompt = "Score: 1-0 WhiteElo: 1600 BlackElo: 1600 Timecontrol: 1800+0 Moves: 1."

# Open the board display, initialised from the current position.
game_board = chessboard.display.start(board.fen())

# Launch the UCI engine and ask it to limit its playing strength via the
# UCI_LimitStrength / UCI_Elo options (the engine must support them).
engine = chess.engine.SimpleEngine.popen_uci(ENGINE_PATH)
engine.configure(dict(UCI_LimitStrength=True, UCI_Elo=1320))
# Main game loop: the LLM plays White and the UCI engine plays Black. Each ply
# is appended to `prompt` so the model always sees the full move history.
# The engine subprocess is shut down in `finally`, so it does not outlive
# the script on normal exit, on `break`, or on an exception / Ctrl-C.
try:
    while not board.is_game_over():
        if board.turn == chess.WHITE:
            # LLM's turn (White): constrain generation to the legal SAN moves
            # of this position.
            regex_pattern = generate_regex(board)
            guided = generate.regex(model, regex_pattern, max_tokens=10)(prompt)
            try:
                move = board.parse_san(guided)
            except ValueError:
                print(f"Invalid move: {guided}")
                break
        else:
            # Engine's turn (Black).
            result = engine.play(board, chess.engine.Limit(time=0.6))
            move = result.move

        # SAN has to be computed before push(): it depends on the position
        # the move is played from.
        san = board.san(move)
        # Derive the move-number markers from the board itself rather than
        # re-parsing the prompt text.
        if board.turn == chess.WHITE:
            # White's move: the "N." marker for this move is already in the prompt.
            prompt += san + " "
        else:
            # Black's move completes move N; open move N+1.
            # fullmove_number is still N here because push() has not run yet.
            prompt += san + " " + str(board.fullmove_number + 1) + "."
        board.push(move)
        print(prompt)
        chessboard.display.update(board.fen(), game_board)
finally:
    engine.quit()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment