beam search
import random
import math
import numpy as np

# a random mock-up environment whose single state is a float
# the goal is to get the float as close to 0 as possible with 2 possible moves
#   action 0:  x <- x + 1
#   action 1:  x <- cos(x)
class Env:
    def __init__(self):
        self.state = random.random()

    def copy(self):
        copied = Env()
        copied.state = self.state
        return copied

    # apply an action and return the next state
    def step(self, action):
        if action == 0:
            self.state = self.state + 1
        if action == 1:
            self.state = math.cos(self.state)
        return self.state

    def cost(self):
        return abs(self.state)
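# e.g. if x = 0.6, action 0 gives x = 1.6, and action 1 then gives x = cos(1.6) ≈ -0.03,
# so env.cost() would be about 0.03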
# an agent that's completely made up and not optimal at all
class Agent:
    def __init__(self):
        self.prob_map = dict()

    # give an arbitrary probability distribution over the 2 actions
    def get_action_prob(self, state):
        # generate an arbitrary probability of taking a0 on a particular state,
        # cached so that the same state always maps to the same distribution
        if state not in self.prob_map:
            a0_prob = random.random()
            a1_prob = 1.0 - a0_prob
            self.prob_map[state] = (a0_prob, a1_prob)
        return self.prob_map[state]

    # sample a random action according to the action probabilities
    def act(self, state):
        action_prob = self.get_action_prob(state)
        return np.random.choice([0, 1], p=action_prob)
# all games end in 10 steps
N_STEP = 10

# forward sample: roll out the agent's stochastic policy n_samples times,
# recording each trace's actions and its total log-probability
def forward_sample(env_in, agent, n_samples):
    traces = []
    for i in range(n_samples):
        env = env_in.copy()
        trace = [env.state]
        trace_logpr = 0.0
        for t_step in range(N_STEP):
            state = env.state
            action_prob = agent.get_action_prob(state)
            action = np.random.choice([0, 1], p=action_prob)
            action_logpr = np.log(action_prob[action])
            trace_logpr += action_logpr
            trace.append(action)
            # advance the environment with the sampled action
            env.step(action)
        trace.append(trace_logpr)
        traces.append(trace)
    return traces
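# each returned trace has the form [start_state, action_1, ..., action_N_STEP, total_logpr],
# which is why the main block below reads the actions with trace[1:-1]
# and the log-likelihood with trace[-1]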
# beam search: keep the beam_width most likely partial action sequences at each step
def beam_search(env_in, agent, beam_width):
    # a beam is a list of beam-particles
    # each beam-particle is (env, log_prob, [actions])
    beams = [(env_in.copy(), 0.0, [])]
    for t_step in range(N_STEP):
        new_beam_candidates = []
        for beam_particle in beams:
            beam_env, beam_logpr, beam_actions = beam_particle
            action_logprobs = np.log(agent.get_action_prob(beam_env.state))
            # expand each particle with both possible actions
            for action in [0, 1]:
                copy_env = beam_env.copy()
                copy_env.step(action)
                new_beam_particle = (copy_env,
                                     beam_logpr + action_logprobs[action],
                                     beam_actions + [action])
                new_beam_candidates.append(new_beam_particle)
        # sort the candidates by descending log-probability
        new_beam_candidates = sorted(new_beam_candidates, key=lambda x: -x[1])
        # keep only the beam_width highest-scoring candidates
        beams = new_beam_candidates[:beam_width]
    return beams
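# optional helper (illustrative sketch, not called anywhere below): replay a
# particle's actions through a fresh copy of the environment, recovering the
# same log-probability that beam_search assigned to it plus the final cost,
# using only Env.step, Env.cost and Agent.get_action_prob from above
def replay(env_in, agent, actions):
    env = env_in.copy()
    total_logpr = 0.0
    for action in actions:
        action_prob = agent.get_action_prob(env.state)
        total_logpr += np.log(action_prob[action])
        env.step(action)
    return total_logpr, env.cost()
# e.g. replay(env, agent, beams[0][2]) should match beams[0][1] and also report the final cost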
if __name__ == '__main__':
    env = Env()
    agent = Agent()
    print("start state ", env.state)

    print("=============== PERFORMING SAMPLING")
    sample_traces = forward_sample(env, agent, 10)
    for trace in sample_traces:
        print(f"trace actions {trace[1:-1]} generated with loglikelihood {trace[-1]}")
    best_logpr = max([trace[-1] for trace in sample_traces])
    print(f"maximum loglikelihood {best_logpr}")

    print("============= PERFORMING BEAM SEARCH")
    beams = beam_search(env, agent, 10)
    for beam in beams:
        print(f"beam actions {beam[2]} generated with loglikelihood {beam[1]}")
    best_logpr = max([beam[1] for beam in beams])
    print(f"maximum loglikelihood {best_logpr}")

    print("VALIDATING THAT BEAM SEARCH USUALLY RETURNS A BETTER LOGPR ON ITS BEST SEQUENCE")
    sample_better, beam_better = 0, 0
    for i in range(1000):
        env = Env()
        agent = Agent()
        sample_traces = forward_sample(env, agent, 10)
        beams = beam_search(env, agent, 10)
        best_sample_logpr = max([trace[-1] for trace in sample_traces])
        best_beam_logpr = max([beam[1] for beam in beams])
        if best_sample_logpr > best_beam_logpr:
            sample_better += 1
        else:
            beam_better += 1
    print(f"out of 1000 trials, sampling is better {sample_better} times and beam search is better {beam_better} times")