Skip to content

Instantly share code, notes, and snippets.

@jeanmidevacc
Created June 24, 2024 15:56
Show Gist options
  • Save jeanmidevacc/e4f32253b7dbe6c3ac7ef1af408cfaca to your computer and use it in GitHub Desktop.
Save jeanmidevacc/e4f32253b7dbe6c3ac7ef1af408cfaca to your computer and use it in GitHub Desktop.
suika_baseline_agents.py
import random
from datetime import datetime, timezone

import pandas as pd
class RandomAgent:
    """Agent that returns a uniformly random x position for every action.

    Attributes:
        creation_date: Naive ``datetime`` holding the UTC time the agent was built.
        tag: Short identifier for this agent type (``"random"``).
        x_min: Smallest x position ``get_action`` may return (inclusive).
        x_max: Largest x position ``get_action`` may return (inclusive).
    """

    def __init__(self, x_min=63, x_max=522):
        """Create the agent.

        Args:
            x_min: Inclusive lower bound of returned actions. Defaults to 63,
                the value previously hard-coded (presumably the left edge of
                the playable area — TODO confirm against the game).
            x_max: Inclusive upper bound of returned actions. Defaults to 522,
                the value previously hard-coded.

        Raises:
            ValueError: If ``x_min`` is greater than ``x_max``.
        """
        if x_min > x_max:
            raise ValueError(f"x_min ({x_min}) must not exceed x_max ({x_max})")
        # datetime.utcnow() is deprecated since Python 3.12; this yields the
        # same naive UTC timestamp without the DeprecationWarning.
        self.creation_date = datetime.now(timezone.utc).replace(tzinfo=None)
        self.tag = "random"
        self.x_min = x_min
        self.x_max = x_max

    def get_action(self, observation):
        """Return a random x position.

        Args:
            observation: Ignored; accepted so this agent has the same
                ``get_action(observation)`` signature as ``BaselineAgent``.

        Returns:
            An int drawn uniformly from ``[x_min, x_max]`` (both inclusive,
            per ``random.randint``).
        """
        return random.randint(self.x_min, self.x_max)
class BaselineAgent:
    """Heuristic agent that aims the next particle at a matching particle.

    If a particle already on the board has the same ``n`` value as the
    upcoming particle, the agent returns the ``position_x`` of the matching
    particle with the smallest ``position_y``. Otherwise it falls back to a
    random x position.

    Attributes:
        creation_date: Naive ``datetime`` holding the UTC time the agent was built.
        tag: Short identifier for this agent type (``"baseline1"``).
        x_min: Inclusive lower bound of the random fallback action.
        x_max: Inclusive upper bound of the random fallback action.
    """

    def __init__(self, x_min=63, x_max=522):
        """Create the agent.

        Args:
            x_min: Inclusive lower bound of the random fallback. Defaults to
                63, the value previously hard-coded.
            x_max: Inclusive upper bound of the random fallback. Defaults to
                522, the value previously hard-coded.

        Raises:
            ValueError: If ``x_min`` is greater than ``x_max``.
        """
        if x_min > x_max:
            raise ValueError(f"x_min ({x_min}) must not exceed x_max ({x_max})")
        # datetime.utcnow() is deprecated since Python 3.12; this yields the
        # same naive UTC timestamp without the DeprecationWarning.
        self.creation_date = datetime.now(timezone.utc).replace(tzinfo=None)
        self.tag = "baseline1"
        self.x_min = x_min
        self.x_max = x_max

    def get_action(self, observation):
        """Choose an x position for the next particle.

        Args:
            observation: Mapping with a ``"particle_states"`` entry (anything
                ``pandas.DataFrame`` accepts; when non-empty its rows must
                provide ``"n"``, ``"position_x"`` and ``"position_y"``) and a
                ``"next_particle"`` entry, compared for equality with ``"n"``.

        Returns:
            The ``position_x`` of the matching particle with the smallest
            ``position_y``, as a native Python value. When several particles
            tie, the first one in row order is used. If no complete row
            matches, a random int from ``[x_min, x_max]`` is returned.
        """
        # Rows with any missing value are ignored entirely.
        particles = pd.DataFrame(observation["particle_states"]).dropna()
        if particles.empty:
            # An empty frame has no "n" column, so it cannot be filtered.
            return random.randint(self.x_min, self.x_max)

        matches = particles[particles["n"] == observation["next_particle"]]
        if matches.empty:
            return random.randint(self.x_min, self.x_max)

        # Sort a new frame rather than sorting the filtered slice in place.
        # In-place sorting of the slice triggers pandas' SettingWithCopyWarning.
        # A stable sort gives a deterministic tie-break: the first matching row wins.
        # NOTE(review): if y grows downward (screen coordinates), the smallest
        # position_y is the highest particle — TODO confirm against the game.
        ordered = matches.sort_values("position_y", ascending=True, kind="stable")
        return ordered["position_x"].tolist()[0]
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment