Skip to content

Instantly share code, notes, and snippets.

@wkentaro
Created June 5, 2020 16:29
Show Gist options
  • Save wkentaro/071e1b2c5cfe7297a23c6a5c807f62f2 to your computer and use it in GitHub Desktop.
Save wkentaro/071e1b2c5cfe7297a23c6a5c807f62f2 to your computer and use it in GitHub Desktop.
from abc import ABC, abstractmethod
from typing import Any, List
import warnings
import numpy as np
class Summary:
    """Base container pairing a summary ``name`` with its ``value``.

    Subclasses tag what kind of value is held (scalar, histogram, image,
    video); this base class only stores the two fields.
    """

    def __init__(self, name: str, value: Any):
        self.name, self.value = name, value
class ScalarSummary(Summary):
    """Summary whose value is expected to be a single scalar number."""
class HistogramSummary(Summary):
    """Summary whose value is presumably a collection of samples to histogram."""
class ImageSummary(Summary):
    """Summary whose value is an image (its array layout is not fixed here)."""
class VideoSummary(Summary):
    """Summary whose value is a video clip, carrying its playback frame rate.

    Args:
        name: Summary name.
        value: The video data (layout not fixed by this class).
        fps: Frames per second used for playback; defaults to 30.
    """

    def __init__(self, name: str, value: Any, fps: int = 30):
        super().__init__(name, value)
        self.fps = fps
class SummaryAccumulator:
    """Accumulate scalar summaries by name and reduce each name to its mean.

    Add summaries with ``acc + summary`` or ``acc += summary``; both mutate
    the accumulator in place and return it. Only ``ScalarSummary`` instances
    are recorded -- any other summary type emits a warning and is ignored.
    """

    def __init__(self):
        # name -> list of summaries recorded under that name, in add order.
        self._accumulated = {}

    def __add__(self, summary: Summary) -> "SummaryAccumulator":
        """Record ``summary`` and return this accumulator.

        ``self`` is returned on every path so that ``acc += summary`` keeps
        ``acc`` bound to the accumulator: Python falls back to ``__add__``
        for ``+=`` and rebinds the name to whatever it returns.

        Raises:
            RuntimeError: if a summary of a different type was already
                recorded under the same name.
        """
        if not isinstance(summary, ScalarSummary):
            warnings.warn("Summary class %s cannot be accumulated"
                          % type(summary).__name__)
            return self
        recorded = self._accumulated.get(summary.name)
        if recorded is None:
            self._accumulated[summary.name] = [summary]
            return self
        # Compare against a stored summary, not the list that holds them
        # (comparing the list's type would reject every repeated name).
        if type(recorded[0]) is not type(summary):
            raise RuntimeError("Trying to accumulate different typed "
                               "summaries with the same name")
        recorded.append(summary)
        return self

    def summaries(self) -> List[Summary]:
        """Return one ScalarSummary per recorded name holding the mean value.

        The mean is computed with ``np.mean`` over the ``value`` of every
        summary recorded under that name. Recorded summaries are left in
        place.

        Raises:
            TypeError: if a name holds a non-scalar summary type (not
                reachable through ``__add__``, which only records
                ScalarSummary instances).
        """
        ret = []
        for name, recorded in self._accumulated.items():
            if not isinstance(recorded[0], ScalarSummary):
                raise TypeError("Cannot generate summaries for: %s" %
                                type(recorded[0]))
            ret.append(ScalarSummary(
                name, np.mean([summary.value for summary in recorded])))
        return ret
class ActResult:
    """Value returned by an agent's ``act`` call.

    Attributes:
        action: The action chosen by the agent.
        observation_elements: Additional observation-related entries; a
            fresh empty dict when the caller passes ``None``. A dict passed
            by the caller is stored as-is (not copied).
        extra_elements: Any further entries; a fresh empty dict when the
            caller passes ``None``. A dict passed by the caller is stored
            as-is (not copied).
    """

    def __init__(self, action: Any, observation_elements: dict = None,
                 extra_elements: dict = None):
        self.action = action
        if observation_elements is None:
            observation_elements = {}
        self.observation_elements = observation_elements
        if extra_elements is None:
            extra_elements = {}
        self.extra_elements = extra_elements
class Agent(ABC):
    """Abstract interface that concrete agents implement.

    Every method except ``reset`` is abstract, so a subclass must override
    all of them before it can be instantiated. ``reset`` is an optional hook
    that does nothing by default.
    """

    @abstractmethod
    def build(self, training: bool, device=None) -> None:
        """Prepare the agent; ``training`` presumably selects train vs. eval
        mode and ``device`` where it runs -- confirm against implementations.
        """

    @abstractmethod
    def update(self, step: int, replay_sample: dict) -> dict:
        """Run one update at ``step`` using ``replay_sample``.

        Returns a dict whose contents are defined by the implementation.
        """

    @abstractmethod
    def act(self, step: int, observation: dict, deterministic: bool) -> ActResult:
        """Choose an action for ``observation`` at ``step``.

        Returns an ``ActResult``; the chosen action is its ``action``
        attribute and any extra values go in its element dicts.
        """

    def reset(self) -> None:
        """Optional hook for clearing internal state; does nothing by default."""

    @abstractmethod
    def summaries(self) -> List[Summary]:
        """Return the list of Summary objects the agent wants reported."""

    @abstractmethod
    def load_weights(self, savedir: str) -> None:
        """Load the agent's weights from ``savedir``."""

    @abstractmethod
    def save_weights(self, savedir: str) -> None:
        """Save the agent's weights to ``savedir``."""
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment