Skip to content

Instantly share code, notes, and snippets.

@rizplate
Forked from goraj/incremental_lightgbm.py
Created July 24, 2018 22:59
Show Gist options
  • Save rizplate/1c3d3556a82e144c1d6bb5b8951f17cc to your computer and use it in GitHub Desktop.
Save rizplate/1c3d3556a82e144c1d6bb5b8951f17cc to your computer and use it in GitHub Desktop.
incremental learning lightgbm
# -*- coding: utf-8 -*-
"""
@author: goraj
"""
import lightgbm as lgbm
from sklearn.datasets import load_digits
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.metrics import roc_auc_score
def iterative_ds(params, X_train, X_test, y_train, y_test,
                 batch_size=5, num_boost_round=5, report_every=15):
    """Train a LightGBM booster incrementally, one fresh Dataset per batch.

    The training rows are walked in consecutive, non-overlapping batches of
    ``batch_size`` rows. Each batch is wrapped in its own ``lgbm.Dataset``
    and handed to ``lgbm.train`` together with the booster produced for the
    previous batch (``init_model``), with ``keep_training_booster=True``.

    Args:
        params: Parameter dict forwarded unchanged to ``lgbm.train``.
        X_train: Training features; must expose ``shape[0]`` and support
            row slicing (assumed to be a NumPy array -- confirm before
            passing other containers).
        X_test: Held-out features used only for the progress report.
        y_train: Training labels, row-aligned with ``X_train``.
        y_test: Held-out labels, row-aligned with ``X_test``.
        batch_size: Number of rows per training batch.
        num_boost_round: Boosting rounds requested for every batch.
        report_every: Print the test AUC on every ``report_every``-th batch
            (batch 0 included). A falsy value disables reporting.

    Returns:
        The booster returned by the last ``lgbm.train`` call, or ``None``
        when ``X_train`` holds fewer than ``batch_size`` rows.

    Raises:
        ValueError: If ``batch_size`` is not positive.
    """
    if batch_size <= 0:
        raise ValueError('batch_size must be a positive integer')

    n_rows = X_train.shape[0]
    estimator = None
    # ``+ 1`` keeps the last *full* batch when n_rows is an exact multiple of
    # batch_size; trailing rows that do not fill a whole batch are skipped so
    # that no batch is ever smaller than batch_size.
    batch_starts = range(0, n_rows - batch_size + 1, batch_size)
    for iteration, start in enumerate(batch_starts):
        stop = start + batch_size
        estimator = lgbm.train(
            params,
            train_set=lgbm.Dataset(X_train[start:stop], y_train[start:stop]),
            num_boost_round=num_boost_round,
            init_model=estimator,
            keep_training_booster=True,
        )
        # Only score the model when the result is actually reported.
        if report_every and iteration % report_every == 0:
            # AUC ranks the predicted scores themselves; thresholding them
            # at 0.5 first would reduce the ROC curve to a single point.
            auc = roc_auc_score(y_test, estimator.predict(X_test))
            print('iteration: {} auc: {}'.format(iteration, auc))
    return estimator
def iterative_subset(params, X_train, X_test, y_train, y_test,
                     batch_size=5, num_boost_round=5, report_every=15):
    """Train a LightGBM booster incrementally on subsets of one Dataset.

    A single ``lgbm.Dataset`` is built over the whole training set (with
    ``free_raw_data=False``) and each batch is taken from it with
    ``Dataset.subset``. Every batch is passed to ``lgbm.train`` together
    with the booster produced for the previous batch (``init_model``), with
    ``keep_training_booster=True``.

    NOTE(review): the script that calls this function labels this variant
    "does not work". The cause is not visible in this file -- presumably it
    depends on how the installed LightGBM version combines ``subset`` with
    ``init_model``; confirm against that version before relying on it.

    Args:
        params: Parameter dict forwarded unchanged to ``lgbm.train``.
        X_train: Training features; must expose ``shape[0]``.
        X_test: Held-out features used only for the progress report.
        y_train: Training labels, row-aligned with ``X_train``.
        y_test: Held-out labels, row-aligned with ``X_test``.
        batch_size: Number of rows per training batch.
        num_boost_round: Boosting rounds requested for every batch.
        report_every: Print the test AUC on every ``report_every``-th batch
            (batch 0 included). A falsy value disables reporting.

    Returns:
        The booster returned by the last ``lgbm.train`` call, or ``None``
        when ``X_train`` holds fewer than ``batch_size`` rows.

    Raises:
        ValueError: If ``batch_size`` is not positive.
    """
    if batch_size <= 0:
        raise ValueError('batch_size must be a positive integer')

    n_rows = X_train.shape[0]
    # ``+ 1`` keeps the last *full* batch when n_rows is an exact multiple of
    # batch_size; trailing rows that do not fill a whole batch are skipped so
    # that no batch is ever smaller than batch_size.
    batch_starts = range(0, n_rows - batch_size + 1, batch_size)
    if not batch_starts:
        # Nothing to train on, so do not build the parent Dataset either.
        return None

    estimator = None
    dset = lgbm.Dataset(X_train, y_train, free_raw_data=False)
    for iteration, start in enumerate(batch_starts):
        batch_rows = list(range(start, start + batch_size))
        estimator = lgbm.train(
            params,
            train_set=dset.subset(batch_rows),
            num_boost_round=num_boost_round,
            init_model=estimator,
            keep_training_booster=True,
        )
        # Only score the model when the result is actually reported.
        if report_every and iteration % report_every == 0:
            # AUC ranks the predicted scores themselves; thresholding them
            # at 0.5 first would reduce the ROC curve to a single point.
            auc = roc_auc_score(y_test, estimator.predict(X_test))
            print('iteration: {} auc: {}'.format(iteration, auc))
    return estimator
if __name__ == '__main__':
    # Reduce the digits data to a binary problem: keep only the samples
    # whose label is 0 or 1.
    digits = load_digits()
    features, labels = digits['data'], digits['target']
    binary_rows = np.where((labels == 1) | (labels == 0))
    X, y = features[binary_rows], labels[binary_rows]

    # Fixed random_state so the 80/20 split is the same on every run.
    X_train, X_test, y_train, y_test = train_test_split(
        X, y, test_size=0.20, random_state=42)

    params = {
        'boosting_type': 'gbdt',
        'objective': 'binary',
        'learning_rate': 0.01,
        'num_leaves': 35,
        'metric': 'auc',
        'is_unbalance': False,
        'seed': 1024,
        'verbosity': -1,
        'min_data': 1,
        'min_data_in_bin': 1,
        # NOTE(review): free_raw_data is normally an lgbm.Dataset
        # constructor argument; presumably it has no effect as a training
        # parameter -- confirm before removing.
        'free_raw_data': False
    }

    # Variant 1: a fresh Dataset per batch (marked as working by the author).
    iterative_ds(params, X_train, X_test, y_train, y_test)

    # Variant 2: batches taken via Dataset.subset (marked as not working by
    # the author).
    iterative_subset(params, X_train, X_test, y_train, y_test)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment