Last active
June 18, 2020 05:34
-
-
Save koozdra/ed3b5456d79576f173ffebcc9afe3f28 to your computer and use it in GitHub Desktop.
post sign flow q-learning
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
/**
 * Return a new array with the elements of `arr` in random order.
 * Decorate-sort-undecorate: pair each element with a random key, sort by
 * the key, strip the key. Does not mutate `arr`.
 * NOTE(review): random-key sorting is adequate for this simulation but is
 * not a strict uniform Fisher-Yates shuffle.
 * @param {Array} arr
 * @returns {Array} shuffled shallow copy of `arr`
 */
const shuffle_array = arr =>
  arr
    .map(a => [a, Math.random()])
    .sort(([, lr], [, rr]) => lr - rr)
    .map(([a]) => a);
/**
 * Simulate the user's response to a flow state.
 * `state` is the sequence of agent action letters taken so far (e.g. 'ps'
 * = promote then share); the user reacts to the most recent action.
 * @param {string} state - concatenated action letters, newest last
 * @param {number} randy - pre-drawn uniform random number in [0, 1)
 * @returns {string} a `rewards` key: the primary action name, or 'skip'
 */
const user_action_for_state = (state, randy) => {
  // States along the p -> s -> a -> m "happy path": 70% chance the user
  // performs the primary action for the latest step, otherwise skip.
  if (state === 'p' || state === 'ps' || state === 'psa' || state === 'psam') {
    if (randy < 0.7) {
      return primary_action_names[state[state.length - 1]];
    }
    return 'skip';
  }

  // Flows that open with 'a' (action) or 'm' (member) are always skipped.
  if (state.startsWith('a') || state.startsWith('m')) {
    return 'skip';
  }

  // Any other state: 50% chance of the primary action.
  if (randy < 0.5) {
    return primary_action_names[state[state.length - 1]];
  }
  return 'skip';
};
// The four agent actions: promote, share, action, member.
const actions = ['p', 's', 'a', 'm'];

// Reward granted for each simulated user response.
const rewards = {
  promote: 3,
  member: 3,
  share: 2,
  action: 2,
  skip: 1
};

// Maps an action letter to its user-facing name (a `rewards` key).
const primary_action_names = {
  p: 'promote',
  s: 'share',
  a: 'action',
  m: 'member'
};
/**
 * Actions still available from a state: every action whose letter does not
 * already appear in the state string.
 * @param {string} state - concatenated action letters taken so far
 * @param {string[]} all_actions - the full action alphabet
 * @returns {string[]} untaken actions, in `all_actions` order
 */
const possible_actions = (state, all_actions) =>
  all_actions.filter(action => !state.includes(action));
get_q_value = (q_table, state, action) => q_table[state + action] || 0; | |
/**
 * Apply a one-step Q-learning update for every (state, action, reward)
 * step of a completed trajectory, in order.
 * @param {Object<string, number>} q_table - mutated in place
 * @param {number} learning_rate
 * @param {number} discount_factor
 * @param {Array<[string, string, number]>} trajectory
 */
const update_q_table_trajectory = (
  q_table,
  learning_rate,
  discount_factor,
  trajectory
) => {
  for (const [state, action, reward] of trajectory) {
    update_q_table(
      q_table,
      learning_rate,
      discount_factor,
      state,
      action,
      reward
    );
  }
};
/**
 * One-step Q-learning update:
 *   Q(s,a) += lr * (reward + gamma * max_a' Q(s+a, a') - Q(s,a))
 * The next state is modelled as the current state string with the chosen
 * action appended; actions already in the state cannot be repeated.
 * @param {Object<string, number>} q_table - keyed by state+action; mutated
 * @param {number} learning_rate
 * @param {number} discount_factor
 * @param {string} state
 * @param {string} action
 * @param {number} reward
 */
const update_q_table = (
  q_table,
  learning_rate,
  discount_factor,
  state,
  action,
  reward
) => {
  const state_action = state + action;
  // Q-values of every action still available from the next state.
  const next_q_values = possible_actions(state_action, actions).map(
    possible_action => get_q_value(q_table, state_action, possible_action)
  );
  // Terminal states (no remaining actions) contribute no future value.
  const max_next_q = next_q_values.length > 0 ? Math.max(...next_q_values) : 0;
  const current_q_value = get_q_value(q_table, state, action);
  q_table[state_action] =
    current_q_value +
    learning_rate * (reward + discount_factor * max_next_q - current_q_value);
};
/**
 * Pick a uniformly random element of `possibles`.
 * @param {string[]} possibles - must be non-empty
 * @returns {string}
 */
const select_random_action = possibles =>
  possibles[Math.floor(Math.random() * possibles.length)];
/**
 * Pick the candidate action with the highest Q-value for `state`.
 * Ties resolve to the earliest entry in `possibles` (the caller shuffles
 * the candidate list first, which randomises tie-breaking).
 * @param {Object<string, number>} q_table
 * @param {string[]} possibles - non-empty candidate actions
 * @param {string} state
 * @returns {string}
 */
const select_greedy_q_table = (q_table, possibles, state) => {
  const q_values = possibles.map(action => get_q_value(q_table, state, action));
  return possibles[q_values.indexOf(Math.max(...q_values))];
};
/**
 * Roll out one complete episode with an epsilon-greedy policy.
 * From `state`, repeatedly pick an action (random with probability
 * `exploration_rate`, greedy otherwise), simulate the user's response,
 * and record a [state, action, reward] step. The episode ends once all
 * four actions have been taken (state length > 3).
 * @param {Object<string, number>} q_table
 * @param {string} state - action letters taken so far ('' to start)
 * @param {Array<[string, string, number]>} acc - steps collected so far
 * @returns {Array<[string, string, number]>} the full trajectory
 */
const generate_trajectory = (q_table, state, acc) => {
  if (state.length > 3) {
    return acc;
  }
  const is_explore = Math.random() < exploration_rate;
  // Shuffle so that greedy selection breaks Q-value ties randomly.
  const possibles = shuffle_array(possible_actions(state, actions));
  const agent_action = is_explore
    ? select_random_action(possibles)
    : select_greedy_q_table(q_table, possibles, state);
  const next_state = state + agent_action;
  const user_page_action_name = user_action_for_state(next_state, Math.random());
  const reward = rewards[user_page_action_name];
  return generate_trajectory(q_table, next_state, [
    ...acc,
    [state, agent_action, reward]
  ]);
};
// Hyperparameters for the Q-learning simulation.
const iterations = 1000; // episodes per trial
const trials = 100; // repetitions, all training the same shared q_table
const learning_rate = 0.9;
const discount_factor = 0.9;
const exploration_rate = 0.01; // epsilon for epsilon-greedy selection
// Shared Q-table, keyed by state+action strings.
const q_table = {};
// Per-trial results, filled during training:
const trial_rewards = []; // trial_rewards[t][i] = total reward of episode i
const trial_paths = []; // trial_paths[t][i] = final state+action of episode i
// Training: each trial rolls out `iterations` episodes, recording every
// episode's total reward and final state+action path, and folds each
// trajectory back into the shared Q-table. Plain loops replace the
// side-effect-only Array.from calls.
for (let trial = 0; trial < trials; trial += 1) {
  const trajectory_rewards = [];
  const trajectory_paths = [];
  for (let iteration = 0; iteration < iterations; iteration += 1) {
    const trajectory = generate_trajectory(q_table, '', []);
    const total_reward = trajectory.reduce(
      (accumulator, [, , reward]) => accumulator + reward,
      0
    );
    trajectory_rewards.push(total_reward);
    // The episode's path is its last visited state plus the final action.
    const [final_state, final_action] = trajectory[trajectory.length - 1];
    trajectory_paths.push(final_state + final_action);
    update_q_table_trajectory(
      q_table,
      learning_rate,
      discount_factor,
      trajectory
    );
  }
  trial_rewards.push(trajectory_rewards);
  trial_paths.push(trajectory_paths);
}
// Sliding-window statistics: for each window of `window_size` consecutive
// episodes, count how often each final path occurred across all trials.
const window_size = 100;
const all_keys = {}; // every path string seen; used below as CSV columns
const stat_frequencies = []; // one {path: count} object per window
for (let i = 0; i < iterations - window_size; i += 1) {
  // Pool the i-th window of every trial's paths.
  let window_accumulator = [];
  trial_paths.forEach(trajectory_paths => {
    window_accumulator = [
      ...window_accumulator,
      ...trajectory_paths.slice(i, i + window_size)
    ];
  });
  const frequency = {};
  window_accumulator.forEach(stat => {
    all_keys[stat] = true;
    frequency[stat] = (frequency[stat] || 0) + 1;
  });
  stat_frequencies.push(frequency);
}
// Emit the window frequencies as CSV: a header row of every path observed
// during training, then one row per sliding window with that window's
// count for each path (0 when the path did not occur in the window).
const frequency_stat_keys = Object.keys(all_keys);
const header = frequency_stat_keys.join(',');
console.log(header);
stat_frequencies.forEach(frequency_stat => {
  const output_line = frequency_stat_keys.map(
    frequency_stat_key => frequency_stat[frequency_stat_key] || 0
  );
  console.log(output_line.join(','));
});
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment