Created
August 5, 2016 13:29
-
-
Save lalinsky/a49b34994a850b5e647aa61f9095e1f9 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"metadata": { | |
"name": "DecisionTreeClassificationModel to JSON" | |
}, | |
"nbformat": 3, | |
"nbformat_minor": 0, | |
"worksheets": [ | |
{ | |
"cells": [ | |
{ | |
"cell_type": "code", | |
"collapsed": false, | |
"input": "# Build a minimal two-row demo DataFrame and fit a decision tree on it.\nfrom pyspark.ml.linalg import Vectors\nfrom pyspark.ml.feature import StringIndexer\nfrom pyspark.ml.classification import DecisionTreeClassifier\n\n# Two rows: label 1.0 with a dense feature vector, label 0.0 with an empty sparse one.\ndf = spark.createDataFrame([\n    (1.0, Vectors.dense(1.0)),\n    (0.0, Vectors.sparse(1, [], []))], [\"label\", \"features\"])\n\n# Map the raw labels into the \"indexed\" column the classifier trains on.\nstringIndexer = StringIndexer(inputCol=\"label\", outputCol=\"indexed\")\nsi_model = stringIndexer.fit(df)\ntd = si_model.transform(df)\n\n# Shallow tree (maxDepth=2) trained on the indexed labels; `model` is used by later cells.\ndt = DecisionTreeClassifier(maxDepth=2, labelCol=\"indexed\")\nmodel = dt.fit(td)", | |
"language": "python", | |
"metadata": {}, | |
"outputs": [], | |
"prompt_number": 1 | |
}, | |
{ | |
"cell_type": "code", | |
"collapsed": false, | |
"input": "# Show the learned tree structure. print() works identically in Python 2\n# (single argument) and Python 3, unlike the old print statement.\nprint(model.toDebugString)", | |
"language": "python", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"stream": "stdout", | |
"text": "DecisionTreeClassificationModel (uid=DecisionTreeClassifier_4c7db90739bec3e34a3c) of depth 1 with 3 nodes\n If (feature 0 <= 0.0)\n Predict: 0.0\n Else (feature 0 > 0.0)\n Predict: 1.0\n\n" | |
} | |
], | |
"prompt_number": 2 | |
}, | |
{ | |
"cell_type": "code", | |
"collapsed": false, | |
"input": "def convert_node(node, features, rules, attribute=None):\n    \"\"\"Recursively convert a Spark decision-tree Java node into a plain dict.\n\n    node      -- org.apache.spark.ml.tree.Node accessed through py4j\n    features  -- list of feature names, indexed by split.featureIndex()\n    rules     -- human-readable rule string that leads to this node\n    attribute -- feature name the parent split on (None at the root)\n    \"\"\"\n    # Leaf: no split to describe, just report the prediction.\n    if node.getClass().getName() == 'org.apache.spark.ml.tree.LeafNode':\n        return {\n            'attribute': attribute,\n            'rules': rules,\n            'prediction': node.prediction(),\n            'good': 0,\n            'bad': 0\n        }\n\n    split = node.split()\n    original_attribute = attribute\n    attribute = features[split.featureIndex()]\n    # Only continuous splits carry a threshold. Previously this name was left\n    # unassigned on the categorical branch, so the return dict below raised a\n    # NameError for any tree containing a CategoricalSplit.\n    threshold = None\n\n    if split.getClass().getName() == 'org.apache.spark.ml.tree.ContinuousSplit':\n        threshold = split.threshold()\n        left_rules = '{}<={}'.format(attribute, threshold)\n        right_rules = '{}>{}'.format(attribute, threshold)\n    elif split.getClass().getName() == 'org.apache.spark.ml.tree.CategoricalSplit':\n        categories = split.leftCategories().mkString(\"{\", \",\", \"}\")\n        left_rules = '{} in {}'.format(attribute, categories)\n        right_rules = '{} not in {}'.format(attribute, categories)\n    else:\n        raise ValueError('unknown split class')\n\n    children = [\n        convert_node(node.leftChild(), features, left_rules, attribute),\n        convert_node(node.rightChild(), features, right_rules, attribute)\n    ]\n\n    return {\n        'attribute_id': split.featureIndex(),\n        'attribute': original_attribute,\n        'threshold': threshold,\n        'prediction': node.prediction(),\n        'good': 0,\n        'bad': 0,\n        'rules': rules,\n        'children': children\n    }\n\ndef convert_model(model, features):\n    \"\"\"Convert a fitted DecisionTreeClassificationModel into a nested dict.\"\"\"\n    # _call_java('rootNode') reaches into the JVM model for the tree root.\n    return convert_node(model._call_java('rootNode'), features, '')\n\nimport pprint\npprint.pprint(convert_model(model, ['feature1']))", | |
"language": "python", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"stream": "stdout", | |
"text": "{'attribute': None,\n 'attribute_id': 0,\n 'bad': 0,\n 'children': [{'attribute': 'feature1',\n 'bad': 0,\n 'good': 0,\n 'prediction': 0.0,\n 'rules': 'feature1<=0.0'},\n {'attribute': 'feature1',\n 'bad': 0,\n 'good': 0,\n 'prediction': 1.0,\n 'rules': 'feature1>0.0'}],\n 'good': 0,\n 'prediction': 0.0,\n 'rules': '',\n 'threshold': 0.0}\n" | |
} | |
], | |
"prompt_number": 3 | |
} | |
], | |
"metadata": {} | |
} | |
] | |
} |
Sign up for free to join this conversation on GitHub.
Already have an account? Sign in to comment.