Last active
June 8, 2024 12:37
-
-
Save alper111/1feadf9e21cb2ef1548284bbe7d97ba1 to your computer and use it in GitHub Desktop.
A Sokoban game with crates having MNIST digits on top of them
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from typing import Optional

import gymnasium as gym
import numpy as np
import pygame
import torchvision
class MNISTSokoban(gym.Env):
    """Sokoban grid world rendered as an image, with MNIST digits drawn on the crates.

    The map is a list of rows. Every cell is a ``(background, tile)`` pair of
    one-character strings: ``background`` is ``"0"`` for a goal cell and ``" "``
    otherwise; ``tile`` is ``"#"`` (wall), ``"@"`` (agent), ``" "`` (empty) or a
    digit character identifying the crate standing on the cell.

    Observations are single-channel ``uint8`` images with 32x32 pixels per cell.
    The four discrete actions move the agent by the (row, column) offsets in
    ``self._delta``; on the rendered image these are 0 = right, 1 = up,
    2 = left, 3 = down. The reward is the fraction of goal cells covered by a
    crate, and the episode terminates when every goal is covered.

    All randomness is drawn from ``self.np_random`` so that ``reset(seed=...)``
    makes an episode reproducible.
    """

    metadata = {"render_modes": ["human", "rgb_array"], "render_fps": 20}

    # Side length, in pixels, of one rendered map cell.
    _TILE = 32

    def __init__(self, map_file: Optional[str] = None, size: Optional[tuple[int, int]] = None,
                 max_crates: int = 5, max_steps=200, render_mode: Optional[str] = None,
                 rand_digits: bool = False, rand_agent: bool = False, rand_x: bool = False):
        """Create the environment and load the MNIST training images.

        Args:
            map_file: Path of a text map to load on every reset. Takes
                precedence over ``size`` when both are given.
            size: ``(rows, columns)`` of the randomly generated map used when
                ``map_file`` is ``None``.
            max_crates: Upper bound on the number of crates in a generated map.
            max_steps: Number of steps after which an episode is truncated.
            render_mode: ``"human"`` to draw into a pygame window on every
                observation, ``"rgb_array"`` to make ``render()`` return the
                frame, or ``None``.
            rand_digits: Pick a new MNIST sample for every digit on every frame.
            rand_agent: Re-draw the agent marker with new corners on every frame.
            rand_x: Re-draw the goal marker with new corners on every frame.
        """
        assert map_file is not None or size is not None, "Either map_file or size must be provided"
        super().__init__()
        self._map_file = map_file
        self._size = size
        self._max_crates = max_crates
        self._max_steps = max_steps
        self.render_mode = render_mode
        self.rand_digits = rand_digits
        self.rand_agent = rand_agent
        self.rand_x = rand_x

        self._shape = None
        self._window = None
        self._clock = None
        self._map = None
        self._digit_idx = None
        self._agent_loc = None
        # (row, column) offset applied to the agent for each action index.
        self._delta = np.array([[0, 1], [-1, 0], [0, -1], [1, 0]])
        self._t = 0

        dataset = torchvision.datasets.MNIST(root="data", train=True, download=True)
        self._data = dataset.data.numpy()
        _labels = dataset.targets.numpy()
        # For every digit class, the indices of all samples carrying that label.
        self._labels = {i: np.where(_labels == i)[0] for i in range(10)}

        self.action_space = gym.spaces.Discrete(4)
        if map_file is None:
            # The generated map always has the requested size, so the space is
            # known before the first reset. For a map file it is set in reset().
            self.observation_space = self._make_observation_space(size)

    @classmethod
    def _make_observation_space(cls, shape: tuple[int, int]) -> gym.spaces.Box:
        """Return the image space for a map with ``shape`` = (rows, columns)."""
        rows, cols = shape
        return gym.spaces.Box(low=0, high=255, shape=(rows * cls._TILE, cols * cls._TILE), dtype=np.uint8)

    def reset(self, *, seed: Optional[int] = None, options: Optional[dict] = None) -> tuple[np.ndarray, dict]:
        """Start a new episode and return ``(observation, info)``.

        Args:
            seed: Optional seed forwarded to ``gym.Env.reset`` to re-seed
                ``self.np_random``.
            options: Accepted for API compatibility; not used.

        Raises:
            ValueError: If the loaded map contains no agent tile ``"@"``.
        """
        super().reset(seed=seed)
        self._init_agent_mark()
        self._init_x_mark()
        self._init_digits()

        if self._map_file is not None:
            self._map = self.read_map(self._map_file)
        else:
            self._map = self.generate_map(self._size, max_crates=self._max_crates, rng=self.np_random)
        self._shape = (len(self._map), len(self._map[0]))
        self.observation_space = self._make_observation_space(self._shape)

        agent = self._find_agent()
        if agent is None:
            raise ValueError("The map does not contain an agent tile '@'")
        self._agent_loc = np.array(agent)
        self._t = 0

        # _get_obs() renders the frame and, in "human" mode, also updates the
        # window, so no separate render call is needed here.
        obs = self._get_obs()
        info = self._get_info()
        return obs, info

    def step(self, action: int) -> tuple[np.ndarray, float, bool, bool, dict]:
        """Apply ``action`` and return ``(obs, reward, terminated, truncated, info)``.

        The agent moves one cell if the target cell is empty, or if it holds a
        crate whose next cell in the same direction is empty (the crate is then
        pushed). Walls, blocked crates and cells outside the map leave the
        state unchanged. The step counter advances in every case.
        """
        assert self._map is not None, "You must call reset() before calling step()"
        pos = self._agent_loc
        move = self._delta[action]
        next_pos = pos + move
        curr_bg, _ = self._map[pos[0]][pos[1]]
        next_bg, next_tile = self._tile_at(next_pos)

        if next_tile == " ":
            # the next tile is empty: just walk
            self._map[pos[0]][pos[1]] = (curr_bg, " ")
            self._map[next_pos[0]][next_pos[1]] = (next_bg, "@")
            self._agent_loc = next_pos
        elif next_tile != "#":
            # the next tile contains a crate: push it if the cell behind is free
            further_pos = next_pos + move
            further_bg, further_tile = self._tile_at(further_pos)
            if further_tile == " ":
                self._map[pos[0]][pos[1]] = (curr_bg, " ")
                self._map[next_pos[0]][next_pos[1]] = (next_bg, "@")
                self._map[further_pos[0]][further_pos[1]] = (further_bg, next_tile)
                self._agent_loc = next_pos

        self._t += 1
        obs = self._get_obs()
        info = self._get_info()
        reward = self._get_reward()
        terminated = reward > 1 - 1e-6
        truncated = (self._t >= self._max_steps)
        return obs, reward, terminated, truncated, info

    def render(self):
        """Return the current frame in ``"rgb_array"`` mode, otherwise ``None``.

        The frame is the same single-channel 2-D array used as the observation.
        """
        if self.render_mode == "rgb_array":
            return self._render_frame()

    def close(self):
        """Shut down the pygame window if one was opened."""
        if self._window is not None:
            pygame.display.quit()
            pygame.quit()
            self._window = None
            self._clock = None

    def _find_agent(self) -> Optional[tuple[int, int]]:
        """Return the (row, column) of the first ``"@"`` tile, or ``None``."""
        for i, row in enumerate(self._map):
            for j, (_, tile) in enumerate(row):
                if tile == "@":
                    return i, j
        return None

    def _tile_at(self, pos) -> tuple[str, str]:
        """Return the cell at ``pos``; positions outside the map read as wall."""
        row, col = int(pos[0]), int(pos[1])
        if 0 <= row < self._shape[0] and 0 <= col < self._shape[1]:
            return self._map[row][col]
        return (" ", "#")

    def _jitter(self, low: int, high: int) -> int:
        """Draw a pixel offset uniformly from ``[low, high)``."""
        return int(self.np_random.integers(low, high))

    def _init_x_mark(self):
        """Pick the four corner points (x, y pairs) of the goal cross."""
        self._x_corners = [
            self._jitter(2, 9), self._jitter(2, 9),      # top-left
            self._jitter(24, 31), self._jitter(24, 31),  # bottom-right
            self._jitter(24, 31), self._jitter(2, 9),    # top-right
            self._jitter(2, 9), self._jitter(24, 31),    # bottom-left
        ]

    def _init_agent_mark(self):
        """Pick the three corner points (x, y pairs) of the agent triangle."""
        self._a_corners = [
            self._jitter(13, 20), self._jitter(2, 9),    # top
            self._jitter(2, 9), self._jitter(24, 31),    # bottom-left
            self._jitter(24, 31), self._jitter(24, 31),  # bottom-right
        ]

    def _init_digits(self):
        """Pick one MNIST sample index per digit class for this episode."""
        self._digit_idx = np.zeros(10, dtype=np.int64)
        for digit, candidates in self._labels.items():
            self._digit_idx[digit] = self.np_random.choice(candidates)

    def _digit_surface(self, digit: int) -> pygame.Surface:
        """Return a cell-sized surface showing an MNIST sample of ``digit``."""
        if self.rand_digits:
            sample = self.np_random.choice(self._labels[digit])
        else:
            sample = self._digit_idx[digit]
        rgb = np.stack([self._data[sample]] * 3, axis=-1)
        # surfarray is indexed (x, y), hence the swap of the two image axes.
        surface = pygame.surfarray.make_surface(np.transpose(rgb, (1, 0, 2)))
        return pygame.transform.scale(surface, (self._TILE, self._TILE))

    @staticmethod
    def _draw_mark(canvas, x0: int, y0: int, corners, segments):
        """Draw white line ``segments`` between corner points of one cell.

        ``corners`` is a flat list of x, y offsets relative to ``(x0, y0)`` and
        ``segments`` lists pairs of point indices to connect. A small disc is
        drawn on every point to round off the line ends.
        """
        color = (255, 255, 255)
        width = 4
        points = [(x0 + corners[k], y0 + corners[k + 1]) for k in range(0, len(corners), 2)]
        for start, end in segments:
            pygame.draw.line(canvas, color, points[start], points[end], width)
        for point in points:
            pygame.draw.circle(canvas, color, point, width // 2)

    def _render_frame(self) -> np.ndarray:
        """Draw the map and return it as a 2-D ``uint8`` array (rows x columns).

        In ``"human"`` mode the frame is also shown in the pygame window and
        the call is paced to ``metadata["render_fps"]``.
        """
        size = self._TILE
        rows, cols = self._shape
        canvas = pygame.Surface((cols * size, rows * size))
        canvas.fill((30, 30, 30))

        for i, row in enumerate(self._map):
            for j, (bg, tile) in enumerate(row):
                x0, y0 = j * size, i * size
                if bg == "0":
                    # goal cells show a "0" digit underneath whatever is on them
                    canvas.blit(self._digit_surface(0), (x0, y0))

                if tile == "#":
                    pygame.draw.rect(canvas, (80, 80, 80), pygame.Rect(x0, y0, size, size))
                elif tile == "@":
                    if self.rand_agent:
                        self._init_agent_mark()
                    self._draw_mark(canvas, x0, y0, self._a_corners, ((0, 1), (0, 2), (1, 2)))
                elif tile != " ":
                    canvas.blit(self._digit_surface(int(tile)), (x0, y0))

                if bg == "0":
                    # the cross is drawn last so it stays visible on top of a crate
                    if self.rand_x:
                        self._init_x_mark()
                    self._draw_mark(canvas, x0, y0, self._x_corners, ((0, 1), (2, 3)))

        if self.render_mode == "human":
            if self._window is None:
                pygame.init()
                pygame.display.init()
                self._window = pygame.display.set_mode((cols * size, rows * size))
            if self._clock is None:
                self._clock = pygame.time.Clock()
            self._window.blit(canvas, canvas.get_rect())
            pygame.event.pump()
            pygame.display.update()
            self._clock.tick(self.metadata["render_fps"])

        # Every drawn colour is grey, so the first channel carries the whole image.
        return np.transpose(pygame.surfarray.array3d(canvas)[:, :, 0], (1, 0))

    def _get_obs(self) -> np.ndarray:
        """Return the rendered frame as the observation."""
        return self._render_frame()

    def _get_info(self) -> dict:
        """Return the info dict; ``"map"`` is the live map, not a copy."""
        return {"map": self._map}

    def _get_reward(self) -> float:
        """Return the fraction of goal cells covered by a crate (0.0 if no goals)."""
        on_goals = [fg for row in self._map for bg, fg in row if bg == "0"]
        if not on_goals:
            return 0.0
        n_crossed = sum(fg not in ("#", " ", "@") for fg in on_goals)
        return n_crossed / len(on_goals)

    @property
    def map(self) -> list[list[tuple[str, str]]]:
        """The current map as a list of rows of ``(background, tile)`` pairs."""
        return self._map

    @staticmethod
    def read_map(map_file: str) -> list[list[tuple[str, str]]]:
        """Load a map from a text file with one character per cell.

        ``"0"`` becomes an empty goal cell; any other character becomes the
        tile of a non-goal cell. Surrounding whitespace of every line is
        stripped, so rows must start and end with a non-space character (for
        example a wall). Blank lines are skipped.

        Raises:
            ValueError: If the file holds no rows or the rows differ in length.
        """
        _map = []
        with open(map_file, "r", encoding="utf-8") as f:
            for line in f:
                cells = line.strip()
                if not cells:
                    continue
                _map.append([("0", " ") if x == "0" else (" ", x) for x in cells])
        if not _map:
            raise ValueError(f"Map file {map_file!r} contains no rows")
        width = len(_map[0])
        if any(len(row) != width for row in _map):
            raise ValueError(f"Map file {map_file!r} is not rectangular")
        return _map

    @staticmethod
    def generate_map(size: tuple[int, int] = (10, 10), max_crates: int = 5,
                     rng: Optional[np.random.Generator] = None) -> list[list[tuple[str, str]]]:
        """Generate a walled map with random crates, goals and agent.

        Crates, goals and the agent are placed on distinct cells that are at
        least two cells away from the map edge, so no crate starts against the
        border wall. Between 1 and ``max_crates`` crates with digits 1-9 are
        created, each with one goal cell.

        Args:
            size: ``(rows, columns)`` of the map, at least 5x5.
            max_crates: Largest number of crates to place, at least 1.
            rng: Random generator to draw from; a fresh unseeded one is used
                when omitted.
        """
        ni, nj = size
        # Placement happens in the (ni-4) x (nj-4) interior, which must be non-empty.
        assert ni >= 5 and nj >= 5, "The size of the map must be at least 5x5"
        assert max_crates >= 1, "max_crates must be at least 1"
        inner_rows, inner_cols = ni - 4, nj - 4
        total_middle_tiles = inner_rows * inner_cols
        assert (2*max_crates+1) <= total_middle_tiles, \
            "The number of crates (together with their goals) must be less than the total non-edge empty tiles"
        if rng is None:
            rng = np.random.default_rng()

        _map = [[(" ", "#") if i in (0, ni - 1) or j in (0, nj - 1) else (" ", " ")
                 for j in range(nj)] for i in range(ni)]

        n = int(rng.integers(1, max_crates + 1))
        digits = rng.integers(1, 10, n)
        # n crate cells, then n goal cells, then one agent cell - all distinct.
        locations = rng.permutation(total_middle_tiles)[:(2 * n + 1)]

        def to_cell(flat_index) -> tuple[int, int]:
            return int(flat_index) // inner_cols + 2, int(flat_index) % inner_cols + 2

        for k, digit in enumerate(digits):
            di, dj = to_cell(locations[k])
            _map[di][dj] = (" ", str(int(digit)))
            di, dj = to_cell(locations[k + n])
            _map[di][dj] = ("0", " ")
        ai, aj = to_cell(locations[-1])
        _map[ai][aj] = (" ", "@")
        return _map
def example_map1(path: str = "map1.txt") -> None:
    """Write a small hand-made 9x7 level to ``path``.

    The level has a border wall, three crates (``1``, ``3``, ``5``), three
    goal cells (``0``) and the agent (``@``).

    Args:
        path: Destination file; defaults to ``"map1.txt"`` in the working
            directory. An existing file is overwritten.
    """
    # NOTE(review): every row must have the same width for the map to load.
    # The interior spacing below was reconstructed to make each row 9
    # characters wide - confirm it matches the intended layout.
    rows = [
        "#########",
        "#   0   #",
        "# 0  5  #",
        "#   1   #",
        "# @  30 #",
        "#       #",
        "#########",
    ]
    with open(path, "w", encoding="utf-8") as f:
        # no newline after the last row
        f.write("\n".join(rows))
def main(episodes: int = 20) -> None:
    """Run ``episodes`` random-action episodes in a window as a visual demo."""
    # To play the hand-written level instead, call example_map1() first and
    # construct the environment with map_file="map1.txt".
    env = MNISTSokoban(size=(7, 7), max_crates=3, max_steps=50, render_mode="human",
                       rand_digits=True, rand_agent=True, rand_x=True)
    try:
        for _ in range(episodes):
            env.reset()
            done = False
            while not done:
                action = env.action_space.sample()
                _, _, term, trun, _ = env.step(action)
                done = term or trun
    finally:
        # release the window even if an episode raises
        env.close()


if __name__ == "__main__":
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Screen.Recording.2024-06-07.at.5.23.26.PM.mov
Screen.Recording.2024-06-07.at.5.36.59.PM.mov
Screen.Recording.2024-06-07.at.5.48.02.PM.mov