Skip to content

Instantly share code, notes, and snippets.

@rday
Last active December 12, 2015 08:18
Show Gist options
  • Save rday/4742553 to your computer and use it in GitHub Desktop.
Save rday/4742553 to your computer and use it in GitHub Desktop.
#
# Skeleton code, don't use this because it isn't going to work for you
"""
This is heavily based on the MovingAverage transform.
"""
import numpy
import talib
from numbers import Number
from collections import defaultdict
from zipline.transforms.utils import EventWindow, TransformMeta
class TATransform(object):
    """
    Zipline transform that applies a TA-Lib function to a rolling window of
    event fields, keeping a separate window for each sid.

    Parameters
    ----------
    fields : str or list of str
        Event field name(s) whose values are fed to the TA-Lib function.
    market_aware : bool
        If True, the window is sized by ``window_length`` (full days);
        otherwise it is sized by ``delta``.
    window_length : int or None
        Window size in days. Required, and the only size allowed, when
        ``market_aware`` is True.
    delta : timedelta or None
        Window span. Required, and the only size allowed, when
        ``market_aware`` is False.
    transform_type : str
        Name of a callable exposed by the ``talib`` module.
    """
    __metaclass__ = TransformMeta

    def __init__(self, fields='price',
                 market_aware=True, window_length=None, delta=None,
                 transform_type=None):
        # Accept a single field name as shorthand for a one-item list.
        # (basestring / __metaclass__: this module targets Python 2.)
        if isinstance(fields, basestring):
            fields = [fields]
        self.fields = fields
        self.market_aware = market_aware
        self.delta = delta
        self.window_length = window_length

        # Resolve the TA-Lib function by name. The default of None is never
        # a valid name, so callers must always supply transform_type.
        assert transform_type in talib.__dict__, \
            "Unknown TA-Lib function: %r" % (transform_type,)
        self.transform_type = talib.__dict__[transform_type]
        # talib.__dict__ also holds non-callables (e.g. module dunders);
        # reject those here rather than failing later inside get_data.
        assert callable(self.transform_type), \
            "talib.%s is not callable" % (transform_type,)

        # Market-aware mode only works with full-day windows.
        if self.market_aware:
            assert self.window_length and not self.delta, \
                "Market-aware mode only works with full-day windows."
        # Non-market-aware mode requires a timedelta.
        else:
            assert self.delta and not self.window_length, \
                "Non-market-aware mode requires a timedelta."

        # defaultdict factories take no arguments, so a bound method is used
        # to build each per-sid window with this transform's configuration.
        self.sid_windows = defaultdict(self.create_window)

    def create_window(self):
        """
        Factory for ``self.sid_windows``: build a new
        TATransformEventWindow configured with this transform's fields,
        window sizing and TA-Lib function.
        """
        return TATransformEventWindow(
            self.fields,
            self.market_aware,
            self.window_length,
            self.delta,
            self.transform_type
        )

    def update(self, event):
        """
        Add ``event`` to the window for its sid and return that window's
        current TA-Lib results (whatever ``get_data`` returns, one entry per
        tracked field).
        """
        # Indexing the defaultdict creates a new window on the first event
        # seen for this sid.
        window = self.sid_windows[event.sid]
        window.update(event)
        return window.get_data()
class TransformData(object):
    """
    Plain attribute container for per-field transform results.

    Results are stored as instance attributes and can also be read with
    subscript syntax, e.g. ``data['price']``.
    """

    def __getitem__(self, name):
        """
        Return the instance attribute called ``name``.

        Only attributes set on the instance itself are visible; an unknown
        name raises KeyError rather than AttributeError.
        """
        attributes = vars(self)
        return attributes[name]
class TATransformEventWindow(EventWindow):
    """
    EventWindow that keeps the raw values of each tracked field for the
    events currently in the window, and applies a TA-Lib function to them
    on request.
    """

    def __init__(self, fields, market_aware, days, delta, transform_type):
        # Set up the base EventWindow infrastructure. Presumably the base
        # class drives the handle_add / handle_remove hooks below as events
        # enter and expire — NOTE(review): confirm against EventWindow.
        EventWindow.__init__(self, market_aware, days, delta)
        self.fields = fields
        # Per-field list of raw values, oldest first, one entry per event
        # currently held in the window.
        self.window_data = defaultdict(list)
        # The TA-Lib callable applied to each field in get_data.
        self.transform_type = transform_type

    # Subclass customization for adding new events.
    def handle_add(self, event):
        """Append each tracked field value of ``event`` to the window."""
        # Reject events whose tracked fields are not numeric.
        self.assert_required_fields(event)
        for field in self.fields:
            self.window_data[field].append(event[field])

    # Subclass customization for removing expired events.
    def handle_remove(self, event):
        """
        Drop the oldest stored value of each tracked field.

        ``event`` itself is not inspected; this assumes expired events are
        removed oldest-first (NOTE(review): confirm against EventWindow).
        """
        for field in self.fields:
            # Delete in place instead of rebuilding the list with a slice;
            # like the slice, this is a no-op when the list is empty.
            del self.window_data[field][:1]

    def average(self, field):
        """
        Return the mean of ``field`` over the values currently in the window.

        Returns 0.0 when the window holds no values for ``field``.
        """
        # Sanity check.
        assert field in self.fields
        values = self.window_data[field]
        if not values:
            return 0.0
        # float() avoids integer (floor) division under Python 2.
        return float(sum(values)) / len(values)

    def get_data(self):
        """
        Apply the TA-Lib function to each tracked field's window and return
        the results as a TransformData with one attribute per field.
        """
        out = TransformData()
        for field in self.fields:
            # TA-Lib's wrappers only accept float64 input; converting here
            # lets integer-valued fields through as well.
            values = numpy.array(self.window_data[field], dtype=numpy.float64)
            out.__dict__[field] = self.transform_type(values)
        return out

    def assert_required_fields(self, event):
        """
        Check that ``event`` has a numeric value for every tracked field.

        Raises AssertionError (when asserts are enabled) on a non-Number
        value; a missing field raises whatever ``event[field]`` raises.
        """
        for field in self.fields:
            assert isinstance(event[field], Number), \
                "Got %s for %s in TATransformEventWindow" % (event[field],
                                                             field)
@rday
Copy link
Author

rday commented Feb 8, 2013

Zipline transform that bridges to TA-Lib

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment