Last active
January 22, 2018 02:52
-
-
Save khuangaf/2882c61e7fc02788ea7d0a1519b17902 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class DataSrc(object):
    """Acts as data provider for each new episode."""

    def __init__(self, df, steps=252, scale=True, scale_extra_cols=True,
                 augment=0.00, window_length=50):
        """
        Build the in-memory data arrays from a price data frame.

        Parameters
        ----------
        df : pandas.DataFrame
            Frame indexed by timestamp with two-level columns: level 0 is the
            asset name (e.g. 'LTCBTC') and level 1 is the feature name
            ('open', 'low', 'high', 'close', ...). The frame passed in is not
            modified.
        steps : int
            Total steps in an episode. Stored as ``steps + 1``.
        scale : bool
            Scale the data for each episode (only stored here).
        scale_extra_cols : bool
            Scale extra (non-price) columns by global mean and std. When
            true, per-feature statistics are computed and kept in
            ``self.stats``; otherwise ``self.stats`` is not set.
        augment : float
            Fraction to augment the data by (only stored here).
        window_length : int
            Stored as ``self.window_length``; not otherwise used here.

        Attributes set
        --------------
        asset_names, features : list
            Values of column levels 0 and 1.
        _data : numpy.ndarray
            Array of shape (assets, time, features).
        _times : pandas.Index
            The index of the cleaned frame.
        price_columns : list
        non_price_columns : set
            Level-1 column names that are not price columns.
        stats : dict
            ``{'mean': ..., 'std': ...}``, one entry per feature, computed
            over all assets and timestamps (only if ``scale_extra_cols``).
        """
        self.steps = steps + 1
        self.augment = augment
        self.scale = scale
        self.scale_extra_cols = scale_extra_cols
        self.window_length = window_length

        # Get rid of NaNs. replace() returns a new frame, so the caller's
        # data frame is left untouched.
        # NOTE(review): the zero replacement runs first, so the forward fill
        # below has nothing left to fill for numeric data. If forward-filling
        # gaps was the intent, the two steps should be swapped - confirm
        # before changing, since it alters the resulting values.
        df = df.replace(np.nan, 0)
        df = df.ffill()

        # Data frame to matrix.
        # NOTE(review): the reshape assumes the columns are laid out
        # asset-major (every feature of asset 0, then asset 1, ...) in the
        # same order as the column levels, and that every asset/feature pair
        # is present - verify against the caller.
        self.asset_names = df.columns.levels[0].tolist()
        self.features = df.columns.levels[1].tolist()
        n_assets = len(self.asset_names)
        n_features = len(self.features)
        # `.values` replaces DataFrame.as_matrix(), which newer pandas
        # releases no longer provide.
        data = df.values.reshape((len(df), n_assets, n_features))
        # (time, asset, feature) -> (asset, time, feature)
        self._data = np.transpose(data, (1, 0, 2))
        self._times = df.index

        self.price_columns = ['close', 'high', 'low', 'open']
        self.non_price_columns = set(
            df.columns.levels[1]) - set(self.price_columns)

        # Stats to let us normalize non price columns: one mean/std per
        # feature, pooled over every asset and timestamp.
        if scale_extra_cols:
            x = self._data.reshape((-1, n_features))
            self.stats = dict(mean=x.mean(0), std=x.std(0))

        # reset() is not defined in this part of the class; it is expected to
        # be provided elsewhere on DataSrc.
        self.reset()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment