Skip to content

Instantly share code, notes, and snippets.

@0x0L
Created April 16, 2024 21:12
Show Gist options
  • Save 0x0L/b09a5b5cfb2a0298acbf3a5637b5effd to your computer and use it in GitHub Desktop.
Save 0x0L/b09a5b5cfb2a0298acbf3a5637b5effd to your computer and use it in GitHub Desktop.
from pathlib import Path
import numpy as np
import pandas as pd
class RaggedIndexer:
    """Map timestamp positions to row ranges of a ragged (flattened) array.

    Timestamp ``i`` owns ``counts[i]`` consecutive rows; ``stops[i]`` is the
    first row of timestamp ``i`` and ``stops[i + 1]`` is one past its last row.

    Args:
        counts: Row count per timestamp (assumed non-negative integers).
    """

    def __init__(self, counts):
        # Explicit integer dtype: np.cumsum([]) on its own is float64, and
        # float offsets cannot be used as slice bounds.
        self.stops = np.concatenate(([0], np.cumsum(counts, dtype=np.int64)))
        self.starts = self.stops[:-1]
        self.n = len(counts)

    def __getitem__(self, item):
        """Return the row ``slice`` for timestamp ``item`` or timestamp slice ``item``.

        Args:
            item: An integer position (negative values count from the end) or
                a slice with step ``None``/1.

        Returns:
            ``slice(first_row, end_row)`` with plain ``int`` bounds.

        Raises:
            IndexError: If an integer ``item`` is out of range.
            ValueError: If a slice has a step other than 1.
        """
        if isinstance(item, slice):
            start, stop, step = item.indices(self.n)
            if step != 1:
                raise ValueError(f"only unit-step slices are supported, got step={step}")
            # A reversed range selects nothing, as with ordinary sequence slicing.
            stop = max(start, stop)
        else:
            if not -self.n <= item < self.n:
                raise IndexError(f"index {item} is out of range for {self.n} timestamps")
            # Normalizes negative positions (and rejects non-integers).
            start = slice(item, None).indices(self.n)[0]
            stop = start + 1
        # stops has n + 1 entries, so start == n (an empty tail) is valid here.
        return slice(int(self.stops[start]), int(self.stops[stop]))
class Unstacker:
    """Scatter long-format sample rows into a dense (timestamp, asset, ...) grid.

    Args:
        counts: Number of rows per timestamp, in timestamp order.
        ids: Asset id of each row; the asset axis is ordered as ``np.unique(ids)``.
    """

    def __init__(self, counts, ids):
        unique, asset = np.unique(ids, return_inverse=True)
        # Column (position in `unique`) of each row.
        self.asset = asset
        # Timestamp position of each row.
        self.ts = np.repeat(np.arange(len(counts)), counts)
        self.shape = (len(counts), len(unique))

    def __call__(self, array, fill_value=None):
        """Return a dense copy of ``array`` with shape ``self.shape + array.shape[1:]``.

        Args:
            array: One row per sample (first axis length == number of rows).
            fill_value: Value for (timestamp, asset) cells with no sample.
                ``None`` (default) keeps the previous behaviour of zero-filling.
                Pass e.g. ``np.nan`` to tell missing cells apart from genuine
                zeros; it must be representable in ``array.dtype``.

        Returns:
            A new array of dtype ``array.dtype``. If a (timestamp, asset) pair
            occurs more than once, only one of its rows survives (NumPy does not
            specify which).

        Raises:
            ValueError: If ``array`` does not have one row per sample.
        """
        array = np.asarray(array)
        if array.ndim == 0 or array.shape[0] != len(self.ts):
            raise ValueError(f"expected {len(self.ts)} rows, got array of shape {array.shape}")
        shape = self.shape + array.shape[1:]
        if fill_value is None:
            out = np.zeros(shape, dtype=array.dtype)
        else:
            out = np.full(shape, fill_value, dtype=array.dtype)
        out[self.ts, self.asset] = array
        return out
class Dataset:
    """Panel of per-timestamp context arrays and ragged per-sample asset data.

    Sample rows are stored "long": the rows of timestamp ``i`` form a
    contiguous run of ``asset_count[i]`` rows, in timestamp order, and
    ``asset_ids`` labels each row.

    Args:
        timestamps: One entry per timestamp. Date-based slicing in
            ``__getitem__`` uses ``np.searchsorted``, so this presumably must be
            sorted ascending (TODO confirm callers guarantee it).
        asset_count: Number of sample rows for each timestamp.
        asset_ids: Asset id of each sample row; ``len(asset_ids)`` must equal
            ``sum(asset_count)``.
    """

    def __init__(self, *, timestamps, asset_count, asset_ids):
        assert len(timestamps) == len(asset_count), "timestamps and asset_count differ in length"
        assert len(asset_ids) == sum(asset_count), "len(asset_ids) must equal sum(asset_count)"
        self.ts = timestamps
        self.asset_ids = asset_ids
        self.asset_count = asset_count
        self.n_ts = len(timestamps)
        self.n_samples = len(self.asset_ids)
        # name -> array whose first axis has one entry per timestamp
        self._context = {}
        # name -> array whose first axis has one entry per sample row
        self._asset_data = {}

    def multiindex(self):
        """Return a ``(TS, ASSET_ID)`` MultiIndex with one entry per sample row."""
        ts = np.repeat(self.ts, self.asset_count)
        return pd.MultiIndex.from_arrays([ts, self.asset_ids], names=["TS", "ASSET_ID"])

    def add_context(self, **arrays):
        """Register arrays whose first axis has length ``n_ts`` (replaces same-named ones)."""
        bad = [k for k, a in arrays.items() if a.shape[0] != self.n_ts]
        assert not bad, f"context arrays must have {self.n_ts} rows: {bad}"
        self._context |= arrays

    def add_asset_data(self, **arrays):
        """Register arrays whose first axis has length ``n_samples`` (replaces same-named ones)."""
        bad = [k for k, a in arrays.items() if a.shape[0] != self.n_samples]
        assert not bad, f"asset data arrays must have {self.n_samples} rows: {bad}"
        self._asset_data |= arrays

    def save(self, path, overwrite=True):
        """Write the dataset into directory ``path``.

        Layout: ``METADATA.npz`` (timestamps, asset_ids, asset_count) plus one
        ``<name>.npy`` per array under ``context/`` and ``asset_data/``.

        Args:
            path: Target directory; its parent must already exist.
            overwrite: If False, raise ``FileExistsError`` when ``path`` exists.
                If True, reuse the directory: ``.npy`` files left by an earlier
                save that this dataset does not contain are deleted, so that
                :meth:`load` does not pick them up.
        """
        path = Path(path)
        if path.exists() and not overwrite:
            raise FileExistsError(path)
        path.mkdir(exist_ok=overwrite)
        np.savez(path / "METADATA.npz", timestamps=self.ts, asset_ids=self.asset_ids, asset_count=self.asset_count)
        for prefix, data in {"context": self._context, "asset_data": self._asset_data}.items():
            folder = path / prefix
            folder.mkdir(exist_ok=overwrite)
            # load() reads every .npy here, so drop leftovers from a previous save.
            for stale in list(folder.glob("*.npy")):
                if stale.stem not in data:
                    stale.unlink()
            for name, array in data.items():
                target = folder / (name + ".npy")
                tmp = target.with_name(target.name + ".tmp")
                # Write-then-rename: `array` may be a memmap of `target` itself
                # (a loaded dataset saved back to its own directory). Opening
                # `target` for writing would truncate the data being read.
                with tmp.open("wb") as fh:
                    np.save(fh, array)
                tmp.replace(target)

    @staticmethod
    def load(path, mmap_mode="r+"):
        """Load a dataset written by :meth:`save`.

        Metadata is read into memory; context and asset-data arrays are opened
        with ``np.lib.format.open_memmap``.

        Args:
            path: Directory previously passed to :meth:`save`.
            mmap_mode: ``mode`` passed to ``open_memmap``. The default ``"r+"``
                (the previous behaviour) writes in-place edits back to disk;
                use ``"r"`` for read-only access, e.g. on read-only storage.
        """
        path = Path(path)
        # Copy the arrays out so the NpzFile handle can be closed.
        with np.load(path / "METADATA.npz") as meta:
            ds = Dataset(**{k: meta[k] for k in meta.files})
        read_fn = np.lib.format.open_memmap
        # Only *.npy files hold arrays; sorted() makes the key order deterministic.
        ds.add_context(
            **{k.stem: read_fn(k, mode=mmap_mode) for k in sorted((path / "context").glob("*.npy"))}
        )
        ds.add_asset_data(
            **{k.stem: read_fn(k, mode=mmap_mode) for k in sorted((path / "asset_data").glob("*.npy"))}
        )
        return ds

    def to_tensors(self):
        """Return asset data unstacked to (n_ts, n_assets, ...) merged with the context arrays.

        Missing (timestamp, asset) cells are zero-filled. On a name clash the
        context array wins.
        """
        unstack = Unstacker(self.asset_count, self.asset_ids)
        asset_data = {k: unstack(v) for k, v in self._asset_data.items()}
        return asset_data | self._context

    def __repr__(self):
        if not self.n_samples:
            return "Dataset (empty)"

        def fmt(x):
            return pd.to_datetime(x).strftime("%Y-%m-%d")

        context = {k: (v.shape[1:], str(v.dtype)) for k, v in self._context.items()}
        data = {k: (v.shape[1:], str(v.dtype)) for k, v in self._asset_data.items()}
        return (
            "Dataset:\n"
            f"Timestamp = {fmt(self.ts[0])} / {fmt(self.ts[-1])} ({len(self.ts)})\n"
            f"Assets = {len(np.unique(self.asset_ids))} / {self.n_samples} samples\n"
            f"Context = {context}\n"
            f"Data = {data}\n"
        )

    def _position(self, bound):
        """Translate a slice bound to a timestamp position.

        ``None`` and integers (Python or NumPy) pass through unchanged. Anything
        else is converted with ``np.datetime64`` and mapped to the index of the
        first timestamp >= ``bound`` (``np.searchsorted``, left side).
        """
        if bound is None or isinstance(bound, (int, np.integer)):
            return bound
        return int(np.searchsorted(self.ts, np.datetime64(bound)))

    def __getitem__(self, item):
        """Return a new Dataset restricted to a contiguous range of timestamps.

        Args:
            item: A slice with step ``None`` or 1. Bounds may be integer
                positions or anything ``np.datetime64`` accepts (e.g.
                ``"2020-01-31"``). Date bounds are half-open like Python
                slicing: the start date is included and the stop date is
                excluded.

        Raises:
            ValueError: If the slice has a step other than 1.
        """
        assert isinstance(item, slice), "Dataset only supports slice indexing"
        if item.step not in (None, 1):
            raise ValueError(f"Dataset slicing does not support step={item.step!r}")
        start, stop, _ = slice(self._position(item.start), self._position(item.stop)).indices(self.n_ts)
        # A reversed range selects nothing.
        stop = max(start, stop)
        ts_slice = slice(start, stop)
        if start < stop:
            ragged_item = RaggedIndexer(self.asset_count)[ts_slice]
        else:
            # Empty selection, including start == n_ts (e.g. a date after the
            # last timestamp), which has no valid start offset to look up.
            ragged_item = slice(0, 0)
        ds = Dataset(
            timestamps=self.ts[ts_slice],
            asset_count=self.asset_count[ts_slice],
            asset_ids=self.asset_ids[ragged_item],
        )
        ds.add_context(**{k: v[ts_slice] for k, v in self._context.items()})
        ds.add_asset_data(**{k: v[ragged_item] for k, v in self._asset_data.items()})
        return ds
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment