Skip to content

Instantly share code, notes, and snippets.

@hartikainen
Last active February 1, 2020 14:14
Show Gist options
  • Save hartikainen/5e9b6bf55b4623e8e6fd561338e02464 to your computer and use it in GitHub Desktop.
Save hartikainen/5e9b6bf55b4623e8e6fd561338e02464 to your computer and use it in GitHub Desktop.
import random
import time
import tensorflow as tf
from ray import tune
from wandb.ray import WandbLogger
class MyLogger(tune.logger.Logger):
    """Skeleton of a custom Ray Tune logger.

    It currently only remembers the config attached to the most recent
    result it receives; it writes nothing anywhere.

    NOTE(review): this class is not included in the ``loggers=`` list passed
    to ``tune.run`` in this file, so it looks like a scratch template for a
    future integration.
    """

    def _init(self):
        # Config taken from the latest result; None until on_result() runs.
        self._config = None
        # TODO: initialize the external integration from
        # self.config.get("env_config", {}).get("my", {}) once it exists.

    def on_result(self, result):
        # Keep the config reported alongside this result (None if absent).
        # No interactive debugger here: loggers run inside Tune workers, and
        # a breakpoint() would block the trial.
        self._config = result.get("config")

    def close(self):
        # This logger holds no resources, so there is nothing to release.
        pass
class ExperimentRunner(tune.Trainable):
    """Toy trainable that reports a saturating reward curve.

    Each training iteration sleeps for a random 0.1-2.0 s and then reports
    ``height * tanh(timestep / width)`` as ``episode_reward_mean``. ``width``
    and ``height`` default to 1 when missing from the config. Trials with
    ``width == 10`` deliberately raise ``ValueError`` on their third
    iteration, to check how the configured loggers handle a failing trial.
    """

    def _setup(self, variant):
        # ``variant`` is the trial config handed in by Tune. It is unused
        # because _train reads ``self.config`` directly.
        self.timestep = 0

    def _train(self):
        self.timestep += 1
        # Read each setting once, so the failure check below uses the same
        # default as the reward computation instead of indexing directly
        # (which raised KeyError when "width" was absent).
        width = self.config.get("width", 1)
        height = self.config.get("height", 1)

        v = tf.tanh(float(self.timestep) / width).numpy()
        v *= height

        # time.sleep expects a plain Python number. Convert the scalar
        # tensor explicitly rather than relying on its implicit conversion.
        time.sleep(float(tf.random.uniform((), minval=0.1, maxval=2.0)))

        if width == 10 and self.timestep > 2:
            raise ValueError("Intentional exception to see if wandb fails.")
        return {"episode_reward_mean": v}
def main():
    """Launch a grid search of ``ExperimentRunner`` trials with Ray Tune.

    Sweeps ``width`` over 10, 20 and 30, with a random integer ``height``
    per trial. Each trial runs for 10 training iterations on 2 CPUs. Results
    go to Tune's ``UnifiedLogger`` and to ``WandbLogger`` under
    ``/tmp/wandb-test``. The ``env_config["wandb"]`` entry carries the W&B
    project name, presumably consumed by ``WandbLogger`` (confirm against
    the wandb.ray integration).
    """

    def sample_height(spec):
        # Integer in [0, 99]; Tune calls this when resolving each trial.
        return int(100 * random.random())

    search_space = {
        "width": tune.grid_search([10, 20, 30]),
        "height": tune.sample_from(sample_height),
        "env_config": {"wandb": {"project": "my-project-name"}},
    }

    tune.run(
        ExperimentRunner,
        config=search_space,
        loggers=[tune.logger.UnifiedLogger, WandbLogger],
        resources_per_trial={"cpu": 2},
        stop={"training_iteration": 10},
        num_samples=1,
        local_dir="/tmp/wandb-test",
    )
if __name__ == "__main__":
    # Start the experiment only when run as a script, never on import.
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment