Skip to content

Instantly share code, notes, and snippets.

@rday
Created February 9, 2013 13:50
Show Gist options
  • Save rday/4745333 to your computer and use it in GitHub Desktop.
TA Transform take#2
'''
As usual, this is a wallofcode beta, don't use this in real life
'''
import numpy
import talib
from talib.abstract import Function
from numbers import Number
from collections import defaultdict
from zipline.transforms.utils import EventWindow, TransformMeta
class TATransform(object):
    """
    Transform that evaluates a TA-Lib indicator over each sid's events.

    One ``TATransformEventWindow`` is kept per sid and created the first
    time that sid is seen. ``update`` feeds an event to its sid's window
    and returns that window's current result.
    """
    __metaclass__ = TransformMeta

    def __init__(self, fields='price',
                 market_aware=True, window_length=None, delta=None,
                 transform_type=None, data=None):
        """
        :param fields: one field name, or a list of field names, to track.
        :param market_aware: True for day-count windows (``window_length``
            must be truthy and ``delta`` falsy); False for timedelta windows
            (``delta`` truthy and ``window_length`` falsy).
        :param window_length: window size in days (market-aware mode).
        :param delta: window size as a timedelta (non-market-aware mode).
        :param transform_type: handed unchanged to every per-sid window.
        :param data: optional priming data handed to every per-sid window.
        """
        self.fields = [fields] if isinstance(fields, basestring) else fields
        self.market_aware = market_aware
        self.window_length = window_length
        self.delta = delta
        self.transform_type = transform_type
        self.data = data

        # Each mode accepts exactly one way of sizing the window.
        if market_aware:
            valid = self.window_length and not self.delta
            reason = "Market-aware mode only works with full-day windows."
        else:
            valid = self.delta and not self.window_length
            reason = "Non-market-aware mode requires a timedelta."
        assert valid, reason

        # A defaultdict factory takes no arguments, so a bound method
        # carries this transform's settings into every new window.
        self.sid_windows = defaultdict(self.create_window)

    def create_window(self):
        """Build a fresh event window configured like this transform."""
        return TATransformEventWindow(
            self.fields,
            self.market_aware,
            days=self.window_length,
            delta=self.delta,
            transform_type=self.transform_type,
            data=self.data,
        )

    def update(self, event):
        """
        Feed ``event`` to the window for ``event.sid`` and return its output.

        The window is created on first use. The return value is whatever
        that window's ``get_data`` produces.
        """
        sid_window = self.sid_windows[event.sid]
        sid_window.update(event)
        return sid_window.get_data()
class TransformData(object):
    """
    Attribute container for transform results.

    Instance attributes can also be read with ``obj['name']`` syntax.
    """

    def __getitem__(self, name):
        """Return instance attribute ``name``; raise KeyError if it is absent."""
        instance_attrs = vars(self)
        return instance_attrs[name]
class TATransformEventWindow(EventWindow):
    """
    Event window that evaluates a TA-Lib function over tracked fields.

    It runs in one of two modes:

    * Primed (``data`` is not None): the TA-Lib function named by
      ``transform_type`` is evaluated once, in the constructor, through
      ``talib.abstract.Function``.  Each ``get_data`` call returns the next
      precomputed value, and ``handle_add`` / ``handle_remove`` do nothing.
    * Realtime (``data`` is None): the raw values of each tracked field are
      collected as events enter and leave the window, and ``get_data``
      applies ``transform_type`` to them.  ``transform_type`` may be a
      callable or the name of a module-level ``talib`` function such as
      ``'SMA'``.
    """

    # The abstract function is given OHLCV inputs, but only one series per
    # symbol is available here, so that series stands in for all of them.
    _PRICE_INPUT_KEYS = ('open', 'high', 'low', 'close', 'volume')

    def __init__(self, fields, market_aware, days, delta, transform_type,
                 data=None):
        """
        :param fields: names of the event fields to track.
        :param market_aware: forwarded to ``EventWindow``.
        :param days: window length, forwarded to ``EventWindow``.
        :param delta: window timedelta, forwarded to ``EventWindow``.
        :param transform_type: TA-Lib function name; in realtime mode it may
            also be a callable taking one float64 array.
        :param data: optional mapping-like object (iterated for its keys and
            indexed by them) whose values are the series used to precompute
            the transform.
        """
        # Set up the base EventWindow infrastructure.
        EventWindow.__init__(self, market_aware, days, delta)
        self.fields = fields
        # field -> raw values currently in the window (realtime mode), or
        # precomputed outputs (primed mode).
        self.window_data = defaultdict(list)
        self.data = data
        self.transform_type = transform_type
        # Index of the next precomputed value that get_data will hand out.
        self.iterations = 0
        if data is not None:
            self._prime(data)

    def _prime(self, data):
        """Precompute the transform output for every tracked field."""
        # Compare case-insensitively, consistent with the name lookup.
        is_bbands = str(self.transform_type).upper() == 'BBANDS'
        for symbol in data:
            # NOTE(review): the field name does not select any data, and every
            # symbol writes into the same per-field slot. BBANDS results are
            # concatenated across symbols, while other functions keep only
            # the last symbol's output. Confirm this is intended.
            # TA-Lib only accepts float64 input, so coerce explicitly.
            series = numpy.asarray(list(data[symbol]), dtype=float)
            inputs = dict.fromkeys(self._PRICE_INPUT_KEYS, series)
            func = Function(self.transform_type)
            # TODO: these parameters are fixed here; they could be passed
            # through just like transform_type.
            func.set_function_args(inputs, timeperiod=20,
                                   matype=talib.MA_Type.EMA, price='open')
            outputs = func.outputs
            for field in self.fields:
                if is_bbands:
                    # Store one (upper, middle, lower) tuple per bar.
                    upper, middle, lower = outputs
                    self.window_data[field].extend(zip(upper, middle, lower))
                else:
                    # Works for single-output functions such as EMA or SMA;
                    # other functions may need their own handling.
                    self.window_data[field] = outputs

    def handle_add(self, event):
        """Record the tracked field values of a newly added event."""
        # Primed windows serve precomputed values; nothing to track.
        if self.data is not None:
            return
        self.assert_required_fields(event)
        for field in self.fields:
            self.window_data[field].append(event[field])

    def handle_remove(self, event):
        """
        Drop the oldest value of each tracked field.

        ``event`` itself is not inspected. This assumes events expire in
        insertion order (presumably guaranteed by EventWindow; confirm).
        """
        if self.data is not None:
            return
        for field in self.fields:
            # Delete in place instead of copying the list. Unlike
            # ``del lst[0]``, the slice form tolerates an empty list.
            del self.window_data[field][:1]

    def _apply_transform(self, values):
        """Apply ``transform_type`` (a callable or talib name) to ``values``."""
        transform = self.transform_type
        if not callable(transform):
            # TA-Lib's module-level functions are upper-case (talib.SMA).
            transform = getattr(talib, str(transform).upper())
        return transform(values)

    def get_data(self):
        """
        Return a TransformData with one attribute per tracked field.

        In primed mode each field gets the precomputed value at the current
        cursor. The cursor advances once per call, so every field reports
        the same bar. An IndexError is raised once the precomputed values
        are exhausted. In realtime mode each field gets ``transform_type``
        applied to the field's current window values as a float64 array.
        """
        out = TransformData()
        if self.data is not None:
            for field in self.fields:
                out.__dict__[field] = self.window_data[field][self.iterations]
            self.iterations += 1
            return out
        for field in self.fields:
            values = numpy.asarray(self.window_data[field], dtype=float)
            out.__dict__[field] = self._apply_transform(values)
        return out

    def assert_required_fields(self, event):
        """
        We only allow events with all of our tracked fields.

        Each tracked field of ``event`` must be a ``numbers.Number``.
        Because this uses ``assert``, the check is skipped under ``python -O``.
        """
        for field in self.fields:
            assert isinstance(event[field], Number), \
                "Got %s for %s in TATransformEventWindow" % (event[field],
                                                             field)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment