# Source: GitHub gist by @davidADSP
# Created December 1, 2019 16:42
# Gist id: davidADSP/fb2e29b666328c0271bc20b4bfbf8cbf
# (GitHub page chrome — "Skip to content", "Instantly share code, notes, and
# snippets.", "Show Gist options", "Save ... to your computer and use it in
# GitHub Desktop." — was captured with the code and is not part of the program.)
class NetworkOutput(typing.NamedTuple):
    """Immutable record bundling the outputs of one network inference call.

    Fields, in positional order: ``value``, ``reward``, ``policy_logits``,
    ``hidden_state``.
    """

    # Scalar value output.
    value: float
    # Scalar reward output.
    reward: float
    # Mapping from action to its logit. ``Action`` is written as a forward
    # reference so that creating this class does not require ``Action`` to be
    # bound yet at this point in the module.
    policy_logits: Dict["Action", float]
    # Hidden state as a list of floats.
    hidden_state: List[float]
class Network:
    """Stub network exposing the inference and bookkeeping interface.

    Every method here returns a fixed placeholder value and ignores its
    arguments; no computation is performed.
    """

    def initial_inference(self, image) -> NetworkOutput:
        """Representation + prediction function.

        Args:
            image: The observation to run inference on (unused by this stub).

        Returns:
            A placeholder ``NetworkOutput(0, 0, {}, [])``.
        """
        return NetworkOutput(0, 0, {}, [])

    def recurrent_inference(self, hidden_state, action) -> NetworkOutput:
        """Dynamics + prediction function.

        Args:
            hidden_state: The hidden state to advance from (unused by this
                stub).
            action: The action to apply (unused by this stub).

        Returns:
            A placeholder ``NetworkOutput(0, 0, {}, [])``.
        """
        return NetworkOutput(0, 0, {}, [])

    def get_weights(self):
        """Return the weights of this network (always an empty list here)."""
        return []

    def training_steps(self) -> int:
        """Return how many steps / batches the network has been trained for.

        This stub always reports ``0``.
        """
        return 0
# (GitHub page footer, not part of the program: "Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment")