```python
import numpy as np

def charge(self):
    """ return a random amount of charge """
    # the reward is a Gaussian distribution with unit variance around the true value 'q'
    value = np.random.randn() + self.q
    # ...
```
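To see what this reward model produces, here is a minimal, self-contained sketch. The `PowerSocket` class is a hypothetical stand-in that only stores the true mean `q`, and it assumes the method simply returns the Gaussian draw (the rest of the original method is not shown above).

```python
import numpy as np


class PowerSocket:
    """Hypothetical socket: each charge is drawn from N(q, 1)."""

    def __init__(self, q: float):
        self.q = q  # the socket's true (hidden) mean reward

    def charge(self) -> float:
        """Return a random amount of charge (assumed: no clipping)."""
        # unit-variance Gaussian centred on the true value q
        return float(np.random.randn() + self.q)


np.random.seed(42)
socket = PowerSocket(q=6.0)
samples = np.array([socket.charge() for _ in range(10_000)])
print(samples.mean(), samples.std())  # close to 6.0 and 1.0
```

Averaging many charges recovers `q`; the estimators in the snippets that follow try to do the same from far fewer draws.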
```python
import numpy as np

def sample(self):
    """ return a value from the posterior normal distribution """
    return (np.random.randn() / np.sqrt(self.τ_0)) + self.μ_0
```
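This `sample` method draws from N(μ_0, 1/τ_0) by scaling a standard normal value by the posterior standard deviation 1/sqrt(τ_0) and shifting it by the posterior mean μ_0. The sketch below wraps it in a hypothetical `GaussianPosterior` class (the class name and the example values of μ_0 and τ_0 are assumptions) and checks that the draws have the expected mean and spread.

```python
import numpy as np


class GaussianPosterior:
    """Hypothetical posterior over a socket's mean reward: N(μ_0, 1/τ_0)."""

    def __init__(self, μ_0=0.0, τ_0=0.0001):
        self.μ_0 = μ_0  # posterior mean
        self.τ_0 = τ_0  # posterior precision (1 / variance)

    def sample(self):
        # scale a standard normal draw by the posterior std dev, then shift by the mean
        return (np.random.randn() / np.sqrt(self.τ_0)) + self.μ_0


np.random.seed(0)
post = GaussianPosterior(μ_0=5.0, τ_0=4.0)  # std dev = 1 / sqrt(4) = 0.5
draws = np.array([post.sample() for _ in range(20_000)])
print(draws.mean(), draws.std())  # roughly 5.0 and 0.5
```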
```python
def get_state_value(state_row, state_col, start_values):
    ''' calculate the state value for a single state '''
    # get the list of next states of the current state
    next_states, number_of_states = get_next_states(state_row, state_col)
    state_value = 0
    for next_state in next_states:
        # add the reward for moving to the next state (always -1) and the value of the next state
        state_value += (1/number_of_states) * (-1 + start_values[next_state[0], next_state[1]])
    # ...
```
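A runnable sketch of the sweep that would call this function, assuming a hypothetical `get_next_states` that returns the in-bounds up/down/left/right neighbours, a hypothetical 1 x 3 corridor, and a single exit state whose value stays fixed at 0. Each sweep reads from the previous value array (`start_values`) and writes to a new one.

```python
import numpy as np

# hypothetical 1 x 3 corridor; the exit is the right-most cell
GRID_ROWS, GRID_COLS = 1, 3
EXIT_STATE = (0, 2)


def get_next_states(state_row, state_col):
    """Hypothetical helper: in-bounds neighbours reachable in one move."""
    moves = [(-1, 0), (1, 0), (0, -1), (0, 1)]  # up, down, left, right
    next_states = [(state_row + dr, state_col + dc) for dr, dc in moves
                   if 0 <= state_row + dr < GRID_ROWS and 0 <= state_col + dc < GRID_COLS]
    return next_states, len(next_states)


def get_state_value(state_row, state_col, start_values):
    """Expected reward (-1 per move) plus next-state value under a uniform random policy."""
    next_states, number_of_states = get_next_states(state_row, state_col)
    state_value = 0.0
    for next_row, next_col in next_states:
        state_value += (1 / number_of_states) * (-1 + start_values[next_row, next_col])
    return state_value


def evaluate_random_policy(tol=1e-9, max_sweeps=10_000):
    values = np.zeros((GRID_ROWS, GRID_COLS))
    for _ in range(max_sweeps):
        new_values = np.zeros_like(values)  # the exit state stays at 0
        for r in range(GRID_ROWS):
            for c in range(GRID_COLS):
                if (r, c) != EXIT_STATE:
                    new_values[r, c] = get_state_value(r, c, values)
        converged = np.max(np.abs(new_values - values)) < tol
        values = new_values
        if converged:
            break
    return values


print(evaluate_random_policy())  # approximately [[-4, -3, 0]]
```

The expected values can be checked by hand: V(1) = 0.5(-1 + V(0)) + 0.5(-1 + 0) and V(0) = -1 + V(1), which give V(1) = -3 and V(0) = -4.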
```python
def get_state_value(state, start_values):
    ''' calculate the value of the specified state using the supplied current state values
        - this implements equation 9 '''
    # iterate over all possible actions for the state
    state_value = 0
    for action in get_π(state):
        target_state = action[0]
        action_probability = action[1]
        # ...
```
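Equation 9 is not reproduced here, so the following is only a sketch under assumptions: each entry returned by `get_π` is a `(target_state, action_probability)` pair, moves are deterministic and cost -1, and there is no discount. The `POLICY` table and the three-state corridor are hypothetical.

```python
import numpy as np

# hypothetical policy: state -> list of (target_state, action_probability)
POLICY = {
    0: [(1, 1.0)],              # from the left end, always move right
    1: [(0, 0.25), (2, 0.75)],  # from the middle, move right 75% of the time
    2: [],                      # exit state: no actions, so its value stays 0
}


def get_π(state):
    return POLICY[state]


def get_state_value(state, start_values):
    """Sum over actions of π(a|s) * (reward + value of the target state)."""
    state_value = 0.0
    for target_state, action_probability in get_π(state):
        state_value += action_probability * (-1 + start_values[target_state])
    return state_value


values = np.zeros(len(POLICY))
for _ in range(200):  # synchronous sweeps; this small example converges quickly
    values = np.array([get_state_value(s, values) for s in range(len(POLICY))])

print(values)  # approximately [-2.667, -1.667, 0.0]
```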
```python
def update(self, x):
    ''' increase the number of times this socket has been used and improve the estimate of the
        mean and variance by combining the single new value 'x' with the current estimate '''
    n = 1
    v = self.n
    self.α = self.α + n/2
    self.β = self.β + ((n*v/(v + n)) * (((x - self.μ_0)**2)/2))
    # estimate the variance - calculate the mean from the gamma hyper-parameters
    # ...
```
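The α and β lines match the standard single-observation Normal-Gamma update, with `v` acting as the pseudo-count behind the current mean. A complete version of that update might look like the sketch below; the class name, the initial hyper-parameters, and the mean update (done after β, so that β uses the previous mean) are assumptions rather than the original code.

```python
import numpy as np


class NormalGammaEstimate:
    """Hypothetical Normal-Gamma estimate of an unknown mean and precision."""

    def __init__(self, μ_0=0.0, v=0.0, α=1.0, β=1.0):
        self.μ_0 = μ_0  # current estimate of the mean
        self.n = v      # pseudo-count: observations backing μ_0
        self.α = α      # gamma shape
        self.β = β      # gamma rate

    def update(self, x):
        n = 1
        v = self.n
        self.α = self.α + n / 2
        # uses the previous mean, so β must be updated before μ_0
        self.β = self.β + ((n * v / (v + n)) * (((x - self.μ_0) ** 2) / 2))
        self.μ_0 = (v * self.μ_0 + n * x) / (v + n)
        self.n = v + n

    def estimated_variance(self):
        # E[precision] = α / β for Gamma(shape=α, rate=β); invert it for a variance
        return self.β / self.α


rng = np.random.default_rng(0)
estimate = NormalGammaEstimate()
for x in rng.normal(3.0, 2.0, size=5000):  # true mean 3, true variance 4
    estimate.update(x)
print(estimate.μ_0, estimate.estimated_variance())  # close to 3 and 4
```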
```python
import numpy as np

def update(self, x):
    ''' increase the number of times this socket has been used and improve the estimate of the
        value (the mean) by combining the new value 'x' with the current mean '''
    self.n += 1
    self.x.append(x)  # append the new value to the list of samples
    # update the mean of the posterior
    self.μ_0 = ((self.τ_0 * self.μ_0) + (self.τ * np.array(self.x).sum()))/(self.τ_0 + (self.n*self.τ))
```
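This update uses the batch form of the known-precision posterior: the prior is combined with the sum of all n samples at once. The sketch below (function names hypothetical) compares that form with a one-value-at-a-time update. The two agree when the batch form is given the original prior (μ_0, τ_0) rather than a prior that has already absorbed earlier samples, so it matters which values `self.μ_0` and `self.τ_0` hold when this line runs.

```python
import numpy as np


def posterior_batch(μ_prior, τ_prior, τ, xs):
    """Posterior (mean, precision) of an unknown mean with known precision τ,
    combining the prior with all observations at once."""
    xs = np.asarray(xs, dtype=float)
    τ_post = τ_prior + len(xs) * τ
    μ_post = (τ_prior * μ_prior + τ * xs.sum()) / τ_post
    return μ_post, τ_post


def posterior_sequential(μ_prior, τ_prior, τ, xs):
    """Same posterior built one value at a time; each step's posterior
    becomes the prior for the next value."""
    μ, τ_0 = μ_prior, τ_prior
    for x in xs:
        μ = (τ_0 * μ + τ * x) / (τ_0 + τ)
        τ_0 = τ_0 + τ
    return μ, τ_0


xs = [4.2, 5.1, 3.9, 4.8]
print(posterior_batch(0.0, 0.5, 1.0, xs))
print(posterior_sequential(0.0, 0.5, 1.0, xs))  # same values, up to rounding
```

For these inputs both give a posterior mean of about 4.0 and a precision of 4.5.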
```python
import numpy as np

def update(self, x):
    ''' increase the number of times this socket has been used and improve the estimate of the
        variance by updating the gamma distribution's hyper-parameters using the new value 'x' '''
    self.n += 1
    self.x.append(x)  # append the new value to the list of samples
    self.α = self.n/2
    self.β = ((np.array(self.x) - self.μ)**2).sum()/2
```
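With the mean μ known, α = n/2 and β = Σ(x - μ)²/2 are the Gamma posterior parameters for the precision when both hyper-parameters start at 0 (that starting prior is inferred from the code, not stated in it). The ratio β/α is then simply the mean squared deviation from μ. A quick numerical check with hypothetical data:

```python
import numpy as np

rng = np.random.default_rng(1)
μ = 2.0                            # the known mean
x = rng.normal(μ, 0.5, size=2000)  # hypothetical data with true variance 0.25

α = len(x) / 2
β = ((x - μ) ** 2).sum() / 2

# β / α equals the mean squared deviation from the known mean
print(β / α, np.mean((x - μ) ** 2))

# drawing precisions from Gamma(shape=α, scale=1/β) gives variances near 0.25
precisions = rng.gamma(α, 1 / β, size=10_000)
print(np.mean(1 / precisions))
```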
```python
import numpy as np

def sample(self):
    ''' sample from our estimated normal '''
    precision = np.random.gamma(self.α, 1/self.β)
    if precision == 0 or self.n == 0:
        precision = 0.001
    estimated_variance = 1/precision
    return np.random.normal(self.μ_0, np.sqrt(estimated_variance))
```
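A standalone version of this two-stage draw (a precision from the Gamma, then a value from the normal) is sketched below as a hypothetical function rather than a method. It checks `n` and β before computing `1/β`; in the method above, `np.random.gamma(self.α, 1/self.β)` runs first, so if β starts at 0 (its initial value is not shown) that call would attempt to divide by zero.

```python
import numpy as np


def sample_estimated_normal(rng, μ_0, α, β, n):
    """Draw a precision from Gamma(shape=α, rate=β), then a value from N(μ_0, 1/precision)."""
    if n == 0 or β <= 0:
        precision = 0.001  # no data yet: use a very wide distribution
    else:
        precision = rng.gamma(α, 1 / β)  # numpy takes the scale, i.e. 1/rate
        if precision == 0:
            precision = 0.001
    estimated_variance = 1 / precision
    return rng.normal(μ_0, np.sqrt(estimated_variance))


rng = np.random.default_rng(7)
print(sample_estimated_normal(rng, μ_0=0.0, α=0.0, β=0.0, n=0))  # a wide draw, no error

# with plenty of data (α = 1000, β = 250) the variance estimate is about 0.25
draws = np.array([sample_estimated_normal(rng, 10.0, 1000.0, 250.0, 2000) for _ in range(20_000)])
print(draws.mean(), draws.std())  # close to 10.0 and 0.5
```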
```python
import gymnasium as gym

###########################################
# Stage 1 - Initialization
###########################################

# create the cartpole environment
env = gym.make('CartPole-v1', render_mode="human")

# run for 10 episodes
# ...
```
```python
import gymnasium as gym

class BabyRobotEnv_v0(gym.Env):
    def __init__(self):
        super().__init__()
        pass

    def step(self, action):
        state = 1
        reward = -1
        terminated = True
        # ...
```