Skip to content

Instantly share code, notes, and snippets.

@mrbid
Last active April 17, 2024 06:52
Show Gist options
  • Save mrbid/f01de7eb6792fda6b61675342a66c4c1 to your computer and use it in GitHub Desktop.
Save mrbid/f01de7eb6792fda6b61675342a66c4c1 to your computer and use it in GitHub Desktop.
A custom brute force alternative to backpropagation.
/*
Test_User (notabug.org/Test_User)
April 2024
A custom brute-force alternative to backpropagation for training a feed-forward network (FFN/MLP) using only forward passes.
Best generated weights from this process are available to download here:
https://raw.githubusercontent.com/mrbid/mrbid.github.io/main/fprop_best_weights_save
The network trains on the Zodiac dataset:
https://github.com/jcwml/neural_zodiac
sint8 range: -128 to 127
gcc rprop.c -Ofast -ggdb3 -o fprop
valgrind -s --track-origins=yes --leak-check=full --show-leak-kinds=all ./fprop
*/
#include <fcntl.h>
#include <stdio.h>
#include <stdlib.h>
#include <stdint.h>
#include <string.h>
#include <signal.h>
#include <time.h>
#include <unistd.h>
#define uint unsigned int
#define sint8 int8_t // signed int8
#define NUM_INPUTS 12
#define NUM_OUTPUTS 1
#define UNITS 16
#define ReLU_LEAK 0 // lower floor applied by ReLU()
// Number of consecutive cycles without a change in loss before a layer starts accepting
// "sideways" moves (new values with equal loss) when no strictly better value is found.
// Layer 4 (the output layer) starts first; each earlier layer waits twice as long.
// Every expansion is parenthesised so the macros compose safely in any expression.
#define UNCHANGED_CYCLE_LIMIT_4 1
#define UNCHANGED_CYCLE_LIMIT_3 (UNCHANGED_CYCLE_LIMIT_4 * 2)
#define UNCHANGED_CYCLE_LIMIT_2 (UNCHANGED_CYCLE_LIMIT_3 * 2)
#define UNCHANGED_CYCLE_LIMIT_1 (UNCHANGED_CYCLE_LIMIT_2 * 2)
// Number of unchanged cycles before every weight and bias is brute-forced over all 256 sint8
// values; helps escape local minima and plateaus that the random sideways moves do not cross.
#define UNCHANGED_CYCLE_LIMIT_BREAKOUT (UNCHANGED_CYCLE_LIMIT_1 * 3)
// Saturating sint8 multiply: returns x*y, clamped to 127 / -128 when the true product
// does not fit in a signed 8-bit value.
sint8 cap_multiply(sint8 x, sint8 y) {
    sint8 product;
    if (!__builtin_mul_overflow(x, y, &product))
        return product;
    // Overflow is only possible with two non-zero operands; the true product is positive
    // exactly when both operands have the same sign.
    return ((x < 0) == (y < 0)) ? 127 : -128;
}
// Saturating sint8 add: returns x+y, clamped to 127 / -128 on overflow.
sint8 cap_add(sint8 x, sint8 y) {
    sint8 sum;
    const int overflowed = __builtin_add_overflow(x, y, &sum);
    if (!overflowed)
        return sum;
    // An addition can only overflow when both operands share a sign; clamp toward it.
    return (x > 0) ? 127 : -128;
}
// Saturating sint8 subtract: returns x-y, clamped to 127 / -128 on overflow.
sint8 cap_sub(sint8 x, sint8 y) {
    sint8 difference;
    if (__builtin_sub_overflow(x, y, &difference) == 0)
        return difference;
    // A non-negative minuend can only overflow upward, a negative one only downward.
    return (x < 0) ? -128 : 127;
}
// Complete set of trainable parameters for the 4-layer network evaluated by forward():
// NUM_INPUTS -> UNITS -> UNITS -> UNITS -> NUM_OUTPUTS, every value stored as sint8.
// save() writes the raw bytes of this struct to disk, so changing field order or sizes
// changes the layout of saved weight files.
typedef struct
{
// weights, indexed [output unit][input of that layer]
sint8 l1w[UNITS][NUM_INPUTS];
sint8 l2w[UNITS][UNITS];
sint8 l3w[UNITS][UNITS];
sint8 l4w[NUM_OUTPUTS][UNITS];
// biases, one per output unit of each layer
sint8 l1b[UNITS];
sint8 l2b[UNITS];
sint8 l3b[UNITS];
sint8 l4b[NUM_OUTPUTS];
} network;
// Network currently being trained; read by forward() and mutated in place by main().
network net; // define network as net
// Copy of `net` taken at the end of every completed training cycle; stats_and_exit()
// restores it on ^C so a half-updated network is never reported or saved.
network net_old; // for stat dumping on ^C
// unused
// Print every weight matrix of `net` (biases are not printed). Each layer gets an
// "lNw:" header line followed by all of its weights, row by row, separated by spaces,
// and a blank line.
void dumpWeights()
{
    puts("l1w:");
    for (int i = 0; i < UNITS; i++)
        for (int j = 0; j < NUM_INPUTS; j++)
            printf("%i ", net.l1w[i][j]);
    printf("\n\n");

    puts("l2w:");
    for (int i = 0; i < UNITS; i++)
        for (int j = 0; j < UNITS; j++)
            printf("%i ", net.l2w[i][j]);
    printf("\n\n");

    puts("l3w:");
    for (int i = 0; i < UNITS; i++)
        for (int j = 0; j < UNITS; j++)
            printf("%i ", net.l3w[i][j]);
    printf("\n\n");

    puts("l4w:");
    for (int i = 0; i < NUM_OUTPUTS; i++)
        for (int j = 0; j < UNITS; j++)
            printf("%i ", net.l4w[i][j]); // output layer (previously printed l3w by mistake)
    printf("\n\n");
}
// unused
void layerStat()
{
int min=999, max=0, avg=0, avgd=0;
for(int i = 0; i < UNITS; i++)
{
for(int j = 0; j < NUM_INPUTS; j++)
{
if(net.l1w[i][j] > max){max = net.l1w[i][j];}
if(net.l1w[i][j] < min){min = net.l1w[i][j];}
avg += net.l1w[i][j];
avgd++;
}
}
avg /= avgd;
printf("l1w: %i %i %i [%i]\n", min, avg/avgd, max, avg);
min=999, max=0, avg=0, avgd=0;
for(int i = 0; i < UNITS; i++)
{
for(int j = 0; j < UNITS; j++)
{
if(net.l2w[i][j] > max){max = net.l2w[i][j];}
if(net.l2w[i][j] < min){min = net.l2w[i][j];}
avg += net.l2w[i][j];
avgd++;
}
}
avg /= avgd;
printf("l2w: %i %i %i [%i]\n", min, avg/avgd, max, avg);
min=999, max=0, avg=0, avgd=0;
for(int i = 0; i < UNITS; i++)
{
for(int j = 0; j < UNITS; j++)
{
if(net.l3w[i][j] > max){max = net.l3w[i][j];}
if(net.l3w[i][j] < min){min = net.l3w[i][j];}
avg += net.l3w[i][j];
avgd++;
}
}
avg /= avgd;
printf("l3w: %i %i %i [%i]\n", min, avg/avgd, max, avg);
min=999, max=0, avg=0, avgd=0;
for (int i = 0; i < NUM_OUTPUTS; i++)
{
for(int j = 0; j < UNITS; j++)
{
if(net.l4w[i][j] > max){max = net.l4w[i][j];}
if(net.l4w[i][j] < min){min = net.l4w[i][j];}
avg += net.l4w[i][j];
avgd++;
}
}
avg /= avgd;
printf("l3w: %i %i %i [%i]\n", min, avg/avgd, max, avg);
}
// Clamped ReLU: returns s limited to the range [ReLU_LEAK, 16].
// Note that ReLU_LEAK is a constant lower floor, not a leak slope.
sint8 ReLU(sint8 s)
{
    // Keep the intermediate signed: with the previous uint8_t temporary, a negative
    // ReLU_LEAK floor would wrap to a large value and then be clamped to 16.
    sint8 t = s < ReLU_LEAK ? ReLU_LEAK : s;
    return t < 16 ? t : 16;
}
// unused?
// ReLU derivative: 1 for strictly positive input, 0 otherwise.
sint8 ReLU_D(sint8 s)
{
    if (s > 0)
        return 1;
    return 0;
}
// unused
sint8 RPROP(signed int input, signed int error) // RPROP optimiser
{
return (input * error) < 0 ? -1 : 1;
}
// unused
// Random signed int8, uniformly covering the full range [-128, 127].
sint8 srnd8()
{
    // rand() % 256 reaches all 256 byte values (the old % 255 could never yield one of them).
    // Subtracting 128 in int arithmetic keeps the result in range, avoiding the
    // implementation-defined narrowing of 128..254 to sint8.
    return (sint8)(rand() % 256 - 128);
}
// unused
// Random non-zero signed int8: redraws from srnd8() until the value differs from 0.
sint8 srnd8_weight()
{
    sint8 r;
    do {
        r = srnd8();
    } while (r == 0);
    return r;
}
//sint8 training_data[][NUM_INPUTS+NUM_OUTPUTS] = {
// // i, x, y
// {0, 0, 0},
// {2, 0, 0},
// {4, 0, 1},
// {6, 2, 2},
// {8, 3, 2},
// {10, 4, 1},
// {12, 6, 0},
// {14, 6, -1},
// {16, 6, -3},
// {18, 6, -6},
// {20, 4, -8},
// {22, 2, -10},
// {24, 0, -12},
// {26, -3, -12},
// {28, -6, -12},
// {30, -10, -10},
// {32, -13, -8},
// {34, -16, -4},
// {36, -18, 0},
// {38, -18, 4},
// {40, -17, 10},
// {42, -14, 14},
// {44, -11, 19},
// {46, -5, 22},
// {48, 0, 24},
// {50, 6, 24},
// {52, 13, 22},
// {54, 19, 19},
// {56, 24, 14},
// {58, 28, 7},
// {60, 30, 0},
// {62, 29, -8},
// {64, 27, -15},
// {66, 23, -23},
// {68, 16, -29},
// {70, 9, -33},
// {72, 0, -36},
// {74, -9, -35},
// {76, -18, -32},
// {78, -27, -27},
// {80, -34, -19},
// {82, -39, -10},
// {84, -42, 0},
// {86, -41, 11},
// {88, -38, 21},
// {90, -31, 31},
// {92, -23, 39},
// {94, -12, 45},
// {96, 0, 48},
// {98, 12, 47},
// {100, 25, 43},
// {102, 36, 36},
// {104, 45, 25},
// {106, 51, 13},
// {108, 54, 0},
// {110, 53, -14},
// {112, 48, -27},
// {114, 40, -40},
// {116, 29, -50},
// {118, 15, -56},
// {120, 0, -60},
// {122, -15, -58},
// {124, -30, -53},
// {126, -44, -44},
//};
// Zodiac training set: the first NUM_INPUTS columns count how many of the two people have
// each sign (every row describes exactly one pair, so the counts sum to 2). The final column
// is the target output, printed with a % sign by stats_and_exit(). The 78 rows cover every
// unordered pair of the 12 signs exactly once.
sint8 training_data[][NUM_INPUTS+NUM_OUTPUTS] = {
// aries, taurus, gemini, cancer, leo, virgo, libra, scorpio, sagittarius, capricorn, aquarius, pisces, output
{ 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 75},
{ 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 63},
{ 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 74},
{ 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 47},
{ 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 83},
{ 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 42},
{ 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 62},
{ 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 48},
{ 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 87},
{ 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 38},
{ 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 68},
{ 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 29}, // aries + pisces (previously also had virgo set)
{ 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 86},
{ 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 23},
{ 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 91},
{ 0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 29},
{ 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 73},
{ 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 33},
{ 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 89},
{ 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 31},
{ 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 89},
{ 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 11},
{ 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 88},
{ 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 83},
{ 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 21},
{ 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 82},
{ 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 40},
{ 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 78},
{ 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 15},
{ 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 92},
{ 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 15},
{ 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 85},
{ 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 10},
{ 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 85},
{ 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 29},
{ 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 77},
{ 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 28},
{ 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 79},
{ 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 27},
{ 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 84},
{ 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 31},
{ 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 72},
{ 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 78},
{ 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 35},
{ 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 75},
{ 0, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 29},
{ 0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 75},
{ 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 27},
{ 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 89},
{ 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 14},
{ 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 65},
{ 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 30},
{ 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 76},
{ 0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 32},
{ 0, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 77},
{ 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 30},
{ 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 86},
{ 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 68},
{ 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 29},
{ 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 71},
{ 0, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 0, 34},
{ 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 68},
{ 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 29},
{ 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 66},
{ 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 30},
{ 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 64},
{ 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 30},
{ 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 1, 81},
{ 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 74},
{ 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 38},
{ 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 83},
{ 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 1, 50},
{ 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 62},
{ 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 37},
{ 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 76},
{ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 74},
{ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 38},
{ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 73},
};
// same as usual
void forward(sint8 input[NUM_INPUTS], sint8 *output)
{
// ---
// --------
// ---
// forward pass
sint8 l1o[UNITS];
for(uint i = 0; i < UNITS; i++) {
sint8 wa = 0;
for (uint j = 0; j < NUM_INPUTS; j++)
wa = cap_add(wa, cap_multiply(input[j], net.l1w[i][j]));
l1o[i] = ReLU(cap_add(wa, net.l1b[i]));
}
sint8 l2o[UNITS];
for(uint i = 0; i < UNITS; i++)
{
sint8 wa = 0;
for(uint j = 0; j < UNITS; j++)
wa = cap_add(wa, cap_multiply(l1o[j], net.l2w[i][j]));
l2o[i] = ReLU(cap_add(wa, net.l2b[i]));
}
sint8 l3o[UNITS];
for(uint i = 0; i < UNITS; i++)
{
sint8 wa = 0;
for(uint j = 0; j < UNITS; j++)
wa = cap_add(wa, cap_multiply(l2o[j], net.l3w[i][j]));
l3o[i] = ReLU(cap_add(wa, net.l3b[i]));
}
for (uint i = 0; i < NUM_OUTPUTS; i++)
{
sint8 wa = 0;
for(uint j = 0; j < UNITS; j++)
wa = cap_add(wa, cap_multiply(l3o[j], net.l4w[0][j]));
output[i] = cap_add(wa, net.l4b[0]);
}
return;
}
// Training objective for the current `net`: the sum of squared per-output absolute errors
// over the whole training set, plus the square of the single worst error, which gives
// extra weight to getting rid of the "bad" answers. (Total absolute error, total squared
// error, or the worst error alone are the alternatives that were tried.)
uint64_t SE()
{
    const size_t rows = sizeof(training_data) / sizeof(training_data[0]);
    uint64_t squared_sum = 0;
    uint64_t max_err = 0;
    size_t r = 0;
    while (r < rows) {
        sint8 predicted[NUM_OUTPUTS];
        forward(training_data[r], predicted); // actually run the network
        for (uint k = 0; k < NUM_OUTPUTS; k++) {
            const int err = abs((int)predicted[k] - training_data[r][NUM_INPUTS + k]);
            if ((uint64_t)err > max_err)
                max_err = err;
            squared_sum += err * err;
        }
        r++;
    }
    return squared_sum + max_err * max_err;
}
// Diagnostic version of the loss computation in SE(): prints every per-output absolute
// error as "e, ", then "[total abs error] (total squared error) {worst error}".
void dump_SE()
{
    uint64_t worst = 0;
    uint64_t total_loss = 0;
    uint64_t total_squared_loss = 0;
    for (size_t i = 0; i < sizeof(training_data)/sizeof(training_data[0]); i++) {
        sint8 out[NUM_OUTPUTS];
        forward(training_data[i], out);
        for (uint j = 0; j < NUM_OUTPUTS; j++) {
            int cur_loss = abs((int)out[j] - training_data[i][NUM_INPUTS+j]);
            if ((uint64_t)cur_loss > worst)
                worst = cur_loss;
            total_loss += cur_loss;
            total_squared_loss += (uint64_t)cur_loss * cur_loss;
            printf("%d, ", cur_loss);
        }
    }
    // uint64_t is not necessarily unsigned long (e.g. 32-bit or Windows targets), so %lu
    // was undefined behaviour there; print through unsigned long long instead.
    printf("[%llu] (%llu) {%llu}\n",
           (unsigned long long)total_loss,
           (unsigned long long)total_squared_loss,
           (unsigned long long)worst);
}
// Output path taken from argv[1]; stays NULL when no path was given.
char *filepath;

// Write the raw bytes of `net` to `filepath`, replacing any previous contents of the file.
// Does nothing when no path is set. Failures are reported on stderr; they do not abort.
void save()
{
    if (filepath == NULL)
        return;
    // O_TRUNC: without it, an existing longer file would keep stale trailing bytes.
    int fd = open(filepath, O_WRONLY | O_CREAT | O_TRUNC, 0600);
    if (fd < 0) {
        perror("save: open");
        return;
    }
    ssize_t written = write(fd, &net, sizeof(net));
    if (written < 0)
        perror("save: write");
    else if ((size_t)written != sizeof(net))
        fprintf(stderr, "save: short write (%zd of %zu bytes)\n", written, sizeof(net));
    if (close(fd) != 0)
        perror("save: close");
}
// SIGINT handler (installed in main). It:
//  - restores the last completed training snapshot;
//  - prints the network's output for every training row next to the expected value;
//  - saves the weights if an output path was given;
//  - exits with status 0.
// Declared with an int parameter so its type matches the handler type signal() expects;
// the old "()" definition was an incompatible function type and is rejected by C23 compilers.
// NOTE(review): printf()/exit() are not async-signal-safe. This is tolerated because the
// process terminates immediately, but output could interleave if ^C lands inside a printf.
void stats_and_exit(int sig)
{
    (void)sig;
    // `net` may be mid-update when the signal arrives; net_old is the last consistent state.
    memcpy(&net, &net_old, sizeof(net));
    printf("\r");
    for (size_t i = 0; i < sizeof(training_data)/sizeof(training_data[0]); i++) {
        // input row as a digit string, e.g. "200000000000"
        for (size_t j = 0; j < NUM_INPUTS; j++)
            printf("%d", training_data[i][j]);
        printf(": ");
        sint8 output[NUM_OUTPUTS];
        forward(training_data[i], output);
        for (size_t j = 0; j < NUM_OUTPUTS; j++)
            printf("%d%% ", output[j]);
        printf("(expected");
        for (size_t j = 0; j < NUM_OUTPUTS; j++)
            printf(" %d%%", training_data[i][NUM_INPUTS+j]);
        printf(")\n");
    }
    if (filepath != NULL)
        save();
    exit(0);
}
int main(int argc, char **argv)
{
if (argc < 2) {
puts("No output file specified.");
} else {
filepath = argv[1];
}
signal(SIGINT, stats_and_exit);
// seed rand
srand(time(0));
// set weights to 0
for(uint i = 0; i < UNITS; i++)
for(uint j = 0; j < NUM_INPUTS; j++) {net.l1w[i][j] = 0; net.l1b[i] = 0;}
for(uint i = 0; i < UNITS; i++)
for(uint j = 0; j < UNITS; j++) {net.l2w[i][j] = 0; net.l2b[i] = 0;}
for(uint i = 0; i < UNITS; i++)
for(uint j = 0; j < UNITS; j++) {net.l3w[i][j] = 0; net.l3b[i] = 0;}
for(uint i = 0; i < UNITS; i++) {net.l4w[0][i] = 0;}
net.l4b[0] = 0;
//dumpWeights();
uint64_t cycles_unchanged = 0; // for when to behave how
uint64_t last_loss = -1UL; // maximum possible value
// train
while (1)
{
uint64_t loss = SE(); // get current loss value... I could prob move this out and reuse last one for better efficience but meh whatever (and *not* last_loss, since that one tracks when to reset unchanged_cycles)
if (cycles_unchanged > UNCHANGED_CYCLE_LIMIT_BREAKOUT) { // if it's time for more aggressive behavior
for (uint i = 0; i < UNITS; i++) { // for each first layer output
for (uint j = 0; j < NUM_INPUTS; j++) { // for each first layer input
uint64_t losses[256];
for (uint x = 0; x < 256; x++) { // get losses for all values
net.l1w[i][j] = x - 128;
losses[x] = SE();
}
uint64_t best = losses[0];
for (uint x = 1; x < 256; x++) { // select best
if (losses[x] < best)
best = losses[x];
}
sint8 best_losses[256];
uint64_t num_best_losses = 0;
for (uint x = 0; x < 256; x++) { // get a list of all that are infact the best
if (losses[x] == best) {
best_losses[num_best_losses] = x - 128;
num_best_losses++;
}
}
net.l1w[i][j] = best_losses[rand()%num_best_losses]; // pick one at random
}
uint64_t losses[256]; // repeat for each first layer output bias
for (uint x = 0; x < 256; x++) {
net.l1b[i] = x - 128;
losses[x] = SE();
}
uint64_t best = losses[0];
for (uint x = 1; x < 256; x++) {
if (losses[x] < best)
best = losses[x];
}
sint8 best_losses[256];
uint64_t num_best_losses = 0;
for (uint x = 0; x < 256; x++) {
if (losses[x] == best) {
best_losses[num_best_losses] = x - 128;
num_best_losses++;
}
}
net.l1b[i] = best_losses[rand()%num_best_losses];
}
// repeat for layer 2-4
for (uint i = 0; i < UNITS; i++) {
for (uint j = 0; j < UNITS; j++) {
uint64_t losses[256];
for (uint x = 0; x < 256; x++) {
net.l2w[i][j] = x - 128;
losses[x] = SE();
}
uint64_t best = losses[0];
for (uint x = 1; x < 256; x++) {
if (losses[x] < best)
best = losses[x];
}
sint8 best_losses[256];
uint64_t num_best_losses = 0;
for (uint x = 0; x < 256; x++) {
if (losses[x] == best) {
best_losses[num_best_losses] = x - 128;
num_best_losses++;
}
}
net.l2w[i][j] = best_losses[rand()%num_best_losses];
}
uint64_t losses[256];
for (uint x = 0; x < 256; x++) {
net.l2b[i] = x - 128;
losses[x] = SE();
}
uint64_t best = losses[0];
for (uint x = 1; x < 256; x++) {
if (losses[x] < best)
best = losses[x];
}
sint8 best_losses[256];
uint64_t num_best_losses = 0;
for (uint x = 0; x < 256; x++) {
if (losses[x] == best) {
best_losses[num_best_losses] = x - 128;
num_best_losses++;
}
}
net.l2b[i] = best_losses[rand()%num_best_losses];
}
for (uint i = 0; i < UNITS; i++) {
for (uint j = 0; j < UNITS; j++) {
uint64_t losses[256];
for (uint x = 0; x < 256; x++) {
net.l3w[i][j] = x - 128;
losses[x] = SE();
}
uint64_t best = losses[0];
for (uint x = 1; x < 256; x++) {
if (losses[x] < best)
best = losses[x];
}
sint8 best_losses[256];
uint64_t num_best_losses = 0;
for (uint x = 0; x < 256; x++) {
if (losses[x] == best) {
best_losses[num_best_losses] = x - 128;
num_best_losses++;
}
}
net.l3w[i][j] = best_losses[rand()%num_best_losses];
}
uint64_t losses[256];
for (uint x = 0; x < 256; x++) {
net.l3b[i] = x - 128;
losses[x] = SE();
}
uint64_t best = losses[0];
for (uint x = 1; x < 256; x++) {
if (losses[x] < best)
best = losses[x];
}
sint8 best_losses[256];
uint64_t num_best_losses = 0;
for (uint x = 0; x < 256; x++) {
if (losses[x] == best) {
best_losses[num_best_losses] = x - 128;
num_best_losses++;
}
}
net.l3b[i] = best_losses[rand()%num_best_losses];
}
for (uint i = 0; i < NUM_OUTPUTS; i++) {
for (uint j = 0; j < UNITS; j++) {
uint64_t losses[256];
for (uint x = 0; x < 256; x++) {
net.l4w[i][j] = x - 128;
losses[x] = SE();
}
uint64_t best = losses[0];
for (uint x = 1; x < 256; x++) {
if (losses[x] < best)
best = losses[x];
}
sint8 best_losses[256];
uint64_t num_best_losses = 0;
for (uint x = 0; x < 256; x++) {
if (losses[x] == best) {
best_losses[num_best_losses] = x - 128;
num_best_losses++;
}
}
net.l4w[i][j] = best_losses[rand()%num_best_losses];
}
uint64_t losses[256];
for (uint x = 0; x < 256; x++) {
net.l4b[i] = x - 128;
losses[x] = SE();
}
uint64_t best = losses[0];
for (uint x = 1; x < 256; x++) {
if (losses[x] < best)
best = losses[x];
}
sint8 best_losses[256];
uint64_t num_best_losses = 0;
for (uint x = 0; x < 256; x++) {
if (losses[x] == best) {
best_losses[num_best_losses] = x - 128;
num_best_losses++;
}
}
net.l4b[i] = best_losses[rand()%num_best_losses];
}
// end of force
} else {
for (uint i = 0; i < UNITS; i++) { // for each first layer output
for (uint j = 0; j < NUM_INPUTS; j++) { // for each first layer input
sint8 old = net.l1w[i][j];
net.l1w[i][j] = cap_add(old, 1); // +1 of original
uint64_t loss_a = SE(); // get loss
net.l1w[i][j] = cap_sub(old, 1); // -1 of original
uint64_t loss_b = SE(); // get loss
if (loss < loss_a && loss < loss_b || (loss <= loss_a && loss <= loss_b && cycles_unchanged < UNCHANGED_CYCLE_LIMIT_1)) { // if neither improves the situation || (neither is better && it's not yet time to explore randomly)
net.l1w[i][j] = old;
} else if (loss_a < loss_b) { // if a is better
loss = loss_a;
net.l1w[i][j] = cap_add(old, 1); // set to the better (+1)
} else if (loss_a == loss_b && (rand() % 2) == 1) { // 50/50 chance
loss = loss_a;
net.l1w[i][j] = cap_add(old, 1);
} else { // it's already set to b, and b is better or equal
loss = loss_b;
}
}
sint8 old = net.l1b[i]; // repeat for first layer output bias
net.l1b[i] = cap_add(old, 1);
uint64_t loss_a = SE();
net.l1b[i] = cap_sub(old, 1);
uint64_t loss_b = SE();
if (loss < loss_a && loss < loss_b || (loss <= loss_a && loss <= loss_b && cycles_unchanged < UNCHANGED_CYCLE_LIMIT_1)) {
net.l1b[i] = old;
} else if (loss_a < loss_b) {
loss = loss_a;
net.l1b[i] = cap_add(old, 1);
} else if (loss_a == loss_b && (rand() % 2) == 1) { // 50/50 chance
loss = loss_a;
net.l1b[i] = cap_add(old, 1);
} else {
loss = loss_b;
}
}
// repeat for layer 2-4
for (uint i = 0; i < UNITS; i++) {
for (uint j = 0; j < UNITS; j++) {
sint8 old = net.l2w[i][j];
net.l2w[i][j] = cap_add(old, 1);
uint64_t loss_a = SE();
net.l2w[i][j] = cap_sub(old, 1);
uint64_t loss_b = SE();
if (loss < loss_a && loss < loss_b || (loss <= loss_a && loss <= loss_b && cycles_unchanged < UNCHANGED_CYCLE_LIMIT_2)) {
net.l2w[i][j] = old;
} else if (loss_a < loss_b) {
loss = loss_a;
net.l2w[i][j] = cap_add(old, 1);
} else if (loss_a == loss_b && (rand() % 2) == 1) { // 50/50 chance
loss = loss_a;
net.l2w[i][j] = cap_add(old, 1);
} else {
loss = loss_b;
}
}
sint8 old = net.l2b[i];
net.l2b[i] = cap_add(old, 1);
uint64_t loss_a = SE();
net.l2b[i] = cap_sub(old, 1);
uint64_t loss_b = SE();
if (loss < loss_a && loss < loss_b || (loss <= loss_a && loss <= loss_b && cycles_unchanged < UNCHANGED_CYCLE_LIMIT_2)) {
net.l2b[i] = old;
} else if (loss_a < loss_b) {
loss = loss_a;
net.l2b[i] = cap_add(old, 1);
} else if (loss_a == loss_b && (rand() % 2) == 1) { // 50/50 chance
loss = loss_a;
net.l2b[i] = cap_add(old, 1);
} else {
loss = loss_b;
}
}
for (uint i = 0; i < UNITS; i++) {
for (uint j = 0; j < UNITS; j++) {
sint8 old = net.l3w[i][j];
net.l3w[i][j] = cap_add(old, 1);
uint64_t loss_a = SE();
net.l3w[i][j] = cap_sub(old, 1);
uint64_t loss_b = SE();
if (loss < loss_a && loss < loss_b || (loss <= loss_a && loss <= loss_b && cycles_unchanged < UNCHANGED_CYCLE_LIMIT_3)) {
net.l3w[i][j] = old;
} else if (loss_a < loss_b) {
loss = loss_a;
net.l3w[i][j] = cap_add(old, 1);
} else if (loss_a == loss_b && (rand() % 2) == 1) { // 50/50 chance
loss = loss_a;
net.l3w[i][j] = cap_add(old, 1);
} else {
loss = loss_b;
}
}
sint8 old = net.l3b[i];
net.l3b[i] = cap_add(old, 1);
uint64_t loss_a = SE();
net.l3b[i] = cap_sub(old, 1);
uint64_t loss_b = SE();
if (loss < loss_a && loss < loss_b || (loss <= loss_a && loss <= loss_b && cycles_unchanged < UNCHANGED_CYCLE_LIMIT_3)) {
net.l3b[i] = old;
} else if (loss_a < loss_b) {
loss = loss_a;
net.l3b[i] = cap_add(old, 1);
} else if (loss_a == loss_b && (rand() % 2) == 1) { // 50/50 chance
loss = loss_a;
net.l3b[i] = cap_add(old, 1);
} else {
loss = loss_b;
}
}
for (uint i = 0; i < NUM_OUTPUTS; i++) {
for (uint j = 0; j < UNITS; j++) {
sint8 old = net.l4w[i][j];
net.l4w[i][j] = cap_add(old, 1);
uint64_t loss_a = SE();
net.l4w[i][j] = cap_sub(old, 1);
uint64_t loss_b = SE();
if (loss < loss_a && loss < loss_b || (loss <= loss_a && loss <= loss_b && cycles_unchanged < UNCHANGED_CYCLE_LIMIT_4)) {
net.l4w[i][j] = old;
} else if (loss_a < loss_b) {
loss = loss_a;
net.l4w[i][j] = cap_add(old, 1);
} else if (loss_a == loss_b && (rand() % 2) == 1) { // 50/50 chance
loss = loss_a;
net.l4w[i][j] = cap_add(old, 1);
} else {
loss = loss_b;
}
}
sint8 old = net.l4b[i];
net.l4b[i] = cap_add(old, 1);
uint64_t loss_a = SE();
net.l4b[i] = cap_sub(old, 1);
uint64_t loss_b = SE();
if (loss < loss_a && loss < loss_b || (loss <= loss_a && loss <= loss_b && cycles_unchanged < UNCHANGED_CYCLE_LIMIT_4)) {
net.l4b[i] = old;
} else if (loss_a < loss_b) {
loss = loss_a;
net.l4b[i] = cap_add(old, 1);
} else if (loss_a == loss_b && (rand() % 2) == 1) { // 50/50 chance
loss = loss_a;
net.l4b[i] = cap_add(old, 1);
} else {
loss = loss_b;
}
}
}
// printf("loss: %lu\n", loss);
dump_SE();
memcpy(&net_old, &net, sizeof(net));
if (loss == 0) {
return 0;
}
if (last_loss != loss) {
cycles_unchanged = 0;
last_loss = loss;
} else {
cycles_unchanged++;
}
printf("%lu; 4: %d; 3: %d; 2: %d; 1: %d; break: %d\n", cycles_unchanged, UNCHANGED_CYCLE_LIMIT_4, UNCHANGED_CYCLE_LIMIT_3, UNCHANGED_CYCLE_LIMIT_2, UNCHANGED_CYCLE_LIMIT_1, UNCHANGED_CYCLE_LIMIT_BREAKOUT);
}
// done
return 0;
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment