Last active
December 12, 2015 08:18
-
-
Save rday/4742553 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#
# Skeleton code -- do not use it as-is; it is not expected to work for you.
"""
This is heavily based on the MovingAverage transform.
"""
import numpy | |
import talib | |
from numbers import Number | |
from collections import defaultdict | |
from zipline.transforms.utils import EventWindow, TransformMeta | |
class TATransform(object):
    """
    Per-sid transform that applies a TA-Lib function to a rolling window.

    One TATransformEventWindow is kept per sid. Each update appends the
    event to that sid's window and returns the TA-Lib output computed over
    the window for every tracked field.
    """
    __metaclass__ = TransformMeta

    def __init__(self, fields='price',
                 market_aware=True, window_length=None, delta=None,
                 transform_type=None):
        """
        :param fields: a field name, or a list of field names, read from
            each event.
        :param market_aware: if True the window is sized in full days via
            ``window_length``; otherwise ``delta`` must be given instead.
        :param window_length: window size in days (market-aware mode only).
        :param delta: window size as a timedelta (non-market-aware mode only).
        :param transform_type: name of a callable in the ``talib`` module.
        :raises AssertionError: if ``transform_type`` does not name a
            callable in ``talib``, or the window arguments do not match the
            selected mode.
        """
        if isinstance(fields, basestring):
            fields = [fields]
        self.fields = fields
        self.market_aware = market_aware
        self.delta = delta
        self.window_length = window_length

        # Resolve the TA-Lib function by name. Require a callable: a bare
        # membership test on talib.__dict__ would also accept plain module
        # attributes (e.g. '__name__'), which would only blow up later when
        # get_data tries to call them.
        transform_func = talib.__dict__.get(transform_type)
        assert callable(transform_func), \
            "Unknown TA-Lib function: %r" % (transform_type,)
        self.transform_type = transform_func

        # Market-aware mode only works with full-day windows.
        if self.market_aware:
            assert self.window_length and not self.delta, \
                "Market-aware mode only works with full-day windows."
        # Non-market-aware mode requires a timedelta.
        else:
            assert self.delta and not self.window_length, \
                "Non-market-aware mode requires a timedelta."

        # No way to pass arguments to the defaultdict factory, so we
        # need to define a method to generate the correct EventWindows.
        self.sid_windows = defaultdict(self.create_window)

    def create_window(self):
        """
        Factory method for self.sid_windows.

        Builds a new TATransformEventWindow configured with this
        transform's fields, window settings and TA-Lib function.
        """
        return TATransformEventWindow(
            self.fields,
            self.market_aware,
            self.window_length,
            self.delta,
            self.transform_type
        )

    def update(self, event):
        """
        Feed ``event`` into the window for its sid.

        Returns the window's TransformData: one attribute per tracked field,
        holding the TA-Lib function's output over that sid's current window.
        """
        # This will create a new EventWindow if this is the first
        # message for this sid.
        window = self.sid_windows[event.sid]
        window.update(event)
        return window.get_data()
class TransformData(object):
    """
    Simple attribute container for transform output.

    Values are stored as instance attributes, and subscripting the object
    with a name returns the same value as attribute access.
    """

    def __getitem__(self, name):
        """
        Return the instance attribute called ``name``, dictionary-style.

        Raises KeyError when no such attribute has been set.
        """
        attributes = vars(self)
        return attributes[name]
class TATransformEventWindow(EventWindow):
    """
    EventWindow that stores the raw values of each tracked field and applies
    a TA-Lib function to them on demand (see get_data).
    """

    def __init__(self, fields, market_aware, days, delta, transform_type):
        # Call the superclass constructor to set up base EventWindow
        # infrastructure.
        EventWindow.__init__(self, market_aware, days, delta)
        self.fields = fields
        # Per-field lists of the values currently inside the window; new
        # values are appended and expired ones are removed from the front.
        self.window_data = defaultdict(list)
        self.transform_type = transform_type

    # Subclass customization for adding new events.
    def handle_add(self, event):
        # Sanity check on the event.
        self.assert_required_fields(event)
        # Append the newest value of each tracked field.
        for field in self.fields:
            self.window_data[field].append(event[field])

    # Subclass customization for removing expired events.
    def handle_remove(self, event):
        # Drop the front (oldest) value of each tracked field in place,
        # instead of rebuilding the whole list with a slice copy. The guard
        # keeps the original's tolerance of an already-empty list.
        for field in self.fields:
            values = self.window_data[field]
            if values:
                del values[0]

    def average(self, field):
        """
        Return the mean of the values currently held for ``field``.

        Returns 0.0 when the window holds no values for the field.
        """
        # Sanity check.
        assert field in self.fields
        values = self.window_data[field]
        if not values:
            return 0.0
        # float() avoids floor division under Python 2 when the event
        # values are ints.
        return float(sum(values)) / len(values)

    def get_data(self):
        """
        Return a TransformData with one attribute per tracked field, holding
        the TA-Lib function's output over that field's current window.
        """
        out = TransformData()
        for field in self.fields:
            # TA-Lib functions only accept float64 (double) input. Without an
            # explicit dtype, a window of integer values would become an int
            # array and be rejected.
            values = numpy.array(self.window_data[field], dtype=numpy.float64)
            out.__dict__[field] = self.transform_type(values)
        return out

    def assert_required_fields(self, event):
        """
        We only allow events with all of our tracked fields, each holding a
        numeric value.
        """
        for field in self.fields:
            assert isinstance(event[field], Number), \
                "Got %s for %s in TATransformEventWindow" % (event[field],
                                                             field)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Zipline transform to link to TA-Lib