Created
June 5, 2020 16:29
-
-
Save wkentaro/071e1b2c5cfe7297a23c6a5c807f62f2 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| from abc import ABC, abstractmethod | |
| from typing import Any, List | |
| import warnings | |
| import numpy as np | |
class Summary(object):
    """Base class pairing a summary ``name`` with an arbitrary ``value``.

    Subclasses only tag the kind of value being summarised; they share
    this constructor (``VideoSummary`` adds an ``fps`` field).
    """

    def __init__(self, name: str, value: Any):
        self.name = name
        self.value = value

    def __repr__(self):
        # Readable representation so warnings/logs that interpolate a
        # summary object show its type, name and value rather than an
        # opaque "<... object at 0x...>".
        return "%s(name=%r, value=%r)" % (
            type(self).__name__, self.name, self.value)
class ScalarSummary(Summary):
    """Summary holding a single scalar value.

    This is the only summary kind that ``SummaryAccumulator`` accepts for
    accumulation; other kinds are rejected with a warning.
    """
class HistogramSummary(Summary):
    """Summary tagged as a histogram.

    The expected format of ``value`` is not constrained here (presumably an
    array of samples -- confirm against the summary writer).
    """
class ImageSummary(Summary):
    """Summary tagged as an image.

    The expected layout of ``value`` is not constrained here (presumably an
    image array -- confirm shape/dtype against the summary writer).
    """
class VideoSummary(Summary):
    """Summary tagged as a video, carrying an extra ``fps`` field.

    Args:
        name: summary name.
        value: the video payload (format not constrained here).
        fps: frame rate, presumably frames per second; defaults to 30.
    """

    def __init__(self, name: str, value: Any, fps: int = 30):
        super().__init__(name, value)
        self.fps = fps
class SummaryAccumulator:
    """Collect scalar summaries by name and report their means.

    Summaries are added with ``acc + summary`` or ``acc += summary``; both
    mutate the accumulator in place and return it. Only ``ScalarSummary``
    instances are accumulated -- anything else triggers a warning and is
    ignored.
    """

    def __init__(self):
        # Maps summary name -> list of accumulated summary objects.
        self._accumulated = {}

    def __add__(self, summary: Summary):
        """Accumulate ``summary`` in place and return ``self``.

        Returning ``self`` (instead of ``None``) keeps ``acc += summary``
        from rebinding ``acc`` to ``None``.

        Raises:
            RuntimeError: if a summary with the same name but a different
                type has already been accumulated.
        """
        if not isinstance(summary, ScalarSummary):
            # stacklevel=2 attributes the warning to the caller's line.
            warnings.warn("Summary class %s cannot be accumulated"
                          % type(summary).__name__, stacklevel=2)
            return self
        existing = self._accumulated.get(summary.name)
        if existing is None:
            self._accumulated[summary.name] = [summary]
            return self
        # Compare against a stored *element*: the dict value is a list, so
        # comparing its type to the summary's type would always mismatch.
        if type(existing[0]) is not type(summary):
            raise RuntimeError("Trying to accumulate different typed "
                               "summaries with the same name")
        existing.append(summary)
        return self

    # ``+=`` behaves exactly like ``+``: in-place accumulation.
    __iadd__ = __add__

    def summaries(self) -> List[Summary]:
        """Return one ``ScalarSummary`` per name holding the mean value.

        Raises:
            TypeError: if a stored summary is not a ``ScalarSummary``
                (unreachable via ``__add__``, kept as a safety check).
        """
        ret = []
        for name, accumulated in self._accumulated.items():
            first = accumulated[0]
            if not isinstance(first, ScalarSummary):
                raise TypeError("Cannot generate summaries for: %s" %
                                type(first))
            ret.append(ScalarSummary(
                name, np.mean([summary.value for summary in accumulated])))
        return ret
class ActResult(object):
    """Value returned by ``Agent.act``.

    Attributes:
        action: the action chosen by the agent.
        observation_elements: extra observation-related values; a fresh
            empty dict when not supplied (a dict that is passed in is
            stored as-is, not copied).
        extra_elements: any other values; a fresh empty dict when not
            supplied (a dict that is passed in is stored as-is).
    """

    def __init__(self, action: Any, observation_elements: dict = None,
                 extra_elements: dict = None):
        # Build per-instance dicts here rather than using mutable defaults,
        # so instances never share state.
        if observation_elements is None:
            observation_elements = {}
        if extra_elements is None:
            extra_elements = {}
        self.action = action
        self.observation_elements = observation_elements
        self.extra_elements = extra_elements
class Agent(ABC):
    """Abstract agent interface.

    Subclasses must implement ``build``, ``update``, ``act``, ``summaries``,
    ``load_weights`` and ``save_weights``; ``reset`` is an optional hook.
    """

    @abstractmethod
    def build(self, training: bool, device=None) -> None:
        """Construct the agent's internal state.

        Args:
            training: flag passed through to the implementation
                (its exact meaning is subclass-defined).
            device: optional device argument; its type is not fixed by
                this interface.
        """

    @abstractmethod
    def update(self, step: int, replay_sample: dict) -> dict:
        """Run one update using ``replay_sample`` at ``step``.

        Returns:
            A dict whose contents are defined by the implementation.
        """

    @abstractmethod
    def act(self, step: int, observation: dict, deterministic: bool) -> ActResult:
        """Select an action for ``observation`` at ``step``.

        Returns:
            An ``ActResult`` whose ``action`` holds the chosen action and
            whose ``observation_elements`` / ``extra_elements`` dicts carry
            any additional values (presumably written to the replay buffer
            by the caller -- confirm against the training loop).
        """

    def reset(self) -> None:
        """Hook for resetting agent state; the default does nothing."""

    @abstractmethod
    def summaries(self) -> List[Summary]:
        """Return the agent's current list of ``Summary`` objects."""

    @abstractmethod
    def load_weights(self, savedir: str) -> None:
        """Load weights from ``savedir``."""

    @abstractmethod
    def save_weights(self, savedir: str) -> None:
        """Save weights to ``savedir``."""
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment