Created
February 23, 2023 18:34
-
-
Save ancientstraits/936fecd67681fba309c76be57fe3b945 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#include <stdio.h> | |
#include <stdlib.h> | |
#include <math.h> | |
// Number of gradient-descent iterations performed by main().
#define ROUNDS 100
// Model output: a single weight `w` applied to the input `x`.
// Returns the product w * x.
float outval(float w, float x) {
    return w * x;
}
// Squared error between two values: (a - b)^2.
// Symmetric in its arguments and never negative.
float err(float a, float b) {
    const float diff = a - b;
    return diff * diff;
}
// Mean squared error of the model outval(w, x) against the target value x,
// sampled at evenly spaced inputs x = 0, dx, 2*dx, ... in [0, 1).
//
// w: the weight being evaluated.
// Returns the average of err(outval(w, x), x) over all sample points.
float avg_err(float w) {
    enum { NUM_SUBDIV = 1000 };
    const float dx = 1.0f / (float)NUM_SUBDIV;
    float sum = 0.0f;

    // Drive the loop with an integer counter and derive x from it. Accumulating
    // `x += dx` in a float lets rounding error build up, which makes both the
    // sample positions and the number of iterations unreliable.
    for (int i = 0; i < NUM_SUBDIV; i++) {
        const float x = (float)i * dx;
        // got: outval(w, x)
        // expected: x
        sum += err(outval(w, x), x);
    }
    return sum / (float)NUM_SUBDIV;
}
// Numerical estimate of the derivative of avg_err with respect to w,
// using a forward difference: (avg_err(w + dw) - avg_err(w)) / dw.
//
// w: the weight at which the slope is estimated.
// Returns the estimated slope of the average error at w.
float err_slope(float w) {
    const float dw = 0.0001f;
    const float err_at_w = avg_err(w);
    const float err_at_w_plus_dw = avg_err(w + dw);
    return (err_at_w_plus_dw - err_at_w) / dw;
}
// Performs one gradient-descent step on the weight.
//
// Prints the current error slope to stdout (no trailing newline, so the
// caller can continue the same output line), then moves w against the slope.
//
// w: the current weight.
// Returns the new value of w: w - learn_rate * slope.
float epoch(float w) {
    const float learn_rate = 1.0f;
    const float slope = err_slope(w);
    printf("slope = %f; ", slope);
    // Step downhill: a positive slope decreases w, a negative slope increases it.
    return w - (learn_rate * slope);
}
int main() { | |
float w = 0.0; | |
int weight_correct; | |
for (int i = 0; i < ROUNDS; i++) { | |
printf("Round %d: w is %f; ", i, w); | |
w = epoch(w); | |
printf("w becomes %f\n", w); | |
} | |
printf("Final model: outval(0.3) = %f\n", outval(w, 0.3)); | |
return 0; | |
} | |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment