lgbm_inference_benchmark.py
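This script benchmarks inference latency and throughput for the same LightGBM model through three paths: LightGBM's native predict(), Treelite's GTIL interpreter, and a shared library compiled with TL2cgen. A minimal invocation sketch, assuming lightgbm, scikit-learn, treelite, and tl2cgen are installed and a gcc toolchain is available (the thread count below is only an example; the other values are the script's defaults):

    python lgbm_inference_benchmark.py --rows 500000 --features 50 --repeats 5 --nthread 8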
import argparse
import os
import sys
import time
import tempfile
import json
import urllib.request
import shutil
import platform
from pathlib import Path
import inspect

import numpy as np

# Third-party dependencies
import lightgbm as lgb
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split
from sklearn.metrics import roc_auc_score
import treelite
import tl2cgen


def log(msg):
    print(f'[INFO] {msg}', flush=True)


def try_download_model(url: str, dst: Path) -> bool:
    try:
        log(f'Downloading LightGBM model from: {url}')
        with urllib.request.urlopen(url, timeout=30) as resp, open(dst, 'wb') as f:
            shutil.copyfileobj(resp, f)
        log(f'Saved to: {dst}')
        return True
    except Exception as e:
        log(f'Download failed: {e}')
        return False


def train_toy_lgbm_model(
    n_samples=200_000,
    n_features=50,
    n_informative=30,
    random_state=42,
    num_leaves=64,
    n_estimators=500,
    learning_rate=0.05,
    test_size=0.2,
):
    log('Training toy LightGBM model on synthetic classification data...')
    X, y = make_classification(
        n_samples=n_samples,
        n_features=n_features,
        n_informative=n_informative,
        n_redundant=n_features - n_informative,
        random_state=random_state,
    )
    X_train, X_valid, y_train, y_valid = train_test_split(X, y, test_size=test_size, random_state=random_state)
    train_set = lgb.Dataset(X_train, label=y_train)
    valid_set = lgb.Dataset(X_valid, label=y_valid)
    params = dict(
        objective='binary',
        metric='auc',
        num_leaves=num_leaves,
        learning_rate=learning_rate,
        verbosity=-1,
    )
    booster = lgb.train(
        params,
        train_set,
        num_boost_round=n_estimators,
        valid_sets=[valid_set],
        valid_names=['valid'],
    )
    # Quick AUC check (plot_metric needs a recorded eval history, so just compute it directly)
    auc = roc_auc_score(y_valid, booster.predict(X_valid))
    log(f'Toy model trained. valid AUC ~ {auc:.4f}')
    return booster, (X, y)


def load_lightgbm_model(model_path: Path):
    log(f'Loading LightGBM model: {model_path}')
    booster = lgb.Booster(model_file=str(model_path))
    return booster


def _call_with_supported_kwargs(func, **kwargs):
    """Call func with only the kwargs it supports to handle version differences."""
    try:
        sig = inspect.signature(func)
    except (TypeError, ValueError):
        return func(**kwargs)
    if any(p.kind == inspect.Parameter.VAR_POSITIONAL for p in sig.parameters.values()):
        return func(**kwargs)
    filtered = {}
    for name, param in sig.parameters.items():
        if param.kind == inspect.Parameter.VAR_KEYWORD:
            return func(**kwargs)
        if name in kwargs:
            filtered[name] = kwargs[name]
    return func(**filtered)


def _locate_compiled_library(out_dir: Path, libname: str) -> Path:
    """Locate the shared library emitted by TL2cgen, accounting for name variations."""
    expected = out_dir / libname
    if expected.exists():
        return expected
    stem = expected.stem
    candidates = []
    for suffix in (expected.suffix, '.so', '.dylib', '.dll', '.pyd'):
        if not suffix:
            continue
        candidates.append(out_dir / f'{stem}{suffix}')
        candidates.append(out_dir / f'lib{stem}{suffix}')
    for cand in candidates:
        if cand.exists():
            return cand
    for cand in sorted(out_dir.glob(f'{stem}*')):
        if cand.is_file():
            return cand
    raise FileNotFoundError(f'Compiled library {libname} not found under {out_dir}')


def _to_treelite_model(booster: lgb.Booster) -> treelite.Model:
    """Convert LightGBM booster into a Treelite model, handling API differences."""
    conversion_errors = []
    frontend = getattr(treelite, 'frontend', None)
    if frontend is not None and hasattr(frontend, 'from_lightgbm'):
        try:
            return frontend.from_lightgbm(booster)
        except Exception as err:
            conversion_errors.append(f'frontend.from_lightgbm: {err}')
    model_cls = getattr(treelite, 'Model', None)
    if model_cls is not None and hasattr(model_cls, 'from_lightgbm'):
        try:
            return model_cls.from_lightgbm(booster=booster)
        except Exception as err:
            conversion_errors.append(f'Model.from_lightgbm: {err}')
    if not conversion_errors:
        conversion_errors.append('no known conversion entry point')
    raise RuntimeError('Treelite LightGBM conversion failed; tried paths: ' + ' | '.join(conversion_errors))


def export_lib_with_tl2cgen(
    model: treelite.Model, out_dir: Path, libname: str, nthread: int
) -> Path:
    libpath = out_dir / libname
    # Also emit the generated C sources into out_dir for inspection.
    tl2cgen.generate_c_code(model, dirpath=out_dir, params={'quantize': 1})
    export_kwargs = {
        'toolchain': 'gcc',
        'libpath': str(libpath),
        'params': {'parallel_comp': nthread, 'quantize': 1},
        'verbose': True,
        'options': ['-O3'],
    }
    compiled_path = None
    last_error = None
    # The keyword naming the model has varied across tl2cgen releases, so try the known variants.
    for model_kw in ('model', 'tree_model', 'treelite_model'):
        try:
            result = _call_with_supported_kwargs(
                tl2cgen.export_lib, **export_kwargs, **{model_kw: model}
            )
            # A call that returns without raising is a successful export,
            # even when export_lib returns None rather than the library path.
            if isinstance(result, (str, os.PathLike)):
                compiled_path = Path(result)
            last_error = None
            break
        except TypeError as err:
            last_error = err
        except Exception as err:  # pragma: no cover - keep the most informative error
            last_error = err
    if last_error is not None:
        raise RuntimeError(f'TL2cgen export failed: {last_error}')
    if compiled_path is None:
        compiled_path = _locate_compiled_library(out_dir, libname)
    return compiled_path


def predictor_infer(libpath: Path, X: np.ndarray, nthread: int, pred_margin: bool, repeats: int):
    """Run inference through tl2cgen's Predictor with a compiled shared library."""
    D = tl2cgen.DMatrix(X)
    kwargs = {'verbose': False, 'nthread': nthread}
    predictor = tl2cgen.Predictor(libpath=libpath, **kwargs)
    # Warm-up
    for _ in range(3):
        _ = predictor.predict(D, pred_margin=pred_margin)
    t0 = time.perf_counter()
    for _ in range(repeats):
        out = predictor.predict(D, pred_margin=pred_margin)
    t1 = time.perf_counter()
    latency = (t1 - t0) / repeats
    throughput = X.shape[0] / latency
    return latency, throughput, out


def gtil_infer(model: treelite.Model, X: np.ndarray, nthread: int, pred_margin: bool, repeats: int):
    """Use Treelite's GTIL execution path when no compiled library is available."""
    try:
        from treelite import gtil
    except ImportError as err:  # pragma: no cover - SciPy missing, for example
        raise RuntimeError('treelite.gtil requires optional dependencies (e.g. SciPy).') from err
    # GTIL always rebuilds intermediate buffers, so do a short warm-up
    for _ in range(3):
        _ = gtil.predict(model, X, nthread=nthread, pred_margin=pred_margin)
    t0 = time.perf_counter()
    for _ in range(repeats):
        out = gtil.predict(model, X, nthread=nthread, pred_margin=pred_margin)
    t1 = time.perf_counter()
    latency = (t1 - t0) / repeats
    throughput = X.shape[0] / latency
    return latency, throughput, out


def lightgbm_infer(booster: lgb.Booster, X: np.ndarray, pred_margin: bool, repeats: int):
    """Baseline: inference with LightGBM's native predict()."""
    # Warm-up
    for _ in range(3):
        _ = booster.predict(X, raw_score=pred_margin)
    t0 = time.perf_counter()
    for _ in range(repeats):
        out = booster.predict(X, raw_score=pred_margin)
    t1 = time.perf_counter()
    latency = (t1 - t0) / repeats
    throughput = X.shape[0] / latency
    return latency, throughput, out


def main():
    parser = argparse.ArgumentParser(description='Benchmark: LightGBM vs Treelite (GTIL/TL2cgen)')
    parser.add_argument('--model-url', type=str, default='', help='URL of a public LightGBM model. If empty, a toy model is trained automatically.')
    parser.add_argument(
        '--model-path', type=str, default='', help='Path to a local LightGBM model (.txt). Takes precedence when given.'
    )
    parser.add_argument('--tmpdir', type=str, default='', help='Output directory for generated artifacts (uses tempfile if empty)')
    parser.add_argument(
        '--nthread', type=int, default=max(1, os.cpu_count() or 1), help='Number of threads used for inference and compilation'
    )
    parser.add_argument(
        '--rows', type=int, default=500_000, help='Number of input rows for the benchmark (when training the toy model, the inference data is built with this many rows)'
    )
    parser.add_argument('--features', type=int, default=50, help='Number of features for the toy data (used for both training and inference)')
    parser.add_argument('--repeats', type=int, default=5, help='Number of repeated runs on the same input (averaged)')
    parser.add_argument('--pred-margin', action='store_true', help='Output raw scores (margins)')
    args = parser.parse_args()

    # Output directory
    if args.tmpdir:
        out_dir = Path(args.tmpdir)
        out_dir.mkdir(parents=True, exist_ok=True)
    else:
        out_dir = Path(tempfile.mkdtemp(prefix='tl_bench_'))
    log(f'Workdir: {out_dir}')
    log(f'System: {platform.platform()} Python: {platform.python_version()} nthread={args.nthread}')

    model_txt = out_dir / 'model_lgb.txt'
    used_pretrained = False

    # 1) Obtain a model
    booster = None
    data_for_pred = None
    if args.model_path:
        # A local path takes top priority
        booster = load_lightgbm_model(Path(args.model_path))
    else:
        # Try downloading a public model (any URL; fall back to toy training if absent or it fails)
        if args.model_url:
            if try_download_model(args.model_url, model_txt):
                try:
                    booster = load_lightgbm_model(model_txt)
                    used_pretrained = True
                except Exception as e:
                    log(f'Downloaded model failed to load: {e}')
                    booster = None
        if booster is None:
            # Fall back to training a toy model
            booster, (X_all, y_all) = train_toy_lgbm_model(
                n_samples=max(args.rows, 200_000),
                n_features=args.features,
            )
            booster.save_model(str(model_txt))
            # Random data with the same dimensionality as training is fine for inference
            # (generalization is not the goal here)
            rng = np.random.default_rng(2025)
            data_for_pred = rng.standard_normal((args.rows, args.features)).astype(np.float32)

    # 2) Prepare the inference data (ideally shaped to match a public model, but generated
    #    here as generic random values)
    if data_for_pred is None:
        # Follow the user-specified feature count instead of inferring it from the Booster
        rng = np.random.default_rng(2025)
        data_for_pred = rng.standard_normal((args.rows, args.features)).astype(np.float32)
    log(f'Input matrix: shape={data_for_pred.shape}, dtype={data_for_pred.dtype}')

    # 3) Convert the model for Treelite
    try:
        tl_model = _to_treelite_model(booster)
    except Exception as e:
        log(f'Treelite conversion failed: {e}')
        sys.exit(1)

    # Build a shared library with TL2cgen
    lib_tl2cgen = None
    try:
        lib_tl2cgen = export_lib_with_tl2cgen(
            tl_model, out_dir, 'model_tl2cgen.so', nthread=args.nthread
        )
        if lib_tl2cgen is not None:
            log(f'TL2cgen shared library: {lib_tl2cgen}')
    except Exception as e:
        log(f'TL2cgen export failed: {e}')
        lib_tl2cgen = None

    # 4) Benchmarks: LightGBM / Treelite (GTIL / TL2cgen)
    results = []

    # Native LightGBM
    try:
        lat, thr, out = lightgbm_infer(booster, data_for_pred, pred_margin=args.pred_margin, repeats=args.repeats)
        results.append(
            {
                'impl': 'lightgbm.predict',
                'latency_s_per_run': lat,
                'throughput_rows_per_s': thr,
            }
        )
        log(f'LightGBM: avg_latency={lat:.6f}s throughput={thr:,.0f} rows/s')
    except Exception as e:
        log(f'LightGBM inference failed: {e}')

    # Treelite (GTIL)
    try:
        lat, thr, out = gtil_infer(
            tl_model,
            data_for_pred,
            nthread=args.nthread,
            pred_margin=args.pred_margin,
            repeats=args.repeats,
        )
        results.append(
            {
                'impl': 'treelite.gtil',
                'latency_s_per_run': lat,
                'throughput_rows_per_s': thr,
            }
        )
        log(f'Treelite GTIL: avg_latency={lat:.6f}s throughput={thr:,.0f} rows/s')
    except Exception as e:
        log(f'Treelite GTIL failed: {e}')

    # Treelite runtime (TL2cgen shared library)
    if lib_tl2cgen is not None:
        if lib_tl2cgen.exists():
            try:
                lat, thr, out = predictor_infer(
                    lib_tl2cgen,
                    data_for_pred,
                    nthread=args.nthread,
                    pred_margin=args.pred_margin,
                    repeats=args.repeats,
                )
                results.append(
                    {
                        'impl': 'TL2cgen',
                        'latency_s_per_run': lat,
                        'throughput_rows_per_s': thr,
                    }
                )
                log(f'TL2cgen runtime: avg_latency={lat:.6f}s throughput={thr:,.0f} rows/s')
            except Exception as e:
                log(f'TL2cgen runtime failed: {e}')
        else:
            log(f'TL2cgen library missing on disk: {lib_tl2cgen}')

    # 5) Also save the results as JSON
    out_json = out_dir / 'benchmark_result.json'
    with open(out_json, 'w', encoding='utf-8') as f:
        json.dump(
            {
                'env': {
                    'python': platform.python_version(),
                    'platform': platform.platform(),
                    'nthread': args.nthread,
                    'used_pretrained_model': used_pretrained,
                },
                'shape': list(data_for_pred.shape),
                'repeats': args.repeats,
                'pred_margin': args.pred_margin,
                'results': results,
            },
            f,
            ensure_ascii=False,
            indent=2,
        )
    log(f'Saved benchmark_result.json → {out_json}')

    # 6) Print a simple throughput ranking
    if results:
        best = sorted(results, key=lambda r: r['throughput_rows_per_s'], reverse=True)
        log('Throughput ranking (higher is better):')
        for i, r in enumerate(best, 1):
            log(
                f'{i}. {r["impl"]:<36} {r["throughput_rows_per_s"]:>12,.0f} rows/s '
                f'(avg_latency={r["latency_s_per_run"]:.6f}s)'
            )
    else:
        log('No successful runs.')


if __name__ == '__main__':
    main()
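When pointing the script at an existing model via --model-path or --model-url, note that the benchmark input is still a random matrix of shape (--rows, --features), so --features must match the feature count the model was trained with. For example (the path and numbers are only placeholders):

    python lgbm_inference_benchmark.py --model-path ./model_lgb.txt --features 31 --rows 100000

All artifacts land in the work directory: model_lgb.txt (when a model is trained or downloaded), the C sources and model_tl2cgen.so generated by TL2cgen, and benchmark_result.json with the timing results.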