Last active
December 8, 2018 15:21
-
-
Save hugobrilhante/8a63fda161bdbc0729e3c233cc8f2972 to your computer and use it in GitHub Desktop.
Perceptron
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def activation(output):
    """
    Map the perceptron's weighted sum to a class label (sign function).

    :param output: weighted sum computed by the perceptron
    :return: 1 when ``output`` is zero or positive, otherwise -1
        (NaN compares false against 0 and therefore yields -1)
    """
    return 1 if output >= 0 else -1
def perceptron(xs, ws):
    """
    Predict the class of one sample with the current weights.

    The inputs and weights are paired position by position (extra items in
    the longer sequence are ignored), multiplied, and summed; the sum is
    passed through ``activation``.

    :param xs: input values of the sample (including any bias input)
    :param ws: one weight per input value
    :return: the class produced by ``activation`` for the weighted sum
    """
    weighted_sum = 0
    for value, weight in zip(xs, ws):
        # Plain left-to-right accumulation (not sum()) keeps float rounding
        # identical across Python versions.
        weighted_sum = weighted_sum + value * weight
    return activation(weighted_sum)
def calculate_weights(inputs, weights, n, e):
    """
    Apply the perceptron learning rule to ``weights`` in place.

    Every weight is updated as ``w_i <- w_i + n * e * x_i``.

    :param inputs: current entries (at least one value per weight)
    :param weights: current weights; this list is modified in place
    :param n: learning rate
    :param e: error (expected class minus predicted class)
    :return: the same ``weights`` list, now holding the recalculated weights
    :raises IndexError: if ``inputs`` has fewer values than ``weights``
    """
    # n * e is the same for every weight; computing it once is float-identical
    # to the left-to-right evaluation of ``n * e * x``.
    step = n * e
    for i, w in enumerate(weights):
        weights[i] = w + step * inputs[i]
    # Return the list as documented; callers relying on the in-place update
    # keep working.
    return weights
def training(data_set, n, bias=1, max_updates=None):
    """
    Train the perceptron on ``data_set`` and return the learned weights.

    Each row holds the input values followed by the expected class as its
    last element. A constant ``bias`` input is appended to every row, so the
    last weight acts as the threshold. When a row is misclassified, the
    weights are adjusted and training restarts from the first row. Training
    stops once every row is classified correctly in one pass. Progress
    (correct / incorrect rows) is printed as training runs.

    :param data_set: indexable sequence of rows ``(x1, ..., xk, class)``
    :param n: learning rate
    :param bias: constant input appended to every row (threshold input)
    :param max_updates: optional cap on the number of weight adjustments.
        ``None`` (the default) trains until convergence. That never
        terminates if the data are not linearly separable, or if the
        classes are not -1/1 (the error can then never be zero).
    :return: the trained weights (last entry is the bias weight), or
        ``None`` when ``data_set`` is empty
    :raises RuntimeError: if ``max_updates`` adjustments were made and
        another one is still needed
    """
    count = 0
    updates = 0
    weights = None
    while count < len(data_set):
        # Separate the entries from the expected class
        *inputs, ce = data_set[count]
        # Add bias to entries
        inputs = inputs + [bias]
        # Lazily create one zero weight per entry (bias included)
        if weights is None:
            weights = [0.0 for _ in inputs]
        # Predict the class
        cp = perceptron(inputs, weights)
        # Calculate the error
        e = ce - cp
        if e == 0:
            print(f'Input {count + 1} is correct')
            # Pass to next entry
            count += 1
        else:
            if max_updates is not None and updates >= max_updates:
                raise RuntimeError(
                    f'Perceptron did not converge after {updates} '
                    f'weight adjustments'
                )
            # Recalculate weights in place
            calculate_weights(inputs, weights, n, e)
            updates += 1
            print(f'Input {count + 1} is incorrect. Adjusting the weights...')
            # Restart: convergence means a full error-free pass
            count = 0
    return weights
if __name__ == '__main__':
    # Toy two-feature data set: each row is (x1, x2, class).
    # Class 1 points sit near the origin and class -1 points sit far from it,
    # so the set is linearly separable (e.g. by the line x1 + x2 = 9).
    samples = (
        # x1   x2  class
        (1.0, 1.0, 1),
        (9.4, 6.4, -1),
        (2.5, 2.1, 1),
        (8.0, 7.7, -1),
        (0.5, 2.2, 1),
        (7.9, 8.4, -1),
        (7.0, 7.0, -1),
        (2.8, 0.8, 1),
        (1.2, 3.0, 1),
        (7.8, 6.1, -1),
    )
    learning_rate = 0.2
    training(samples, n=learning_rate)
Sign up for free to join this conversation on GitHub.
Already have an account?
Sign in to comment