Created
December 9, 2022 12:53
-
-
Save rversteegen/460640b658217c1434633d3d6509456b to your computer and use it in GitHub Desktop.
KunzEgg's Forward-Forward example in numpy
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# using the Forward-Forward algorithm to train a neural network to classify positive and negative data | |
# the positive data is real data and the negative data is generated by the network itself | |
# the network is trained to have high goodness for positive data and low goodness for negative data | |
# the goodness is measured by the sum of the squared activities in a layer | |
# the network is trained to correctly classify input vectors as positive data or negative data | |
# the probability that an input vector is positive is given by applying the logistic function, σ to the goodness, minus some threshold, θ | |
# the negative data may be predicted by the neural net using top-down connections, or it may be supplied externally | |
import numpy as np | |
# Define the activation function and its derivative | |
def activation(x):
    """ReLU nonlinearity: zero out every negative entry of `x` element-wise."""
    return x * (x > 0)
def activation_derivative(x):
    """Derivative of ReLU: 1.0 where `x` is strictly positive, else 0.0."""
    return np.where(x > 0, 1.0, 0.0)
# Goodness of a layer: the total squared activity, sum_i x_i**2.
def goodness(x):
    """Return the scalar "goodness" of activity array `x` (sum of squares)."""
    return (x * x).sum()
# Forward pass used for the positive (real) data.
def forward_pass_positive(X, W1, W2):
    """Run the two-layer forward pass on `X`.

    Returns the post-ReLU activities `(a1, a2)` of the hidden and
    output layers.
    """
    hidden = activation(X.dot(W1))   # layer 1 activities
    output = activation(hidden.dot(W2))  # layer 2 activities
    return hidden, output
# Forward pass used for the negative (fake) data; same network, same weights.
def forward_pass_negative(X, W1, W2):
    """Run the identical two-layer forward pass on negative data `X`.

    Returns the post-ReLU activities `(a1, a2)` of both layers.
    """
    a1 = activation(np.dot(X, W1))
    return a1, activation(np.dot(a1, W2))
# Hyperparameters.
learning_rate = 0.01
# Threshold theta in p = sigma(goodness - theta).
theta = 0.1
# Number of training epochs.
epochs = 100

# Positive (real) data: four 3-d input vectors.
X = np.array([[0, 0, 1], [0, 1, 1], [1, 0, 1], [1, 1, 1]])
# Negative data, supplied externally rather than generated by the net.
Xn = np.array([[0, 0, 0], [0, 1, 0], [1, 0, 0], [1, 1, 0]])

# Initialize the weights uniformly in [-1, 1).
# FIX: the original read `2np.random.random(...)`, which is a SyntaxError --
# the multiplication operator between the 2 and the call was missing.
W1 = 2 * np.random.random((3, 4)) - 1
W2 = 2 * np.random.random((4, 1)) - 1
# Training: one positive pass and one negative pass per epoch.
for epoch in range(epochs):
    # Forward both data sets through the current weights.
    a1, a2 = forward_pass_positive(X, W1, W2)
    a1n, a2n = forward_pass_negative(Xn, W1, W2)

    # Layer-wise goodness (sum of squared activities) for each pass.
    g1, g2 = goodness(a1), goodness(a2)
    g1n, g2n = goodness(a1n), goodness(a2n)

    # p = sigma(goodness - theta): probability the input is positive data.
    p1 = 1 / (1 + np.exp(-(g1 - theta)))
    p2 = 1 / (1 + np.exp(-(g2 - theta)))
    p1n = 1 / (1 + np.exp(-(g1n - theta)))
    p2n = 1 / (1 + np.exp(-(g2n - theta)))

    # Errors against the targets: positive data -> 1, negative data -> 0.
    error2 = p2 - 1
    error1 = p1 - 1
    error2n = p2n - 0
    error1n = p1n - 0

    # Backprop-style deltas (scalar error broadcast over the activities).
    delta2 = error2 * activation_derivative(a2)
    delta1 = error1 * activation_derivative(a1)
    delta2n = error2n * activation_derivative(a2n)
    delta1n = error1n * activation_derivative(a1n)

    # Apply both passes' weight changes. None of the deltas depend on the
    # weights, so adding the four updates in sequence is equivalent to the
    # original compute-then-apply order.
    # NOTE(review): both passes add their update with the same sign; for the
    # negative pass one would usually expect the opposite sign, so that the
    # negative goodness is pushed DOWN -- confirm against the intended
    # algorithm before relying on this script.
    W2 += learning_rate * a1.T.dot(delta2)
    W1 += learning_rate * X.T.dot(delta1)
    W2 += learning_rate * a1n.T.dot(delta2n)
    W1 += learning_rate * Xn.T.dot(delta1n)

# Report the final weights, goodness values, and probabilities.
print("W1 = ", W1)
print("W2 = ", W2)
print("g1 = ", g1)
print("g2 = ", g2)
print("g1n = ", g1n)
print("g2n = ", g2n)
print("p1 = ", p1)
print("p2 = ", p2)
print("p1n = ", p1n)
print("p2n = ", p2n)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment