Last active
November 13, 2021 00:29
-
-
Save innat/ed748ef99505f3b66dd7c6f7cfc3e7ad to your computer and use it in GitHub Desktop.
TF 2 Implementation.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Reference: https://keras.io/examples/structured_data/deep_neural_decision_forests/
import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
class NeuralDecisionTree(keras.Model):
    """A single differentiable decision tree (deep neural decision forests).

    Each inner node's routing probability comes from one sigmoid unit of a
    shared dense layer applied to a randomly selected subset of the input
    features; each leaf holds a trainable vector of class logits.

    Args:
        depth: Tree depth; the tree has 2 ** depth leaves.
        num_features: Dimensionality of the input feature vector.
        used_features_rate: Fraction of features this tree is allowed to see.
        num_classes: Number of output classes.
    """

    def __init__(self, depth, num_features, used_features_rate, num_classes):
        super().__init__()
        self.depth = depth
        self.num_leaves = 2 ** depth
        self.num_classes = num_classes

        # Build a fixed 0/1 mask selecting a random feature subset for this
        # tree (feature bagging — each tree in an ensemble sees different
        # features).
        num_used_features = int(num_features * used_features_rate)
        one_hot = np.eye(num_features)
        sampled_feature_indices = np.random.choice(
            np.arange(num_features), num_used_features, replace=False
        )
        self.used_features_mask = one_hot[sampled_feature_indices]

        # Trainable class logits for every leaf: [num_leaves, num_classes].
        self.pi = tf.Variable(
            initial_value=tf.random_normal_initializer()(
                shape=[self.num_leaves, self.num_classes]
            ),
            dtype="float32",
            trainable=True,
        )

        # Stochastic routing layer: one sigmoid "go left" probability per
        # node, consumed level by level in `call`.
        self.decision_fn = layers.Dense(
            units=self.num_leaves, activation="sigmoid", name="decision"
        )

    def call(self, features):
        """Return class probabilities of shape [batch_size, num_classes]."""
        batch_size = tf.shape(features)[0]

        # Project the input onto this tree's random feature subset.
        features = tf.matmul(
            features, self.used_features_mask, transpose_b=True
        )  # [batch_size, num_used_features]

        # Routing probabilities paired with their complements.
        decisions = tf.expand_dims(
            self.decision_fn(features), axis=2
        )  # [batch_size, num_leaves, 1]
        decisions = layers.concatenate(
            [decisions, 1 - decisions], axis=2
        )  # [batch_size, num_leaves, 2]

        # Traverse the tree in breadth-first order, accumulating in `mu` the
        # probability of reaching each node at the current level.
        mu = tf.ones([batch_size, 1, 1])
        begin_idx = 1
        end_idx = 2
        for level in range(self.depth):
            mu = tf.reshape(mu, [batch_size, -1, 1])  # [batch_size, 2**level, 1]
            mu = tf.tile(mu, (1, 1, 2))  # [batch_size, 2**level, 2]
            level_decisions = decisions[
                :, begin_idx:end_idx, :
            ]  # [batch_size, 2**level, 2]
            mu = mu * level_decisions  # [batch_size, 2**level, 2]
            begin_idx = end_idx
            end_idx = begin_idx + 2 ** (level + 1)

        mu = tf.reshape(mu, [batch_size, self.num_leaves])  # [batch_size, num_leaves]
        probabilities = keras.activations.softmax(self.pi)  # [num_leaves, num_classes]
        outputs = tf.matmul(mu, probabilities)  # [batch_size, num_classes]
        return outputs

    def build_graph(self, input_dim=12):
        """Wrap the tree in a functional Model (handy for summary/plotting).

        `input_dim` defaults to 12 to preserve the original hard-coded shape.
        """
        x = tf.keras.Input(shape=(input_dim,))
        return tf.keras.Model(inputs=[x], outputs=self.call(x))
class NeuralDecisionForest(keras.Model):
    """An ensemble of NeuralDecisionTree models whose outputs are averaged.

    Args:
        num_trees: Number of trees in the ensemble.
        depth: Depth of every tree.
        num_features: Dimensionality of the input feature vector.
        used_features_rate: Fraction of features each tree may use.
        num_classes: Number of output classes.
    """

    def __init__(self, num_trees, depth, num_features, used_features_rate, num_classes):
        super().__init__()
        # BUG FIX: the original `call` read a module-level global
        # `num_classes`; store it on the instance so the model is
        # self-contained.
        self.num_classes = num_classes
        # Each tree draws its own randomly selected input features to use.
        self.ensemble = [
            NeuralDecisionTree(depth, num_features, used_features_rate, num_classes)
            for _ in range(num_trees)
        ]

    def call(self, inputs):
        """Return averaged class probabilities: [batch_size, num_classes]."""
        batch_size = tf.shape(inputs)[0]
        # Start from zeros and accumulate every tree's output.
        outputs = tf.zeros([batch_size, self.num_classes])
        for tree in self.ensemble:
            outputs += tree(inputs)
        # Average over the ensemble.
        outputs /= len(self.ensemble)
        return outputs

    def build_graph(self, input_dim=12):
        """Wrap the forest in a functional Model; `input_dim` defaults to 12."""
        x = tf.keras.Input(shape=(input_dim,))
        return tf.keras.Model(inputs=[x], outputs=self.call(x))
# Build the final model: a 25-tree forest of depth-10 trees over 12 input
# features with a single output class (removed dead commented-out
# single-tree variant).
inputs = tf.keras.Input(shape=(12,))
tree = NeuralDecisionForest(
    num_trees=25, depth=10, num_features=12, used_features_rate=1.0, num_classes=1
)
outputs = tree(inputs)
model = keras.Model(inputs=inputs, outputs=outputs)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment