Skip to content

Instantly share code, notes, and snippets.

@serihiro
Last active October 9, 2025 01:42
Show Gist options
  • Select an option

  • Save serihiro/2e45038de2d4dad1b0b7bded5c7bd00b to your computer and use it in GitHub Desktop.

Select an option

Save serihiro/2e45038de2d4dad1b0b7bded5c7bd00b to your computer and use it in GitHub Desktop.
lgbm_inference_benchmark.py
import argparse
import os
import sys
import time
import tempfile
import json
import urllib.request
import shutil
import platform
from pathlib import Path
import inspect
import numpy as np
# 依存ライブラリ
import lightgbm as lgb
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split
from sklearn.metrics import roc_auc_score
import treelite
import tl2cgen
def log(msg):
    """Write one ``[INFO]``-prefixed progress line to stdout.

    The line is flushed immediately so progress shows up promptly even when
    stdout is redirected to a file or pipe.
    """
    line = f'[INFO] {msg}'
    print(line, flush=True)
def try_download_model(url: str, dst: Path, timeout: float = 30) -> bool:
    """Download a LightGBM model file from *url* into *dst*.

    This is best-effort: any failure is logged and reported through the return
    value instead of being raised.

    Args:
        url: Source URL. Anything ``urllib.request.urlopen`` accepts works.
        dst: Destination file path. It is overwritten if it already exists.
        timeout: Socket timeout in seconds passed to ``urlopen``.

    Returns:
        True if the whole payload was written to *dst*, False otherwise.
    """
    log(f'Downloading LightGBM model from: {url}')
    started_writing = False
    try:
        with urllib.request.urlopen(url, timeout=timeout) as resp, open(dst, 'wb') as f:
            # Both the connection and the destination file are open from here on.
            started_writing = True
            shutil.copyfileobj(resp, f)
    except Exception as e:
        log(f'Download failed: {e}')
        if started_writing:
            # Remove the truncated file this call created so nobody later
            # mistakes it for a valid model. A file that already existed is
            # never touched when the failure happens before it was opened.
            try:
                Path(dst).unlink(missing_ok=True)
            except OSError as cleanup_err:
                log(f'Could not remove partial download {dst}: {cleanup_err}')
        return False
    log(f'Saved to: {dst}')
    return True
def train_toy_lgbm_model(
    n_samples=200_000,
    n_features=50,
    n_informative=30,
    random_state=42,
    num_leaves=64,
    n_estimators=500,
    learning_rate=0.05,
    test_size=0.2,
):
    """Train a binary LightGBM classifier on synthetic data from ``make_classification``.

    Args:
        n_samples: Total number of synthetic rows. This covers both the train
            and validation splits.
        n_features: Number of feature columns.
        n_informative: Requested number of informative features. It is capped
            at ``n_features``; the remaining columns are generated as redundant
            features.
        random_state: Seed for the data generation and the train/valid split.
        num_leaves: LightGBM ``num_leaves``.
        n_estimators: Number of boosting rounds.
        learning_rate: LightGBM ``learning_rate``.
        test_size: Fraction of rows held out as the validation split.

    Returns:
        ``(booster, (X, y))``: the trained booster and the full synthetic
        dataset, including the training rows.
    """
    log('Training toy LightGBM model on synthetic classification data...')
    # make_classification needs n_informative + n_redundant <= n_features and
    # a non-negative n_redundant. Cap the informative count so that narrow
    # feature sets (e.g. n_features=20 with the default 30) don't crash.
    n_informative = min(n_informative, n_features)
    X, y = make_classification(
        n_samples=n_samples,
        n_features=n_features,
        n_informative=n_informative,
        n_redundant=n_features - n_informative,
        random_state=random_state,
    )
    X_train, X_valid, y_train, y_valid = train_test_split(
        X, y, test_size=test_size, random_state=random_state
    )
    train_set = lgb.Dataset(X_train, label=y_train)
    valid_set = lgb.Dataset(X_valid, label=y_valid)
    params = {
        'objective': 'binary',
        'metric': 'auc',
        'num_leaves': num_leaves,
        'learning_rate': learning_rate,
        'verbosity': -1,
    }
    booster = lgb.train(
        params,
        train_set,
        num_boost_round=n_estimators,
        valid_sets=[valid_set],
        valid_names=['valid'],
    )
    # Quick AUC report. plot_metric would need the recorded training history,
    # so the score is computed directly instead.
    auc = roc_auc_score(y_valid, booster.predict(X_valid))
    log(f'Toy model trained. valid AUC ~ {auc:.4f}')
    return booster, (X, y)
def load_lightgbm_model(model_path: Path):
    """Log and load a LightGBM ``Booster`` from the model file at *model_path*."""
    log(f'Loading LightGBM model: {model_path}')
    return lgb.Booster(model_file=str(model_path))
def _call_with_supported_kwargs(func, **kwargs):
"""Call func with only the kwargs it supports to handle version differences."""
try:
sig = inspect.signature(func)
except (TypeError, ValueError):
return func(**kwargs)
if any(p.kind == inspect.Parameter.VAR_POSITIONAL for p in sig.parameters.values()):
return func(**kwargs)
filtered = {}
for name, param in sig.parameters.items():
if param.kind == inspect.Parameter.VAR_KEYWORD:
return func(**kwargs)
if name in kwargs:
filtered[name] = kwargs[name]
return func(**filtered)
def _locate_compiled_library(out_dir: Path, libname: str) -> Path:
"""Locate the shared library emitted by TL2cgen, accounting for name variations."""
expected = out_dir / libname
if expected.exists():
return expected
stem = expected.stem
candidates = []
for suffix in (expected.suffix, '.so', '.dylib', '.dll', '.pyd'):
if not suffix:
continue
candidates.append(out_dir / f'{stem}{suffix}')
candidates.append(out_dir / f'lib{stem}{suffix}')
for cand in candidates:
if cand.exists():
return cand
for cand in sorted(out_dir.glob(f'{stem}*')):
if cand.is_file():
return cand
raise FileNotFoundError(f'Compiled library {libname} not found under {out_dir}')
def _to_treelite_model(booster: lgb.Booster) -> treelite.Model:
    """Convert a LightGBM booster to a Treelite model, trying several API entry points.

    First ``treelite.frontend.from_lightgbm(booster)`` is tried, then
    ``treelite.Model.from_lightgbm(booster=booster)``. An entry point is only
    tried if it exists in the installed Treelite.

    Raises:
        RuntimeError: if no entry point exists or every attempted one failed.
            The message lists each attempt and its error.
    """
    errors = []

    fe = getattr(treelite, 'frontend', None)
    if fe is not None and hasattr(fe, 'from_lightgbm'):
        try:
            return fe.from_lightgbm(booster)
        except Exception as err:
            errors.append(f'frontend.from_lightgbm: {err}')

    model_type = getattr(treelite, 'Model', None)
    if model_type is not None and hasattr(model_type, 'from_lightgbm'):
        try:
            return model_type.from_lightgbm(booster=booster)
        except Exception as err:
            errors.append(f'Model.from_lightgbm: {err}')

    raise RuntimeError(
        'Treelite LightGBM conversion failed; tried paths: '
        + ' | '.join(errors or ['no known conversion entry point'])
    )
def export_lib_with_tl2cgen(
    model: treelite.Model, out_dir: Path, libname: str, nthread: int
) -> Path | None:
    """Compile *model* into a shared library under *out_dir* with TL2cgen.

    C sources are first emitted into *out_dir* via ``tl2cgen.generate_c_code``.
    Then ``tl2cgen.export_lib`` is called with gcc, ``-O3`` and
    ``quantize=1``. Several keyword names are tried for the model argument so
    the call works across TL2cgen versions.

    Args:
        model: The Treelite model to compile.
        out_dir: Directory for the generated sources and the library.
        libname: Requested library file name (e.g. ``model_tl2cgen.so``).
        nthread: Passed as the ``parallel_comp`` code-generation parameter.

    Returns:
        Path of the compiled library. This is the path ``export_lib`` returned
        if it returned one, otherwise the file found by
        ``_locate_compiled_library``.

    Raises:
        RuntimeError: if every ``export_lib`` attempt raised. It is chained to
            the last error.
        FileNotFoundError: if export succeeded but no library file is found.
    """
    libpath = out_dir / libname
    # Writes the C sources into out_dir.
    # NOTE(review): export_lib below may generate its own code internally,
    # in which case this call only keeps the sources for inspection — confirm.
    tl2cgen.generate_c_code(model, dirpath=out_dir,
                            params={"quantize": 1})
    export_kwargs = {
        'toolchain': 'gcc',
        'libpath': str(libpath),
        'params': {'parallel_comp': nthread, 'quantize': 1},
        'verbose': True,
        'options': ['-O3'],
    }
    compiled_path = None
    last_error = None
    for model_kw in ('model', 'tree_model', 'treelite_model'):
        try:
            result = _call_with_supported_kwargs(
                tl2cgen.export_lib, **export_kwargs, **{model_kw: model}
            )
        except Exception as err:  # keep the most recent error for the report
            last_error = err
            continue
        # A call that returned without raising built the library. export_lib
        # may return None, so stop here rather than retrying with the other
        # keyword names, which would fail and mask this success.
        if isinstance(result, (str, os.PathLike)):
            compiled_path = Path(result)
        break
    else:
        raise RuntimeError(f'TL2cgen export failed: {last_error}') from last_error
    if compiled_path is None:
        compiled_path = _locate_compiled_library(out_dir, libname)
    return compiled_path
def predictor_infer(libpath: Path, X: np.ndarray, nthread: int, pred_margin: bool, repeats: int):
    """Benchmark a TL2cgen-compiled shared library through ``tl2cgen.Predictor``.

    *X* is wrapped in a ``tl2cgen.DMatrix`` once, outside the timed region.
    Three untimed warm-up predictions run first, then *repeats* predictions
    are timed.

    Returns:
        ``(latency, throughput, out)``:
        - latency: mean seconds per ``predict`` call.
        - throughput: ``X.shape[0] / latency`` in rows per second.
        - out: the output of the last timed call.

    Raises:
        ValueError: if ``repeats < 1``. The mean latency would otherwise
            divide by zero and ``out`` would be undefined.
    """
    if repeats < 1:
        raise ValueError(f'repeats must be >= 1, got {repeats}')
    dmat = tl2cgen.DMatrix(X)
    predictor = tl2cgen.Predictor(libpath=libpath, verbose=False, nthread=nthread)
    # Warm-up, so one-time costs are not included in the timing.
    for _ in range(3):
        predictor.predict(dmat, pred_margin=pred_margin)
    t0 = time.perf_counter()
    for _ in range(repeats):
        out = predictor.predict(dmat, pred_margin=pred_margin)
    latency = (time.perf_counter() - t0) / repeats
    throughput = X.shape[0] / latency
    return latency, throughput, out
def gtil_infer(model: treelite.Model, X: np.ndarray, nthread: int, pred_margin: bool, repeats: int):
    """Benchmark Treelite's GTIL path, which needs no compiled library.

    Three untimed warm-up calls to ``treelite.gtil.predict`` run first, then
    *repeats* calls are timed.

    Returns:
        ``(latency, throughput, out)``:
        - latency: mean seconds per call.
        - throughput: ``X.shape[0] / latency`` in rows per second.
        - out: the output of the last timed call.

    Raises:
        ValueError: if ``repeats < 1``.
        RuntimeError: if ``treelite.gtil`` cannot be imported.
    """
    if repeats < 1:
        raise ValueError(f'repeats must be >= 1, got {repeats}')
    try:
        from treelite import gtil
    except ImportError as err:  # pragma: no cover - SciPy missing, for example
        raise RuntimeError('treelite.gtil requires optional dependencies (e.g. SciPy).') from err
    # Short warm-up, so one-time setup costs are not part of the timing.
    for _ in range(3):
        gtil.predict(model, X, nthread=nthread, pred_margin=pred_margin)
    t0 = time.perf_counter()
    for _ in range(repeats):
        out = gtil.predict(model, X, nthread=nthread, pred_margin=pred_margin)
    latency = (time.perf_counter() - t0) / repeats
    throughput = X.shape[0] / latency
    return latency, throughput, out
def lightgbm_infer(booster: lgb.Booster, X: np.ndarray, pred_margin: bool, repeats: int):
    """Benchmark LightGBM's own ``Booster.predict``, used as the baseline.

    *pred_margin* is forwarded as ``raw_score``. Three untimed warm-up
    predictions run first, then *repeats* predictions are timed.

    Returns:
        ``(latency, throughput, out)``:
        - latency: mean seconds per ``predict`` call.
        - throughput: ``X.shape[0] / latency`` in rows per second.
        - out: the output of the last timed call.

    Raises:
        ValueError: if ``repeats < 1``.
    """
    if repeats < 1:
        raise ValueError(f'repeats must be >= 1, got {repeats}')
    # Warm-up
    for _ in range(3):
        booster.predict(X, raw_score=pred_margin)
    t0 = time.perf_counter()
    for _ in range(repeats):
        out = booster.predict(X, raw_score=pred_margin)
    latency = (time.perf_counter() - t0) / repeats
    throughput = X.shape[0] / latency
    return latency, throughput, out
def _parse_args() -> argparse.Namespace:
    """Parse and validate the benchmark's command-line options."""
    parser = argparse.ArgumentParser(description='Benchmark: LightGBM vs Treelite (GTIL/TL2cgen)')
    parser.add_argument('--model-url', type=str, default='', help='公開LightGBMモデルのURL。空ならtoyで自動学習。')
    parser.add_argument(
        '--model-path', type=str, default='', help='ローカルのLightGBMモデル(.txt)へのパス。指定時はこれを使用。'
    )
    parser.add_argument('--tmpdir', type=str, default='', help='一時生成物の出力先(空ならtempfile使用)')
    parser.add_argument(
        '--nthread', type=int, default=max(1, os.cpu_count() or 1), help='推論およびコンパイルに使うスレッド数'
    )
    parser.add_argument(
        '--rows', type=int, default=500_000, help='ベンチ用入力行数(toy学習時はこの値で推論データを作成)'
    )
    parser.add_argument('--features', type=int, default=50, help='toyデータの特徴量数(学習&推論の両方に使用)')
    parser.add_argument('--repeats', type=int, default=5, help='同一入力での繰り返し回数(平均化)')
    parser.add_argument('--pred-margin', action='store_true', help='raw score(margin)で出力する')
    args = parser.parse_args()
    # The timing helpers divide by `repeats`; reject bad values up front
    # rather than failing every benchmark mid-run.
    if args.repeats < 1:
        parser.error('--repeats must be >= 1')
    return args


def _acquire_booster(args: argparse.Namespace, model_txt: Path):
    """Return ``(booster, used_pretrained)``.

    Sources are tried in order: --model-path, then a download from
    --model-url, then a freshly trained toy model. ``used_pretrained`` is True
    only for a successfully downloaded and loaded model. A toy model is saved
    to *model_txt*.
    """
    if args.model_path:
        # An explicit local model always takes priority; load errors propagate.
        return load_lightgbm_model(Path(args.model_path)), False
    if args.model_url and try_download_model(args.model_url, model_txt):
        try:
            return load_lightgbm_model(model_txt), True
        except Exception as e:
            log(f'Downloaded model failed to load: {e}')
    # Fall back to training a toy model.
    booster, _ = train_toy_lgbm_model(
        n_samples=max(args.rows, 200_000),
        n_features=args.features,
    )
    booster.save_model(str(model_txt))
    return booster, False


def _run_benchmark(results: list, impl: str, label: str, fail_label: str, bench) -> None:
    """Run the zero-arg *bench* callable and record its outcome.

    *bench* is expected to return ``(latency, throughput, out)``. On success a
    result row tagged *impl* is appended to *results* and logged under
    *label*. On failure the error is logged under *fail_label* and nothing is
    appended.
    """
    try:
        lat, thr, _ = bench()
    except Exception as e:
        log(f'{fail_label}: {e}')
        return
    results.append(
        {
            'impl': impl,
            'latency_s_per_run': lat,
            'throughput_rows_per_s': thr,
        }
    )
    log(f'{label}: avg_latency={lat:.6f}s throughput={thr:,.0f} rows/s')


def main():
    """Benchmark LightGBM, Treelite GTIL and a TL2cgen-compiled library on the same input.

    Steps:
      1. Obtain a model: local file, else URL download, else toy training.
      2. Build a random float32 input with the model's feature count.
      3. Convert the model to Treelite and compile it with TL2cgen.
      4. Time each backend.
      5. Write ``benchmark_result.json`` into the work directory.
      6. Log a throughput ranking.

    Exits with status 1 if the Treelite conversion fails.
    """
    args = _parse_args()

    # Output directory: either user-supplied, or a fresh temp dir. The temp
    # dir is not deleted, so the generated library and JSON can be inspected.
    if args.tmpdir:
        out_dir = Path(args.tmpdir)
        out_dir.mkdir(parents=True, exist_ok=True)
    else:
        out_dir = Path(tempfile.mkdtemp(prefix='tl_bench_'))
    log(f'Workdir: {out_dir}')
    log(f'System: {platform.platform()} Python: {platform.python_version()} nthread={args.nthread}')
    model_txt = out_dir / 'model_lgb.txt'

    # 1) Obtain a model
    booster, used_pretrained = _acquire_booster(args, model_txt)

    # 2) Benchmark input. Random values are fine because accuracy is not
    #    measured, but the column count must match the model, otherwise the
    #    backends reject the input. For a toy model this equals --features.
    n_features = booster.num_feature()
    if n_features != args.features:
        log(f'Model expects {n_features} features; ignoring --features={args.features}')
    rng = np.random.default_rng(2025)
    data_for_pred = rng.standard_normal((args.rows, n_features)).astype(np.float32)
    log(f'Input matrix: shape={data_for_pred.shape}, dtype={data_for_pred.dtype}')

    # 3) Convert the model for Treelite
    try:
        tl_model = _to_treelite_model(booster)
    except Exception as e:
        log(f'Treelite conversion failed: {e}')
        sys.exit(1)

    # Build the shared library with TL2cgen; failure only skips that benchmark.
    lib_tl2cgen = None
    try:
        lib_tl2cgen = export_lib_with_tl2cgen(
            tl_model, out_dir, 'model_tl2cgen.so', nthread=args.nthread
        )
        if lib_tl2cgen is not None:
            log(f'TL2cgen shared library: {lib_tl2cgen}')
    except Exception as e:
        log(f'TL2cgen export failed: {e}')
        lib_tl2cgen = None

    # 4) Benchmarks: LightGBM / Treelite (GTIL / TL2cgen)
    results = []
    _run_benchmark(
        results, 'lightgbm.predict', 'LightGBM', 'LightGBM inference failed',
        lambda: lightgbm_infer(
            booster, data_for_pred, pred_margin=args.pred_margin, repeats=args.repeats
        ),
    )
    _run_benchmark(
        results, 'treelite.gtil', 'Treelite GTIL', 'Treelite GTIL failed',
        lambda: gtil_infer(
            tl_model,
            data_for_pred,
            nthread=args.nthread,
            pred_margin=args.pred_margin,
            repeats=args.repeats,
        ),
    )
    if lib_tl2cgen is not None:
        if lib_tl2cgen.exists():
            _run_benchmark(
                results, 'TL2cgen', 'TL2cgen runtime', 'TL2cgen runtime failed',
                lambda: predictor_infer(
                    lib_tl2cgen,
                    data_for_pred,
                    nthread=args.nthread,
                    pred_margin=args.pred_margin,
                    repeats=args.repeats,
                ),
            )
        else:
            log(f'TL2cgen library missing on disk: {lib_tl2cgen}')

    # 5) Also save the results as JSON
    out_json = out_dir / 'benchmark_result.json'
    with open(out_json, 'w', encoding='utf-8') as f:
        json.dump(
            {
                'env': {
                    'python': platform.python_version(),
                    'platform': platform.platform(),
                    'nthread': args.nthread,
                    'used_pretrained_model': used_pretrained,
                },
                'shape': list(data_for_pred.shape),
                'repeats': args.repeats,
                'pred_margin': args.pred_margin,
                'results': results,
            },
            f,
            ensure_ascii=False,
            indent=2,
        )
    log(f'Saved benchmark_result.json → {out_json}')

    # 6) Simple ranking by throughput
    if results:
        best = sorted(results, key=lambda r: r['throughput_rows_per_s'], reverse=True)
        log('Throughput ranking (higher is better):')
        for i, r in enumerate(best, 1):
            log(
                f'{i}. {r["impl"]:<36} {r["throughput_rows_per_s"]:>12,.0f} rows/s '
                f'(avg_latency={r["latency_s_per_run"]:.6f}s)'
            )
    else:
        log('No successful runs.')
if __name__ == '__main__':
    # Script entry point: run the benchmark only when executed directly, not on import.
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment