Skip to content

Instantly share code, notes, and snippets.

@hanjae-jea
Created July 7, 2020 08:50
Show Gist options
  • Save hanjae-jea/be6c1bc093b07abb03988299141048f0 to your computer and use it in GitHub Desktop.
Save hanjae-jea/be6c1bc093b07abb03988299141048f0 to your computer and use it in GitHub Desktop.
import pickle
import matplotlib.pyplot as plt
import matplotlib.dates as mdates
from datetime import datetime, timedelta
class Backtest:
    """Simple bar-by-bar portfolio backtester over a universe of stock codes.

    Price data for each code is loaded from ``{data_dir}/{code}.pickle`` and is
    read as a dict of parallel lists: ``'d'`` (date as a YYYYMMDD int), ``'t'``
    (time as an HHMM int), ``'o'``/``'h'``/``'l'``/``'c'`` (open, high, low,
    close) and ``'v'`` (presumably volume -- confirm against the data source).
    Bars may be stored oldest-first or newest-first; the order is detected per
    code so that "N bars back" always means N bars earlier in time.

    ``self.portfolio`` maps ``'cash'`` to ``{'open': amount, 'quantity': amount}``
    and every held code to ``{'open': book value, 'quantity': shares}``.
    """

    def __init__(self, universe_path='universe.csv', data_dir='data',
                 initial_cash=10000000000):
        """Load the code universe and each code's price history.

        universe_path: text file with one stock code per line. Blank lines are
            ignored; a trailing newline would otherwise yield an empty code.
        data_dir: directory holding one ``<code>.pickle`` file per code.
        initial_cash: starting cash balance.
        """
        with open(universe_path, 'r') as f:
            self.codes = [line.strip() for line in f.read().split('\n') if line.strip()]
        self.portfolio = {'cash': {'open': initial_cash, 'quantity': initial_cash}}
        self.pp_date = []  # timestamps of recorded portfolio valuations
        self.pp_eval = []  # portfolio value recorded at each pp_date entry
        self.price_dict = {}
        for code in self.codes:
            # pickle can execute arbitrary code on load: only use trusted data files.
            with open(f'{data_dir}/{code}.pickle', 'rb') as f:
                self.price_dict[code] = pickle.load(f)

    def _bar_key(self, code, i):
        """Return bar ``i`` of ``code`` as a sortable YYYYMMDDHHMM integer."""
        bars = self.price_dict[code]
        return bars['d'][i] * 10000 + bars['t'][i]

    def _is_descending(self, code):
        """Return True when ``code``'s bars are stored newest-first."""
        n = len(self.price_dict[code]['d'])
        return n > 1 and self._bar_key(code, 0) > self._bar_key(code, n - 1)

    def _bars_back(self, code, idx, back):
        """Return the list index that lies ``back`` bars earlier in time than ``idx``."""
        return idx + back if self._is_descending(code) else idx - back

    def _price_bound(self, code, dt, tm):
        """Return the index of the latest bar at or before date ``dt``, time ``tm``.

        Works for either storage order. Raises IndexError when ``code`` has no
        bar at or before the requested moment (including an empty history).
        """
        target = dt * 10000 + tm
        n = len(self.price_dict[code]['d'])
        descending = self._is_descending(code)

        def storage_index(pos):
            # Map a chronological position (0 = oldest bar) to a list index.
            return n - 1 - pos if descending else pos

        # Binary search for the number of bars whose key is <= target.
        lo, hi = 0, n
        while lo < hi:
            mid = (lo + hi) // 2
            if self._bar_key(code, storage_index(mid)) <= target:
                lo = mid + 1
            else:
                hi = mid
        if lo == 0:
            raise IndexError(f'no bar for {code} at or before {dt} {tm}')
        return storage_index(lo - 1)

    def _evaluate(self):
        """Return cash plus the close-price value of every held position."""
        total = 0
        for code, position in self.portfolio.items():
            if code == 'cash':
                total += position['open']
            else:
                total += self.price(code)['c'] * position['quantity']
        return total

    def run(self, start=datetime(2016, 7, 7, 9, 0, 0), end=datetime(2020, 7, 7),
            step=timedelta(minutes=5)):
        """Step simulated time from ``start`` to ``end`` and record portfolio value.

        ``tick`` is called for every step between 09:00 and 15:30 inclusive. The
        portfolio is valued after every step, including outside those hours.
        Each valuation is recorded under the already-advanced timestamp.
        """
        dt = start
        while dt < end:
            in_session = 9 <= dt.hour <= 14 or (dt.hour == 15 and dt.minute <= 30)
            if in_session:
                self.day = int(dt.strftime("%Y%m%d"))
                self.time = int(dt.strftime('%H%M'))
                self.tick(self.day, self.time)
            dt = dt + step
            self.pp_date.append(dt)
            self.pp_eval.append(self._evaluate())

    def price(self, code, index=0):
        """Return the bar of ``code`` at the current simulated time.

        index: how many bars to step back in time from the current bar (0 = current).
        Returns a dict with 'o' (open), 'h' (high), 'l' (low), 'c' (close) and
        'v'. 'cash' always prices at 1 and has no 'v' entry.
        Raises IndexError when the requested bar lies outside the loaded history.
        """
        if code == 'cash':
            return {'o': 1, 'h': 1, 'l': 1, 'c': 1}
        bars = self.price_dict[code]
        current = self._price_bound(code, self.day, self.time)
        idx = self._bars_back(code, current, index)
        # Check explicitly: a negative index would otherwise wrap around the list.
        if not 0 <= idx < len(bars['c']):
            raise IndexError(f'{code}: bar {index} back from {self.day} {self.time} is out of range')
        return {key: bars[key][idx] for key in ('o', 'h', 'l', 'c', 'v')}

    def buy(self, code, percent=0, quantity=0):
        """Buy ``code`` at the current close; return 0 on success, -1 on failure.

        Give either ``quantity`` (shares) or ``percent``, not both. ``percent`` is
        a percentage of the sum of every position's 'open' value, cash included.
        Fails when both are given, when the resulting quantity is not positive,
        when ``code`` is 'cash', or when there is not enough cash.
        """
        if percent > 0 and quantity > 0:
            return -1
        if code == 'cash':
            return -1
        close = self.price(code)['c']
        if close <= 0:
            return -1  # cannot size or value a position at a non-positive price
        if percent > 0:
            book_value = sum(position['open'] for position in self.portfolio.values())
            quantity = (book_value * percent // 100) // close
        if quantity <= 0:
            return -1
        cost = close * quantity
        cash = self.portfolio['cash']['open']
        if cash < cost:
            return -1
        self.portfolio['cash']['open'] = cash - cost
        self.portfolio['cash']['quantity'] = self.portfolio['cash']['open']
        position = self.portfolio.setdefault(code, {'open': 0, 'quantity': 0})
        position['open'] = position['open'] + cost
        position['quantity'] = position['quantity'] + quantity
        return 0

    def sell(self, code, percent=0, quantity=0):
        """Sell ``code`` at the current close; return 0 on success, -1 on failure.

        Give either ``quantity`` (shares) or ``percent`` of the held shares, not
        both. Fails when ``code`` is not held (or is 'cash'), or when the
        quantity is not positive or exceeds the held quantity. A position
        reduced to zero shares is removed from the portfolio.
        """
        if percent > 0 and quantity > 0:
            return -1
        if code == 'cash' or code not in self.portfolio:
            return -1
        position = self.portfolio[code]
        held = position['quantity']
        if percent > 0:
            quantity = held * percent // 100
        if quantity <= 0 or quantity > held:
            return -1
        proceeds = self.price(code)['c'] * quantity
        self.portfolio['cash']['open'] = self.portfolio['cash']['open'] + proceeds
        self.portfolio['cash']['quantity'] = self.portfolio['cash']['open']
        if quantity == held:
            del self.portfolio[code]
            return 0
        # The book value is reduced by the sale proceeds (market value), not by the
        # cost of the shares sold. NOTE(review): confirm this accounting is intended.
        position['open'] = position['open'] - proceeds
        position['quantity'] = held - quantity
        return 0

    def ma(self, code, period=14):
        """Return the simple moving average of closes over the last ``period`` bars.

        The window ends at the current bar and extends back in time. When fewer
        than ``period`` bars of history exist, only the available bars are averaged.
        Raises ValueError if ``period`` < 1.
        """
        if period < 1:
            raise ValueError('period must be >= 1')
        closes = self.price_dict[code]['c']
        current = self._price_bound(code, self.day, self.time)
        if self._is_descending(code):
            window = closes[current:current + period]
        else:
            window = closes[max(0, current - period + 1):current + 1]
        return sum(window) / len(window)

    def report(self):
        """Plot the recorded portfolio value over time and show the figure."""
        plt.plot(self.pp_date, self.pp_eval)
        plt.show()

    def tick(self, day: int, time: int):
        """Strategy hook called by ``run`` for every in-session step.

        day: date as a YYYYMMDD int, e.g. 20190830.
        time: time as an HHMM int, e.g. 1015.

        Example strategy: buy 10 shares of A010820 at 09:00 on every simulated day.
        Useful state inside a strategy:
          self.portfolio['cash']['quantity']   -- cash holdings
          self.portfolio['A010820']['open']    -- book value of the position
          self.portfolio['A010820']['quantity'] * self.price('A010820', 5)['c']
                                               -- position valued at the close 5 bars back
        """
        if time == 900:
            self.buy('A010820', quantity=10)
if __name__ == '__main__':
    # Script entry point: load data, run the simulation, then plot the equity curve.
    b = Backtest()
    b.run()
    b.report()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment