Created
April 25, 2023 11:32
-
-
Save cheadrian/4e59c9d24fe6b3e9db6adaf034d7387a to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np | |
from collections import namedtuple, deque | |
from itertools import count | |
import random, datetime, os, copy, glob, math, time, asyncio, base64, shutil | |
import cv2 | |
import gym, ray | |
from gym import Env, spaces | |
from gym.spaces import Box | |
from ray.rllib.algorithms import ppo | |
from ray.air.config import RunConfig, ScalingConfig, CheckpointConfig | |
from ray.air.checkpoint import Checkpoint | |
from ray.train.rl import RLTrainer, RLCheckpoint, RLPredictor | |
from ray.tune.tuner import Tuner | |
from pyppeteer import launch | |
import nest_asyncio | |
nest_asyncio.apply() | |
from ray.tune.registry import register_env | |
import IPython | |
from google.colab.patches import cv2_imshow | |
from IPython.display import clear_output | |
class Snake(Env):
    """Gym environment that plays a browser Snake game through pyppeteer.

    Observations are the game's ``#canvas`` element decoded to a BGR image
    (as produced by ``cv2.imdecode``). Actions are clicks on the page's
    direction buttons; action 4 does nothing.
    """

    # Data URL of the most recently returned frame, used to detect a new frame.
    last_base64_canvas = ""

    # Maps an action index to the JavaScript snippet that performs it.
    # Action 4 ("Do Nothing") intentionally has no entry.
    _ACTION_SCRIPTS = {
        0: "rightBtn.click()",
        1: "leftBtn.click()",
        2: "downBtn.click()",
        3: "upBtn.click()",
    }

    def __init__(self, env_config):
        """Launch the browser, load the game page and define the spaces.

        Args:
            env_config: Environment configuration passed by the caller
                (currently unused).
        """
        super().__init__()
        self.i_id = random.randint(0, 100)
        self.browser = None
        try:
            self.loop = asyncio.get_event_loop()
        except RuntimeError:
            # No current event loop in this thread: create and install one so
            # the synchronous wrappers below can drive the coroutines.
            self.loop = asyncio.new_event_loop()
            asyncio.set_event_loop(self.loop)
        self.page = self.loop.run_until_complete(self.get_browser_page())
        # NOTE(review): assumes the page canvas renders at 168x168 so frames
        # match the declared observation space - confirm against the game page.
        self.observation_shape = (168, 168, 3)
        self.canvas = self.get_canvas()
        self.observation_space = spaces.Box(
            low=0,
            high=255,
            shape=self.observation_shape,
            dtype=np.uint8,
        )
        self.action_space = spaces.Discrete(5)
        self.score = 0
        self.detect_score = 0

    def readb64(self, uri):
        """Decode a base64 data URI into an image array.

        Args:
            uri: Data URI of the form ``"<header>,<base64 payload>"``.

        Returns:
            The result of ``cv2.imdecode`` on the decoded payload.
        """
        encoded_data = uri.split(',')[1]
        raw_bytes = np.frombuffer(base64.b64decode(encoded_data), np.uint8)
        return cv2.imdecode(raw_bytes, cv2.IMREAD_COLOR)

    async def do_async_action(self, input_btn):
        """Click the page button mapped to ``input_btn``, if there is one."""
        script = self._ACTION_SCRIPTS.get(input_btn)
        if script is not None:
            await self.page.evaluate(script)

    def do_action(self, input_btn):
        """Synchronous wrapper around :meth:`do_async_action`."""
        return self.loop.run_until_complete(self.do_async_action(input_btn))

    async def start_play(self):
        """Start a game by clicking the play button.

        If the play button is hidden, the page is reloaded first so the
        button is available again.
        """
        is_playbtn_hidden = await self.page.evaluate("playBtn.offsetParent === null")
        if is_playbtn_hidden:
            await self.page.reload({"waitUntil": ["networkidle0", "domcontentloaded"]})
        await self.page.evaluate("playBtn.click()")

    async def check_done(self):
        """Return True when the play button is visible again."""
        is_playbtn_hidden = await self.page.evaluate("playBtn.offsetParent === null")
        return not is_playbtn_hidden

    async def get_async_canvas(self):
        """Return the current canvas frame as an image.

        Polls the canvas until its data URL differs from the previously
        returned one, or until :meth:`check_done` reports the episode over.
        """
        base64_canvas = await self.page.evaluate("document.querySelector('#canvas').toDataURL();")
        while base64_canvas == self.last_base64_canvas:
            base64_canvas = await self.page.evaluate("document.querySelector('#canvas').toDataURL();")
            # Stop waiting for a new frame once the game has ended; the
            # canvas may never change again.
            if await self.check_done():
                break
        self.last_base64_canvas = base64_canvas
        return self.readb64(base64_canvas)

    async def get_async_score(self):
        """Read the page's ``scorutz`` variable."""
        return await self.page.evaluate("scorutz")

    async def reset_async_score(self):
        """Set the page's ``scorutz`` variable to zero."""
        await self.page.evaluate("scorutz = 0")

    def reset_score(self):
        """Synchronous wrapper around :meth:`reset_async_score`."""
        self.loop.run_until_complete(self.reset_async_score())

    def get_canvas(self):
        """Synchronous wrapper around :meth:`get_async_canvas`."""
        return self.loop.run_until_complete(self.get_async_canvas())

    async def get_browser_page(self):
        """Launch headless Chromium and open the game page.

        The browser handle is kept on ``self.browser`` so :meth:`close` can
        shut the process down.

        Returns:
            The pyppeteer page with the game loaded.
        """
        self.browser = await launch(headless=True, executablePath="/usr/lib/chromium-browser/chromium-browser", args=['--no-sandbox'])
        page = await self.browser.newPage()
        await page.goto('https://localhost/snake/Jungle%20Snake.html')
        return page

    def reset(self):
        """Reset scores, start a new game and return the first observation."""
        self.score = 0
        self.detect_score = 0
        self.reset_score()
        self.loop.run_until_complete(self.start_play())
        # Short pause before grabbing the first frame after starting.
        time.sleep(0.3)
        self.canvas = self.get_canvas()
        return self.canvas

    def step(self, action):
        """Apply ``action`` and return ``(observation, reward, done, info)``.

        Reward is 1 per step, plus 4 when the page score changed since the
        previous step, minus 3 when the episode ended on this step.
        """
        assert self.action_space.contains(action), "Invalid Action"
        self.do_action(action)

        reward = 0
        page_score = self.get_score()
        if page_score != self.detect_score:
            reward = 4
            self.score += 1
            self.detect_score = page_score
        reward += 1

        self.canvas = self.get_canvas()
        done = self.loop.run_until_complete(self.check_done())
        if done:
            reward -= 3
        return self.canvas, reward, done, {}

    def close(self):
        """Shut down the browser (if any) and destroy OpenCV windows."""
        browser = getattr(self, "browser", None)
        if browser is not None:
            # Without this the headless Chromium process outlives the env.
            self.loop.run_until_complete(browser.close())
            self.browser = None
        cv2.destroyAllWindows()

    def get_action_meanings(self):
        """Return a mapping from action index to a human-readable name."""
        return {0: "Right", 1: "Left", 2: "Down", 3: "Up", 4: "Do Nothing"}

    def get_score(self):
        """Synchronous wrapper around :meth:`get_async_score`."""
        return self.loop.run_until_complete(self.get_async_score())
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment