Skip to content

Instantly share code, notes, and snippets.

@buswedg
Created April 12, 2020 03:58
Show Gist options
  • Save buswedg/bbb738211e0d26ef68556cb5322ea3c6 to your computer and use it in GitHub Desktop.
Save buswedg/bbb738211e0d26ef68556cb5322ea3c6 to your computer and use it in GitHub Desktop.
reinforcement_learning_for_share_trading\q_learner
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import random as rand\n",
"\n",
"import numpy as np\n",
"\n",
"from indicators import *\n",
"from col_refs import *"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"class QLearner(object):\n",
"    \"\"\"Tabular Q-learner with epsilon-greedy exploration and optional Dyna-Q.\n",
"\n",
"    States and actions are integer indices into the Q table. A typical loop\n",
"    calls querysetstate() once, then query() after each environment step.\n",
"\n",
"    Args:\n",
"        num_states: Number of discrete states.\n",
"        num_actions: Number of discrete actions.\n",
"        alpha: Learning rate used to blend new estimates into Q and R.\n",
"        gamma: Discount factor applied to the best next-state value.\n",
"        rar: Random action rate (probability of exploring).\n",
"        radr: Factor that rar is multiplied by after every query() call.\n",
"        dyna: Number of simulated updates to run after every query() call.\n",
"        verbose: If True, print each state/action (and reward) as it is chosen.\n",
"        seed: Seed for the learner's private random generator.\n",
"    \"\"\"\n",
"\n",
"    def __init__(self, num_states=100, num_actions=4, alpha=0.2, gamma=0.9, rar=0.98, radr=0.999, dyna=0,\n",
"                 verbose=False, seed=0):\n",
"\n",
"        self.alpha = alpha\n",
"        self.gamma = gamma\n",
"        self.rar = rar\n",
"        self.radr = radr\n",
"        self.dyna = dyna\n",
"        self.verbose = verbose\n",
"\n",
"        self.num_actions = num_actions\n",
"        self.num_states = num_states\n",
"\n",
"        # Last state visited and last action taken.\n",
"        self.s = 0\n",
"        self.a = 0\n",
"\n",
"        # Action values and the learned expected immediate reward.\n",
"        self.Q = np.zeros(shape=(num_states, num_actions))\n",
"        self.R = np.zeros(shape=(num_states, num_actions))\n",
"\n",
"        # Transition counts (Tc) and the probabilities derived from them (T).\n",
"        self.T = np.zeros((num_states, num_actions, num_states))\n",
"        self.Tc = np.zeros((num_states, num_actions, num_states))\n",
"\n",
"        # Seed a private generator once. Re-seeding the global generators on\n",
"        # every call would make each draw identical, so the learner would keep\n",
"        # picking the same random action and never really explore.\n",
"        self._rng = np.random.RandomState(seed)\n",
"\n",
"        # (state, action) pairs with at least one recorded transition. Dyna\n",
"        # replays only these because unvisited pairs have no model yet.\n",
"        self._seen = []\n",
"\n",
"    def _choose_action(self, s):\n",
"        \"\"\"Return a random action with probability rar, else the greedy one.\"\"\"\n",
"        if self._rng.uniform() < self.rar:\n",
"            return int(self._rng.randint(self.num_actions))\n",
"\n",
"        return int(self.Q[s, :].argmax())\n",
"\n",
"    def _learn(self, s, a, s_prime, r):\n",
"        \"\"\"Apply one Q-learning update for the transition (s, a) -> s_prime.\"\"\"\n",
"        r_fut = self.Q[s_prime, :].max()\n",
"\n",
"        self.Q[s, a] = (1 - self.alpha) * self.Q[s, a] + self.alpha * (r + (self.gamma * r_fut))\n",
"\n",
"    def querysetstate(self, s):\n",
"        \"\"\"Set the current state to s and pick an action without learning.\n",
"\n",
"        Returns:\n",
"            The chosen action as an int.\n",
"        \"\"\"\n",
"        action = self._choose_action(s)\n",
"\n",
"        self.s = s\n",
"        self.a = action\n",
"\n",
"        if self.verbose: print(\"s =\", s, \"a =\", action)\n",
"\n",
"        return action\n",
"\n",
"    def query(self, s_prime, r):\n",
"        \"\"\"Learn from the last step, move to s_prime and pick the next action.\n",
"\n",
"        Args:\n",
"            s_prime: State reached after taking the previous action.\n",
"            r: Reward received for that transition.\n",
"\n",
"        Returns:\n",
"            The next action as an int.\n",
"        \"\"\"\n",
"        self._learn(self.s, self.a, s_prime, r)\n",
"\n",
"        if self.dyna != 0:\n",
"            self.model_update(self.s, self.a, s_prime, r)\n",
"\n",
"            self.hallucinate()\n",
"\n",
"        # Choose after learning so a greedy pick sees the updated Q values.\n",
"        action = self._choose_action(s_prime)\n",
"\n",
"        self.rar = self.rar * self.radr\n",
"\n",
"        self.s = s_prime\n",
"        self.a = action\n",
"\n",
"        if self.verbose: print(\"s =\", s_prime, \"a =\", action, \"r =\", r)\n",
"\n",
"        return action\n",
"\n",
"    def model_update(self, s, a, s_prime, r):\n",
"        \"\"\"Record an observed transition in the Dyna model (Tc, T and R).\"\"\"\n",
"        if not self.Tc[s, a].any():\n",
"            self._seen.append((s, a))\n",
"\n",
"        self.Tc[s, a, s_prime] += 1\n",
"\n",
"        # Normalise only the row that changed. Dividing the whole table would\n",
"        # compute 0/0 = NaN for every pair that has not been visited yet.\n",
"        self.T[s, a] = self.Tc[s, a] / self.Tc[s, a].sum()\n",
"\n",
"        self.R[s, a] = ((1 - self.alpha) * self.R[s, a]) + (self.alpha * r)\n",
"\n",
"    def hallucinate(self):\n",
"        \"\"\"Run dyna simulated Q updates using the learned T and R models.\"\"\"\n",
"        if not self._seen:\n",
"            return\n",
"\n",
"        for _ in range(self.dyna):\n",
"            s_rnd, a_rnd = self._seen[self._rng.randint(len(self._seen))]\n",
"\n",
"            # Most frequent successor out of 100 draws from T[s_rnd, a_rnd].\n",
"            s_prime = self._rng.multinomial(100, self.T[s_rnd, a_rnd, :]).argmax()\n",
"\n",
"            self._learn(s_rnd, a_rnd, s_prime, self.R[s_rnd, a_rnd])\n"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.6"
}
},
"nbformat": 4,
"nbformat_minor": 4
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment