Created
April 16, 2024 21:12
-
-
Save 0x0L/b09a5b5cfb2a0298acbf3a5637b5effd to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from pathlib import Path | |
import numpy as np | |
import pandas as pd | |
class RaggedIndexer:
    """Translate indices over groups into slices over their flattened rows.

    ``counts[i]`` is the number of consecutive rows that belong to group ``i``
    in a flat array. Indexing this object with a group position, or with a
    contiguous slice of group positions, returns the ``slice`` of flat rows
    that covers those groups.
    """

    def __init__(self, counts):
        """Precompute the cumulative row offsets of every group.

        Args:
            counts: Sequence of row counts, one entry per group.
        """
        # Cast to an integer dtype: the cumsum of an empty sequence is
        # float64, which would produce slice bounds that cannot index arrays.
        self.stops = np.concatenate(([0], np.cumsum(counts))).astype(np.intp)
        # ``stops[i]`` is the first row of group ``i`` and ``stops[i + 1]`` is
        # one past its last row; ``starts`` is kept as a public alias.
        self.starts = self.stops[:-1]
        self.n = len(counts)

    def __getitem__(self, item):
        """Return the flat-row slice for a group index or a group slice.

        Args:
            item: An integer group position (negative values count from the
                end) or a ``slice`` of group positions with step 1.

        Returns:
            A ``slice`` with plain ``int`` bounds addressing the flat rows.

        Raises:
            IndexError: If an integer ``item`` is out of range.
            ValueError: If a slice has a step other than 1; strided or
                reversed group selections are not contiguous in the flat
                array and cannot be expressed as a single slice.
        """
        if isinstance(item, slice):
            first, last, step = item.indices(self.n)
            if step != 1:
                raise ValueError(
                    "RaggedIndexer only supports contiguous slices (step 1), "
                    f"got step {step}"
                )
            # Reversed bounds select nothing; collapse them to an empty range.
            last = max(first, last)
        else:
            # ``range`` indexing normalises negative positions, accepts any
            # ``__index__`` type and rejects out-of-range values.
            try:
                first = range(self.n)[item]
            except IndexError:
                raise IndexError(
                    f"group index {item} is out of range for {self.n} groups"
                ) from None
            last = first + 1
        # Both bounds are read from ``stops`` (length n + 1) so that
        # ``first == n`` -- an empty selection at the end -- stays in range.
        return slice(int(self.stops[first]), int(self.stops[last]))
class Unstacker:
    """Scatter flat per-sample rows into a dense (timestamp, asset, ...) array.

    The flat layout is described by ``counts`` (number of consecutive samples
    per timestamp) and ``ids`` (the asset identifier of every sample). Each
    distinct identifier is assigned one column, in the sorted order returned
    by ``np.unique``.
    """

    def __init__(self, counts, ids):
        """Precompute the row and column position of every flat sample.

        Args:
            counts: Number of samples for each timestamp.
            ids: Asset identifier of each flat sample.
        """
        n_rows = len(counts)
        distinct_ids, column_of_sample = np.unique(ids, return_inverse=True)
        # Dense output grid: one row per timestamp, one column per asset.
        self.shape = (n_rows, len(distinct_ids))
        # Row position of each sample: timestamp i repeated counts[i] times.
        self.ts = np.repeat(np.arange(n_rows), counts)
        # Column position of each sample.
        self.asset = column_of_sample

    def __call__(self, array):
        """Return ``array`` unstacked onto the dense grid.

        Args:
            array: Array whose first axis enumerates the flat samples.

        Returns:
            A new array of shape ``self.shape + array.shape[1:]`` with the
            same dtype as ``array``; cells that receive no sample stay zero.
        """
        trailing_dims = array.shape[1:]
        dense = np.zeros((*self.shape, *trailing_dims), dtype=array.dtype)
        dense[self.ts, self.asset] = array
        return dense
class Dataset:
    """Ragged panel of per-timestamp context and per-sample asset data.

    Samples are stored flat: timestamp ``i`` owns ``asset_count[i]``
    consecutive entries of ``asset_ids`` (and of every asset-data array).
    Context arrays carry one row per timestamp.
    """

    def __init__(self, *, timestamps, asset_count, asset_ids):
        """Create a dataset skeleton without any context or asset data.

        Args:
            timestamps: One entry per timestamp.
            asset_count: Number of samples belonging to each timestamp.
            asset_ids: Asset identifier of every flat sample.

        Raises:
            ValueError: If the lengths of the inputs are inconsistent.
        """
        if len(timestamps) != len(asset_count):
            raise ValueError(
                f"got {len(timestamps)} timestamps but "
                f"{len(asset_count)} asset counts"
            )
        expected_samples = np.sum(asset_count)
        if len(asset_ids) != expected_samples:
            raise ValueError(
                f"got {len(asset_ids)} asset ids but asset_count sums to "
                f"{expected_samples}"
            )
        self.ts = timestamps
        self.asset_ids = asset_ids
        self.asset_count = asset_count
        self.n_ts = len(timestamps)
        self.n_samples = len(self.asset_ids)
        self._context = {}
        self._asset_data = {}

    def multiindex(self):
        """Return a ``(TS, ASSET_ID)`` MultiIndex with one entry per sample."""
        ts = np.repeat(self.ts, self.asset_count)
        return pd.MultiIndex.from_arrays(
            [ts, self.asset_ids], names=["TS", "ASSET_ID"]
        )

    def add_context(self, **arrays):
        """Register per-timestamp arrays under the given keyword names.

        Raises:
            ValueError: If any array's first dimension differs from the
                number of timestamps; nothing is registered in that case.
        """
        self._check_rows(arrays, self.n_ts, "context")
        self._context |= arrays

    def add_asset_data(self, **arrays):
        """Register per-sample arrays under the given keyword names.

        Raises:
            ValueError: If any array's first dimension differs from the
                number of samples; nothing is registered in that case.
        """
        self._check_rows(arrays, self.n_samples, "asset data")
        self._asset_data |= arrays

    @staticmethod
    def _check_rows(arrays, expected, label):
        """Raise ValueError unless every array has ``expected`` leading rows."""
        for name, array in arrays.items():
            if array.shape[0] != expected:
                raise ValueError(
                    f"{label} array {name!r} has {array.shape[0]} rows, "
                    f"expected {expected}"
                )

    def save(self, path, overwrite=True):
        """Write the dataset to the directory ``path``.

        The layout is ``METADATA.npz`` plus one ``.npy`` file per array
        inside the ``context`` and ``asset_data`` sub-directories.

        Args:
            path: Target directory; missing parents are created.
            overwrite: When false, refuse to write into an existing path.

        Raises:
            FileExistsError: If ``path`` exists and ``overwrite`` is false.
        """
        path = Path(path)
        if path.exists() and not overwrite:
            raise FileExistsError(f"{path} already exists and overwrite is False")
        path.mkdir(parents=True, exist_ok=overwrite)
        np.savez(
            path / "METADATA.npz",
            timestamps=self.ts,
            asset_ids=self.asset_ids,
            asset_count=self.asset_count,
        )
        # NOTE(review): arrays left by an earlier save under names that this
        # dataset does not have are not removed, and load() will pick them up.
        groups = {"context": self._context, "asset_data": self._asset_data}
        for prefix, data in groups.items():
            folder = path / prefix
            folder.mkdir(exist_ok=overwrite)
            for name, array in data.items():
                np.save(folder / (name + ".npy"), array)

    @classmethod
    def load(cls, path):
        """Load a dataset previously written by :meth:`save`.

        Context and asset-data arrays are opened as memory maps rather than
        read into memory. ``open_memmap`` is called with its default mode,
        which is writable (``"r+"``): writing into a loaded array modifies
        the file on disk.

        Args:
            path: Directory produced by :meth:`save`.

        Returns:
            The reconstructed dataset.
        """
        path = Path(path)
        # The archive is closed once the metadata arrays have been read.
        with np.load(path / "METADATA.npz") as meta:
            ds = cls(
                timestamps=meta["timestamps"],
                asset_count=meta["asset_count"],
                asset_ids=meta["asset_ids"],
            )
        read_fn = np.lib.format.open_memmap
        # Only ``.npy`` files are arrays; anything else in the folder is
        # skipped instead of being handed to the memmap reader.
        context_files = sorted((path / "context").glob("*.npy"))
        asset_files = sorted((path / "asset_data").glob("*.npy"))
        ds.add_context(**{f.stem: read_fn(f) for f in context_files})
        ds.add_asset_data(**{f.stem: read_fn(f) for f in asset_files})
        return ds

    def to_tensors(self):
        """Return every array as a dense, timestamp-leading array.

        Asset data is unstacked to ``(n_ts, n_assets, ...)``; context arrays
        are returned as stored. On a name clash the context array wins.
        """
        unstack = Unstacker(self.asset_count, self.asset_ids)
        asset_data = {k: unstack(v) for k, v in self._asset_data.items()}
        return asset_data | self._context

    def __repr__(self):
        if not self.n_samples:
            return "Dataset (empty)"

        def as_date(value):
            return pd.to_datetime(value).strftime("%Y-%m-%d")

        def describe(arrays):
            return {k: (v.shape[1:], str(v.dtype)) for k, v in arrays.items()}

        lines = [
            "Dataset:",
            f"Timestamp = {as_date(self.ts[0])} / {as_date(self.ts[-1])} "
            f"({len(self.ts)})",
            f"Assets = {len(np.unique(self.asset_ids))} / "
            f"{self.n_samples} samples",
            f"Context = {describe(self._context)}",
            f"Data = {describe(self._asset_data)}",
        ]
        return "\n".join(lines) + "\n"

    def _resolve_bound(self, bound):
        """Convert one slice bound to a timestamp position (or ``None``).

        Integers (Python or numpy) are positions and are returned as is.
        Any other value is converted with ``np.datetime64`` and located in
        ``self.ts`` with ``np.searchsorted`` (left side), which presumes the
        timestamps are sorted.
        """
        if bound is None or isinstance(bound, (int, np.integer)):
            return bound
        return int(np.searchsorted(self.ts, np.datetime64(bound)))

    def __getitem__(self, item):
        """Return the sub-dataset for a contiguous range of timestamps.

        Args:
            item: A ``slice`` whose bounds are integer positions or values
                accepted by ``np.datetime64``. The stop bound is exclusive
                in both cases.

        Returns:
            A new ``Dataset`` whose arrays are slices of this dataset's
            arrays.

        Raises:
            TypeError: If ``item`` is not a slice.
            ValueError: If the slice has a step other than 1, which cannot
                be mapped onto a contiguous range of flat samples.
        """
        if not isinstance(item, slice):
            raise TypeError(
                f"Dataset indices must be slices, not {type(item).__name__}"
            )
        if item.step not in (None, 1):
            raise ValueError("Dataset slicing does not support a step")

        bounds = slice(
            self._resolve_bound(item.start), self._resolve_bound(item.stop)
        )
        start, stop, _ = bounds.indices(self.n_ts)
        stop = max(start, stop)
        item = slice(start, stop)

        if start == stop:
            # An empty selection (possibly positioned at the very end) owns
            # no samples, so there is no group to look up.
            ragged_item = slice(0, 0)
        else:
            ragged_item = RaggedIndexer(self.asset_count)[item]

        ds = Dataset(
            timestamps=self.ts[item],
            asset_count=self.asset_count[item],
            asset_ids=self.asset_ids[ragged_item],
        )
        ds.add_context(**{k: v[item] for k, v in self._context.items()})
        ds.add_asset_data(
            **{k: v[ragged_item] for k, v in self._asset_data.items()}
        )
        return ds
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment