This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch | |
from functorch import vmap, grad | |
from torch.autograd import Function | |
sigmoid = torch.sigmoid | |
sigmoid_grad = vmap(vmap(grad(sigmoid))) | |
class TopK(Function): | |
@staticmethod | |
def forward(ctx, xs, k): |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
[nav] In [48]: table = torch.arange(12, dtype=torch.float32).reshape(4,3) | |
[ins] In [49]: new_table = torch.zeros(4, 3) | |
[ins] In [50]: index = torch.tensor([1,1,0,3]) | |
[ins] In [51]: index2 = index.unsqueeze(1).expand(4,3) | |
[ins] In [52]: table | |
Out[52]: |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def play(r1, r2, replay_buffer): | |
privs = [game.make_priv(r1, 0), game.make_priv(r2, 1)] | |
def play_inner(state): | |
cur = game.get_cur(state) # Current player id | |
calls = game.get_calls(state) # Bets made by player so far | |
if calls and calls[-1] == game.LIE_ACTION: | |
prev_call = calls[-2] if len(calls) >= 2 else -1 | |
# If prev_call is good it mean we won (because our opponent called lie) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def play(r1, r2, replay_buffer): | |
privs = [game.make_priv(r1, 0), game.make_priv(r2, 1)] | |
def play_inner(state): | |
cur = game.get_cur(state) | |
calls = game.get_calls(state) | |
assert cur == len(calls) % 2 | |
if calls and calls[-1] == game.LIE_ACTION: | |
prev_call = calls[-2] if len(calls) >= 2 else -1 |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def play(r1, r2, replay_buffer): | |
privs = [game.make_priv(r1, 0).to(device), | |
game.make_priv(r2, 1).to(device)] | |
def play_inner(state): | |
cur = game.get_cur(state) | |
calls = game.get_calls(state) | |
assert cur == len(calls) % 2 | |
if calls and calls[-1] == game.LIE_ACTION: |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def update(state, t): | |
pi = compute_policy(state, t) | |
score = 0 | |
for i in actions(state): | |
score_i = update(state + i) | |
score += pi[i] * score_i | |
state.mean_score = (state.mean_score * t + score)/(t + 1) | |
return score | |
def compute_policy(state): |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def update(state, t): | |
pi = compute_policy(state, t) | |
score = 0 | |
for i in actions(state): | |
score_i = update(state + action) | |
score += pi[i] * score_i | |
state.mean_score = (state.mean_score * t + score)/(t + 1) | |
return score | |
def compute_policy(state): |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def play(r1, r2, replay_buffer): | |
privs = [game.make_priv(r1, 0).to(device), | |
game.make_priv(r2, 1).to(device)] | |
def play_inner(state): | |
cur = game.get_cur(state) | |
calls = game.get_calls(state) | |
assert cur == len(calls) % 2 | |
if calls and calls[-1] == game.LIE_ACTION: |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class ParameterDeque(nn.Module): | |
def __init__(self) -> None: | |
super(ParameterDeque, self).__init__() | |
self.left = 0 | |
self.right = 0 # Points at the first non-existing element | |
def _convert_idx(self, idx): | |
"""Get the absolute index for the list of modules""" | |
idx = operator.index(idx) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def sinkhorn_forward(C, mu, nu, epsilon, max_iter): | |
bs, n, k_ = C.size() | |
v = torch.ones([bs, 1, k_])/(k_) | |
G = torch.exp(-C/epsilon) | |
if torch.cuda.is_available(): | |
v = v.cuda() | |
for i in range(max_iter): | |
u = mu/(G*v).sum(-1, keepdim=True) |