Reinforcement Learning for Finance

Workshop at ODSC London 2024

Dr. Yves J. Hilpisch | The Python Quants | CPF Program

London, September 6, 2024

(short link to this Gist: http://bit.ly/odsc_ldn_2024)

Slides

You can find the slides at:

http://certificate.tpq.io/odsc_ldn_2024.pdf

Book

You can find an early (pre-print) version of my new book at:

https://certificate.tpq.io/rlfinance.html

The book on O'Reilly:

https://learning.oreilly.com/library/view/reinforcement-learning-for/9781098169169/

Resources

This Gist contains selected resources used during the workshop.

Social Media

https://cpf.tpq.io
https://x.com/dyjh
https://linkedin.com/in/dyjh/
https://github.com/yhilpisch
https://youtube.com/c/yves-hilpisch
https://bit.ly/quants_dev

Disclaimer

All the content, Python code, Jupyter Notebooks, and other materials (the “Material”) come without warranties or representations, to the extent permitted by applicable law.

None of the Material represents any kind of recommendation or investment advice.

The Material is only meant as a technical illustration.

(c) Dr. Yves J. Hilpisch

{
"cells": [
{
"cell_type": "markdown",
"id": "475819a4-e148-4616-b1cb-44b659aeb08a",
"metadata": {},
"source": [
"<img src=\"https://hilpisch.com/tpq_logo.png\" alt=\"The Python Quants\" width=\"35%\" align=\"right\" border=\"0\"><br>"
]
},
{
"cell_type": "markdown",
"id": "280cc0c6-2c18-46cd-8af7-3f19b64a6d7e",
"metadata": {},
"source": [
"# Reinforcement Learning for Finance\n",
"\n",
"**Chapter 03 &mdash; Financial Q-Learning**\n",
"\n",
"&copy; Dr. Yves J. Hilpisch\n",
"\n",
"<a href=\"https://tpq.io\" target=\"_blank\">https://tpq.io</a> | <a href=\"https://twitter.com/dyjh\" target=\"_blank\">@dyjh</a> | <a href=\"mailto:[email protected]\">[email protected]</a>"
]
},
{
"cell_type": "markdown",
"id": "d6be6f8b-e00e-402c-9df1-1d3f16e76c7e",
"metadata": {},
"source": [
"## Finance Environment"
]
},
{
"cell_type": "raw",
"id": "e11c4482-fae4-4485-8009-af5905a4e350",
"metadata": {},
"source": [
"# tag::01[]"
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "f2c8cd7e-d93d-4c4d-ba77-3c0cb7b677af",
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"import random"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "bd8d3cf4-c30c-432a-bd3f-23e98c4d201c",
"metadata": {},
"outputs": [],
"source": [
"random.seed(100)\n",
"os.environ['PYTHONHASHSEED'] = '0'"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "cb33cd0c-4fb1-4456-911f-0d92597db8c0",
"metadata": {},
"outputs": [],
"source": [
"class ActionSpace:\n",
" def sample(self):\n",
" return random.randint(0, 1)"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "30d49bdd-e24b-4d87-a4dc-5639cc172f8e",
"metadata": {},
"outputs": [],
"source": [
"action_space = ActionSpace()"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "416ce315-16d7-4c47-845a-f21a099b8ba3",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[0, 1, 1, 0, 1, 1, 1, 0, 0, 0]"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"[action_space.sample() for _ in range(10)]"
]
},
{
"cell_type": "raw",
"id": "e4f06fae-938f-4db7-b474-b3ad79986ae4",
"metadata": {},
"source": [
"# end::01[]"
]
},
{
"cell_type": "raw",
"id": "6cfbea7a-6223-4a7d-82e4-9b2a74d4faf9",
"metadata": {},
"source": [
"# tag::02[]"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "f4df457f-9014-4e6a-878a-23645c77037d",
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"import pandas as pd"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "952353e1-8f39-48ac-ac6d-5a21b9a44315",
"metadata": {},
"outputs": [],
"source": [
"class Finance:\n",
" url = 'https://certificate.tpq.io/rl4finance.csv' # <1>\n",
" def __init__(self, symbol, feature,\n",
" min_accuracy=0.485, n_features=4):\n",
" self.symbol = symbol # <2>\n",
" self.feature = feature # <3>\n",
" self.n_features = n_features # <4>\n",
" self.action_space = ActionSpace() # <5>\n",
" self.min_accuracy = min_accuracy # <6>\n",
" self._get_data() # <7>\n",
" self._prepare_data() # <8>\n",
" def _get_data(self):\n",
" self.raw = pd.read_csv(self.url,\n",
" index_col=0, parse_dates=True) # <7>"
]
},
{
"cell_type": "raw",
"id": "3c3f5df0-c8af-46f3-8e3c-a944fb858a3c",
"metadata": {},
"source": [
"# end::02[]"
]
},
{
"cell_type": "raw",
"id": "a26425a0-8c93-48e2-8671-5b1bb7d130d2",
"metadata": {},
"source": [
"# tag::03[]"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "69e1ed75-1e55-42f4-86a3-db54c60acf1f",
"metadata": {},
"outputs": [],
"source": [
"class Finance(Finance):\n",
" def _prepare_data(self):\n",
" self.data = pd.DataFrame(self.raw[self.symbol]).dropna() # <1>\n",
" self.data['r'] = np.log(self.data / self.data.shift(1)) # <2>\n",
" self.data['d'] = np.where(self.data['r'] > 0, 1, 0) # <3>\n",
" self.data.dropna(inplace=True) # <4>\n",
" self.data_ = (self.data - self.data.mean()) / self.data.std() # <5>\n",
" def reset(self):\n",
" self.bar = self.n_features # <6>\n",
" self.treward = 0 # <7>\n",
" state = self.data_[self.feature].iloc[\n",
" self.bar - self.n_features:self.bar].values # <8>\n",
" return state, {}"
]
},
{
"cell_type": "raw",
"id": "a31a1b7b-e112-46e9-abb8-98a80714ac2e",
"metadata": {},
"source": [
"# end::03[]"
]
},
{
"cell_type": "raw",
"id": "e4a7e889-618d-405a-ba7d-50e0258cb8fa",
"metadata": {},
"source": [
"# tag::04[]"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "a2b0ccc6-d8ec-4156-bf7a-30ba263fdde9",
"metadata": {},
"outputs": [],
"source": [
"class Finance(Finance):\n",
" def step(self, action):\n",
" if action == self.data['d'].iloc[self.bar]: # <1>\n",
" correct = True\n",
" else:\n",
" correct = False\n",
" reward = 1 if correct else 0 # <2>\n",
" self.treward += reward # <3>\n",
" self.bar += 1 # <4>\n",
" self.accuracy = self.treward / (self.bar - self.n_features) # <5>\n",
" if self.bar >= len(self.data): # <6>\n",
" done = True\n",
" elif reward == 1: # <7>\n",
" done = False\n",
" elif (self.accuracy < self.min_accuracy) and (self.bar > 15): # <8>\n",
" done = True\n",
" else:\n",
" done = False\n",
" next_state = self.data_[self.feature].iloc[\n",
" self.bar - self.n_features:self.bar].values # <9>\n",
" return next_state, reward, done, False, {}"
]
},
{
"cell_type": "raw",
"id": "d90d759c-3877-45c6-a314-7958b1ac7dba",
"metadata": {},
"source": [
"# end::04[]"
]
},
{
"cell_type": "raw",
"id": "34cac22d-5c3b-499e-af03-f581323bd709",
"metadata": {},
"source": [
"# tag::05[]"
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "373a0a8c-3b85-4933-8de5-1103d4cc1a6b",
"metadata": {},
"outputs": [],
"source": [
"fin = Finance(symbol='EUR=', feature='EUR=') # <1>"
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "d4c4248b-2168-42d2-b766-27270681b5dd",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"['AAPL.O',\n",
" 'MSFT.O',\n",
" 'INTC.O',\n",
" 'AMZN.O',\n",
" 'GS.N',\n",
" '.SPX',\n",
" '.VIX',\n",
" 'SPY',\n",
" 'EUR=',\n",
" 'XAU=',\n",
" 'GDX',\n",
" 'GLD']"
]
},
"execution_count": 11,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"list(fin.raw.columns) # <2>"
]
},
{
"cell_type": "code",
"execution_count": 12,
"id": "0c2042dd-3d9a-4976-bb6d-d58daeeaf650",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(array([2.74844931, 2.64643904, 2.69560062, 2.68085214]), {})"
]
},
"execution_count": 12,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"fin.reset()\n",
"# four lagged, normalized price points"
]
},
{
"cell_type": "code",
"execution_count": 13,
"id": "d0e04a87-7f63-4532-8609-2ad598d67067",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"1"
]
},
"execution_count": 13,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"fin.action_space.sample()"
]
},
{
"cell_type": "code",
"execution_count": 14,
"id": "2c6a11b6-87da-4226-baad-0fa9f4942c44",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(array([2.64643904, 2.69560062, 2.68085214, 2.63046153]), 0, False, False, {})"
]
},
"execution_count": 14,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"fin.step(fin.action_space.sample())"
]
},
{
"cell_type": "code",
"execution_count": 15,
"id": "c0a3b905-2eea-406f-9bee-bb61d6f5e463",
"metadata": {},
"outputs": [],
"source": [
"fin = Finance('EUR=', 'r') # <3>"
]
},
{
"cell_type": "code",
"execution_count": 16,
"id": "c490647f-9757-46bf-911d-c53477d9b3d0",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(array([-1.19130476, -1.21344494, 0.61099805, -0.16094865]), {})"
]
},
"execution_count": 16,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"fin.reset()\n",
"# four lagged, normalized log returns"
]
},
{
"cell_type": "raw",
"id": "b102da68-1f69-42aa-bbe8-d1d264610429",
"metadata": {},
"source": [
"# end::05[]"
]
},
{
"cell_type": "raw",
"id": "4ebe6007-d6e4-4e7e-bb7e-760a024911cd",
"metadata": {},
"source": [
"# tag::06[]"
]
},
{
"cell_type": "code",
"execution_count": 17,
"id": "1c0bab87-6d45-4e17-a52c-3d19273bd804",
"metadata": {},
"outputs": [],
"source": [
"class RandomAgent:\n",
" def __init__(self):\n",
" self.env = Finance('EUR=', 'r')\n",
" def play(self, episodes=1):\n",
" self.trewards = list()\n",
" for e in range(episodes):\n",
" self.env.reset()\n",
" for step in range(1, 100):\n",
" a = self.env.action_space.sample()\n",
" state, reward, done, trunc, info = self.env.step(a)\n",
" if done:\n",
" self.trewards.append(step)\n",
" break"
]
},
{
"cell_type": "code",
"execution_count": 18,
"id": "417b3f00-199f-4db7-b500-b7b7f99ce15b",
"metadata": {},
"outputs": [],
"source": [
"ra = RandomAgent()"
]
},
{
"cell_type": "code",
"execution_count": 19,
"id": "99850e42-8c2b-46a6-9a92-59a0e5940061",
"metadata": {},
"outputs": [],
"source": [
"ra.play(15)"
]
},
{
"cell_type": "code",
"execution_count": 20,
"id": "1a6351f5-e532-4703-ae3b-0f7ec2483f48",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[17, 13, 17, 12, 12, 12, 13, 23, 31, 13, 12, 15]"
]
},
"execution_count": 20,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"ra.trewards"
]
},
{
"cell_type": "code",
"execution_count": 21,
"id": "9590104e-899f-4a4a-81a3-0b952a8f1818",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"15.83"
]
},
"execution_count": 21,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"round(sum(ra.trewards) / len(ra.trewards), 2) # <1>"
]
},
{
"cell_type": "code",
"execution_count": 22,
"id": "2252d5e0-0c3f-4900-a96f-1fe6348ccd18",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"2607"
]
},
"execution_count": 22,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"len(fin.data) # <2>"
]
},
{
"cell_type": "raw",
"id": "c79f2e02-5fae-4dae-b904-02d5ec680bbc",
"metadata": {},
"source": [
"# end::06[]"
]
},
{
"cell_type": "raw",
"id": "8d77c50a-5a97-45c4-8f58-cea31439427b",
"metadata": {},
"source": [
"# tag::07[]"
]
},
{
"cell_type": "code",
"execution_count": 23,
"id": "06e651e5-4eb4-4001-b8a3-d629721b6eed",
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"import random\n",
"import warnings\n",
"import numpy as np\n",
"import tensorflow as tf\n",
"from tensorflow import keras\n",
"from collections import deque\n",
"from keras.layers import Dense\n",
"from keras.models import Sequential"
]
},
{
"cell_type": "code",
"execution_count": 24,
"id": "a04e9dcb-5a0c-463b-9714-012a9b8e4093",
"metadata": {},
"outputs": [],
"source": [
"warnings.simplefilter('ignore')\n",
"os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'"
]
},
{
"cell_type": "code",
"execution_count": 25,
"id": "c047b3c4-d7ca-4e17-b290-6dfce70690fc",
"metadata": {},
"outputs": [],
"source": [
"from tensorflow.python.framework.ops import disable_eager_execution\n",
"disable_eager_execution()"
]
},
{
"cell_type": "code",
"execution_count": 26,
"id": "9c5656a5-7378-494b-a43f-5ba736105485",
"metadata": {},
"outputs": [],
"source": [
"opt = keras.optimizers.legacy.Adam(learning_rate=0.0001)"
]
},
{
"cell_type": "raw",
"id": "066ca540-a76a-44fe-9c2d-64b150b292d4",
"metadata": {},
"source": [
"# end::07[]"
]
},
{
"cell_type": "raw",
"id": "bc6d81b9-f4a9-4ad6-914a-e787aa1dde31",
"metadata": {},
"source": [
"# tag::08[]"
]
},
{
"cell_type": "code",
"execution_count": 27,
"id": "9a1c06c7-6477-4a73-9bf5-68b497c52e8c",
"metadata": {},
"outputs": [],
"source": [
"class DQLAgent:\n",
" def __init__(self, symbol, feature, min_accuracy, n_features=4):\n",
" self.epsilon = 1.0\n",
" self.epsilon_decay = 0.9975\n",
" self.epsilon_min = 0.1\n",
" self.memory = deque(maxlen=2000)\n",
" self.batch_size = 32\n",
" self.gamma = 0.5\n",
" self.trewards = list()\n",
" self.max_treward = 0\n",
" self.n_features = n_features\n",
" self._create_model()\n",
" self.env = Finance(symbol, feature,\n",
" min_accuracy, n_features) # <1>\n",
" def _create_model(self):\n",
" self.model = Sequential()\n",
" self.model.add(Dense(24, activation='relu',\n",
" input_dim=self.n_features))\n",
" self.model.add(Dense(24, activation='relu'))\n",
" self.model.add(Dense(2, activation='linear'))\n",
" self.model.compile(loss='mse', optimizer=opt)\n",
" def act(self, state):\n",
" if random.random() < self.epsilon:\n",
" return self.env.action_space.sample()\n",
" return np.argmax(self.model.predict(state)[0])\n",
" def replay(self):\n",
" batch = random.sample(self.memory, self.batch_size)\n",
" for state, action, next_state, reward, done in batch:\n",
" if not done:\n",
" reward += self.gamma * np.amax(\n",
" self.model.predict(next_state)[0])\n",
" target = self.model.predict(state)\n",
" target[0, action] = reward\n",
" self.model.fit(state, target, epochs=1, verbose=False)\n",
" if self.epsilon > self.epsilon_min:\n",
" self.epsilon *= self.epsilon_decay\n",
" def learn(self, episodes):\n",
" for e in range(1, episodes + 1):\n",
" state, _ = self.env.reset()\n",
" state = np.reshape(state, [1, self.n_features])\n",
" for f in range(1, 5000):\n",
" action = self.act(state)\n",
" next_state, reward, done, trunc, _ = \\\n",
" self.env.step(action)\n",
" next_state = np.reshape(next_state,\n",
" [1, self.n_features])\n",
" self.memory.append(\n",
" [state, action, next_state, reward, done])\n",
" state = next_state \n",
" if done:\n",
" self.trewards.append(f)\n",
" self.max_treward = max(self.max_treward, f)\n",
" templ = f'episode={e:4d} | treward={f:4d}'\n",
" templ += f' | max={self.max_treward:4d}'\n",
" print(templ, end='\\r')\n",
" break\n",
" if len(self.memory) > self.batch_size:\n",
" self.replay()\n",
" print()\n",
" def test(self, episodes):\n",
" ma = self.env.min_accuracy # <2>\n",
" self.env.min_accuracy = 0.5 # <3>\n",
" for e in range(1, episodes + 1):\n",
" state, _ = self.env.reset()\n",
" state = np.reshape(state, [1, self.n_features])\n",
" for f in range(1, 5001):\n",
" action = np.argmax(self.model.predict(state)[0])\n",
" state, reward, done, trunc, _ = self.env.step(action)\n",
" state = np.reshape(state, [1, self.n_features])\n",
" if done:\n",
" tmpl = f'total reward={f} | '\n",
" tmpl += f'accuracy={self.env.accuracy:.3f}'\n",
" print(tmpl)\n",
" break\n",
" self.env.min_accuracy = ma # <2>"
]
},
{
"cell_type": "raw",
"id": "25315e20-127f-4661-a3c7-7fbc01ff4cd1",
"metadata": {},
"source": [
"# end::08[]"
]
},
{
"cell_type": "raw",
"id": "35bc9b01-777e-4bc7-a716-f10de251020e",
"metadata": {},
"source": [
"# tag::09[]"
]
},
{
"cell_type": "code",
"execution_count": 28,
"id": "d83cf567-0389-474d-accd-38431edaf755",
"metadata": {},
"outputs": [],
"source": [
"random.seed(250)\n",
"tf.random.set_seed(250)"
]
},
{
"cell_type": "code",
"execution_count": 29,
"id": "268f6f90-082d-4827-bdef-8bffa57016c7",
"metadata": {},
"outputs": [],
"source": [
"agent = DQLAgent('EUR=', 'r', 0.495, 4)"
]
},
{
"cell_type": "code",
"execution_count": 30,
"id": "ae2336af-de7e-4b3a-8ecd-292a06a0beb4",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"episode= 250 | treward= 12 | max=2603\n",
"CPU times: user 14.6 s, sys: 1.67 s, total: 16.2 s\n",
"Wall time: 13.7 s\n"
]
}
],
"source": [
"%time agent.learn(250)"
]
},
{
"cell_type": "code",
"execution_count": 31,
"id": "6a1023a5-07ef-4ac3-86c4-307a356cd2ba",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"total reward=2603 | accuracy=0.525\n",
"total reward=2603 | accuracy=0.525\n",
"total reward=2603 | accuracy=0.525\n",
"total reward=2603 | accuracy=0.525\n",
"total reward=2603 | accuracy=0.525\n"
]
}
],
"source": [
"agent.test(5) # <1>"
]
},
{
"cell_type": "raw",
"id": "e4f924a1-7f2b-4a77-af99-c549ccc3f35c",
"metadata": {},
"source": [
"# end::09[]"
]
},
{
"cell_type": "markdown",
"id": "20e3eaa7-ac35-44e5-bffc-93662c2d2c55",
"metadata": {},
"source": [
"<img src=\"https://hilpisch.com/tpq_logo.png\" alt=\"The Python Quants\" width=\"35%\" align=\"right\" border=\"0\"><br>\n",
"\n",
"<a href=\"https://tpq.io\" target=\"_blank\">https://tpq.io</a> | <a href=\"https://twitter.com/dyjh\" target=\"_blank\">@dyjh</a> | <a href=\"mailto:[email protected]\">[email protected]</a>"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.13"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
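The notebook above (Chapter 03, Financial Q-Learning) defines the Finance environment and a deep Q-learning agent. As a quick, standalone illustration of the data preparation inside the Finance class, the following sketch reproduces the same transformations (log returns and direction labels) for the 'EUR=' series and prints the base rate of upward moves that the min_accuracy threshold is measured against. It is not part of the original materials and only assumes that the CSV file at the URL used in the notebook is reachable.

import numpy as np
import pandas as pd

# same data source as in the Finance environment above
url = 'https://certificate.tpq.io/rl4finance.csv'
raw = pd.read_csv(url, index_col=0, parse_dates=True)

# replicate the _prepare_data() logic for the 'EUR=' series
data = pd.DataFrame(raw['EUR=']).dropna()
data['r'] = np.log(data / data.shift(1))  # log returns
data['d'] = np.where(data['r'] > 0, 1, 0)  # 1 = upward move, 0 = otherwise
data.dropna(inplace=True)

print(f'observations: {len(data)}')
print(f'share of upward moves: {data["d"].mean():.3f}')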
#
# Deep Q-Learning Agent
#
# (c) Dr. Yves J. Hilpisch
# Reinforcement Learning for Finance
#
import os
import random
import warnings
import numpy as np
import tensorflow as tf
from tensorflow import keras
from collections import deque
from keras.layers import Dense, Flatten
from keras.models import Sequential
warnings.simplefilter('ignore')
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
from tensorflow.python.framework.ops import disable_eager_execution
disable_eager_execution()
opt = keras.optimizers.legacy.Adam
class DQLAgent:
    def __init__(self, symbol, feature, n_features, env, hu=24, lr=0.001):
        self.epsilon = 1.0
        self.epsilon_decay = 0.9975
        self.epsilon_min = 0.1
        self.memory = deque(maxlen=2000)
        self.batch_size = 32
        self.gamma = 0.5
        self.trewards = list()
        self.max_treward = -np.inf
        self.n_features = n_features
        self.env = env
        self.episodes = 0
        self._create_model(hu, lr)
    def _create_model(self, hu, lr):
        self.model = Sequential()
        self.model.add(Dense(hu, activation='relu',
                             input_dim=self.n_features))
        self.model.add(Dense(hu, activation='relu'))
        self.model.add(Dense(2, activation='linear'))
        self.model.compile(loss='mse', optimizer=opt(learning_rate=lr))
    def _reshape(self, state):
        state = state.flatten()
        return np.reshape(state, [1, len(state)])
    def act(self, state):
        if random.random() < self.epsilon:
            return self.env.action_space.sample()
        return np.argmax(self.model.predict(state)[0])
    def replay(self):
        batch = random.sample(self.memory, self.batch_size)
        for state, action, next_state, reward, done in batch:
            if not done:
                reward += self.gamma * np.amax(
                    self.model.predict(next_state)[0])
            target = self.model.predict(state)
            target[0, action] = reward
            self.model.fit(state, target, epochs=1, verbose=False)
        if self.epsilon > self.epsilon_min:
            self.epsilon *= self.epsilon_decay
    def learn(self, episodes):
        for e in range(1, episodes + 1):
            self.episodes += 1
            state, _ = self.env.reset()
            state = self._reshape(state)
            treward = 0
            for f in range(1, 5000):
                self.f = f
                action = self.act(state)
                next_state, reward, done, trunc, _ = self.env.step(action)
                treward += reward
                next_state = self._reshape(next_state)
                self.memory.append(
                    [state, action, next_state, reward, done])
                state = next_state
                if done:
                    self.trewards.append(treward)
                    self.max_treward = max(self.max_treward, treward)
                    templ = f'episode={self.episodes:4d} | '
                    templ += f'treward={treward:7.3f}'
                    templ += f' | max={self.max_treward:7.3f}'
                    print(templ, end='\r')
                    break
            if len(self.memory) > self.batch_size:
                self.replay()
        print()
    def test(self, episodes, min_accuracy=0.0,
             min_performance=0.0, verbose=True,
             full=True):
        ma = self.env.min_accuracy
        self.env.min_accuracy = min_accuracy
        if hasattr(self.env, 'min_performance'):
            mp = self.env.min_performance
            self.env.min_performance = min_performance
        self.performances = list()
        for e in range(1, episodes + 1):
            state, _ = self.env.reset()
            state = self._reshape(state)
            for f in range(1, 5001):
                action = np.argmax(self.model.predict(state)[0])
                state, reward, done, trunc, _ = self.env.step(action)
                state = self._reshape(state)
                if done:
                    templ = f'total reward={f:4d} | '
                    templ += f'accuracy={self.env.accuracy:.3f}'
                    if hasattr(self.env, 'min_performance'):
                        self.performances.append(self.env.performance)
                        templ += f' | performance={self.env.performance:.3f}'
                    if verbose:
                        if full:
                            print(templ)
                        else:
                            print(templ, end='\r')
                    break
        self.env.min_accuracy = ma
        if hasattr(self.env, 'min_performance'):
            self.env.min_performance = mp
        print()
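For completeness, a hedged usage sketch (not part of the original script): it shows how this module's DQLAgent might be wired to the Finance environment from the notebook above. The module name finance_env is hypothetical; in practice, the Finance class must be importable or defined in the same session.

if __name__ == '__main__':
    from finance_env import Finance  # hypothetical module with the notebook's Finance class
    env = Finance('EUR=', 'r', min_accuracy=0.495, n_features=4)
    agent = DQLAgent('EUR=', 'r', n_features=4, env=env, hu=24, lr=0.0001)
    agent.learn(250)  # train for 250 episodes
    agent.test(5)  # evaluate the greedy policy over 5 episodes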