Created January 6, 2022 18:53
-
-
Save thomasahle/28e088ecd49579a103f73ce75742ccbb to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def play(r1, r2, replay_buffer):
    """Self-play a single game and append value targets to ``replay_buffer``.

    Moves are chosen with ``game.sample_action`` (passing ``args.eps``,
    presumably an exploration rate -- confirm) until the most recent call
    equals ``game.LIE_ACTION``.  That final position is scored with
    ``game.evaluate_call``.  The +1/-1 result is then propagated back
    through every visited state, flipping sign at each step (negamax).

    Args:
        r1: Private information of player 0.  It is passed to
            ``game.make_priv`` and ``game.evaluate_call`` (presumably that
            player's roll -- confirm).
        r2: Same as ``r1``, for player 1.
        replay_buffer: Any object with an ``append`` method.  For every
            visited state, two ``(priv, state, value)`` tuples are appended:
            first for the player to move, then for the opponent with the
            value negated.  States are recorded from the final one back to
            the initial one.

    Returns:
        None.
    """
    # Private-information tensors, indexed by player number (0 or 1).
    privs = [
        game.make_priv(r1, 0).to(device),
        game.make_priv(r2, 1).to(device),
    ]

    with torch.no_grad():
        state = game.make_state().to(device)

        # Forward pass: play until someone calls "lie", remembering which
        # player was to move in each state we passed through.
        visited = []
        while True:
            player = game.get_cur(state)
            history = game.get_calls(state)
            # Players alternate, so the player to move is the parity of the
            # number of calls made so far.
            assert player == len(history) % 2
            if history and history[-1] == game.LIE_ACTION:
                break
            previous = history[-1] if history else -1
            move = game.sample_action(privs[player], state, previous, args.eps)
            visited.append((player, state))
            state = game.apply_action(state, move)

        # The call that was challenged is the one just before the "lie".
        # Per the original author's note: if it holds up, the player to move
        # (the one who made it) wins, because the opponent called lie.
        challenged = history[-2] if len(history) >= 2 else -1
        value = 1 if game.evaluate_call(r1, r2, challenged) else -1
        visited.append((player, state))

        # Backward pass: record targets from the terminal state back to the
        # root, negating the value at each ply (zero-sum, alternating turns).
        for mover, seen in reversed(visited):
            replay_buffer.append((privs[mover], seen, value))
            replay_buffer.append((privs[1 - mover], seen, -value))
            value = -value
Sign up for free to join this conversation on GitHub.
Already have an account?
Sign in to comment