Skip to content

Instantly share code, notes, and snippets.

# Game grid. The first row is seeded with 1, 2 and -1; presumably positive
# cells are snake segments and -1 marks food -- verify against `do`.
# NOTE(review): this call was cut off after the shape argument in the source
# text; an integer dtype is assumed here -- confirm against the original.
snake = t.zeros((32, 32), dtype=t.int)
snake[0, :3] = T([1, 2, -1])

fig, ax = plt.subplots(1, 1)
img = ax.imshow(snake)

# Latest steering input, kept in a dict so the key handler can update it in
# place. It starts at 1; 'a' maps to 0 and 'd' maps to 2.
action = {'val': 1}
action_dict = {'a': 0, 'd': 2}


def _on_key_press(event):
    """Store the action for a mapped key and ignore every other key."""
    # Indexing action_dict directly raised KeyError for any key other than
    # 'a' or 'd' (arrow keys, modifiers, ...), so check membership first.
    if event.key in action_dict:
        action['val'] = action_dict[event.key]


fig.canvas.mpl_connect('key_press_event', _on_key_press)
def do(snake: t.Tensor, action: int):
    """Work out the snake's next cell and report a collision, if any.

    The two largest values in ``snake`` are taken as the current and the
    previous head cell. The step between them is rotated according to
    ``action`` and added to the current cell, wrapping around at the grid
    edges.

    Args:
        snake: 2-D grid tensor. Positive cells count as occupied and a cell
            equal to -1 triggers the food branch.
        action: Steering input. The heading is multiplied by a 90-degree
            rotation matrix raised to the power ``3 + action``, so 1 keeps
            the heading (power 4 is the identity) while 0 and 2 are quarter
            turns in opposite directions.

    Returns:
        The value of the current head cell minus 2, as a Python number, when
        the next cell holds a positive value; otherwise ``None``.
    """
    # topk returns values sorted largest-first, so the first index is the
    # head and the second is the cell right behind it.
    positions = snake.flatten().topk(2)[1]
    pos_cur, pos_prev = (T(unravel(x, snake.shape)) for x in positions)

    rotation = T([[0, -1], [1, 0]]).matrix_power(3 + action)
    # The modulo makes a step off one edge re-enter on the opposite edge.
    pos_next = (pos_cur + (pos_cur - pos_prev) @ rotation) % T(snake.shape)

    # Stepping onto a positive cell ends the move with a score.
    if (snake[tuple(pos_next)] > 0).any():
        return (snake[tuple(pos_cur)] - 2).item()

    if snake[tuple(pos_next)] == -1:
        # Sample the flat index of one zero-valued cell, uniformly at random.
        pos_food = (snake == 0).flatten().to(t.float).multinomial(1)[0]
        # NOTE(review): the source text ends here. The statements that use
        # pos_food, and whatever advances the snake, are not present, so the
        # grid is left unmodified by the visible code.
eliasffyksen /
Created August 8, 2022 18:31
Imports and function declaration
import torch as t
from torch import tensor as T
from numpy import unravel_index as unravel
import matplotlib.pyplot as plt
def do(snake, action):
'''This is where the magic happens :) '''