Created
January 24, 2023 04:59
-
-
Save jflam/fd6ed8b0f9cab9590797c2c6ebe7b6df to your computer and use it in GitHub Desktop.
Character Chess: GPT 3.5 and Stockfish team up!
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import base64, chess.svg, re, stockfish, openai, os | |
import streamlit as st | |
from streamlit_chat import message | |
from langchain.prompts import PromptTemplate | |
# Finds the first chess move embedded in free-form chat text.
#   group 1: the move itself, SAN-style ("e4", "Nf3", "exd5") or coordinate
#            style ("e2e4").
#   group 2: optional promotion ("=Q") followed by an optional check/mate
#            marker ("+" or "#").
# The suffix used to be written as ([+#]?|=[QBNR]); its first alternative can
# match the empty string, so the promotion alternative was never tried and
# "e8=Q" was cut down to "e8".
chess_move_regex = re.compile(
    r"([KQBNR]?[a-h]?[1-8]?x?[a-h][1-8])((?:=[QBNR])?[+#]?)"
)
# Persona substituted into the prompt and the page heading.
CHARACTER = "Yoda"
# Prompt template filled in on every turn:
#   character - persona whose voice the model should imitate
#   my_move   - the player's full chat line (move plus commentary)
#   your_move - the reply move that the model must mention
# The only wording change is the spelling fix "reponse" -> "response".
house_prompt = PromptTemplate(
    input_variables=["character", "my_move", "your_move"],
    template="""
You are playing a game of chess with me and responding to my move. Your moves
are being computed by the Stockfish chess engine. I want you to focus on
communicating with me in the voice of {character}. Don't comment on the state
of the game as you cannot possibly know the game as well as Stockfish does.
In my move is usually some commentary from me. Make sure that you respond to
whatever I'm saying in the voice of {character}.
After you respond to my commentary, make sure to explicitly mention your move
in your response.
Write at least 6 sentences in your response. Do not include any explanatory
text or my move in your response. Don't say what your name is either. Don't
quote your response.
My move: {my_move}
Your move: {your_move}
""")
# Prefix shared by the environment variables that hold the OpenAI settings.
ENVIRONMENT = "EAST_AZURE_OPENAI"

# Each os.environ[...] lookup raises KeyError when its variable is not set.
openai.api_key = os.environ[ENVIRONMENT + "_API_KEY"]
openai.api_base = os.environ[ENVIRONMENT + "_ENDPOINT"]
openai.api_type = "azure"
openai.api_version = "2022-12-01"  # this may change in the future
# Passed as the `engine` argument of the completion request.
DEPLOYMENT_ID = os.environ[ENVIRONMENT + "_DEPLOYMENT"]
# Page title and icon passed to Streamlit's page configuration.
st.set_page_config(
    page_title="House Chess",
    page_icon=":robot:",
)
# Emit the heading with an explicit call. A bare f-string expression is only
# rendered when Streamlit's "magic" feature is enabled, so the heading would
# silently disappear if that option were turned off.
st.markdown(f"## {CHARACTER} Chess")
def get_move(s):
    """Return the first chess move that chess_move_regex finds in *s*.

    Args:
        s: Free-form text, typically a chat line containing a move.

    Returns:
        The full text of the first match.

    Raises:
        ValueError: If the text contains nothing that looks like a move.
    """
    found = chess_move_regex.search(s)
    if found is None:
        raise ValueError("No move found in string")
    return found.group(0)
def render_svg(svg):
    """Write an SVG string to the page as an inline base64 ``<img>`` tag.

    Args:
        svg: The SVG markup to display.

    Returns:
        None. The image is written to the page as a side effect.
    """
    encoded = base64.b64encode(svg.encode("utf-8")).decode("utf-8")
    img_tag = f'<img src="data:image/svg+xml;base64,{encoded}"/>'
    # The tag is raw HTML, so it has to be written with HTML allowed.
    st.write(img_tag, unsafe_allow_html=True)
# Populate the session state once. The "chess" key doubles as the marker that
# initialisation has already happened.
if "chess" not in st.session_state:
    st.session_state["chess"] = stockfish.Stockfish()  # engine wrapper
    st.session_state["board"] = chess.Board()          # game position
    st.session_state["generated"] = []                 # model replies
    st.session_state["past"] = []                      # player chat lines
    st.session_state["error_message"] = ""
    st.session_state["prompt"] = ""                    # last prompt sent

# Module-level shortcuts to the stored board and engine.
board, sf = st.session_state["board"], st.session_state["chess"]
def _parse_player_move(board, text):
    """Return the legal move on *board* that *text* describes, or None.

    SAN ("e4", "Nf3", "e8=Q") is tried first, then UCI ("e2e4").
    """
    for parse in (board.parse_san, board.parse_uci):
        try:
            return parse(text)
        except ValueError:
            # python-chess reports malformed, illegal and ambiguous moves as
            # ValueError or a subclass of it; fall through to the next parser.
            continue
    return None


def move():
    """Handle a submitted chat line: play the player's move and the reply.

    Reads the text from ``st.session_state.input``, extracts the move with
    get_move(), plays it on the stored board, asks Stockfish for the answer,
    and asks the language model to announce that answer in the voice of
    CHARACTER. On success the exchange is appended to ``past``/``generated``
    and the prompt is stored; on any failure the message is stored in
    ``st.session_state.error_message`` and nothing is raised.
    """
    state = st.session_state
    board = state.board
    engine = state.chess
    try:
        if board.is_game_over():
            state.error_message = f"Game over: {board.result()}"
            return

        my_prompt = state.input
        my_move = get_move(my_prompt)
        state["input"] = ""

        # Parse once and push the resulting Move object. Validating with
        # Move.from_uci() and then applying with push_san() rejected plain
        # SAN input such as "e4".
        player_move = _parse_player_move(board, my_move)
        if player_move is None:
            state.error_message = f"Invalid move: {my_move}! Try again."
            return

        plies_before = len(board.move_stack)
        board.push(player_move)

        if board.is_game_over():
            # The player's move ended the game; there is no reply to compute.
            state.past.append(my_prompt)
            state.generated.append(f"Game over: {board.result()}")
            state.error_message = ""
            return

        try:
            with st.spinner("Thinking ..."):
                engine.set_fen_position(board.fen())
                stockfish_move = engine.get_best_move()
                if stockfish_move is None:
                    raise ValueError("Stockfish did not return a move")
                # The engine answers in UCI notation.
                board.push_uci(stockfish_move)
                prompt = house_prompt.format(
                    character=CHARACTER,
                    my_move=my_prompt,
                    your_move=stockfish_move,
                )
                house_response = openai.Completion.create(
                    engine=DEPLOYMENT_ID,
                    prompt=prompt,
                    max_tokens=1000,
                    temperature=0.8,
                ).choices[0].text
        except Exception:
            # Undo everything played in this turn so the board never ends up
            # ahead of the conversation (or with the wrong side to move).
            while len(board.move_stack) > plies_before:
                board.pop()
            raise

        state.past.append(my_prompt)
        state.generated.append(house_response)
        state.prompt = prompt
        state.error_message = ""
    except Exception as e:
        # Callback boundary: surface the problem in the UI instead of raising.
        state.error_message = str(e)
# Chat box; move() is registered as its on_change callback.
st.text_input("You: ", "", key="input", on_change=move)
if st.session_state.error_message:
    st.error(st.session_state.error_message, icon="🚨")

# render_svg() writes the board image itself and returns None, so its result
# must not be handed to st.write(): that would write the None to the page.
render_svg(chess.svg.board(board=board))

with st.expander("Show prompt:"):
    st.write(st.session_state.prompt)

# Show the conversation newest-first; each exchange is the model's reply
# followed by the player's line that triggered it.
replies = st.session_state["generated"]
player_lines = st.session_state["past"]
for i in reversed(range(len(replies))):
    message(replies[i], key=str(i))
    message(player_lines[i], is_user=True, key=str(i) + "_user")
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment