Last active
September 19, 2019 17:50
-
-
Save aeftimia/5587286cb844953528b92bea0cd80bdb to your computer and use it in GitHub Desktop.
Convolutional Decision Tree
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| import numpy | |
| from keras.datasets import cifar10 | |
| from sklearn import preprocessing | |
| from sklearn.tree import DecisionTreeClassifier | |
# Load the CIFAR-10 train/test splits and convert both image arrays to
# float32 (the label arrays are left exactly as returned by the loader).
(X_train, y_train), (X_test, y_test) = cifar10.load_data()
X_train, X_test = (images.astype('float32') for images in (X_train, X_test))
| import numpy | |
| from sklearn import preprocessing | |
| from sklearn.tree import DecisionTreeClassifier | |
def preprocess(X_batch, kernel=3, stride=2):
    """Gather strided sliding windows along every inner axis of ``X_batch``.

    Axis 0 and the last axis are left alone.  For each axis in between,
    ``kernel`` strided slices (starting at offsets ``0 .. kernel - 1``) are
    stacked on a new trailing axis.  Finally, everything after the first
    ``ndim - 1`` axes is flattened into a single trailing feature axis.

    The slice stop is ``offset - kernel``, which is always negative, so the
    last element along an axis is never used as a window start.  For example
    a length-7 axis with ``kernel=3, stride=2`` yields 2 window positions.

    Args:
        X_batch: array exposing ``shape``, ``swapaxes`` and slicing
            (a NumPy array in this file).
        kernel: number of consecutive offsets gathered per window.
        stride: step between successive window starts.

    Returns:
        Array with the same number of axes as ``X_batch``; the inner axes are
        shortened to the number of window positions and the last axis holds
        the flattened window contents.
    """
    leading_axes = len(X_batch.shape) - 1
    windows = X_batch
    for axis in range(1, leading_axes):
        # Bring the axis being windowed to the front so plain slicing hits it.
        fronted = windows.swapaxes(0, axis)
        taps = []
        for offset in range(kernel):
            taps.append(fronted[offset:offset - kernel:stride])
        # New trailing axis indexes the offset inside the window; then put the
        # windowed axis back where it came from.
        windows = numpy.stack(taps, axis=-1).swapaxes(0, axis)
    flat_shape = windows.shape[:leading_axes] + (-1,)
    return windows.reshape(flat_shape)
def reshape(X, y):
    """Turn a batch into per-window rows with one label per row.

    ``X`` is first passed through ``preprocess``.  Its result is flattened to
    two dimensions (one row per window position, keeping the last axis as the
    feature axis), and every entry of ``y`` is repeated once per window
    position so labels line up with those rows.

    Args:
        X: batch forwarded unchanged to ``preprocess``.
        y: label array; must support ``repeat`` and ``len``.

    Returns:
        A 3-tuple ``(rows, labels, shape)`` where ``rows`` is the 2-D window
        matrix, ``labels`` is ``y`` with each element repeated, and ``shape``
        is ``(len(y),)`` followed by the window-grid dimensions.
    """
    patches = preprocess(X)
    # Axes strictly between the batch axis and the feature axis.
    grid = patches.shape[1:-1]
    rows = patches.reshape(-1, patches.shape[-1])
    labels = y.repeat(numpy.prod(grid))
    return rows, labels, (len(y),) + grid
class ConvTree:
    """Stack of decision trees applied to sliding windows of the input.

    Every layer except the last two is fitted on windowed patches and its
    leaf assignments are binarized, reshaped back onto the window grid and
    windowed again for the next layer.  The second-to-last layer's binarized
    leaf assignments are flattened per sample and fed to the last tree.

    ``predict_proba`` only mirrors ``fit`` when at least two depths are
    given; with fewer layers the two code paths diverge.
    """

    def __init__(self, max_depth: list):
        """Create one entropy-criterion tree per entry of ``max_depth``."""
        self.layers = [
            DecisionTreeClassifier(max_depth=d, random_state=0, criterion='entropy')
            for d in max_depth
        ]

    def fit(self, X_train, y_train):
        """Fit the layers in order, re-encoding the data between layers.

        Args:
            X_train: training batch, passed to ``reshape``.
            y_train: labels, one per sample in ``X_train``.
        """
        self.binarizer = []
        last = len(self.layers) - 1
        features, targets, shape = reshape(X_train, y_train)
        for i, layer in enumerate(self.layers):
            layer.fit(features, targets)
            lb = preprocessing.LabelBinarizer()
            nodes = layer.tree_.apply(features)
            lb.fit(nodes)
            # One binarizer per layer is kept so len(binarizer) == len(layers).
            self.binarizer.append(lb)
            if i == last:
                # Nothing consumes an encoding of the final layer's leaves.
                break
            encoded = lb.transform(nodes)
            if i < last - 1:
                # Put the encoding back on the window grid and window it again.
                features, targets, shape = reshape(
                    encoded.reshape(shape + (-1,)).astype('float32'), y_train
                )
            else:
                # Second-to-last layer: one flat feature row per sample.
                features = encoded.reshape(shape[:1] + (-1,)).astype('float32')
                targets = y_train

    def predict_proba(self, X_train):
        """Run ``X_train`` through the fitted layers.

        Returns:
            The result of the last tree's ``predict_proba`` on the per-sample
            encoding produced by the preceding layers.
        """
        # reshape() needs a label array; its repeated labels are discarded here.
        placeholder = numpy.ones(len(X_train))
        features, _, shape = reshape(X_train, placeholder)
        for i, (layer, lb) in enumerate(zip(self.layers, self.binarizer)):
            # Trace: prints the shape of the feature matrix fed to each layer.
            print(features.shape)
            nodes = layer.tree_.apply(features)
            encoded = lb.transform(nodes)
            if i == len(self.layers) - 2:
                features = encoded.reshape(shape[:1] + (-1,)).astype('float32')
                break
            features, _, shape = reshape(
                encoded.reshape(shape + (-1,)).astype('float32'), placeholder
            )
        return self.layers[-1].predict_proba(features)

    def score(self, X_train, y_train):
        """Return the fraction of samples whose predicted label equals ``y_train``.

        ``y_train`` may be shaped ``(N,)`` or ``(N, 1)``.  It is flattened
        before comparison; otherwise an ``(N, 1)`` label array would broadcast
        against the ``(N,)`` predictions into an ``N x N`` matrix and the mean
        would no longer be an accuracy.
        """
        proba = self.predict_proba(X_train)
        # argmax yields column indices; classes_ maps them back to labels.
        predicted = self.layers[-1].classes_[numpy.argmax(proba, 1)]
        return (predicted == numpy.ravel(y_train)).mean()
# Build a three-layer ConvTree (tree depths 2, 2 and 4), fit it on the
# training split and evaluate it on the test split.
depths = [2, 2, 4]
convtree = ConvTree(max_depth=depths)
convtree.fit(X_train=X_train, y_train=y_train)
# The score is computed but not stored or printed here.
convtree.score(X_train=X_test, y_train=y_test)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment