Gist by @MCarlomagno · Last active January 20, 2022
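The snippet below shows a GymTask that delegates training to an RL agent, and a MyRLAgent implementation that learns a per-good price expectation by repeatedly acting in a ProxyEnv and updating the corresponding GoodPriceModel from the observed reward.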
import random
from typing import Any, Dict, Tuple

# Task, ProxyEnv, RLAgent and GoodPriceModel come from the surrounding
# skill code (not shown in this gist).


class GymTask(Task):
    # ...
    def _fit(self, proxy_env: ProxyEnv, nb_steps: int) -> None:
        """Fit the RL agent."""
        self._rl_agent.fit(proxy_env, nb_steps)
    # ...


class MyRLAgent(RLAgent):
    def __init__(self, nb_goods: int) -> None:
        # One price model per good.
        self.good_price_models: Dict[int, GoodPriceModel] = {
            good_id: GoodPriceModel() for good_id in range(nb_goods)
        }

    def _pick_an_action(self) -> Any:
        """Pick a random good and ask its model for a price expectation."""
        good_id = self._get_random_next_good()
        good_price_model = self.good_price_models[good_id]
        price = good_price_model.get_price_expectation()
        action = [good_id, price]
        return action

    def _update_model(
        self,
        observation: Any,
        reward: float,
        done: bool,
        info: Dict,
        action: Tuple[int, int],
    ) -> None:
        """Update the price model of the good the action applied to."""
        good_id, price = action
        good_price_model = self.good_price_models[good_id]
        outcome = reward == 1.0  # a reward of exactly 1.0 counts as a success
        good_price_model.update(outcome, price)

    def _get_random_next_good(self) -> int:
        """Sample a good id uniformly at random."""
        return random.choice(list(self.good_price_models.keys()))

    def fit(self, proxy_env: ProxyEnv, nb_steps: int) -> None:
        """Train the agent for nb_steps interactions with the environment."""
        action_counter = 0
        proxy_env.reset()
        while action_counter < nb_steps:
            action = self._pick_an_action()
            obs, reward, done, info = proxy_env.step(action)
            self._update_model(obs, reward, done, info, action)
            action_counter += 1
        proxy_env.close()
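The gist does not include GoodPriceModel itself; the agent only assumes it exposes get_price_expectation() and update(outcome, price). Below is a minimal, hypothetical sketch of such a model, treating each candidate price as an arm of a Bernoulli bandit with a Beta posterior and picking prices by Thompson sampling. The max_price parameter and the Beta(1, 1) prior are assumptions for illustration, not part of the original code.

import random


class GoodPriceModel:
    def __init__(self, max_price: int = 100) -> None:
        # Beta(1, 1) prior (uniform) over success probability per price.
        self.params = {price: [1, 1] for price in range(max_price + 1)}

    def get_price_expectation(self) -> int:
        # Thompson sampling: return the price whose sampled success
        # probability is highest.
        return max(
            self.params,
            key=lambda price: random.betavariate(*self.params[price]),
        )

    def update(self, outcome: bool, price: int) -> None:
        # Update the Beta posterior for the price that was tried.
        alpha, beta = self.params[price]
        self.params[price] = [alpha + int(outcome), beta + int(not outcome)]


# Hypothetical usage, assuming a proxy_env with the Gym-style
# reset/step/close API that fit() relies on above:
#   agent = MyRLAgent(nb_goods=10)
#   agent.fit(proxy_env, nb_steps=4000)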