Created
September 17, 2018 21:15
-
-
Save arthurmensch/f6a80691662e59f10283205eb15762ce to your computer and use it in GitHub Desktop.
MCTS + dask
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import threading as thr | |
from concurrent.futures import ThreadPoolExecutor | |
from concurrent.futures import as_completed as thr_as_completed | |
from queue import Queue as ThrQueue | |
from time import sleep | |
import numpy as np | |
import tornado | |
from distributed import Client, Queue, as_completed, Pub, Sub, get_client, \ | |
get_worker | |
class _MCTS:
    """Dummy tree search, to be replaced by the Cython implementation.

    ``grow`` runs two threads: ``_explore`` pushes ``max_eval`` zero-filled
    state batches on ``state_q`` and ``_backup`` pops ``max_eval``
    evaluations from ``eval_q``. ``get`` and ``put`` are the other ends of
    those two queues.
    """

    def __init__(self, max_eval):
        """Create the queues, events and the two-thread pool.

        Parameters
        ----------
        max_eval : int
            Number of states produced / evaluations consumed per ``grow``.
        """
        self.max_eval = max_eval
        self.state_counter = 0
        self.eval_counter = 0
        self.turn_counter = 0
        # Set while the corresponding worker thread is inside its loop.
        self.exploring = thr.Event()
        self.backuping = thr.Event()
        self.thread_pool = ThreadPoolExecutor(max_workers=2)
        self.state_q = ThrQueue()
        self.eval_q = ThrQueue()

    def _explore(self):
        """Push ``max_eval`` dummy state batches on ``state_q``."""
        self.exploring.set()
        while self.state_counter < self.max_eval:
            sleep(.001)
            self.state_q.put(np.zeros((2048, 11, 11), dtype=np.uint8))
            self.state_counter += 1
        self.exploring.clear()

    def _backup(self):
        """Pop ``max_eval`` evaluations from ``eval_q``."""
        self.backuping.set()
        while self.eval_counter < self.max_eval:
            self.eval_q.get()
            sleep(.001)
            self.eval_counter += 1
        self.backuping.clear()

    def get(self):
        """Return the next state batch, blocking until one is available.

        The item that was fetched is the one returned; previously it was
        dropped, which lost states. A ``None`` item is skipped: the call
        waits for the exploring flag and fetches the following item.
        """
        res = self.state_q.get()
        if res is None:
            self.exploring.wait()
            res = self.state_q.get()
        return res

    def put(self, eval):
        """Queue one evaluation, waiting until the backup flag is set."""
        self.backuping.wait()
        self.eval_q.put(eval)

    def act(self):
        """Advance the turn counter and return a zero-filled ``Record``."""
        self.turn_counter += 1
        return Record(np.zeros((11, 11), dtype=np.uint8),
                      np.zeros((4, 7), dtype=np.float32))

    def grow(self):
        """Run ``_explore`` and ``_backup`` to completion, then reset counters.

        Raises
        ------
        ChildProcessError
            If either thread raised; the original error is chained.
        """
        explore_future = self.thread_pool.submit(self._explore)
        backup_future = self.thread_pool.submit(self._backup)
        for future in thr_as_completed((explore_future, backup_future)):
            try:
                future.result()
            except Exception as e:
                raise ChildProcessError from e
        self.eval_counter = 0
        self.state_counter = 0

    def __repr__(self):
        return (f'turn {self.turn_counter}, '
                f'state/eval {self.state_counter}/{self.eval_counter} '
                f'buffers state/eval {self.state_q.qsize()}/{self.eval_q.qsize()}')
class Record:
    """Plain container pairing a ``state`` with its ``target``."""

    def __init__(self, state, target):
        # Both values are stored as given, without copying or validation.
        self.state, self.target = state, target
class Player:
    """Actor that drives an ``_MCTS`` tree over dask queues.

    Four threads run while ``alive`` is set: ``_send`` forwards tree states
    to ``state_q``, ``_recv`` feeds evaluations from ``eval_q`` back to the
    tree, ``_loop`` grows the tree and pushes one record per turn on
    ``train_q``, and ``_monitor`` prints the tree status once per second.
    """

    @classmethod
    def start(cls, *args, **kwargs):
        """Build a player, start its threads and block in ``watch``."""
        player = cls(*args, **kwargs)
        player.loop()
        return player.watch()

    def __init__(self, state_q, eval_q, train_q):
        """Store the three queues and look up the dask client and worker."""
        self.thread_pool = ThreadPoolExecutor(max_workers=5)
        self.alive = thr.Event()
        self.mcts = _MCTS(max_eval=500, )
        self.state_counter = 0
        self.eval_counter = 0
        self.turn_counter = 0
        # Dask communication
        self.client = get_client()
        self.worker = get_worker()
        self.state_q = state_q
        self.eval_q = eval_q
        self.train_q = train_q

    def _send(self):
        """Scatter each state produced by the tree and queue its future."""
        while self.alive.is_set():
            state = self.mcts.get()
            future = self.client.scatter(state)
            self.state_q.put(future)
            self.state_counter += 1

    def _recv(self):
        """Gather each queued evaluation and hand its value to the tree."""
        while self.alive.is_set():
            eval_future = self.eval_q.get()
            # The queue carries futures; the tree must receive the gathered
            # value, not the future (the gathered result used to be dropped).
            self.mcts.put(self.client.gather(eval_future))
            self.eval_counter += 1

    def _monitor(self):
        """Print the tree status once per second."""
        while self.alive.is_set():
            sleep(1)
            print(f'[Player {self.worker.id[7:14]}] {self.mcts}')

    def _loop(self):
        """Grow the tree, act, and push the resulting record on ``train_q``."""
        while self.alive.is_set():
            self.mcts.grow()
            self.state_counter = 0
            self.eval_counter = 0
            record = self.mcts.act()
            record_future = self.client.scatter(record)
            self.train_q.put(record_future)
            self.turn_counter += 1

    def loop(self):
        """Set ``alive`` and submit the four worker threads."""
        self.alive.set()
        self._futures = {'monitor': self.thread_pool.submit(self._monitor),
                         'send': self.thread_pool.submit(self._send),
                         'recv': self.thread_pool.submit(self._recv),
                         'loop': self.thread_pool.submit(self._loop)}

    def watch(self):
        """Wait for the worker threads.

        Raises
        ------
        ChildProcessError
            If a thread raised; ``kill`` is called first and the original
            error is chained.
        """
        for future in thr_as_completed(self._futures.values()):
            try:
                future.result()
            except Exception as ex:
                self.kill()
                raise ChildProcessError from ex

    def kill(self):
        """Clear ``alive`` and wait for every thread, ignoring their errors."""
        self.alive.clear()
        for future in thr_as_completed(self._futures.values()):
            try:
                future.result()
            except Exception:
                # Best-effort shutdown: the caller reports the first failure.
                continue
class Model:
    """Placeholder model: constant outputs, no real parameters."""

    def __init__(self):
        """Nothing to initialise."""

    def __call__(self, state):
        """Sleep briefly, ignore ``state`` and return ``np.array(10)``."""
        sleep(0.0001)
        return np.array(10)

    def load(self, model):
        """Accept a serialized model and do nothing with it."""
        return None

    def serialize(self):
        """Return the placeholder serialized form, ``np.array(100)``."""
        return np.array(100)
class Evaluator:
    """Worker that evaluates states from ``state_q`` with a model.

    ``_loop`` gathers a state, applies the model, scatters the result to
    ``eval_q`` and then polls the ``'model_q'`` subscription for an updated
    model. ``_monitor`` prints the counters once per second.
    """

    @classmethod
    def start(cls, *args, **kwargs):
        """Build an evaluator, start its threads and block in ``watch``."""
        evaluator = cls(*args, **kwargs)
        evaluator.loop()
        return evaluator.watch()

    def __init__(self, model, state_q, eval_q):
        """Store the model and queues; look up the dask client and worker."""
        self.model = model
        self.thread_pool = ThreadPoolExecutor(max_workers=3)
        self.alive = thr.Event()
        self.state_counter = 0
        self.eval_counter = 0
        self.update_counter = 0
        # Dask communication
        self.client = get_client()
        self.worker = get_worker()
        self.sub = Sub('model_q')
        self.state_q = state_q
        self.eval_q = eval_q

    def _monitor(self):
        """Print the state/eval/update counters once per second."""
        while self.alive.is_set():
            sleep(1)
            print(f'[Evaluator {self.worker.id[7:14]}] state/eval/update: '
                  f'{self.state_counter}/{self.eval_counter}/{self.update_counter}')

    def _loop(self):
        """Evaluate queued states and pick up published model updates."""
        while self.alive.is_set():
            state_future = self.state_q.get()
            # The model must see the gathered value, not the future (the
            # gathered result used to be dropped).
            state = self.client.gather(state_future)
            self.state_counter += 1
            evaluation = self.model(state)
            self.eval_q.put(self.client.scatter(evaluation))
            self.eval_counter += 1
            try:
                # Zero timeout: poll for an update without blocking.
                model_state = self.sub.get(timeout=0.)
            except tornado.util.TimeoutError:
                continue
            if model_state is not None:
                model_state = self.client.gather(model_state)
                self.model.load(model_state)
                self.update_counter += 1

    def loop(self):
        """Set ``alive`` and submit the two worker threads."""
        self.alive.set()
        self._futures = {'loop': self.thread_pool.submit(self._loop),
                         'monitor': self.thread_pool.submit(self._monitor)}

    def kill(self):
        """Clear ``alive`` and wait for every thread, ignoring their errors."""
        self.alive.clear()
        for future in thr_as_completed(self._futures.values()):
            try:
                future.result()
            except Exception:
                # Best-effort shutdown: the caller reports the first failure.
                continue

    def watch(self):
        """Wait for the worker threads.

        Raises
        ------
        ChildProcessError
            If a thread raised; ``kill`` is called first and the original
            error is chained.
        """
        for future in thr_as_completed(self._futures.values()):
            try:
                future.result()
            except Exception as ex:
                self.kill()
                raise ChildProcessError from ex
class ReplayBuffer:
    """Ring buffer of records fed from ``train_q`` and sampled to ``sample_q``.

    ``_recv`` stores incoming records at ``train_counter % max_size``;
    ``_send`` draws random batches once more than ``min_size`` records have
    arrived; ``_monitor`` prints the counters once per second.
    """

    @classmethod
    def start(cls, *args, **kwargs):
        """Build a buffer, start its threads and block in ``watch``."""
        buffer = cls(*args, **kwargs)
        buffer.loop()
        return buffer.watch()

    def __init__(self, train_q, sample_q, max_size=10000, batch_size=32,
                 min_size=3):
        """Allocate storage and look up the dask client and worker.

        Parameters
        ----------
        train_q, sample_q
            Input queue of record futures and output queue of batch futures.
        max_size : int
            Number of slots in the ring buffer.
        batch_size : int
            Maximum number of records per sampled batch.
        min_size : int
            Sampling starts once ``train_counter`` exceeds this value.
        """
        self.max_size = max_size
        self.min_size = min_size
        self.batch_size = batch_size
        self.train_counter = 0
        self.sample_counter = 0
        self.states = np.zeros((max_size, 11, 11))
        self.targets = np.zeros((max_size, 4, 7))
        self.thread_pool = ThreadPoolExecutor(max_workers=4)
        self.alive = thr.Event()
        # Set once enough records are stored for sampling to begin.
        self.sampling = thr.Event()
        # Dask communication
        self.client = get_client()
        self.worker = get_worker()
        self.train_q = train_q
        self.sample_q = sample_q

    def _monitor(self):
        """Print the train/sample counters once per second."""
        while self.alive.is_set():
            sleep(1)
            print(f'[Buffer {self.worker.id[7:14]}]'
                  f' train/sample: {self.train_counter}/{self.sample_counter}')

    def _recv(self):
        """Gather records from ``train_q`` and store them in the ring."""
        while self.alive.is_set():
            record = self.client.gather(self.train_q.get())
            slot = self.train_counter % self.max_size
            # Store the data BEFORE publishing the new count and waking the
            # sampler, so ``_send`` never draws a slot that is still empty.
            self.states[slot] = record.state
            self.targets[slot] = record.target
            self.train_counter += 1
            if self.train_counter > self.min_size:
                self.sampling.set()

    def _send(self):
        """Scatter random batches of stored records to ``sample_q``."""
        while self.alive.is_set():
            self.sampling.wait()
            if not self.alive.is_set():
                # ``kill`` sets ``sampling`` only to unblock this wait; do
                # not emit a batch during shutdown.
                break
            lim = min(self.train_counter, self.max_size)
            batch = np.random.permutation(lim)[:self.batch_size]
            sample = Record(self.states[batch], self.targets[batch])
            sample_future = self.client.scatter(sample)
            self.sample_q.put(sample_future)
            self.sample_counter += 1

    def loop(self):
        """Set ``alive``, clear ``sampling`` and submit the three threads."""
        self.alive.set()
        self.sampling.clear()
        self._futures = {'recv': self.thread_pool.submit(self._recv),
                         'send': self.thread_pool.submit(self._send),
                         'monitor': self.thread_pool.submit(self._monitor)}

    def watch(self):
        """Wait for the worker threads.

        Raises
        ------
        ChildProcessError
            If a thread raised; ``kill`` is called first and the original
            error is chained.
        """
        for future in thr_as_completed(self._futures.values()):
            try:
                future.result()
            except Exception as ex:
                self.kill()
                raise ChildProcessError from ex

    def kill(self):
        """Clear ``alive``, unblock ``_send`` and wait for every thread."""
        self.alive.clear()
        self.sampling.set()
        for future in thr_as_completed(self._futures.values()):
            try:
                future.result()
            except Exception:
                # Best-effort shutdown: the caller reports the first failure.
                continue
class Trainer:
    """Worker that consumes batches from ``sample_q`` and publishes the model.

    ``_loop`` gathers one batch per iteration and, every 100 iterations,
    scatters the serialized model and publishes it on ``'model_q'``.
    ``_monitor`` prints the iteration counter once per second.
    """

    @classmethod
    def start(cls, *args, **kwargs):
        """Build a trainer, start its threads and block in ``watch``."""
        trainer = cls(*args, **kwargs)
        trainer.loop()
        return trainer.watch()

    def __init__(self, model, sample_q):
        """Store the model and queue; look up the dask client and worker."""
        self.model = model
        self.sample_q = sample_q
        self.thread_pool = ThreadPoolExecutor(max_workers=3)
        self.alive = thr.Event()
        self.iter_counter = 0
        # Dask communication
        self.client = get_client()
        self.worker = get_worker()
        self.pub = Pub('model_q')

    def watch(self):
        """Wait for the worker threads.

        Raises
        ------
        ChildProcessError
            If a thread raised; ``alive`` is cleared first so the remaining
            thread leaves its loop, and the original error is chained.
        """
        for future in thr_as_completed(self._futures.values()):
            try:
                future.result()
            except Exception as ex:
                # Without this the surviving thread keeps running forever
                # after its sibling has failed.
                self.alive.clear()
                raise ChildProcessError from ex

    def _monitor(self):
        """Print the iteration counter once per second."""
        while self.alive.is_set():
            sleep(1)
            print(f'[Trainer {self.worker.id[7:14]}]'
                  f' iter: {self.iter_counter}')

    def _loop(self):
        """Consume batches and publish the model every 100 iterations."""
        self.iter_counter = 0
        while self.alive.is_set():
            sample_future = self.sample_q.get()
            # The batch is fetched but not used yet: the training step below
            # is a placeholder that only sleeps.
            sample = self.client.gather(sample_future)
            sleep(0.001)
            self.iter_counter += 1
            if self.iter_counter % 100 == 0:
                model_state = self.model.serialize()
                model_future = self.client.scatter(model_state)
                self.pub.put(model_future)

    def loop(self):
        """Set ``alive`` and submit the two worker threads."""
        self.alive.set()
        self._futures = {'loop': self.thread_pool.submit(self._loop),
                         'monitor': self.thread_pool.submit(self._monitor)}

    def kill(self):
        """Clear ``alive`` and wait for every thread.

        Raises
        ------
        ChildProcessError
            If a thread raised; the original error is chained.
        """
        self.alive.clear()
        for future in thr_as_completed(self._futures.values()):
            try:
                future.result()
            except Exception as e:
                raise ChildProcessError from e
def monitor(queues):
    """Print the size of every queue once per second, forever.

    ``queues`` maps a display name to an object exposing ``qsize()``.
    This function never returns.
    """
    while True:
        sleep(1)
        names = '/'.join(queues)
        sizes = '/'.join(str(q.qsize()) for q in queues.values())
        print(f"[Queues] {names}: {sizes}")
if __name__ == '__main__':
    # Number of actors of each kind. These were previously unused while the
    # evaluator count was hard-coded to 2; the values keep that behaviour.
    n_evaluators = 2
    n_players = 1

    client = Client(processes=False, n_workers=5)

    state_q = Queue(maxsize=1000)
    eval_q = Queue(maxsize=1000)
    train_q = Queue(maxsize=1000)
    sample_q = Queue(maxsize=1000)
    queues = {
        'state_q': state_q,
        'eval_q': eval_q,
        'train_q': train_q,
        'sample_q': sample_q,
    }
    model = Model()

    # pure=False: these tasks are long-running and side-effecting. With the
    # default pure=True, submissions with the same function and arguments
    # may be given the same key and collapsed into a single task.
    futures = [client.submit(Player.start, state_q, eval_q, train_q,
                             pure=False)
               for _ in range(n_players)]
    futures += [client.submit(Evaluator.start, model, state_q, eval_q,
                              pure=False)
                for _ in range(n_evaluators)]
    futures.append(client.submit(ReplayBuffer.start, train_q, sample_q,
                                 pure=False))
    futures.append(client.submit(Trainer.start, model, sample_q, pure=False))
    futures.append(client.submit(monitor, queues, pure=False))

    # Surface the first task failure, chaining the original error.
    for future in as_completed(futures):
        try:
            future.result()
        except Exception as ex:
            raise ChildProcessError from ex
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment