Created
November 2, 2024 15:23
-
-
Save tommyip/7f1173dc6a74f13e45ac48cff843ee9b to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import time | |
from functools import wraps | |
from math import ceil | |
from typing import Callable | |
import lightgbm as lgb | |
import numpy as np | |
import polars as pl | |
from tqdm.notebook import tqdm | |
# Column selector matching every column whose name starts with "feature_".
features = pl.selectors.matches("^feature_")
def r2_score(y_true, y_pred, weight=None):
    """Return the weighted, zero-mean R^2 of ``y_pred`` against ``y_true``.

    The score is ``1 - sum(w * (y_true - y_pred)**2) / sum(w * y_true**2)``.
    The denominator uses the raw targets; no mean is subtracted.

    Args:
        y_true: Array of observed targets.
        y_pred: Array of predictions, broadcastable against ``y_true``.
        weight: Per-sample weights, broadcastable against ``y_true``.
            ``None`` means uniform weights; this is what
            ``lgb.Dataset.get_weight()`` yields for an unweighted dataset,
            which previously caused a ``TypeError`` here.

    Returns:
        The score. 1.0 is a perfect fit and 0.0 equals predicting all zeros.
        The result is non-finite when the (weighted) sum of squared targets
        is zero.
    """
    squared_error = np.square(y_true - y_pred)
    squared_target = np.square(y_true)
    if weight is None:
        # Uniform weights cancel out of the ratio, so skip the multiply.
        num = np.sum(squared_error)
        denom = np.sum(squared_target)
    else:
        num = np.sum(weight * squared_error)
        denom = np.sum(weight * squared_target)
    return 1 - num / denom
def lgb_eval(y_pred: np.ndarray, eval_data: lgb.Dataset):
    """Custom LightGBM evaluation metric reporting the weighted R^2.

    Args:
        y_pred: Predictions for the rows of ``eval_data``.
        eval_data: Dataset whose label and weight are read for scoring.

    Returns:
        A ``(name, value, flag)`` tuple: the metric name ``"r2"``, the score
        from ``r2_score``, and ``True`` (LightGBM's custom-metric convention
        treats the third element as "higher is better").
    """
    # Label is fetched before weight, then both go straight to r2_score.
    score = r2_score(eval_data.get_label(), y_pred, eval_data.get_weight())
    return "r2", score, True
def n_days(lf: pl.LazyFrame) -> int:
    """Count the distinct ``date_id`` values in ``lf``.

    The lazy query is collected here, so calling this executes ``lf``.
    """
    distinct_day_count = pl.col("date_id").unique().len()
    counted = lf.select(distinct_day_count).collect()
    return counted.item()
def train_val_split(
    lf: pl.LazyFrame, *, train_days: int, val_ratio: float
) -> tuple[pl.LazyFrame, pl.LazyFrame]:
    """Split ``lf`` by ``date_id`` into a training window and a validation tail.

    The last ``ceil(n * val_ratio)`` days (``n`` being the number of distinct
    ``date_id`` values) form the validation set; the ``train_days`` days
    immediately before them form the training set.

    NOTE(review): the day *count* is used directly as a ``date_id`` threshold,
    so this presumably expects ``date_id`` to be contiguous integers starting
    at 0 -- confirm against the data.

    Args:
        lf: Frame with a ``date_id`` column.
        train_days: Width of the training window, in ``date_id`` units.
        val_ratio: Fraction of the distinct days reserved for validation.

    Returns:
        ``(lf_train, lf_val)`` as lazy frames.
    """
    total_days = n_days(lf)
    # First date_id of the validation tail, and of the training window before it.
    val_start = total_days - ceil(total_days * val_ratio)
    train_start = val_start - train_days

    day = pl.col("date_id")
    train_part = lf.filter((day >= train_start) & (day < val_start))
    val_part = lf.filter(day >= val_start)
    return train_part, val_part
def perf_metrics(f):
    """Decorator that prints how long each call to ``f`` took.

    The wrapped function's arguments and return value pass through
    unchanged. After a call returns, a line of the form
    ``"<name> took <seconds>s"`` is printed, with the seconds shown to two
    decimals. Nothing is printed if ``f`` raises.
    """

    @wraps(f)
    def wrapper(*args, **kwargs):
        began = time.perf_counter()
        result = f(*args, **kwargs)
        elapsed = time.perf_counter() - began
        print(f"{f.__name__} took {elapsed:.2f}s")
        return result

    return wrapper
def dummy_predict(test: pl.DataFrame, lags: pl.DataFrame | None) -> pl.DataFrame:
    """Baseline predictor: return 0.0 as ``responder_6`` for every test row.

    Args:
        test: Batch to predict; only its ``row_id`` column is used.
        lags: Ignored.

    Returns:
        A frame with columns ``row_id`` and ``responder_6``.
    """
    zero_prediction = pl.lit(0.0).alias("responder_6")
    return test.select("row_id", zero_prediction)
def simulate(
    lf: pl.LazyFrame,
    start_date_id: int,
    predict_fn: Callable[[pl.DataFrame, pl.DataFrame | None], pl.DataFrame],
    fast: bool = False,
) -> pl.DataFrame:
    """Simulate the submission API by feeding ``lf`` to ``predict_fn`` in batches.

    By default ``predict_fn`` is called once per (date_id, time_id) group.
    With ``fast=True`` it is called once per date_id to reduce overhead; note
    that this is not the same as the submission API!

    Args:
        lf: Full dataset; rows with ``date_id >= start_date_id`` are replayed.
        start_date_id: First ``date_id`` to replay.
        predict_fn: Called as ``predict_fn(test, lags)``. ``test`` holds
            ``row_id``, ``date_id``, ``time_id``, ``symbol_id``, ``weight``,
            ``is_scored`` and the feature columns. ``lags`` holds the previous
            day's responder columns (suffixed ``_lag_1``) when the first row of
            the batch has ``time_id == 0``, otherwise ``None``. It must return
            a frame with ``row_id`` and ``responder_6``.
        fast: Batch by ``date_id`` only instead of (date_id, time_id).

    Returns:
        A dataframe with columns ('date_id', 'time_id', 'symbol_id', 'weight',
        'responder_6', 'responder_6_pred').

    Raises:
        Exception: If a single ``predict_fn`` call takes longer than 60 seconds.
    """
    if fast:
        batch_keys = ("date_id",)
    else:
        batch_keys = ("date_id", "time_id")

    # Only rows on or after the start date are replayed; the same lazy plan is
    # used both to size the progress bar and to drive the simulation.
    replay = lf.filter(pl.col("date_id") >= start_date_id)
    n_batches = replay.unique(batch_keys).select(pl.len()).collect().item()

    result_schema = {
        "date_id": pl.Int64,
        "time_id": pl.Int64,
        "symbol_id": pl.String,
        "weight": pl.Float32,
        "responder_6": pl.Float32,
        "responder_6_pred": pl.Float32,
    }

    def lags_for(date_id):
        """Collect the responder columns of the day before ``date_id``."""
        previous_day = lf.filter(pl.col("date_id") == date_id - 1)
        return previous_day.select(
            "date_id",
            "time_id",
            "symbol_id",
            pl.selectors.matches("responder").name.suffix("_lag_1"),
        ).collect()

    with tqdm(total=n_batches) as progress:

        def step(batch: pl.DataFrame) -> pl.DataFrame:
            first_row = batch.select("date_id", "time_id").row(0, named=True)
            # Lags are only supplied when the batch starts at time_id 0.
            if first_row["time_id"] == 0:
                lags = lags_for(first_row["date_id"])
            else:
                lags = None

            test = batch.select(
                pl.int_range(pl.len(), dtype=pl.UInt32).alias("row_id"),
                "date_id",
                "time_id",
                "symbol_id",
                "weight",
                pl.lit(True).alias("is_scored"),
                pl.selectors.matches("feature"),
            )

            clock_start = time.perf_counter()
            pred = predict_fn(test, lags)
            pred_elapsed = time.perf_counter() - clock_start
            # Enforce a 60-second budget per prediction call.
            if pred_elapsed > 60:
                print(f"Prediction function took {pred_elapsed:.2f}s!")
                raise Exception("Inference server internal error")
            progress.update(1)

            # Re-attach predictions to the ground truth by positional row_id.
            truth = batch.lazy().with_row_index("row_id")
            scored = truth.join(pred.lazy(), on="row_id", suffix="_pred")
            return scored.select(
                "date_id",
                "time_id",
                "symbol_id",
                "weight",
                "responder_6",
                "responder_6_pred",
            ).collect()

        grouped = replay.group_by(*batch_keys, maintain_order=True)
        return grouped.map_groups(step, schema=result_schema).collect()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment