Created June 27, 2018 13:38
Save BlGene/7a2585ed3726cd08ae536aea43493db4 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import os | |
import glob | |
import multiprocessing | |
from collections import deque | |
import matplotlib.pyplot as plt | |
import matplotlib.gridspec as gridspec | |
import pygame | |
from pdb import set_trace | |
import numpy as np | |
# The environment runs in a different process, so its observations are
# handed over through this module-level multiprocessing queue: env_callback
# puts frames onto it and Viewer.draw takes them off.
env_data_queue = multiprocessing.Queue(maxsize=0)
class Viewer:
    """Live matplotlib figure showing an env frame, a network-input frame and
    a rolling reward/value plot.

    Env frames arrive through the module-level ``env_data_queue``, because the
    env runs in a different process. Network inputs, rewards and values arrive
    through :meth:`run_callback`, which also triggers the redraw. When
    ``video`` is true, every redraw is also saved as ``./video/NNN.png``.

    Args:
        transpose: if true, :meth:`get_video_size` reports
            ``(shape[1], shape[0])``; otherwise ``(shape[0], shape[1])``.
        zoom: optional scale factor applied by :meth:`get_video_size`.
        video: if true, delete existing ``./video/*.png`` files and save one
            PNG per redraw.
    """

    def __init__(self, transpose=True, zoom=None, video=False):
        self.env_initialized = False
        self.run_initialized = False
        self.transpose = transpose
        self.zoom = zoom

        if video:
            # Start every recording from an empty ./video directory.
            os.makedirs('./video', exist_ok=True)
            for stale_frame in glob.glob('./video/*.png'):
                os.remove(stale_frame)
        self.frame_count = 0
        self.video = video

        # Latest frames, None until data arrives. draw() reads run_data, so it
        # must exist even before the first run_callback.
        self.col_data = None
        self.run_data = None

        plt.ion()
        self.fig = plt.figure(figsize=(2, 2))
        gs = gridspec.GridSpec(2, 2)
        gs.update(wspace=0.001, hspace=0.001)  # set the spacing between axes.
        self.col_ax = plt.subplot(gs[0, 0])
        self.net_ax = plt.subplot(gs[0, 1])
        self.plt_ax = plt.subplot(gs[1, :])
        for ax in (self.col_ax, self.net_ax, self.plt_ax):
            ax.set_axis_off()
        plt.subplots_adjust(wspace=0.5, hspace=0, left=0, bottom=0, right=1, top=1)

        # time series: reward and value, each keeping the last
        # horizon_timesteps points.
        num_plots = 2
        self.horizon_timesteps = 30 * 5
        self.t = 0
        self.cur_plot = [None for _ in range(num_plots)]
        self.data = [deque(maxlen=self.horizon_timesteps) for _ in range(num_plots)]

    def env_callback(self, env, draw=True):
        """Push ``env._observation`` onto ``env_data_queue``.

        ``draw`` is accepted for interface compatibility and is ignored.
        """
        env_data_queue.put(env._observation)

    def run_callback(self, prev_obs, obs, actions, rew, masks, values, draw=True):
        """Record one step and redraw the figure.

        Stores a copy of ``obs`` as the network-input frame and appends
        ``rew[0]`` and ``values[0]`` to the rolling time series. It then
        re-plots both series over the visible window and calls :meth:`draw`.
        ``prev_obs``, ``actions``, ``masks`` and ``draw`` are unused.
        """
        self.run_data = obs.copy()
        print("Run callback: ", rew, np.max(obs))

        # time series: one new point per series per step
        points = [rew[0], values[0]]
        for point, series in zip(points, self.data):
            series.append(point)
        self.t += 1

        # The x-range has the same length as each deque: t while t is below
        # the horizon, horizon_timesteps afterwards.
        xmin, xmax = max(0, self.t - self.horizon_timesteps), self.t
        for i, (color, label) in enumerate(zip(['C1', 'C2'], ['rew', 'val'])):
            if self.cur_plot[i] is not None:
                self.cur_plot[i].remove()
            self.cur_plot[i], = self.plt_ax.plot(
                range(xmin, xmax), list(self.data[i]), color=color, label=label)
        self.plt_ax.set_xlim(xmin, xmax)
        self.plt_ax.legend(loc='lower left')
        self.draw()

    def get_video_size(self, obs):
        """Return the display size of a 2-D or (H, W, 1|3) frame.

        The result is ``(shape[1], shape[0])`` when ``self.transpose`` is true,
        otherwise ``(shape[0], shape[1])``. When ``self.zoom`` is set, both
        entries are scaled by it and truncated to int.

        Raises:
            AssertionError: if ``obs`` is not 2-D, or is not 3-D with 1 or 3
                channels.
        """
        assert len(obs.shape) == 2 or (len(obs.shape) == 3 and obs.shape[2] in [1, 3]), \
            "shape was {}".format(obs.shape)
        if self.transpose:
            video_size = obs.shape[1], obs.shape[0]
        else:
            video_size = obs.shape[0], obs.shape[1]
        if self.zoom is not None:
            # Use the instance attribute; a bare `zoom` raised NameError here.
            video_size = int(video_size[0] * self.zoom), int(video_size[1] * self.zoom)
        return video_size

    def draw(self):
        """Refresh both image panels, redraw the canvas and optionally save a frame.

        Blocks until a frame is available on ``env_data_queue``. Each panel's
        ``imshow`` is created on the first usable frame and updated with
        ``set_data`` afterwards.
        """
        # NOTE(review): the blocking get() keeps the env panel in lock-step with
        # run_callback, but it hangs if the env puts fewer frames than there are
        # run_callback calls. Confirm one env_callback per step.
        col_data = env_data_queue.get()
        if self.env_initialized:
            self.col_screen.set_data(col_data)
        elif col_data is not None:
            self.get_video_size(col_data)  # validates the frame shape
            self.col_screen = self.col_ax.imshow(col_data, aspect='auto')
            self.env_initialized = True

        if self.run_initialized:
            self.net_screen.set_data(self.run_data)
        elif self.run_data is not None and not np.all(self.run_data == 0):
            # All-zero frames are skipped before the first imshow. Presumably
            # this avoids fixing the colour scale from a blank frame; confirm.
            self.get_video_size(self.run_data)  # validates the frame shape
            self.net_screen = self.net_ax.imshow(self.run_data, cmap='gray', aspect='auto')
            self.run_initialized = True

        self.fig.tight_layout()
        self.fig.canvas.draw()

        if self.video:
            fn = './video/{0:03d}.png'.format(self.frame_count)
            self.fig.savefig(fn, bbox_inches='tight', pad_inches=0)
            self.frame_count += 1

    @staticmethod
    def _rescale_to_255(arr):
        """Linearly map ``arr`` onto [0, 255].

        A constant array maps to all zeros instead of dividing by zero, which
        would produce NaNs.
        """
        arr_min, arr_max = arr.min(), arr.max()
        span = arr_max - arr_min
        if span == 0:
            return np.zeros_like(arr, dtype=float)
        return 255.0 * (arr - arr_min) / span

    @staticmethod
    def display_arr(screen, arr, transpose=True, video_size=(84, 84)):
        """Blit ``arr``, rescaled to [0, 255], onto the pygame ``screen`` at (0, 0).

        When ``transpose`` is true, the first two axes are swapped before the
        surface is built. The surface is then scaled to ``video_size``.
        """
        arr = Viewer._rescale_to_255(arr)
        pyg_img = pygame.surfarray.make_surface(arr.swapaxes(0, 1) if transpose else arr)
        pyg_img = pygame.transform.scale(pyg_img, video_size)
        screen.blit(pyg_img, (0, 0))
Sign up for free to join this conversation on GitHub.
Already have an account?
Sign in to comment.