Gist by @crawles, created March 5, 2019 02:18
import tensorflow as tf

# feature_columns, train_input_fn and eval_input_fn are assumed to be defined
# earlier (the feature columns and input functions for your dataset).
params = {
    'n_trees': 50,
    'max_depth': 3,
    'n_batches_per_layer': 1,
    # You must set center_bias=True to get directional feature contributions
    # (DFCs). This forces the model to make an initial prediction before
    # using any features (e.g. the mean of the training labels for
    # regression, or the log odds for classification when using cross
    # entropy loss).
    'center_bias': True
}

est = tf.estimator.BoostedTreesClassifier(feature_columns, **params)
# Train model.
est.train(train_input_fn, max_steps=100)
# Evaluation.
results = est.evaluate(eval_input_fn)
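The `center_bias` comment describes the model's starting prediction before any tree is applied. As a rough illustration only (a plain-Python sketch, not TensorFlow's internal implementation), the snippet below computes that starting value: the mean label for regression, and the log odds log(p / (1 - p)) of the positive class for binary classification with cross-entropy loss. The helper names `initial_prediction_regression`, `initial_prediction_binary` and `sigmoid` are hypothetical.

```python
import math


def initial_prediction_regression(labels):
    """Hypothetical helper: the regression bias is the mean training label."""
    return sum(labels) / len(labels)


def initial_prediction_binary(labels):
    """Hypothetical helper: the binary-classification bias is the log odds
    log(p / (1 - p)), where p is the fraction of positive (1) labels."""
    p = sum(labels) / len(labels)
    if p in (0.0, 1.0):
        # Log odds are undefined (infinite) when every label is the same.
        raise ValueError("log odds undefined when all labels are identical")
    return math.log(p / (1.0 - p))


def sigmoid(x):
    """Map a logit back to a probability."""
    return 1.0 / (1.0 + math.exp(-x))


if __name__ == "__main__":
    y_reg = [2.0, 4.0, 6.0]
    y_cls = [1, 0, 0, 1, 0]  # 2 positives out of 5, so p = 0.4
    print(initial_prediction_regression(y_reg))  # 4.0
    bias = initial_prediction_binary(y_cls)      # log(0.4 / 0.6)
    print(round(sigmoid(bias), 6))               # 0.4
```

Passing the log-odds bias through the sigmoid recovers the base rate of the positive class, so with `center_bias=True` the model starts from the class prior rather than from a 50/50 guess.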