Last active
November 24, 2016 09:53
-
-
Save dat-boris/4b39c6a27a9858d1a8dd582365b9ea2e to your computer and use it in GitHub Desktop.
Pipeline abstraction class for handling Quantopian pipeline data.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# https://www.quantopian.com/posts/help-with-custom-factor
class SidInList(CustomFilter):
    """
    Custom pipeline filter that marks each entry of ``assets`` that also
    appears in ``sid_list``.

    ``sid_list`` is declared in ``params``, so it is supplied when the
    filter is constructed (e.g. ``SidInList(sid_list=(1, 2, 3))``). The
    filter reads no input terms and uses a one-day window.
    NOTE(review): ``assets`` is presumably the array of sids for the day;
    callers in this file build ``sid_list`` from ``asset.sid``.
    """
    params = ('sid_list',)
    window_length = 1
    inputs = []

    def compute(self, today, assets, out, sid_list):
        # An empty list would silently screen out every asset; fail loudly.
        assert len(sid_list) > 0
        is_member = np.in1d(assets, sid_list)
        out[:] = is_member
class PipelineData(object):
    """
    Wrapper around a pipeline-output DataFrame that adds helper methods.

    Every helper indexes ``self.df.index`` with level 0 as the date and
    level 1 as the asset, so ``df`` is expected to carry a two-level
    (date, asset) MultiIndex. Inherits from ``object`` so that it is a
    new-style class under Python 2 as well (full ``@property`` support).
    """

    def __init__(self, df):
        """
        :param df: DataFrame indexed by a (date, asset) MultiIndex.
        """
        self.df = df

    def dropna(self, how='all'):
        """
        Replace ``self.df`` with a copy that has missing-value rows dropped.

        :param how: forwarded to ``DataFrame.dropna``; ``'all'`` drops a row
            only when every column is NaN, ``'any'`` when at least one is.
        :return: self, so calls can be chained.
        """
        self.df = self.df.dropna(how=how)
        return self

    @property
    def first_date(self):
        """
        :return: the earliest date (index level 0) of the pipeline data.
        """
        return min(self.df.index.get_level_values(0))

    @property
    def last_date(self):
        """
        :return: the latest date (index level 0) of the pipeline data.
        """
        return max(self.df.index.get_level_values(0))

    def get_universe(self, df=None):
        """
        Return the set of assets present in a DataFrame.

        :param df: DataFrame to inspect; defaults to ``self.df``.
        :return: set of the distinct values of index level 1.
        """
        if df is None:
            df = self.df
        # Read the frame that was actually requested, not always self.df.
        return set(df.index.get_level_values(1))

    def get_all_rows_of_stock(self, symbol):
        """
        :param symbol: asset label to select from index level 1.
        :return: all rows for that asset, indexed by date (the asset level
            is dropped by ``DataFrame.xs``).
        """
        return self.df.xs(symbol, level=1)

    def iterdate(self):
        """
        :return: set of the distinct dates (index level 0). A set has no
            defined order; sort it when chronological order matters.
        """
        return set(self.df.index.get_level_values(0))

    def get_data_from_date(self, date):
        """
        :return: the rows for ``date`` (index level 0), indexed by asset.
        """
        return self.df.xs(date, level=0)

    def get_universe_from_date(self, date):
        """
        :return: index of the assets that have a row on ``date``.
        """
        return self.get_data_from_date(date).index

    def merge_df(self, date, df):
        """
        Write the values of ``df`` into ``self.df`` for index entries that
        appear in both frames.

        :param date: unused; kept so existing callers keep working.
        :param df: DataFrame indexed like ``self.df``. ``enrich`` creates
            its columns in ``self.df`` before calling this.
        :return: None. Prints a warning and changes nothing when the two
            frames share no index entry.
        """
        existing = set(self.df.index)
        # Hand pandas a list (in df's order), not a set: indexing with a
        # set is deprecated/rejected by newer pandas versions.
        shared_index = [key for key in df.index.drop_duplicates()
                        if key in existing]
        if not shared_index:
            print("WARNING: no shared index found!")
            return
        # Label-based assignment; pandas aligns ``df`` on index/columns.
        # ``.loc`` replaces ``.ix``, which newer pandas no longer has.
        self.df.loc[shared_index, df.columns] = df

    def get_sorted_column(self, colname):
        """
        Intended to return the stocks ordered by the values of ``colname``.

        :raises NotImplementedError: always; not implemented yet.
        """
        # ``NotImplemented`` is a comparison sentinel, not an exception
        # class; calling it raised TypeError instead of this error.
        raise NotImplementedError(
            "get_sorted_column({!r}) is not implemented".format(colname))

    def get_pricing_delta(self, day_after):
        """
        Add forward-pricing columns to ``self.df``.

        Three columns are added:

        * ``price_start`` -- the price on each row's own date,
        * ``price_<day_after>`` -- the price ``day_after`` calendar days
          later,
        * ``delta_<day_after>`` -- ``price_<day_after> / price_start - 1``.

        Prices come from ``get_pricing`` for the ``PRICE_USED`` field. When
        the target date has no prices (e.g. a holiday), the first date with
        any price within the following 7 calendar days is used instead.

        :param day_after: number of calendar days to look ahead.
        :return: self, so calls can be chained.
        """
        def pricing_func(start_date,
                         screen=None,
                         stocks=None,
                         days=None,
                         col_name='price_after_{}'.format(None)):
            # ``screen`` matches enrich()'s calling convention; pricing
            # does not need it.
            target_date = start_date + timedelta(days=days)
            window_end = target_date + timedelta(days=7)
            prices = get_pricing(
                stocks,
                fields=PRICE_USED,
                start_date=target_date,
                # Look a week ahead so a holiday/weekend target date still
                # yields a trading-day price.
                end_date=window_end
            )
            # Drop dates without any price (holidays); 'all' because
            # individual assets may have empty values on a trading day.
            prices = prices.dropna(how='all')
            assert len(prices) > 0, (
                "no pricing data between {} and {}".format(
                    target_date, window_end))
            # Keep only the first available date (positional slice), and
            # relabel it with the pipeline date so it lines up with
            # self.df's (date, asset) index after stacking.
            first_row = prices.iloc[0:1, :]
            first_row.index = [start_date]
            stacked = first_row.stack()
            return_df = stacked.to_frame(col_name)
            # Debug hook: expose the last pricing frame at module level.
            global DEBUG_DF
            DEBUG_DF = first_row
            return return_df

        price_col_name = 'price_{}'.format(day_after)
        delta_col_name = 'delta_{}'.format(day_after)
        self.enrich(pricing_func, days=0, col_name='price_start')
        self.enrich(pricing_func, days=day_after, col_name=price_col_name)
        self.df[delta_col_name] = (
            self.df[price_col_name] / self.df['price_start'] - 1)
        return self

    def enrich(self, pipeline_func, **kwargs):
        """
        Call ``pipeline_func`` once per date and merge its output into
        ``self.df``.

        :param pipeline_func: callable invoked as
            ``pipeline_func(date, screen=..., stocks=..., **kwargs)`` where
            ``screen`` is a ``SidInList`` filter over the sids of the assets
            present on ``date`` and ``stocks`` is the index of those assets.
            It must return a DataFrame indexed like ``self.df`` with at
            least one column.
        :param kwargs: forwarded unchanged to ``pipeline_func``.
        :return: self, so calls can be chained.
        """
        # Debug hook: the last result is exposed at module level.
        global DEBUG_DF_PARAM
        columns_created = False
        # Sorted so dates are processed (and logged) in chronological order.
        for date in sorted(self.iterdate()):
            stocks_on_date = self.get_data_from_date(date).index
            # https://www.quantopian.com/posts/pipeline-set-screen-by-sid
            sid_list = tuple(s.sid for s in stocks_on_date)
            assert len(sid_list) > 0
            print("Enriching {}, {} with {} items: {}".format(
                date, date.tz, len(stocks_on_date), sid_list))
            # NOTE(review): the original author noted dates must be
            # localized to UTC; they are passed through unchanged here.
            new_data = pipeline_func(date,
                                     screen=SidInList(sid_list=sid_list),
                                     stocks=stocks_on_date,
                                     **kwargs)
            DEBUG_DF_PARAM = new_data
            # Create the result columns (NaN-filled) once, before the first
            # merge, so merge_df only writes into existing columns.
            if not columns_created:
                assert len(new_data.columns) > 0
                for column in new_data.columns:
                    self.df[column] = np.nan
                columns_created = True
            self.merge_df(date, new_data)
        return self
# def test_pipeline_data(): | |
# output = get_ceo_change_dates( | |
# pd.Timestamp('2013-01-01', tz=utc), | |
# pd.Timestamp('2013-03-01', tz=utc) | |
# ) | |
# test_df = output.get_pricing_delta(30) | |
# assert len(test_df.df.dropna().price_30) > 0 | |
# return test_df | |
# test_pipeline_data().df |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment