@micahmelling
Created August 17, 2020 03:48
import os
import subprocess

import pandas as pd
from sklearn.tree import DecisionTreeClassifier, export_graphviz


def main():
    # Toy dataset: home state, age, and whether the person is a Royals fan.
    df = pd.DataFrame({
        'location': ['MO', 'MO', 'MO', 'MO', 'KS', 'KS', 'KS', 'IL', 'IL', 'IL'],
        'age': [29, 30, 21, 40, 45, 60, 35, 24, 47, 50],
        'royals_fan': [1, 1, 1, 0, 1, 1, 0, 0, 0, 0]
    })
    # One-hot encode the categorical 'location' column.
    df = pd.get_dummies(df)
    model = DecisionTreeClassifier()
    y = df['royals_fan']
    # Features are every column except the target.
    x = df.drop(columns=['royals_fan'])
    model.fit(x, y)
    # Write the fitted tree in Graphviz DOT format.
    export_graphviz(model, out_file='decision_tree.dot', class_names=['non_royals_fan', 'royals_fan'],
                    feature_names=list(x))
    # Render the DOT file to PNG; requires the Graphviz `dot` binary on the PATH.
    subprocess.call(['dot', '-Tpng', 'decision_tree.dot', '-o', 'decision_tree.png'])
    # Remove the intermediate DOT file.
    os.remove('decision_tree.dot')


if __name__ == "__main__":
    main()
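
The feature names that appear in the rendered tree come from one-hot encoding the `location` column. The sketch below illustrates that step in plain Python so the resulting column names are easy to see. It is not how `pd.get_dummies` is implemented, and the `one_hot` helper and the three-row sample are hypothetical, introduced only for illustration. Depending on the pandas version, the real indicator columns may be boolean rather than 0/1.

```python
def one_hot(rows, column):
    """Replace `column` with one 0/1 indicator per distinct value (illustrative helper)."""
    categories = sorted({row[column] for row in rows})
    encoded = []
    for row in rows:
        # Keep every other column as-is.
        new_row = {key: value for key, value in row.items() if key != column}
        # Add one indicator column per category, named "<column>_<category>".
        for category in categories:
            new_row[f"{column}_{category}"] = int(row[column] == category)
        encoded.append(new_row)
    return encoded


# A small sample in the shape of the script's data (target column omitted).
rows = [
    {"location": "MO", "age": 29},
    {"location": "KS", "age": 45},
    {"location": "IL", "age": 24},
]

encoded = one_hot(rows, "location")
feature_names = list(encoded[0])
print(feature_names)
```

Each split in the tree on a `location_*` feature is therefore a yes/no test on a single state, while splits on `age` are numeric thresholds.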