@vwxyzjn
Created October 25, 2024 18:41
import queue
import threading
import time


class Agent:
    def __init__(self):
        self.param = 1

    def learn(self, data):
        self.param += 1


def query_generator_fn():
    for i in range(1, 100):
        yield i


ITER = 7
batch_size = 32
agent = Agent()
data_Q = queue.Queue(maxsize=1)
param_and_query_Q = queue.Queue(maxsize=1)


def actor():
    for i in range(1, ITER + 1):
        params, query = param_and_query_Q.get()
        data = params  # the rollout data is tagged with the policy version that generated it
        print(f"[actor] generating data π_{params} -> p_{query} D_π_{data}")
        time.sleep(1)  # simulate data generation
        data_Q.put((query, data))


actor_thread = threading.Thread(target=actor)
actor_thread.start()

# initial param put
generator = query_generator_fn()
next_queries = next(generator)
param_and_query_Q.put((agent.param, next_queries))

# Cleanba-style training loop
async_mode = True
start_time = time.time()
for g in range(1, ITER + 1):
    queries = next_queries
    if async_mode:
        if g != 1:
            next_queries = next(generator)
        param_and_query_Q.put((agent.param, queries))
    else:
        if g != 1:
            next_queries = next(generator)
            param_and_query_Q.put((agent.param, next_queries))  # note the indent here is different
    _, data = data_Q.get()
    old_param = agent.param
    agent.learn(data)
    time.sleep(1)  # simulate training
    print(f"--[learner] get π_{old_param} ->  p_{queries} D_π_{data} -> π_{agent.param}, time: {time.time() - start_time}")
actor_thread.join()

vwxyzjn commented Oct 25, 2024

This gist demonstrates the core async RL concept presented in https://arxiv.org/abs/2410.18252:

Assuming rollout and training each take 1 second, running in async mode lets us overlap the actor's and the learner's computation:

[actor] generating data π_1 -> p_1 D_π_1
[actor] generating data π_1 -> p_1 D_π_1
--[learner] get π_1 ->  p_1 D_π_1 -> π_2, time: 2.0022785663604736
[actor] generating data π_2 -> p_1 D_π_2
--[learner] get π_2 ->  p_1 D_π_1 -> π_3, time: 3.003446578979492
[actor] generating data π_3 -> p_2 D_π_3
--[learner] get π_3 ->  p_2 D_π_2 -> π_4, time: 4.004631042480469
[actor] generating data π_4 -> p_3 D_π_4
--[learner] get π_4 ->  p_3 D_π_3 -> π_5, time: 5.005782842636108
[actor] generating data π_5 -> p_4 D_π_5
--[learner] get π_5 ->  p_4 D_π_4 -> π_6, time: 6.006959915161133
[actor] generating data π_6 -> p_5 D_π_6
--[learner] get π_6 ->  p_5 D_π_5 -> π_7, time: 7.008115530014038
--[learner] get π_7 ->  p_6 D_π_6 -> π_8, time: 8.00928544998169
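The timings above can be sanity-checked with a small wall-clock model (a sketch under this gist's assumptions of a 1-second rollout and 1-second training step; the function name is mine): sync mode pays rollout plus training on every step, while async mode pays both only on the warm-up step and max(rollout, training) afterwards.

```python
def wall_clock_seconds(iters: int, rollout_s: float, train_s: float, async_mode: bool) -> float:
    """Approximate total time for `iters` learner steps."""
    if async_mode:
        # One warm-up rollout, then rollout and training overlap each step.
        return rollout_s + iters * max(rollout_s, train_s)
    # Sync: each step waits for a fresh rollout, then trains.
    return iters * (rollout_s + train_s)

print(wall_clock_seconds(7, 1, 1, async_mode=True))   # matches the ~8 s async run above
print(wall_clock_seconds(7, 1, 1, async_mode=False))  # a full sync run would take ~14 s
```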

Versus if we run in sync mode:

[actor] generating data π_1 -> p_1 D_π_1
--[learner] get π_1 ->  p_1 D_π_1 -> π_2, time: 2.002211332321167
[actor] generating data π_2 -> p_2 D_π_2
--[learner] get π_2 ->  p_1 D_π_2 -> π_3, time: 4.004531145095825
[actor] generating data π_3 -> p_3 D_π_3
--[learner] get π_3 ->  p_2 D_π_3 -> π_4, time: 6.006798267364502
[actor] generating data π_4 -> p_4 D_π_4
--[learner] get π_4 ->  p_3 D_π_4 -> π_5, time: 8.009055852890015
[actor] generating data π_5 -> p_5 D_π_5
--[learner] get π_5 ->  p_4 D_π_5 -> π_6, time: 10.011354207992554
[actor] generating data π_6 -> p_6 D_π_6
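The trade-off for the speedup is data staleness. Reading the D_π_* column of the two logs: in sync mode the learner at step g always consumes on-policy data D_π_g, while in async mode it consumes data from the previous policy version (and D_π_1 twice during warm-up). A minimal sketch of that pattern (the function name is mine, not from the gist):

```python
def data_policy_version(g: int, async_mode: bool) -> int:
    """Which policy version generated the batch the learner consumes at step g."""
    if not async_mode:
        return g          # sync: always on-policy, D_π_g
    return max(1, g - 1)  # async: one policy version stale after the warm-up step

# Reproduces the D_π_* sequence of the async log above:
print([data_policy_version(g, async_mode=True) for g in range(1, 8)])
```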
