Skip to content

Instantly share code, notes, and snippets.

@evanthebouncy
Created March 13, 2019 15:38
Show Gist options
  • Save evanthebouncy/9cebb89da3ac2b2174de12641a7eb99b to your computer and use it in GitHub Desktop.
Save evanthebouncy/9cebb89da3ac2b2174de12641a7eb99b to your computer and use it in GitHub Desktop.
def train_dagger(env, teacher, student, n_iters=100, n_updates=10, batch_size=40):
    """Train ``student`` to imitate ``teacher`` with a DAgger-style loop.

    Each round rolls out the *student* policy from the environment's initial
    state. It then asks the *teacher* for the action it would take in every
    visited state and appends those (state, teacher_action) pairs to a growing
    aggregate dataset. Finally it performs ``n_updates`` supervised updates of
    the student on random mini-batches drawn from the whole aggregate.

    Args:
        env: Environment. ``env.reset()`` is called exactly once, and its
            return value is the start state passed to every rollout.
        teacher: Expert policy; ``teacher.act(state)`` supplies the label
            action for a state.
        student: Learner policy. It is passed to ``get_rollout`` and trained
            with ``student.learn_supervised(states, actions)``.
        n_iters: Number of rollout/aggregate/train rounds (default 100).
        n_updates: Supervised updates per round (default 10).
        batch_size: Maximum mini-batch size per update (default 40). When
            fewer pairs have been aggregated, every pair is used.

    Returns:
        None. ``student`` is updated in place through ``learn_supervised``.
    """
    import random  # local import: no module-level import is guaranteed here

    init_state = env.reset()
    s_a_agg = []
    for it in range(n_iters):
        # learning
        print("learning . . . ", it)
        # Roll out the current student. Element [0] of each trace step is
        # treated as the state (assumes get_rollout yields indexable steps
        # with the state first; TODO confirm against get_rollout).
        trace = get_rollout(env, init_state, student)
        states = [step[0] for step in trace]
        # Label the states the *student* actually reached with *expert*
        # actions; this relabeling is what distinguishes DAgger from
        # plain behavior cloning.
        actions = [teacher.act(s) for s in states]
        s_a_agg.extend(zip(states, actions))
        if not s_a_agg:
            # Empty rollouts so far: there is nothing to train on yet.
            continue
        # Cap the batch at the dataset size. random.sample raises ValueError
        # when asked for more items than the population holds, e.g. after a
        # first episode shorter than batch_size steps.
        k = min(batch_size, len(s_a_agg))
        for _update in range(n_updates):
            batch = random.sample(s_a_agg, k)
            sub_states = [s for s, _a in batch]
            sub_actions = [a for _s, a in batch]
            student.learn_supervised(sub_states, sub_actions)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment