Forked from WilliamJou/gist:344e8cc879f4d51629bde3dcf166c99b
Created July 13, 2019 12:53
-
-
Save gaybro8777/867b2337a190cff5b00e2861a4091365 to your computer and use it in GitHub Desktop.
Updated Faucet
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
workspace() | |
importall POMDPs | |
using POMDPToolbox | |
using StatsBase | |
#using SARSOP | |
using BasicPOMCP | |
using D3Trees | |
using ParticleFilters | |
# Integer identifiers for the three faucet tasks.
const DISH = 1
const HAND = 2
const POT = 3

# Discrete temperature settings (celsius) and a lookup from setting to its position.
const TEMPS = 27:3:52
const TINDEX = Dict{Int, Int}(zip(TEMPS, 1:length(TEMPS)))

# Discrete flow settings and a lookup from setting to its position.
const FLOWS = 5:10:95
const FINDEX = Dict{Int, Int}(zip(FLOWS, 1:length(FLOWS)))

# Identifiers of the user types.
const USERS = 1:1:4
"""
    FState(task, time, prev_temp, prev_flow, user)

One state of the faucet model. All fields are plain integers, so two states
built from the same values compare equal and can be used as dictionary keys.
"""
struct FState
    task::Int       # task identifier (DISH, HAND or POT)
    time::Int       # step counter; advanced by one on every transition
    prev_temp::Int  # temperature setting carried over from the previous step
    prev_flow::Int  # flow setting carried over from the previous step (not sure if this is necessary)
    user::Int       # user type identifier
end
"""
    FPOMDP(max_time)

Faucet POMDP. The type parameters of `POMDP` are, in order:

- state: `FState`
- action: `(temperature, flow)` as `Tuple{Int,Int}`
- observation: `Tuple{Int,Int,Float64,Int}`
"""
struct FPOMDP <: POMDP{FState, Tuple{Int,Int}, Tuple{Int,Int,Float64,Int}}
    max_time::Int  # last non-terminal time step
end
# Problem instance shared by the index tables and the simulation at the end of the file.
p = FPOMDP(10)

# Desired temperature for each task.
# NOTE(review): 40 and 46 are not members of TEMPS (27:3:52 yields 27, 30, ..., 51), so no
# action can ever match the desired temperature for DISH or POT -- confirm the intended values.
const DTEMP = Dict{Int, Int}(DISH=>40, HAND=>27, POT=>46)

# Desired flow for each task.
const DFLOW = Dict{Int, Int}(DISH=>65, HAND=>55, POT=>75)

# Possible readings of the metal sensor.
const METAL = [.05, .1, .75, .85]

# Sampling weights for the observed user label, keyed by the true user: the true user gets
# weight .4 and every other user .2. Built without a `Dict{Int, Any}` annotation so the
# value type is concrete and lookups stay type-stable.
const U_WEIGHTS = Dict(u => pweights([v == u ? .4 : .2 for v in USERS]) for u in USERS)
# A state is terminal once its step counter has passed the horizon `p.max_time`.
function isterminal(p::FPOMDP, s::FState)
    return p.max_time < s.time
end
# Enumerate every state as a flat vector. The task varies fastest, then time,
# previous temperature, previous flow and finally the user.
function states(p::FPOMDP)
    grid = [FState(task, time, pt, pf, u)
            for task in [DISH, HAND, POT],
                time in 0:p.max_time,
                pt in TEMPS,
                pf in FLOWS,
                u in USERS]
    return vec(grid)
end

# `user` is an integer identifier drawn from USERS, not a vector; the types are
# presumably ordered [picky, resource-conscious, patient, doubting].

# Number of states: 3 tasks x (max_time + 1) time steps x temperatures x flows x users.
function n_states(p::FPOMDP)
    return 3 * (p.max_time + 1) * length(TEMPS) * length(FLOWS) * length(USERS)
end
# Position of every state in `states(p)`, computed once for the global instance `p`.
const SINDEX = let all_states = states(p)
    Dict{FState, Int}(zip(all_states, 1:length(all_states)))
end

# Look up the position of `s`; throws a KeyError when `s` is not in the table.
function state_index(p::FPOMDP, s::FState)
    return SINDEX[s]
end
# All actions as a flat vector of (temperature, flow) tuples; temperature varies fastest.
actions(p::FPOMDP) = vec(collect((t, f) for t in TEMPS, f in FLOWS))

# Number of actions: one per (temperature, flow) combination.
n_actions(p::FPOMDP) = length(TEMPS) * length(FLOWS)

# Position of every action in `actions(p)`, computed once for the global instance `p`.
const AINDEX = Dict(a => i for (i, a) in enumerate(actions(p)))

# Look up the position of action `a`. Actions are (temperature, flow) tuples, which is
# also the key type of AINDEX, so the method must accept a tuple rather than an Int.
action_index(p::FPOMDP, a::Tuple{Int,Int}) = AINDEX[a]
# All observations as a flat vector of (temperature, flow, metal reading, user) tuples;
# temperature varies fastest, then flow, metal reading and user.
# NOTE(review): `observation` also builds tuples whose first two entries are 0, which are
# not part of this enumeration -- confirm whether they should be.
function observations(p::FPOMDP)
    grid = [(t, f, m, w) for t in TEMPS, f in FLOWS, m in METAL, w in USERS]
    return vec(grid)
end

# Number of observations: one per combination of the four components.
function n_observations(p::FPOMDP)
    return length(TEMPS) * length(FLOWS) * length(METAL) * length(USERS)
end

# Position of every observation in `observations(p)`, computed once for the global `p`.
const OINDEX = let all_obs = observations(p)
    Dict(zip(all_obs, 1:length(all_obs)))
end

# Look up the position of `o`; throws a KeyError when `o` is not in the table.
function obs_index(p::FPOMDP, o::Tuple{Int,Int,Float64,Int})
    return OINDEX[o]
end
# Distribution over next states. Time always advances by one. With probability .9 the
# action's (temperature, flow) becomes the new previous setting; with probability .1 the
# old setting is kept. Task and user never change.
function transition(p::FPOMDP, s::FState, a::Tuple{Int,Int})
    next_time = s.time + 1
    applied = FState(s.task, next_time, a[1], a[2], s.user)
    unchanged = FState(s.task, next_time, s.prev_temp, s.prev_flow, s.user)
    return SparseCat([applied, unchanged], [.9, .1])
    # Deterministic alternative:
    #SparseCat([FState(s.task, s.time+1, a[1], a[2], s.user)],[1.0])
end
# Sampling weights over METAL for the metal sensor, depending on the task.
# The named task constants are used so that the dishwashing weights apply to DISH and the
# handwashing weights to HAND.
function _metal_weights(task::Int)
    if task == POT
        return pweights([.05, .05, .45, .45])  # weight for if it is a pot
    elseif task == DISH
        return pweights([.3, .3, .3, .1])      # weight for dishwashing
    else
        return pweights([.45, .45, .05, .05])  # weight for handwashing
    end
end

# For a user type, return `(delay, probs)`: the user only corrects the faucet once
# `time > delay`, and `probs` are the probabilities of [correcting, leaving it alone].
function _correction_profile(user::Int)
    if user == 3
        return (4, [.95, 0.05])
    elseif user == 4
        return (1, [.75, 0.25])
    else
        return (2, [.95, 0.05])
    end
end

"""
    observation(p::FPOMDP, a::Tuple{Int,Int}, sp::FState)

Return the observation distribution after taking action `a` and landing in `sp`.

An observation is `(temperature, flow, metal reading, user label)`. The metal reading and
the user label are sampled inside this call, so the returned distribution is conditional
on those two draws.

- If `sp.time` exceeds the user's delay and `a` differs from the desired setting of the
  task, the result is a two-point distribution: the user sets the desired
  `(DTEMP, DFLOW)`, or leaves the faucet alone (temperature and flow reported as 0).
- Otherwise the only outcome is the "leave" observation. This holds for every user type,
  so the function always returns a distribution.
"""
function observation(p::FPOMDP, a::Tuple{Int,Int}, sp::FState)
    # Output of the metal sensor depends on whether the task is pot, dish or hand.
    m = sample(METAL, _metal_weights(sp.task))
    # Observed user label, drawn with weights keyed by the true user.
    u_val = sample([1, 2, 3, 4], U_WEIGHTS[sp.user])

    leave = (0, 0, m, u_val)
    delay, probs = _correction_profile(sp.user)
    off_target = a[1] != DTEMP[sp.task] || a[2] != DFLOW[sp.task]

    if sp.time > delay && off_target
        change = (DTEMP[sp.task], DFLOW[sp.task], m, u_val)
        return SparseCat([change, leave], probs)
    end
    return SparseCat([leave], [1.0])
end
# Reward for action `a` in state `s`: 4.0 when both temperature and flow match the task's
# desired setting, 2.0 when only the temperature matches, 1.0 when only the flow matches,
# and -10.0 when neither does.
function reward(p::FPOMDP, s::FState, a::Tuple{Int,Int})
    temp_ok = a[1] == DTEMP[s.task]
    flow_ok = a[2] == DFLOW[s.task]
    temp_ok && flow_ok && return 4.0
    temp_ok && return 2.0
    flow_ok && return 1.0
    return -10.0
end
# The user type is drawn once, uniformly, when the script is loaded.
initial_user = sample([1,2,3,4], pweights([.25,.25,.25,.25]))

# Initial belief: time 0, previous temperature and flow 0, the sampled user, and a uniform
# distribution over the three tasks. The probabilities are computed as 1/3 each so that
# they sum to 1.
# To spread the belief over users as well, enumerate `u in USERS` alongside the tasks and
# give each of the resulting states the same probability.
function initial_state_distribution(p::FPOMDP)
    tasks = [DISH, HAND, POT]
    starts = [FState(t, 0, 0, 0, initial_user) for t in tasks]
    return SparseCat(starts, fill(1 / length(tasks), length(tasks)))
end
# Plan with POMCP (exploration constant c=100) and step through one simulated episode,
# printing how the belief particles are split over the tasks at every step.
# policy = RandomPolicy(p)
solver = POMCPSolver(c=100)
policy = solve(solver, p)

# Alternative hand-written policy and particle filter, kept for reference:
#function my_policy(b::ParticleCollection)
# s = rand(Base.GLOBAL_RNG, b)
# return (DTEMP[s.task], DFLOW[s.task])
#end
#policy = FunctionPolicy(my_policy)
#up = SIRParticleFilter(p, 1000)

for (b, s, a, r, o) in stepthrough(p, policy, "bsaro")
    # Fraction of belief particles whose task equals `task`.
    share(task) = count(particle -> particle.task == task, particles(b)) / n_particles(b)
    frac_hand = share(HAND)
    frac_dish = share(DISH)
    frac_pot = share(POT)
    @show frac_hand frac_dish frac_pot s a r o
end
# inchrome(D3Tree(policy))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment