Last active
January 12, 2016 15:17
-
-
Save robinvanemden/bc7c25adf75c9e67038c to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# -*- coding: utf-8 -*- | |
#============================================================================== | |
# Reset IPython - clears the interactive namespace between runs (e.g. in
# Spyder). Outside an IPython session get_ipython() returns None, and IPython
# may not be installed at all; in both cases there is nothing to reset, so the
# step is skipped instead of crashing a plain `python script.py` run.
try:
    from IPython import get_ipython
except ImportError:
    get_ipython = None
ipython_shell = get_ipython() if get_ipython is not None else None
if ipython_shell is not None:
    # run_line_magic is the supported replacement for the deprecated .magic().
    ipython_shell.run_line_magic('reset', '-sf')
#============================================================================== | |
import sys,random | |
import json | |
import asyncio,aiohttp,pymongo,time | |
import numpy as np | |
import matplotlib.pyplot as plt | |
from scipy.interpolate import UnivariateSpline | |
from scipy import misc | |
#============================================================================== | |
# Back-end endpoints: the MongoDB instance queried for logs in the chart step,
# and the StreamingBandit REST server the simulation talks to.
MONGO_IP, MONGO_PORT = "78.46.212.194", 27017
SB_BASE_URL = "https://strm.mnds.org:8080"
#==============================================================================
# Run / experiment identifiers (used in request contexts and the log query).
print_progress = 1                  # print progress to console? (truthy = yes)
exp_key, exp_id = "36c067b1d3", 1   # StreamingBandit experiment key and id
question_nr = 891441222             # question id
#============================================================================== | |
x_value_min = 0             # Domain, min X
x_value_max = 100           # Domain, max X. The start value x0 and the
                            # amplitude A below are derived from the domain,
                            # so changing it changes them too!
rounded_to = 0              # Number of decimals the x shown to the subject is
                            # rounded to (np.round). The rounding step
                            # 10**-rounded_to must be well below the domain
                            # width - roughly half an order of magnitude or
                            # more smaller than x_value_max - otherwise the
                            # displayed x collapses onto a few values
                            # (e.g. always 0.0).
y_line = 850                # image row (pixel y) used as the decoy line
# Start value: 30% into the domain (equals 0.3 * x_value_max when min is 0).
x0 = x_value_min + 0.3 * (x_value_max - x_value_min)
# Amplitude: 7% of the domain width.
A = 0.07 * (x_value_max - x_value_min)
#==============================================================================
parallel_processes = 10     # how many parallel request workers
repeats_per_process = 400   # how many getAction requests per worker
                            # total: parallel_processes * repeats_per_process
                            # (10 x 400 = 4000)
max_time_set_reward = 0.1   # max random delay (seconds) before a reward is set
choice_probability = 0.95   # probability that the simulated agent responds,
                            # i.e. that a reward is sent for an action
#============================================================================== | |
# | |
# In the current simulation we make use of a Decoy decision model from: | |
# | |
# Trueblood, Jennifer S., Scott D. Brown, and Andrew Heathcote. | |
# "The multiattribute linear ballistic accumulator model of context effects in | |
# multialternative choice." | |
# Psychological review 121.2 (2014): 179. | |
# | |
# The model of Trueblood et al is itself based on data from previous | |
# experiments. We will use it here to set the parameters of our simulation. | |
# | |
#============================================================================== | |
def _read_decoy_image(path):
    """Read the decoy matrix image at *path* into a numpy array.

    ``scipy.misc.imread`` was removed in SciPy 1.2; when it is unavailable,
    fall back to ``matplotlib.pyplot.imread`` (which reads non-PNG files such
    as BMP through Pillow).
    """
    reader = getattr(misc, "imread", None)
    if reader is None:
        reader = plt.imread
    return reader(path)


def _brightness(img):
    """Return the mean colour-channel value of every pixel of *img*.

    A 2-D (greyscale) image already is a brightness map and is returned as
    floats. For a colour image only the first three channels are averaged,
    so an alpha channel (if present) does not skew the brightness.
    """
    if img.ndim == 2:
        return img.astype(float)
    rgb = img[..., :3]
    return rgb.sum(axis=2) / rgb.shape[2]


# import matrix graph from article
decoy_img = _read_decoy_image("Decoy_Matrix.bmp")
# create brightness matrix
decoy_img_brightness = _brightness(decoy_img)
# Read the decoy probability along the chosen image row. Approximation:
# brightness does not seem to map 1:1 onto p, but is good enough here.
# NOTE(review): dividing by 255 assumes 8-bit pixel values - confirm.
decoy_y = decoy_img_brightness[y_line] / 255
# Map pixel columns 0..n-1 linearly onto [x_value_min, x_value_max).
decoy_x = x_value_min + (np.arange(decoy_y.size) / decoy_y.size) \
    * (x_value_max - x_value_min)
# Spline approximation of the decoy values. k=4 makes the derivative cubic,
# which is what UnivariateSpline.roots() requires.
# NOTE(review): the -0.04 shift is undocumented; presumably it corrects the
# brightness-to-probability mapping - confirm.
spline = UnivariateSpline(decoy_x, decoy_y - 0.04, k=4, s=1)
spline_roots = spline.derivative().roots()
# Relevant local max - hard-coded to the second stationary point.
# A more robust choice would pick a root r whose second derivative
# spline.derivatives(r)[2] is negative (a maximum, not a minimum).
spline_local_max = spline_roots[1]
#============================================================================== | |
# Global run state: start_at_t holds the first t handed out by the server
# (0 means "not seen yet"); start is the wall-clock time at module load.
start_at_t = 0
start = time.time()


def end(settle_time=None):
    """Print the elapsed run time, excluding the final settling delay.

    Before the run is marked complete, set_reward sleeps for
    ``max_time_set_reward + 3`` seconds so in-flight rewards can finish;
    that idle time is subtracted from the reported figure.

    Args:
        settle_time: seconds to subtract from the elapsed wall-clock time.
            Defaults to ``max_time_set_reward + 3`` (previously hard-coded
            as 3.1, which went stale when max_time_set_reward changed).

    Returns:
        The reported elapsed time in seconds.
    """
    if settle_time is None:
        settle_time = max_time_set_reward + 3
    elapsed = time.time() - start - settle_time
    print("elapsed time:" + str(elapsed))
    return elapsed
def main():
    """Run the simulation on a fresh event loop, then report and plot it.

    A new loop is created on every call (ProactorEventLoop on Windows,
    SelectorEventLoop elsewhere) so the script can be re-run in the same
    console. The loop runs until the future handed to do_async is resolved,
    and is always closed afterwards, even if the run raises.
    """
    req_version = (3, 5)
    if sys.version_info < req_version:
        # NOTE: this file also uses ``async def``, which older interpreters
        # reject at compile time, so in practice this branch is unreachable.
        print("This script makes use of asyncio, introduced in Python 3.5")
        print("Consider upgrading your Python interpreter to 3.5 or higher.")
        sys.exit(1)  # non-zero status: the simulation did not run

    # Necessary to run code more than once in same console in Win
    if sys.platform == 'win32':
        loop = asyncio.ProactorEventLoop()
    else:
        loop = asyncio.SelectorEventLoop()
    asyncio.set_event_loop(loop)
    try:
        # Bind the future and the task explicitly to this loop instead of
        # relying on the implicit "current loop" lookup.
        future = loop.create_future()
        loop.create_task(do_async(future))
        loop.run_until_complete(future)
    finally:
        loop.close()
    end()
    do_chart()
async def do_async(future):
    """Run ``parallel_processes`` concurrent get_action workers.

    Normal completion of the simulation is signalled through *future* by
    the reward side. If any worker raises, the error is forwarded to *future*
    so the event loop waiting on it stops with that error instead of
    waiting forever for a result that will never arrive.
    """
    workers = [asyncio.ensure_future(get_action(future))
               for _ in range(parallel_processes)]
    try:
        await asyncio.gather(*workers)
    except Exception as exc:
        if not future.done():
            future.set_exception(exc)
async def get_action(future):
    """Request ``repeats_per_process`` actions from StreamingBandit in turn.

    Each JSON response carries the action's ``t`` and proposed ``x``. A
    set_reward task is scheduled for every action without awaiting it, so
    reward delays overlap with the next requests.

    Args:
        future: completion future, passed on to set_reward.
    """
    global start_at_t
    for _ in range(repeats_per_process):
        request = SB_BASE_URL + "/" + str(exp_id)
        request += "/getAction.json?key=" + exp_key
        request += "&context={\"question\":" + str(question_nr)
        request += ",\"x0\":" + str(x0) + ",\"A\":" + str(A) + "}"
        # SECURITY NOTE(review): verify_ssl=False disables TLS certificate
        # checks for the StreamingBandit server - acceptable for testing only.
        conn = aiohttp.TCPConnector(verify_ssl=False)
        try:
            session = aiohttp.ClientSession(connector=conn)
            response = await session.get(request, encoding='utf-8')
            try:
                body = await response.read(decode=False)
            finally:
                response.close()
        finally:
            # Close the connector even when the request fails, so errors do
            # not leak sockets.
            conn.close()
        obj = json.loads(body.decode('utf-8'))
        t = obj["action"]["t"]
        x = obj["action"]["x"]
        # Parallel requests can complete out of order, so keep the smallest
        # t seen rather than the first one to arrive: the end-of-run check
        # counts forward from start_at_t. 0 means "not set yet".
        if start_at_t == 0 or t < start_at_t:
            start_at_t = t
        if print_progress:
            print("#####" + str(t))
        asyncio.ensure_future(set_reward(t, x, future))
async def set_reward(t, x, future):
    """Simulate a subject's choice for action (t, x) and report the reward.

    The subject picks option B with the probability given by the decoy
    spline at the displayed (rounded) x. Choosing B earns reward 1.0;
    otherwise the reward is 0.0. With probability ``choice_probability`` the
    reward is sent to StreamingBandit after a random delay of up to
    ``max_time_set_reward`` seconds; otherwise nothing is sent.

    The action with the last expected t resolves *future* after a settling
    delay, whether or not its own reward was sent.
    """
    # just a failsafe, makes sure value not out of bounds, mostly superfluous
    x = min(max(x, x_value_min), x_value_max)
    # in our experiment, displayed x will often be rounded to integers
    display_x = np.round(x, rounded_to)
    # The spline is only an approximation and can leave [0, 1];
    # np.random.choice raises on such probabilities, which would silently
    # kill this task. Clip to a valid probability instead.
    p_b = float(np.clip(spline(display_x), 0.0, 1.0))
    # randomly pick A's or B's - p changes according to position of decoy C
    sample_choice = np.random.choice(["B", "A"], size=(1,), p=[p_b, 1 - p_b])
    # did agent/subject choose B (reward 1.0) or A (reward 0.0)?
    y = 1.0 if sample_choice[0] == 'B' else 0.0
    if np.random.binomial(1, choice_probability, 1) == 1:
        request = SB_BASE_URL + "/" + str(exp_id) + "/setReward.json"
        request += "?key=" + exp_key
        request += "&context={\"question\":" + str(question_nr)
        request += ",\"x0\":" + str(x0) + ",\"A\":" + str(A) + "}"
        request += "&action={\"x\":" + str(float(x))
        request += ",\"t\":" + str(float(t)) + "}"
        request += "&reward=" + str(float(y))
        await asyncio.sleep(random.uniform(0.0, max_time_set_reward))
        # SECURITY NOTE(review): TLS certificate verification is disabled.
        conn = aiohttp.TCPConnector(verify_ssl=False)
        try:
            session = aiohttp.ClientSession(connector=conn)
            response = await session.get(request, encoding='utf-8')
            response.close()
        finally:
            # Always release the connector, even if the request fails.
            conn.close()
        if print_progress:
            print("$$$$$$$$$$" + str(t))
    # The completion check runs whether or not this action's reward was sent,
    # so the run still ends when the final action goes unrewarded.
    if t >= repeats_per_process * parallel_processes + start_at_t - 1:
        # Give in-flight rewards time to reach the server before stopping.
        await asyncio.sleep(max_time_set_reward + 3)
        # The future may already be done (e.g. failed by a worker error);
        # setting a result twice would raise InvalidStateError.
        if not future.done():
            future.set_result(future)
def _tick_positions(t_values, step):
    """Return tick positions every *step* units spanning *t_values*.

    The first tick is the multiple of *step* at or below ``min(t_values)``;
    the last is at or below ``max(t_values)``. Empty input yields no ticks.
    """
    if len(t_values) == 0:
        return np.array([])
    first = (min(t_values) // step) * step
    return np.arange(first, max(t_values) + 1, step)


def do_chart():
    """Plot the x0 trajectory of this question's rewards from the log store.

    Fetches every ``setreward`` log entry for ``question_nr`` from MongoDB,
    sorted by t, and plots x0 against t, with a dashed red reference line at
    the local maximum of the decoy spline.
    """
    client = pymongo.MongoClient(MONGO_IP, MONGO_PORT)
    try:
        cursor = client.logs.logs.find({"type": "setreward", "q": question_nr}) \
            .sort([("t", pymongo.ASCENDING)])
        result_list = list(cursor)
    finally:
        client.close()
    t = [entry['t'] for entry in result_list]
    x0_values = [entry['x0'] for entry in result_list]
    fig = plt.figure(figsize=(15, 7))
    ax = fig.add_subplot(1, 1, 1)
    ax.tick_params(which='both', direction='out')
    # Ticks follow the plotted t range, not the number of log entries: t is
    # plotted on the x axis, and since not every action is rewarded the t
    # values need not run from 0 to len(t). (set_xticks widens the view to
    # include all ticks, so mismatched ticks would distort the plot.)
    ax.set_xticks(_tick_positions(t, 100))
    ax.set_xticks(_tick_positions(t, 50), minor=True)
    ax.grid(which='both')
    ax.plot(t, x0_values)
    ax.set_ylim([0, x_value_max / 2])
    ax.axhline(y=spline_local_max, linewidth=1, color='r', linestyle='--')
    plt.show()
# Script entry point: start the simulation only when this file is executed
# directly, not when it is imported.
if __name__ == "__main__":
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment