Created
February 9, 2013 13:50
-
-
Save rday/4745333 to your computer and use it in GitHub Desktop.
TA Transform, take #2
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
'''
As usual, this is a wall-of-code beta; don't use this in real life.
'''
import numpy | |
import talib | |
from talib.abstract import Function | |
from numbers import Number | |
from collections import defaultdict | |
from zipline.transforms.utils import EventWindow, TransformMeta | |
class TATransform(object):
    """
    Zipline transform that applies a TA-Lib indicator to tracked event fields.

    A TATransformEventWindow is created lazily for each sid the first time
    an event for that sid arrives; every event is routed to its sid's window.

    :param fields: a single field name or a list of field names to track.
    :param market_aware: if True the window is sized by ``window_length``
        (full-day windows); otherwise it is sized by ``delta``.
    :param window_length: window length in days; required when market_aware.
    :param delta: window length as a timedelta; required when not market_aware.
    :param transform_type: the TA-Lib indicator to apply; forwarded unchanged
        to each TATransformEventWindow, which interprets it.
    :param data: optional pre-loaded data forwarded to each window for priming.
    :raises AssertionError: if the window configuration is inconsistent.
    """

    __metaclass__ = TransformMeta

    def __init__(self, fields='price',
                 market_aware=True, window_length=None, delta=None,
                 transform_type=None, data=None):
        # Accept a single field name as a convenience.
        if isinstance(fields, basestring):
            fields = [fields]
        self.fields = fields
        self.market_aware = market_aware
        self.delta = delta
        self.window_length = window_length
        self.transform_type = transform_type
        self.data = data

        # Validate with explicit raises instead of ``assert`` so the checks
        # still run under ``python -O``; AssertionError is kept so existing
        # callers that catch it keep working.
        if self.market_aware:
            # Market-aware mode only works with full-day windows.
            if not (self.window_length and not self.delta):
                raise AssertionError(
                    "Market-aware mode only works with full-day windows.")
        elif not (self.delta and not self.window_length):
            # Non-market-aware mode requires a timedelta.
            raise AssertionError(
                "Non-market-aware mode requires a timedelta.")

        # defaultdict cannot pass arguments to its factory, so a bound method
        # builds each new window with this transform's configuration.
        self.sid_windows = defaultdict(self.create_window)

    def create_window(self):
        """
        Factory for ``self.sid_windows``: build a window configured like self.
        """
        return TATransformEventWindow(
            self.fields,
            self.market_aware,
            self.window_length,
            self.delta,
            self.transform_type,
            self.data
        )

    def update(self, event):
        """
        Feed ``event`` into the window for its sid and return the result.

        :param event: an event object exposing ``.sid``; it is passed to the
            window's ``update``.
        :returns: whatever the window's ``get_data()`` produces after the
            update (the indicator output for each tracked field).
        """
        # Indexing the defaultdict creates the window on a sid's first event.
        window = self.sid_windows[event.sid]
        window.update(event)
        return window.get_data()
class TransformData(object):
    """
    Plain attribute container holding one transform result.

    Values live as instance attributes and may also be read with subscript
    syntax: for an instance attribute ``price``, ``result['price']`` returns
    the same object as ``result.price``.
    """

    def __getitem__(self, name):
        """
        Return the instance attribute called ``name``.

        :raises KeyError: if no instance attribute with that name exists.
        """
        attributes = vars(self)
        return attributes[name]
class TATransformEventWindow(EventWindow):
    """
    Event window that applies a TA-Lib indicator to the tracked fields.

    Two modes of operation:

    * Primed: when ``data`` is given, the indicator is computed once in the
      constructor, and each get_data() call hands out the next precomputed
      value.
    * Realtime: otherwise, field values from incoming events are buffered,
      and the indicator is recomputed over the buffer on every get_data().

    ``transform_type`` may be either a TA-Lib function name (e.g. 'BBANDS')
    or the TA-Lib function object itself (e.g. ``talib.SMA``). Both modes
    accept either form.
    """

    def __init__(self, fields, market_aware, days, delta, transform_type,
                 data=None, function_args=None):
        """
        :param fields: names of the event fields to track.
        :param market_aware: forwarded to EventWindow.
        :param days: forwarded to EventWindow as the window length.
        :param delta: forwarded to EventWindow.
        :param transform_type: TA-Lib function name or function object.
        :param data: optional mapping of symbol -> value sequence used to
            precompute the indicator (primed mode).
        :param function_args: keyword arguments for the abstract
            ``Function`` when priming. Defaults to the previously hard-coded
            ``timeperiod=20, matype=talib.MA_Type.EMA, price='open'``.
        """
        # Set up the base EventWindow infrastructure.
        EventWindow.__init__(self, market_aware, days, delta)
        self.fields = fields
        # Per-field raw value buffers (realtime) or precomputed series (primed).
        self.window_data = defaultdict(list)
        self.data = data
        self.transform_type = transform_type
        # Cursor into the primed series; advanced once per get_data() call.
        self.iterations = 0

        if function_args is None:
            function_args = dict(timeperiod=20, matype=talib.MA_Type.EMA,
                                 price='open')
        self.function_args = function_args

        if data is not None:
            self._prime(data)

    def _function_name(self):
        """Return the TA-Lib function name for ``self.transform_type``."""
        if callable(self.transform_type):
            return self.transform_type.__name__
        return self.transform_type

    def _function(self):
        """Return the callable TA-Lib function for ``self.transform_type``."""
        if callable(self.transform_type):
            return self.transform_type
        return getattr(talib, self.transform_type)

    def _prime(self, data):
        """Precompute the indicator series for every tracked field from data."""
        name = self._function_name()
        for symbol in data:
            # TA-Lib only accepts double arrays, so coerce explicitly; an
            # integer series would otherwise be rejected.
            array = numpy.asarray([v for v in data[symbol]], dtype=float)
            # The abstract API expects OHLCV inputs, but only one series is
            # available here, so it stands in for all of them.
            inputs = {key: array
                      for key in ('high', 'open', 'low', 'close', 'volume')}
            func = Function(name)
            func.set_function_args(inputs, **self.function_args)
            outputs = func.outputs
            if name == 'BBANDS':
                # One (upper, middle, lower) tuple per bar.
                upper, middle, lower = outputs
                series = list(zip(upper, middle, lower))
            else:
                # Works for single-output functions such as EMA/SMA; other
                # multi-output functions are not handled specifically.
                series = outputs
            # The inputs do not depend on the field, so every field gets the
            # same series.
            # NOTE(review): this window does not know its own sid, so when
            # ``data`` holds several symbols the last one iterated wins.
            for field in self.fields:
                self.window_data[field] = series

    # Subclass customization for adding new events.
    def handle_add(self, event):
        """Buffer the tracked field values of ``event`` (realtime mode only)."""
        # Primed windows replay precomputed values; nothing to buffer.
        if self.data is not None:
            return
        self.assert_required_fields(event)
        for field in self.fields:
            self.window_data[field].append(event[field])

    # Subclass customization for removing expired events.
    def handle_remove(self, event):
        """Drop the oldest buffered value per field (realtime mode only)."""
        if self.data is not None:
            return
        # Trims index 0 in place; presumably EventWindow expires events
        # oldest-first -- confirm against EventWindow.
        for field in self.fields:
            del self.window_data[field][:1]

    def get_data(self):
        """
        Return a TransformData holding one attribute per tracked field.

        Primed mode: each call returns the next precomputed value for every
        field and advances the cursor. Raises IndexError once the series is
        exhausted.
        Realtime mode: applies the TA-Lib function to each field's buffered
        values and stores the function's return value.
        """
        out = TransformData()
        if self.data is not None:
            for field in self.fields:
                out.__dict__[field] = self.window_data[field][self.iterations]
            self.iterations += 1
            return out

        func = self._function()
        for field in self.fields:
            # TA-Lib requires doubles; event values may be ints.
            values = numpy.asarray(self.window_data[field], dtype=float)
            out.__dict__[field] = func(values)
        return out

    def assert_required_fields(self, event):
        """
        Check that every tracked field of ``event`` holds a number.

        :raises AssertionError: if a tracked field is not a Number (only when
            assertions are enabled).
        """
        for field in self.fields:
            assert isinstance(event[field], Number), \
                "Got %s for %s in TATransformEventWindow" % (event[field],
                                                             field)
Sign up for free to join this conversation on GitHub.
Already have an account?
Sign in to comment