Skip to content

Instantly share code, notes, and snippets.

@rversteegen
Created December 9, 2022 12:53
Show Gist options
  • Save rversteegen/460640b658217c1434633d3d6509456b to your computer and use it in GitHub Desktop.
Save rversteegen/460640b658217c1434633d3d6509456b to your computer and use it in GitHub Desktop.
KunzEgg's Forward-Forward example in numpy
# using the Forward-Forward algorithm to train a neural network to classify positive and negative data
# the positive data is real data; the negative data can be generated by the network itself, though in this example it is a fixed, hand-made array (Xn)
# the network is trained to have high goodness for positive data and low goodness for negative data
# the goodness is measured by the sum of the squared activities in a layer
# the network is trained to correctly classify input vectors as positive data or negative data
# the probability that an input vector is positive is obtained by applying the logistic function σ to the goodness minus a threshold θ, i.e. p = σ(goodness − θ)
# the negative data may be predicted by the neural net using top-down connections, or it may be supplied externally
import numpy as np
# Define the activation function and its derivative
def activation(x):
    """Apply the rectified linear unit element-wise: ``max(0, x)``.

    Negative entries become 0; non-negative entries pass through unchanged.
    """
    rectified = np.maximum(0, x)
    return rectified
def activation_derivative(x):
    """Return the ReLU slope for each element: 1.0 where ``x > 0``, else 0.0.

    For the ReLU ``activation`` in this file, the output is positive exactly
    where the input is positive. So the same mask results whether ``x`` is the
    pre-activation or the post-activation value.
    """
    is_active = x > 0
    return is_active * 1.
# Define the goodness function (the sum of the squared activities in a layer)
def goodness(x, axis=None):
    """Return the goodness of layer activities: the sum of their squares.

    Parameters
    ----------
    x : array_like
        Layer activities (an ndarray, or anything ``np.asanyarray`` accepts).
    axis : int or tuple of ints, optional
        Axis or axes to sum over. The default, ``None``, sums over every
        element, which collapses a whole batch into one scalar (the original
        behaviour). For a 2-D ``(samples, units)`` batch pass ``axis=-1`` to
        get one goodness value per input vector.

    Returns
    -------
    numpy scalar or ndarray
        Sum of ``x ** 2`` reduced over ``axis``.
    """
    # asanyarray lets plain lists work and keeps ndarray subclasses intact.
    x = np.asanyarray(x)
    return np.sum(x ** 2, axis=axis)
# Define the forward pass for the positive data
def forward_pass_positive(X, W1, W2):
    """Propagate the positive (real) batch ``X`` through both layers.

    Each layer computes ``activation(np.dot(inputs, weights))``, and its
    output is the input of the next layer.

    Returns
    -------
    tuple
        ``(a1, a2)``: the first-layer and second-layer activities.
    """
    activities = []
    signal = X
    for weights in (W1, W2):
        signal = activation(np.dot(signal, weights))
        activities.append(signal)
    return tuple(activities)
# Define the forward pass for the negative data
def forward_pass_negative(X, W1, W2):
    """Propagate the negative batch ``X`` through both layers.

    This performs the same computation as ``forward_pass_positive``. Each
    layer applies ``activation(np.dot(inputs, weights))`` and feeds the next.
    Only the intended input (negative data) differs.

    Returns
    -------
    tuple
        ``(a1, a2)``: the first-layer and second-layer activities.
    """
    activities = []
    signal = X
    for weights in (W1, W2):
        signal = activation(np.dot(signal, weights))
        activities.append(signal)
    return tuple(activities)
# Define the learning rate
learning_rate = 0.01
# Define the threshold for the goodness
# (theta is subtracted from the goodness before the logistic function)
theta = 0.1
# Define the number of epochs
epochs = 100
# Generate the positive data: 4 input vectors of 3 features, last feature 1
X = np.array([[0, 0, 1], [0, 1, 1], [1, 0, 1], [1, 1, 1]])
# Generate the negative data: the same vectors but with the last feature 0
Xn = np.array([[0, 0, 0], [0, 1, 0], [1, 0, 0], [1, 1, 0]])
# Initialize the weights uniformly in [-1, 1): np.random.random samples
# [0, 1), and 2 * r - 1 maps that onto [-1, 1).
# W1: 3 input features -> 4 hidden units; W2: 4 hidden units -> 1 unit.
W1 = 2 * np.random.random((3, 4)) - 1
W2 = 2 * np.random.random((4, 1)) - 1
# Perform the positive and negative passes for each epoch.
# Each layer is trained locally from its own goodness: no error is propagated
# from layer 2 back into layer 1, and layer 2 treats a1 as a fixed input.
# For every input vector and layer:
#     g = sum over the layer's units of a**2,   p = sigmoid(g - theta)
# and the loss is binary cross-entropy against y (1 = positive, 0 = negative).
# Its gradient w.r.t. the layer's pre-activation z is
#     dL/dz = (p - y) * 2 * a * relu'(z)
# and the weights are moved AGAINST that gradient (gradient descent).
for _ in range(epochs):
    # Forward pass for the positive data
    a1, a2 = forward_pass_positive(X, W1, W2)
    # Forward pass for the negative data
    a1n, a2n = forward_pass_negative(Xn, W1, W2)
    # Goodness of every input vector (one value per row of activities),
    # rather than a single value pooled over the whole batch.
    g1 = np.array([goodness(row) for row in a1])
    g2 = np.array([goodness(row) for row in a2])
    g1n = np.array([goodness(row) for row in a1n])
    g2n = np.array([goodness(row) for row in a2n])
    # Probability that each input vector is positive data
    p1 = 1 / (1 + np.exp(-(g1 - theta)))
    p2 = 1 / (1 + np.exp(-(g2 - theta)))
    p1n = 1 / (1 + np.exp(-(g1n - theta)))
    p2n = 1 / (1 + np.exp(-(g2n - theta)))
    # dL/dg = p - y: target 1 for the positive data ...
    error1 = p1 - 1
    error2 = p2 - 1
    # ... and target 0 for the negative data
    error1n = p1n - 0
    error2n = p2n - 0
    # Chain rule through g = sum(a**2) (dg/da = 2a) and the ReLU. The
    # per-sample error is made a column ([:, None]) so it scales each row.
    delta1 = error1[:, None] * 2 * a1 * activation_derivative(a1)
    delta2 = error2[:, None] * 2 * a2 * activation_derivative(a2)
    delta1n = error1n[:, None] * 2 * a1n * activation_derivative(a1n)
    delta2n = error2n[:, None] * 2 * a2n * activation_derivative(a2n)
    # Scaled loss gradients w.r.t. the weights, all computed before any
    # weight changes so both passes see the same weights.
    dW1 = learning_rate * X.T.dot(delta1)
    dW2 = learning_rate * a1.T.dot(delta2)
    dW1n = learning_rate * Xn.T.dot(delta1n)
    dW2n = learning_rate * a1n.T.dot(delta2n)
    # Gradient descent: subtract the gradient steps. This raises the goodness
    # of positive data and lowers the goodness of negative data.
    W1 -= dW1
    W2 -= dW2
    W1 -= dW1n
    W2 -= dW2n
# Report the final weights, then the goodness values and probabilities from
# the last epoch. Those were computed during the final epoch *before* its
# weight update, so they lag the printed weights by one step.
for label, value in (
    ("W1 = ", W1),
    ("W2 = ", W2),
    ("g1 = ", g1),
    ("g2 = ", g2),
    ("g1n = ", g1n),
    ("g2n = ", g2n),
    ("p1 = ", p1),
    ("p2 = ", p2),
    ("p1n = ", p1n),
    ("p2n = ", p2n),
):
    print(label, value)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment