Created June 22, 2018 12:42
Save wingrime/12aa614665c88f4d0254a66efc4ba745 to your computer and use it in GitHub Desktop.
CatBoost with YetiRank
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import mmh3
import pandas as pd
from catboost import CatBoost, Pool
def get_positive_hash(x):
    """Map a query string to a non-negative 31-bit integer group id.

    The value is first passed through ``get_unique_tokens`` (defined elsewhere
    in the notebook, not shown here). The resulting tokens are joined with
    single spaces and hashed with MurmurHash3, so any two queries that yield
    the same token sequence receive the same group id.

    NOTE(review): distinct token sequences can still collide after hashing,
    which would merge two unrelated queries into one group. This is rare, but
    possible.

    Args:
        x: A raw value from the ``query_string`` column.

    Returns:
        An ``int`` in the range ``[0, 2**31)``.
    """
    tokens = get_unique_tokens(x)
    # mmh3.hash returns a signed 32-bit int by default. Python's ``%`` with a
    # positive modulus always yields a non-negative result, which folds the
    # sign away.
    return mmh3.hash(" ".join(tokens)) % 2**31
# Give every row a query-group id, then split whole *groups* (not rows) into
# train/validation so that no query's rows end up in both sets.
df['group_id'] = df['query_string'].apply(get_positive_hash)
query_groups = df.groupby("group_id")
# The keys of the GroupBy's ``groups`` mapping: one entry per distinct group id.
group_list = list(query_groups.groups)
# NOTE(review): train_test_split is assumed to be sklearn.model_selection's,
# imported elsewhere. No random_state is passed, so the 80/20 group split
# changes on every run; pass one if experiments must be reproducible.
train_groups, validation_groups = train_test_split(group_list, test_size=0.2)
def get_group_with_index(x, groups=None):
    """Return the rows belonging to group ``x``, keeping their original index.

    Args:
        x: A group id, i.e. a key of the ``GroupBy`` object.
        groups: The ``GroupBy`` in which to look ``x`` up. When omitted, the
            module-level ``query_groups`` is used, which matches the original
            single-argument behaviour.

    Returns:
        The ``DataFrame`` slice for that group. Its index labels are those of
        the frame that was grouped.

    Raises:
        KeyError: If ``x`` is not a key of ``groups``.
    """
    if groups is None:
        # Resolved lazily so that the helper can be defined before the
        # module-level grouping exists.
        groups = query_groups
    return groups.get_group(x)
# Build the train/validation frames one group at a time. Concatenating whole
# groups keeps all rows of a query contiguous, and CatBoost requires objects
# that share a group_id to form a single consecutive block.
train_queries = pd.concat(
    [get_group_with_index(g) for g in train_groups], axis=0
)
validation_queries = pd.concat(
    [get_group_with_index(g) for g in validation_groups], axis=0
)

y_train = train_queries['target']
y_validation = validation_queries['target']
group_id_train = train_queries['group_id']
group_id_validation = validation_queries['group_id']

# NOTE(review): non_train_variables is defined elsewhere. It presumably lists
# at least 'target', 'group_id' and 'query_string'; confirm that none of these
# leak into the feature matrix.
X_train = train_queries.drop(non_train_variables, axis=1)
X_validation = validation_queries.drop(non_train_variables, axis=1)

# These are positional indices into the columns of X_* *after* the drop above.
# They will silently point at the wrong columns if non_train_variables or the
# column order changes.
cat_features = [0, 1, 2, 5]

validation_pool = Pool(
    data=X_validation,
    label=y_validation,
    cat_features=cat_features,
    group_id=group_id_validation,
)
train_pool = Pool(
    data=X_train,
    label=y_train,
    cat_features=cat_features,
    group_id=group_id_train,
)
# Train a CatBoost ranker with the YetiRank objective. The ranking groups come
# from the group_id passed to each Pool above.
param = {
    'loss_function': 'YetiRank',
    # NOTE(review): 1e-5 is extremely small for 1000 iterations; the model is
    # likely to barely move from its initial prediction. Confirm that this
    # value is intentional.
    'learning_rate': .00001,
    'iterations': 1000,
    'depth': 7,
    # Keep the iteration that scores best on eval_set (validation_pool).
    'use_best_model': True,
    # NOTE(review): this option may be rejected by newer CatBoost releases;
    # verify it against the installed version.
    'calc_feature_importance': True,
    'one_hot_max_size': 20,
    'bagging_temperature': .7,
    'max_ctr_complexity': 4,
}
model = CatBoost(param)
# logging_level='Silent' suppresses console output, and plot=True shows the
# learning-curve widget in Jupyter. The trailing ';' suppresses the notebook's
# echo of the return value.
model.fit(train_pool, eval_set=validation_pool,
          logging_level='Silent',
          plot=True);
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.
Thanks for the example!
Can you describe what the labels look like? From the documentation I understand that they can only be 0 or 1, but what if I have more than two ranks — for example, from 1 to 5?
Also, is it possible to work with sample weights as well?
Thanks,
Uriel