from etaler import et
import gym

env = gym.make('FrozenLake-v0', is_slippery=False)
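# FrozenLake-v0 is a 4x4 grid world with 16 discrete states and 4 discrete
# actions; is_slippery=False makes the state transitions deterministic.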
# The agent is very simple. Mathematically, think of it as a weaker Q table. Biologically, it is a bunch
# of pyramidal neurons receiving input signals and directly connected to a motor. The reward is then
# used to modify the behavior of the neurons.
class Agent:
    def __init__(self):
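        # A 256-bit input SDR is mapped to a 256-bit output SDR. Boosting
        # (factor 2) pushes under-used cells to participate, and the global
        # density caps output activity at ~15% of the cells.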
        self.sp1 = et.SpatialPooler(input_shape=(256,), output_shape=(256,))
        self.sp1.setBoostingFactor(2)
        self.sp1.setGlobalDensity(0.15)

    def act(self, observation):
        # The observation is just an integer in this environment.
        self.sdr = et.encoder.gridCell1d(observation*8)  # x8 to increase the distance between values
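        # The grid cell encoder must emit a 256-bit SDR here so that it
        # matches sp1's input_shape of (256,).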
        self.out = self.sp1.compute(self.sdr)
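        # Treat the 256 output bits as 4 groups of 64 cells, one group per
        # action, and vote: the action whose group has the most active cells wins.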
        res = self.out.reshape((4, 64)).sum(dim=1).numpy()
        #print(res)
        return res.argmax()

    def learn(self, reward):
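        # Reinforcement via the Spatial Pooler's Hebbian update: the reward is
        # used directly as the permanence learning rate for the input/output
        # pair computed in act().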
        if reward > 0:
            # Positive reinforcement: strengthen synapses to active inputs and
            # weaken synapses to inactive ones, both scaled by the reward.
            self.sp1.setPermanenceInc(reward)
            self.sp1.setPermanenceDec(reward)
            self.sp1.learn(self.sdr, self.out)
        elif reward < 0:
            # Negative reinforcement: the negative increment weakens synapses
            # to active inputs; inactive synapses are left untouched.
            self.sp1.setPermanenceInc(reward)
            self.sp1.setPermanenceDec(0)
            self.sp1.learn(self.sdr, self.out)
        else:
            pass

agent = Agent()
for episode in range(200):
    print("======================Ep {}=========================".format(episode))
    observation = env.reset()
    while True:
        env.render()
        action = agent.act(observation)
        old_observation = observation
        observation, reward, done, info = env.step(action)
        # A new reward function (the standard reward does not work well with
        # our simple agent). We only support reward == 0, > 0 and < 0.
        if reward == 0.0 and not done:  # Case: we are on an ice tile
            if observation != old_observation:
                htm_reward = 0.02
            else:
                htm_reward = -0.3
        elif reward == 0.0 and done:  # Case: we fell into a hole
            htm_reward = -0.3
        else:  # Case: we are at the goal
            htm_reward = 0.5
        agent.learn(htm_reward)
        #print(htm_reward)
        if done:
            break
    # exit(0)