Last active: January 20, 2022 18:18
-
-
Save MCarlomagno/33864fd1235a9217b12765be76394a99 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class GymTask(Task):
    """Task whose training step is delegated to an RL agent.

    Only the training hook is shown here; the remaining members are elided.
    """

    # ... (other members elided)

    def _fit(self, proxy_env: ProxyEnv, nb_steps: int) -> None:
        """Train the RL agent against the given proxy environment.

        :param proxy_env: the environment the agent interacts with.
        :param nb_steps: how many steps the agent should train for.
        """
        agent = self._rl_agent
        agent.fit(proxy_env, nb_steps)

    # ... (other members elided)
class MyRLAgent(RLAgent):
    """RL agent that keeps one ``GoodPriceModel`` per good.

    On each step it picks a good uniformly at random and proposes the price
    that good's model currently expects. It then feeds the model whether the
    step's reward was exactly ``1.0``.
    """

    def __init__(self, nb_goods: int) -> None:
        """Create one ``GoodPriceModel`` for every good id in ``range(nb_goods)``.

        :param nb_goods: number of goods; good ids are ``0 .. nb_goods - 1``.
        """
        # NOTE(review): RLAgent.__init__ is not invoked here; confirm the base
        # class needs no initialisation before adding a super().__init__() call.
        self.good_price_models = {
            good_id: GoodPriceModel() for good_id in range(nb_goods)
        }

    def _pick_an_action(self) -> Any:
        """Return the next action as ``[good_id, price]`` for a random good."""
        good_id = self._get_random_next_good()
        price = self.good_price_models[good_id].get_price_expectation()
        return [good_id, price]

    def _update_model(
        self,
        observation: Any,
        reward: float,
        done: bool,
        info: Dict,
        action: Tuple[int, int],
    ) -> None:
        """Update the price model of the good targeted by ``action``.

        Only ``reward`` and ``action`` are used; ``observation``, ``done`` and
        ``info`` are accepted for interface compatibility but ignored.

        :param observation: observation returned by the environment step.
        :param reward: reward of the step; exactly ``1.0`` counts as success.
        :param done: episode-finished flag from the environment (unused).
        :param info: extra step information from the environment (unused).
        :param action: the ``(good_id, price)`` pair that produced this step.
        """
        good_id, price = action
        # The step counts as a success only when the reward is exactly 1.0.
        outcome = reward == 1.0
        self.good_price_models[good_id].update(outcome, price)

    def _get_random_next_good(self) -> int:
        """Return a good id chosen uniformly at random.

        :raises IndexError: if the agent was created with no goods.
        """
        return random.choice(list(self.good_price_models))

    def fit(self, proxy_env: ProxyEnv, nb_steps: int) -> None:
        """Train the agent for ``nb_steps`` steps, then close the environment.

        The environment is reset once before training. It is always closed
        afterwards, even if a step or model update raises.

        :param proxy_env: the environment to interact with.
        :param nb_steps: number of actions to take.
        """
        try:
            proxy_env.reset()
            for _ in range(nb_steps):
                action = self._pick_an_action()
                obs, reward, done, info = proxy_env.step(action)
                # NOTE(review): `done` is not used to reset the episode;
                # confirm the proxy environment never needs a mid-run reset.
                self._update_model(obs, reward, done, info, action)
        finally:
            # Release the environment on success and on failure alike.
            proxy_env.close()
Sign up for free to join this conversation on GitHub.
Already have an account? Sign in to comment.