@andrea-dagostino
Created August 23, 2022 12:17
# Assumes df_train and df_test are pandas DataFrames (the wine-quality train/test
# split) created earlier; they are not defined in this snippet.
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn import metrics, tree

train_accs = []
test_accs = []
cols = [
    'fixed.acidity', 'volatile.acidity', 'citric.acid', 'residual.sugar',
    'chlorides', 'free.sulfur.dioxide', 'total.sulfur.dioxide', 'density',
    'pH', 'sulphates', 'alcohol',
]
depths = range(1, 25)

# init a loop where we dynamically change the value of max_depth
for depth in depths:
    # fit a decision tree limited to the current depth
    clf = tree.DecisionTreeClassifier(max_depth=depth)
    clf.fit(df_train[cols], df_train.quality)
    # predict on both the training data and the held-out test data
    train_predictions = clf.predict(df_train[cols])
    test_predictions = clf.predict(df_test[cols])
    train_acc = metrics.accuracy_score(df_train.quality, train_predictions)
    test_acc = metrics.accuracy_score(df_test.quality, test_predictions)
    # append the accuracies to the lists
    train_accs.append(train_acc)
    test_accs.append(test_acc)

# plot the data against the actual depth values (1..24), not list indices (0..23)
sns.set_style('whitegrid')
plt.figure(figsize=(10, 5))
plt.plot(depths, train_accs, label='train accuracy')
plt.plot(depths, test_accs, label='test accuracy')
plt.legend(loc='upper left', prop={'size': 15})
plt.xticks(range(0, 26, 5))
plt.xlabel('max_depth', size=20)
plt.ylabel('accuracy', size=20)
plt.show()
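The loop above chooses `max_depth` by watching accuracy on `df_test`, so the test set is also doing the job of model selection and the best test score it shows tends to be optimistic. A common alternative is to pick the depth with k-fold cross-validation on the training data and score the test set only once at the end. The sketch below does this with scikit-learn's `validation_curve`. It is a minimal illustration under stated assumptions, not the author's method: synthetic data from `make_classification` stands in for the wine-quality frames (which are not shown here), `cv_depth_sweep` is a hypothetical helper name, and the `random_state` values are fixed only for repeatability.

```python
# Minimal sketch. ASSUMPTION: synthetic data stands in for the wine-quality
# df_train/df_test, which this snippet does not define.
import numpy as np
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split, validation_curve
from sklearn.tree import DecisionTreeClassifier


def cv_depth_sweep(X, y, depths, cv=5, seed=0):
    """Hypothetical helper: mean train/validation accuracy per max_depth via k-fold CV."""
    train_scores, val_scores = validation_curve(
        DecisionTreeClassifier(random_state=seed),
        X,
        y,
        param_name="max_depth",
        param_range=depths,
        cv=cv,
        scoring="accuracy",
    )
    # Each score array has shape (len(depths), cv); average over the folds.
    return train_scores.mean(axis=1), val_scores.mean(axis=1)


# Stand-in data: 11 features, mirroring the 11 wine columns in `cols`.
X, y = make_classification(n_samples=1000, n_features=11, n_informative=6,
                           n_classes=3, random_state=0)
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.25, random_state=0, stratify=y)

depths = list(range(1, 25))
train_mean, val_mean = cv_depth_sweep(X_train, y_train, depths)

# Pick the depth with the best cross-validated accuracy (np.argmax returns the
# first maximum, so ties resolve to the shallowest tree).
best_depth = depths[int(np.argmax(val_mean))]

# Refit on the full training split and evaluate on the test set exactly once.
final = DecisionTreeClassifier(max_depth=best_depth, random_state=0)
final.fit(X_train, y_train)
test_acc = final.score(X_test, y_test)

print(f"best max_depth={best_depth}, CV acc={val_mean.max():.3f}, "
      f"test acc={test_acc:.3f}")
```

Plotting `train_mean` and `val_mean` against `depths` gives the same kind of curve as the original figure, but the validation line is averaged over 5 folds, so it is less sensitive to one particular train/test split.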