Skip to content

Instantly share code, notes, and snippets.

@rday
Created March 4, 2013 02:36
Show Gist options
  • Save rday/5079515 to your computer and use it in GitHub Desktop.
Save rday/5079515 to your computer and use it in GitHub Desktop.
TATransform example for handling two TA-Lib functions, SMA and BBANDS.
import numpy
from talib.abstract import Function
from numbers import Number
from collections import defaultdict
from zipline.transforms.utils import EventWindow, TransformMeta
class TATransform(object):
    """
    Zipline transform that applies a TA-Lib function separately per sid.

    Each sid gets its own TATransformEventWindow, created lazily by
    ``create_window`` the first time an event for that sid arrives.
    ``update`` feeds the event into that window and returns the window's
    current TransformData.

    :param fields: name, or list of names, of the event fields to track.
    :param market_aware: if True the window is sized in days via
        ``window_length``; otherwise it is sized by ``delta``.
    :param window_length: window size in days (market-aware mode only).
    :param delta: window size as a timedelta (non-market-aware mode only).
    :param transform_type: TA-Lib function name, e.g. ``'SMA'`` or
        ``'BBANDS'``.
    :param transform_opts: keyword options forwarded to each window and
        passed to the TA-Lib function when priming from ``data``.
        Defaults to no options.
    :param data: optional pre-loaded price data forwarded to each window.
        Each ``data[symbol]`` must provide 'open', 'high', 'low', 'close'
        and 'volume' columns (presumably a pandas DataFrame per symbol --
        confirm against callers).
    """
    __metaclass__ = TransformMeta

    def __init__(self, fields='price',
                 market_aware=True, window_length=None, delta=None,
                 transform_type=None, transform_opts=None, data=None):
        if isinstance(fields, basestring):
            fields = [fields]
        self.fields = fields

        self.market_aware = market_aware
        self.delta = delta
        self.window_length = window_length

        self.transform_type = transform_type
        # Copy the options.  This avoids a shared mutable default, and it
        # keeps later mutation of the caller's dict from leaking into the
        # windows.
        self.transform_opts = dict(transform_opts) if transform_opts else {}
        self.data = data

        # Market-aware mode only works with full-day windows.
        if self.market_aware:
            assert self.window_length and not self.delta, \
                "Market-aware mode only works with full-day windows."
        # Non-market-aware mode requires a timedelta.
        else:
            assert self.delta and not self.window_length, \
                "Non-market-aware mode requires a timedelta."

        # No way to pass arguments to the defaultdict factory, so we
        # need to define a method to generate the correct EventWindows.
        self.sid_windows = defaultdict(self.create_window)

    def create_window(self):
        """
        Factory method for self.sid_windows.

        Builds a TATransformEventWindow configured from this transform's
        settings.
        """
        return TATransformEventWindow(
            self.fields,
            self.market_aware,
            self.window_length,
            self.delta,
            self.transform_type,
            data=self.data,
            transform_opts=self.transform_opts,
        )

    def update(self, event):
        """
        Feed ``event`` into the window for its sid and return the window's
        current TransformData (see TATransformEventWindow.get_data).
        """
        # This will create a new EventWindow if this is the first
        # message for this sid.
        window = self.sid_windows[event.sid]
        window.update(event)
        return window.get_data()
class TransformData(object):
    """
    Attribute bag for transform results that also supports ``obj[name]``.

    Results are stored directly in the instance ``__dict__``.  Subscript
    lookup reads that same mapping, so an unknown name raises KeyError.
    Class attributes and methods are never reachable through subscripting.
    """

    def __getitem__(self, name):
        """Return the stored result for ``name`` (dictionary-style access)."""
        instance_attrs = vars(self)
        return instance_attrs[name]
class TATransformEventWindow(EventWindow):
    """
    Per-sid event window that evaluates a TA-Lib function.

    The window works in one of two modes:

    * Primed (``data`` is not None): the TA-Lib function is evaluated once,
      up front, over the full price history of every symbol in ``data``.
      This uses the ``talib.abstract`` API.  Each ``get_data`` call then
      returns the next precomputed value for every symbol, and live events
      are ignored.
    * Realtime (``data`` is None): each tracked field's values are buffered
      as events arrive.  The transform is applied to the buffered values on
      every ``get_data`` call.
    """

    # Price columns every symbol in ``data`` must provide for priming.
    _PRICE_COLUMNS = ('open', 'high', 'low', 'close', 'volume')

    def __init__(self, fields, market_aware, days, delta, transform_type,
                 data=None, transform_opts=None):
        # Call the superclass constructor to set up base EventWindow
        # infrastructure.
        EventWindow.__init__(self, market_aware, days, delta)

        self.fields = fields
        # Realtime mode: field name -> list of buffered event values.
        # Primed mode:   symbol -> sequence of precomputed outputs.
        self.window_data = defaultdict(list)
        self.data = data
        self.transform_type = transform_type
        # Index of the next precomputed output to hand out (primed mode).
        self.iterations = 0

        # NOTE(review): transform_opts are only applied when priming; the
        # realtime path calls the transform with the buffered values alone.
        if data is not None:
            self._prime(data, transform_type, transform_opts or {})

    def _prime(self, data, transform_type, transform_opts):
        """Precompute the TA-Lib output for every symbol in ``data``."""
        for symbol in data:
            frame = data[symbol]
            for column in self._PRICE_COLUMNS:
                assert column in frame, \
                    "Primed data for %s is missing the %r column" % (
                        symbol, column)

            # TA-Lib only accepts float64 arrays, so convert explicitly.
            inputs = dict(
                (column, numpy.asarray(frame[column], dtype=float))
                for column in self._PRICE_COLUMNS
            )

            func = Function(transform_type)
            func.set_function_args(inputs, **transform_opts)
            outputs = func.outputs

            # Single-output functions (e.g. SMA) yield one array per call.
            # Multi-output functions (e.g. BBANDS: upper, middle, lower)
            # yield a list of arrays.  Those are zipped into one tuple per
            # time period.
            if transform_type == 'BBANDS' or isinstance(outputs, (list, tuple)):
                self.window_data[symbol] = list(zip(*outputs))
            else:
                self.window_data[symbol] = outputs

    def _resolve_transform(self):
        """
        Return a callable for ``self.transform_type``.

        A non-callable value (normally a TA-Lib function name such as
        ``'SMA'``) is looked up on the top-level ``talib`` module.  A
        callable is returned unchanged.
        """
        transform = self.transform_type
        if not callable(transform):
            # Local import: the module itself only imports talib.abstract.
            import talib
            transform = getattr(talib, transform)
        return transform

    # Subclass customization for adding new events.
    def handle_add(self, event):
        # Primed windows already hold every output; ignore live events.
        if self.data is not None:
            return

        # Sanity check on the event.
        self.assert_required_fields(event)

        # Buffer the new value of each tracked field.
        for field in self.fields:
            self.window_data[field].append(event[field])

    # Subclass customization for removing expired events.
    def handle_remove(self, event):
        # Primed windows already hold every output; ignore live events.
        if self.data is not None:
            return

        # Drop the first (earliest-appended) buffered value of each field,
        # in place.  ``del lst[:1]`` is a no-op on an empty list.
        for field in self.fields:
            del self.window_data[field][:1]

    def get_data(self):
        """
        Return a TransformData holding the current transform output.

        Primed mode: one entry per symbol in ``data``, holding the
        precomputed output at the current iteration.  The iteration counter
        then advances.  Indexing past the end of the precomputed outputs
        raises IndexError.

        Realtime mode: one entry per tracked field, holding whatever the
        transform returns for that field's buffered values.
        """
        out = TransformData()
        if self.data is not None:
            for symbol in self.data:
                out.__dict__[symbol] = self.window_data[symbol][self.iterations]
            self.iterations += 1
            return out

        transform = self._resolve_transform()
        for field in self.fields:
            # TA-Lib only accepts float64 arrays.
            values = numpy.array(self.window_data[field], dtype=float)
            out.__dict__[field] = transform(values)
        return out

    def assert_required_fields(self, event):
        """
        We only allow events with all of our tracked fields.

        Each tracked field must hold a numbers.Number.
        """
        for field in self.fields:
            assert isinstance(event[field], Number), \
                "Got %s for %s in TATransformEventWindow" % (event[field],
                                                             field)
# Example usage, kept in a bare module-level string so it is never executed
# (it is Python 2 code: note the print statements).
# NOTE(review): the example assumes ``import talib`` (for talib.MA_Type) and
# a ``stock_data`` object defined elsewhere.  ``initialize`` uses ``self``
# without taking it as a parameter.  ``handle_data`` also reads an ``sma``
# transform that ``initialize`` never registers; presumably a second
# add_transform call with transform_type='SMA' was omitted -- confirm.
"""
Algorithm
def initialize():
# ...
bbands_opts = {'timeperiod': 20, 'matype': talib.MA_Type.EMA, 'price': 'close'}
self.add_transform(TATransform, 'bband', ['price', 'volume'],
window_length=20, data=stock_data,
transform_type='BBANDS',
transform_opts=bbands_opts)
def handle_data():
# ...
for symbol in data.iterkeys():
print symbol
print data[symbol].bband[symbol]
print data[symbol].sma[symbol]
Output:
...
AAPL
(544.64970780971544, 484.14718997781034, 423.64467214590519)
490.838
AAPL
(541.79244699255105, 480.16364807516175, 418.5348491577725)
486.604
AAPL
(538.38445732479624, 478.03758635371776, 417.69071538263927)
483.301...
"""
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment