Created
April 17, 2018 14:44
-
-
Save g-leech/1db2e1a6a28b4989a66c611fdf04e452 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import copy
import sys
from hashlib import sha1

import numpy

from ai_safety_gridworlds.environments.shared.rl.environment import TimeStep
from ai_safety_gridworlds.environments.shared.safety_game import Actions

# Cap interpreter recursion at 500 frames (CPython's default is 1000).
# NOTE(review): presumably chosen so a runaway recursive crawl fails fast -- confirm.
sys.setrecursionlimit(500)

# Every action the crawler tries from each state; QUIT is deliberately left out.
ACTIONS = [a for a in Actions if a is not Actions.QUIT]
def merge_two_dicts(x, y):
    """Return a new mapping holding the entries of `x` overlaid with those of `y`.

    Neither argument is modified. When a key appears in both, the value
    from `y` wins.
    """
    merged = x.copy()
    merged.update(y)
    return merged
def hash_board(state):
    """Return the SHA-1 hex digest of the board stored in `state.observation['RGB']`.

    The digest covers the raw bytes of the board in C (row-major) order, so
    two states share a hash exactly when those bytes are identical.
    """
    # sha1() reads its argument through the buffer protocol, which rejects
    # non-contiguous arrays (e.g. transposed or sliced views). Normalising
    # first returns an already-contiguous board untouched, so the digests of
    # such boards are the same as hashing them directly.
    board = numpy.ascontiguousarray(state.observation['RGB'])
    return sha1(board).hexdigest()
def recurse(envi, state, hashes, transitions):
    """Crawl onward from `state` and fold the findings into the caller's maps.

    `envi` must be an environment whose current board is `state`.

    Returns a `(transitions, hashes)` pair of new dicts. On a key present
    on both sides, the caller's existing entry is the one that is kept.
    """
    subtransitions, subhashes = crawl_for_transitions(envi, state, hashes)
    transitions = merge_two_dicts(subtransitions, transitions)
    hashes = merge_two_dicts(subhashes, hashes)
    return transitions, hashes


def crawl_for_transitions(envir, lastState, hashMap):
    """Depth-first crawl of every board reachable from `lastState`.

    Each action in `ACTIONS` is tried on its own deep copy of `envir`, so
    `envir` itself is never stepped. A resulting board that has not been
    seen before is recorded and crawled in turn; boards already present in
    `hashMap` are not crawled again, which is what makes the crawl stop
    when moves can be undone.

    `hashMap` may be updated in place, but callers should use the returned
    map: it is the one guaranteed to hold every board discovered.

    Returns
    * `transitions`, a dict from (state, action, nextState) to 1 or 0:
      1 when the action changed the board, 0 when it left it unchanged.
      States are identified by their `hash_board` digest.
    * `hashMap`, a dict from state-hash to ndarray state
    """
    transitions = {}

    # Every action at this level starts from the same board, so hash it once
    # and never rebind it inside the loop (doing so would key later actions
    # on a child board instead of the true source).
    lastIndex = hash_board(lastState)
    # Register the source board too, so the crawl's root is part of the map
    # even when no move ever leads back to it.
    if lastIndex not in hashMap:
        hashMap[lastIndex] = lastState.observation['RGB']

    # NOTE(review): episode termination is not checked here; what `step`
    # returns once an episode has ended depends on the environment -- confirm.
    for action in ACTIONS:
        frozenEnv = copy.deepcopy(envir)
        nextState = frozenEnv.step(action)
        index = hash_board(nextState)

        if index == lastIndex:
            # The board did not change: the action is recorded as impossible.
            transitions[(lastIndex, action, index)] = 0
            continue

        # If the state changed after the last action, transition prob was 1
        transitions[(lastIndex, action, index)] = 1

        if index in hashMap:
            # Already discovered elsewhere in the crawl; descending again
            # would loop forever between mutually reachable boards.
            continue

        hashMap[index] = nextState.observation['RGB']
        transitions, hashMap = recurse(frozenEnv, nextState, hashMap, transitions)

    return transitions, hashMap
def test_transition_crawler(envi):
    """Crawl `envi` from a fresh reset and spot-check the crawler's output.

    Checks that the number of distinct boards found equals the count
    expected for level 0, that UP from the start is recorded as a 0
    transition, and that the following DOWN is recorded as a 1 transition.
    Prints the number of boards found. Steps `envi` as a side effect.
    """
    NUM_STATES_LEV_0 = 60

    initialState = envi.reset()
    hashes = {}
    transitions, hashMap = crawl_for_transitions(envi, initialState, hashes)

    print(len(hashMap))
    assert len(hashMap) == NUM_STATES_LEV_0

    # Do we think it's impossible to go up at the start?
    initialIndex = hash_board(initialState)
    impossibleState = envi.step(Actions.UP)
    impossibleIndex = hash_board(impossibleState)
    startUp = (initialIndex, Actions.UP, impossibleIndex)
    assert transitions[startUp] == 0

    # And subsequently possible to go down?
    # (This relies on the UP step above having left the board unchanged.)
    pushState = envi.step(Actions.DOWN)
    pushIndex = hash_board(pushState)
    startDown = (initialIndex, Actions.DOWN, pushIndex)
    assert transitions[startDown] == 1

    # A problem (left disabled, as the original author had it):
    # envi.reset()
    # cornerState = envi.step(Actions.LEFT)
    # cornerIndex = hash_board(cornerState)
    # startToLeft = (initialIndex, Actions.LEFT, cornerIndex)
    # #('00a5f33743256bb024894ce659bfad4aca93df7f', Actions.LEFT, '33afd6afa7c715b6bea347350d3f9fbf8e747ff2')
    # assert( transitions[startToLeft] == 1)
# Script entry: build the level-0 environment and run the crawler check.
# NOTE(review): `sokoban_game` is not defined or imported in the visible part
# of this file -- confirm where it is meant to come from.
env = sokoban_game(level=0)
test_transition_crawler(env)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment