Created
January 6, 2022 18:54
-
-
Save thomasahle/d4c467fd64d3e9491cde2949741781cd to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def play(r1, r2, replay_buffer):
    """Self-play one game and record every visited state in ``replay_buffer``.

    Args:
        r1: Private information of player 0; passed to ``game.make_priv``
            and ``game.evaluate_call`` (presumably that player's dice
            roll -- confirm against ``game``).
        r2: Same as ``r1``, for player 1.
        replay_buffer: Object with an ``append`` method. For every state
            visited, two ``(priv, state, result)`` tuples are appended: one
            from the perspective of the player to move and one, with the
            result negated, from the perspective of the other player.

    Returns:
        None. The outcome is only observable through ``replay_buffer``.
    """
    privs = [game.make_priv(r1, 0), game.make_priv(r2, 1)]

    def play_inner(state):
        """Play on from ``state``; return +1/-1 as seen by the player to move."""
        cur = game.get_cur(state)
        calls = game.get_calls(state)
        # Sanity check: the player to move must match the parity of the
        # number of calls made so far.
        assert cur == len(calls) % 2

        if calls and calls[-1] == game.LIE_ACTION:
            # Terminal state: the last action was a "lie" call, which
            # challenges the call before it (-1 when there was none).
            prev_call = calls[-2] if len(calls) >= 2 else -1
            # If prev_call is good it means we won (because our opponent
            # called lie).
            res = 1 if game.evaluate_call(r1, r2, prev_call) else -1
        else:
            last_call = calls[-1] if calls else -1
            action = game.sample_action(privs[cur], state, last_call, args.eps)
            new_state = game.apply_action(state, action)
            # Just classic min/max stuff: the value for the next player to
            # move is the negation of the value for the current one.
            res = -play_inner(new_state)

        # Save the result from the perspective of both sides.
        replay_buffer.append((privs[cur], state, res))
        replay_buffer.append((privs[1 - cur], state, -res))

        return res

    # Self-play only generates data, so gradient tracking is switched off.
    with torch.no_grad():
        state = game.make_state()
        play_inner(state)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment