A small pytorch routine approaching the n-armed bandit problem
""" | |
Casino strategies and some pytorch. The code is explained in the video | |
https://youtu.be/VUVVVi9CgbI | |
I think you can install pytorch via `pip` or `pip3` through | |
`pip3 install torch torchvision` | |
Some mathy reading material: | |
* http://yuanz.web.illinois.edu/teaching/IE498fa19/ | |
E.g. already in "Lecture 03: the Multi-Armed Bandit Problem" you see | |
some hard theorems regarding potential strategies. | |
Things depend on your goal ("find best arm", "maximize rewards") | |
and there's various methods (jumping around differently, as in this script, or arm elimination strategies, etc.) | |
For a nice early table, see e.g. Table 1 in | |
* http://yuanz.web.illinois.edu/teaching/IE498fa19/lec_05.pdf | |
""" | |
import numpy as np
import random
import matplotlib.pyplot as plt
import torch

class LinAlgLib:  # Linear algebra'ish functions
    def is_positive_vec(vec) -> bool:
        return all(elem >= 0 for elem in vec)

    def is_positive_mat(mat) -> bool:
        return all(LinAlgLib.is_positive_vec(row) for row in mat)

    def is_non_zero_vec(vec, prec) -> bool:
        return any(abs(elem) > prec for elem in vec)

    def l1_normalized(vec) -> list:  # assumes all vec elements are non-negative
        PREC = 10**-8
        assert LinAlgLib.is_non_zero_vec(vec, PREC), f"Tried to normalize the zero vector, {vec}"
        n = sum(abs(elem) for elem in vec)
        return [elem / n for elem in vec]

    def normalized_rows(mat) -> np.ndarray:
        """
        Takes any matrix of non-negative doubles and returns a valid stochastic matrix by normalizing the rows.
        :param mat: A matrix.
        Example: [[17.5, 0], [2, 3]] is mapped to [[1, 0], [2/5, 3/5]]
        """
        assert LinAlgLib.is_positive_mat(mat)
        return np.array(list(map(LinAlgLib.l1_normalized, mat)))
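
# A minimal usage sketch (the helper name `_demo_normalized_rows` is not part of the
# original script): it shows LinAlgLib.normalized_rows turning a non-negative matrix
# into a row-stochastic one, matching the example in the docstring above.
def _demo_normalized_rows():
    mat = [[17.5, 0.0], [2.0, 3.0]]
    stochastic = LinAlgLib.normalized_rows(mat)
    assert np.allclose(stochastic, [[1.0, 0.0], [0.4, 0.6]])
    return stochastic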

class ProbLib:  # Probability theory'ish functions
    def rand_bool(p):  # p ... chance of returning True
        return random.random() < p

    def is_stochastic_matrix(mat) -> bool:
        """
        Return True if mat is a square matrix in which all rows are normalized to 1.
        """
        return all(
            len(row) == len(mat) and abs(sum(row) - 1) < 10**-15
            for row in mat
        )

    def compute_fundamental_matrix_N(mat_Q):
        r"""
        :param mat_Q: Transient block of a stochastic matrix. (=Upper left block, if the matrix is in canonical form.)
        N := \lim_{num_steps \to \infty} \sum_{s=0}^{num_steps} Q^s = (I - Q)^{-1}
        N_{ij} is the expected number of visits to state j when starting from state i, over any number of steps.
        """
        id = np.eye(len(mat_Q))
        return np.linalg.inv(id - mat_Q)

    def run_walk(transition_function, exit_condition_predicate, start_state):
        s = start_state
        while not exit_condition_predicate(s):
            s = transition_function(s)
        return s

    def weighted_vertex_transition(weights: list) -> int:
        """
        Return the index of a weight in weights, with probability determined by the weight itself.
        :param weights: List of positive numbers
        Details: Weights don't have to be l_1-normalized.
        """
        ps = LinAlgLib.l1_normalized(weights)
        # return np.random.choice(range(len(ps)), p=ps)  # Numpy equivalent
        r = random.random()  # Uniformly sampled float from the interval [0, 1]
        acc_p = 0
        for idx, p in enumerate(ps):
            acc_p += p
            if acc_p > r:
                return idx
        assert False  # Should not be reachable
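
# A minimal numeric check (the helper name `_demo_fundamental_matrix` is not part of the
# original script): for the toy transient block Q = [[0.5, 0.25], [0.0, 0.5]] the
# fundamental matrix N = (I - Q)^{-1} is [[2, 1], [0, 2]], i.e. starting in state 0 we
# expect 2 visits to state 0 and 1 visit to state 1 before absorption.
def _demo_fundamental_matrix():
    mat_Q = np.array([[0.5, 0.25], [0.0, 0.5]])
    mat_N = ProbLib.compute_fundamental_matrix_N(mat_Q)
    assert np.allclose(mat_N, [[2.0, 1.0], [0.0, 2.0]])
    return mat_N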

class Choice:
    def get_idx_max(lst):
        # Get index of maximal value in 'lst'
        idx_max = 0
        val_max = lst[idx_max]
        for idx, val in enumerate(lst):
            if val > val_max:
                idx_max = idx
                val_max = val
        return idx_max

    def get_choice_argmax_epsilon(lst, _alpha):  # TODO: Use _alpha, e.g. for epsilon
        EPS_ARGMAX = 0.8  # 80% chance of using argmax
        if ProbLib.rand_bool(EPS_ARGMAX):
            return Choice.get_idx_max(lst)  # Fall back on using argmax
        weights = len(lst) * [1 / len(lst)]  # Otherwise explore uniformly over all arms
        return ProbLib.weighted_vertex_transition(weights)

    def get_choice_softmax(lst, _alpha):  # TODO: Maybe use _alpha
        # Compute softmax of lst
        TAU = 1.5
        weights = np.exp(np.array(lst) / TAU)
        weights /= np.sum(weights)
        return ProbLib.weighted_vertex_transition(weights)
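
# A minimal sampling sketch (the helper name `_demo_choice_strategies` is not part of the
# original script) contrasting the two strategies on a fixed value estimate: the
# epsilon-argmax rule mostly picks the greedy arm, while softmax spreads its picks
# proportionally to exp(value / TAU).
def _demo_choice_strategies(q_values=(0.2, 0.9, 0.5), num_samples=1000):
    counts_eps = [0] * len(q_values)
    counts_soft = [0] * len(q_values)
    for _ in range(num_samples):
        counts_eps[Choice.get_choice_argmax_epsilon(list(q_values), 0.0)] += 1
        counts_soft[Choice.get_choice_softmax(list(q_values), 0.0)] += 1
    return counts_eps, counts_soft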

class Casino:
    """
    Conceptualizations/visualizations of the scenario, for one room (one "state"):
    * https://cassies.ca/wp-content/uploads/2021/04/one-armed-bandit-slot-781x512.jpg
    * https://miro.medium.com/v2/resize:fit:1400/format:webp/1*QtG3PRxhrP-BkB6bO6aVgw.png
    * https://numberly.com/2021/07/9ZABndg0-image-bandit-manchot.png
    """
    def __init__(self, num_rooms, num_arms_per_room, pulls_per_play):
        self.casino_matrix = np.random.rand(num_rooms, num_arms_per_room)
        self.__pulls_per_play = pulls_per_play  # Kept constant
        self.__update_state()

    def is_in_state(self, idx):
        return self.state == idx

    def play_and_goto_random_room(self, arm_choice, room_choice):
        if room_choice >= 0:  # Ignore room choice param if negative
            self.state = room_choice
        assert LinAlgLib.is_positive_mat(self.casino_matrix)
        p_winning = self.casino_matrix[self.state][arm_choice]
        reward = sum(ProbLib.rand_bool(p_winning) for _ in range(self.__pulls_per_play))
        self.__update_state()
        return reward

    def __update_state(self):
        possible_states = range(len(self.casino_matrix))
        self.state = random.choice(possible_states)

def manual_run(epochs, env_casino, choice_function):
    # Note: Implemented only for one room (as opposed to 'routine_ML')
    IDX_ROOM = 0
    num = len(env_casino.casino_matrix[IDX_ROOM])
    if len(env_casino.casino_matrix) != 1:
        print(f"Warning: len(env_casino.casino_matrix) = {len(env_casino.casino_matrix)}")
    q = [0] * num  # average rewards
    counts = [0] * num
    reward_smoothed = -1e20
    for idx in range(epochs):
        # Get reward for pull
        alpha_progress = idx / epochs
        choice = choice_function(q, alpha_progress)
        reward = env_casino.play_and_goto_random_room(choice, IDX_ROOM)
        cnt = counts[choice]
        # Perform iterative average computation (and use "learning" rate language in variable names)
        lr = 1 / (cnt + 1)
        q[choice] = lr * reward + (1 - lr) * q[choice]  # = (cnt * q[choice] + reward) / (cnt + 1)
        counts[choice] += 1
        # Append mean
        reward_smoothed = (idx * reward_smoothed + reward) / (idx + 1)
        yield reward_smoothed
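
# A minimal sketch (the helper name `_check_incremental_mean` is not part of the original
# script) verifying the incremental average used in manual_run: the update
# q <- lr * reward + (1 - lr) * q with lr = 1 / (cnt + 1) reproduces the plain sample
# mean of all rewards seen so far.
def _check_incremental_mean(rewards=(3, 0, 7, 7, 1)):
    q, cnt = 0.0, 0
    for reward in rewards:
        lr = 1 / (cnt + 1)
        q = lr * reward + (1 - lr) * q
        cnt += 1
    assert abs(q - sum(rewards) / len(rewards)) < 1e-12
    return q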

def routine_ML(epochs, env_casino, choice_function):
    num_rooms = len(env_casino.casino_matrix)
    num_arms = len(env_casino.casino_matrix[0])
    WIDTH_LAYER_1 = 100  # Model size
    model = torch.nn.Sequential(
        torch.nn.Linear(num_rooms, WIDTH_LAYER_1),
        torch.nn.ReLU(),
        # torch.nn.Linear(WIDTH_LAYER_1, WIDTH_LAYER_2),
        # torch.nn.ReLU(),
        torch.nn.Linear(WIDTH_LAYER_1, num_arms),
        torch.nn.ReLU(),
    )
    LEARNING_RATE = 1e-3
    optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
    loss_function = torch.nn.MSELoss()
    reward_smoothed = -1e20
    for idx in range(epochs):
        state__torch_format = list(map(env_casino.is_in_state, range(num_rooms)))  # One-hot encoding of 'state' for pytorch
        # Get reward for pull
        q = model(torch.Tensor(state__torch_format))
        alpha_progress: float = idx / epochs
        choice = choice_function(q.detach().numpy(), alpha_progress)
        ROOM_CHOICE = -1  # No room choice
        reward = env_casino.play_and_goto_random_room(choice, ROOM_CHOICE)  # The value of 'reward' is used as incidental target, in this iteration, for the network
        q_reward = [(reward if i == choice else val.item()) for i, val in enumerate(q)]  # Encoding empirical 'reward' at index 'choice' for diff. computation
        # Learning
        loss = loss_function(q, torch.Tensor(q_reward))
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # Note for other problems: If the casino_matrix were not constant, then doing a learning step at each
        # iteration would not give a helpful learning experience (repeated learning and, in a changed situation,
        # potential "counter-learning", and hence unlearning). We'd have to introduce batch methods before
        # backpropagating information.
        # Append mean
        reward_smoothed = (idx * reward_smoothed + reward) / (idx + 1)
        yield reward_smoothed
        # Auxiliary logging
        if idx % int(epochs / 20) == 0:
            print(f"training ... {idx} / {epochs} ({round(100 * alpha_progress, 1)}%)")

def plot_rewards():
    global rewards_explore, rewards_softmax, rewards_learned
    fig, ax = plt.subplots(1, 1)
    fig.set_size_inches(16, 8)
    ax.plot(range(len(rewards_explore)), rewards_explore, c='blue', label='explore')
    ax.plot(range(len(rewards_softmax)), rewards_softmax, c='green', label='softmax')
    ax.plot(range(len(rewards_learned)), rewards_learned, c='red', label='learned')
    ax.set_xlabel("idx")
    ax.set_ylabel("rewards (smoothed)")
    ax.legend()
    plt.show()

if __name__ == '__main__':
    NUM_ARMS = 10
    STATES = 1
    MAX_REWARD_PER_PLAY = 10
    env_casino = Casino(STATES, NUM_ARMS, MAX_REWARD_PER_PLAY)
    print("env_casino.casino_matrix:")
    for idx_row, row in enumerate(env_casino.casino_matrix):
        row = list(map(lambda x: round(x, 3), row))  # cut digits
        print(f"| row #{idx_row}: {row}, (max is {max(row)}, at col index {Choice.get_idx_max(row)})")
    EPOCHS = 3000  # i.e. "pulls" / "trials"
    rewards_explore = list(manual_run(EPOCHS, env_casino, Choice.get_choice_argmax_epsilon))
    rewards_softmax = list(manual_run(EPOCHS, env_casino, Choice.get_choice_softmax))
    rewards_learned = list(routine_ML(EPOCHS, env_casino, Choice.get_choice_softmax))
    plot_rewards()