Created
October 4, 2023 13:40
-
-
Save ksasao/0d552ed637f302a0c990d751157abd03 to your computer and use it in GitHub Desktop.
SEFR multi-class classifier algorithm implementation for M5Atom
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
/*
This code is derived from the SEFR multi-class classifier algorithm developed by Alan Wang.
The original code can be found at https://github.com/alankrantas/sefr_multiclass_classifier.
The code is licensed under the MIT License.
https://opensource.org/licenses/mit-license.php
*/
#include "SEFR.h" | |
// File-scope model storage used by the functions in this file.
//
// NOTE: class SEFR also declares private members named m_features and
// m_labels. Inside SEFR member functions the unqualified names resolve to
// those members, so the two globals below only change when code assigns them
// explicitly (e.g. via ::m_labels).
uint8_t m_features;
uint8_t m_labels;
float** m_weights = NULL;  // one row of weights per label, or NULL when unallocated
float* m_bias = NULL;      // one bias per label, or NULL when unallocated

// Release the weight rows, the row table and the bias array.
//
// The number of rows to free is read from the file-scope m_labels, so that
// value has to match the number of rows held by m_weights when this is called.
// Both pointers are reset to NULL afterwards so that calling this function
// again is a harmless no-op instead of a double free.
void deleteBuffer() {
  if (m_weights != NULL) {
    for (int i = 0; i < m_labels; i++) {
      delete[] m_weights[i];
    }
    delete[] m_weights;
    m_weights = NULL;
  }
  // delete[] on a null pointer is well defined, so no guard is needed here.
  delete[] m_bias;
  m_bias = NULL;
}
// Construct an empty classifier. The dimensions start at zero so that
// getFeatures()/getLabels() return a defined value before setup() is called.
SEFR::SEFR() : m_features(0), m_labels(0) {}
// Destructor: hands the model buffers back through the file-level helper.
SEFR::~SEFR()
{
  ::deleteBuffer();
}
void SEFR::setup(uint8_t features, uint8_t labels) { | |
deleteBuffer(); | |
m_features = features; | |
m_labels = labels; | |
m_weights = new float*[m_labels]; | |
for (int i = 0; i < m_labels; i++) { | |
m_weights[i] = new float[m_features]; | |
for (int j=0;j<m_features;j++){ | |
m_weights[i][j] = 0; | |
} | |
} | |
m_bias = new float[m_labels]; | |
} | |
uint8_t SEFR::getFeatures() { | |
return m_features; | |
} | |
uint8_t SEFR::getLabels() { | |
return m_labels; | |
} | |
void SEFR::fit(float** dataset, uint8_t* target, int dataset_size) { | |
// iterate all labels | |
for (byte l = 0; l < m_labels; l++) { | |
unsigned int count_pos = 0, count_neg = 0; | |
// iterate all features | |
for (byte f = 0; f < m_features; f++) { | |
float avg_pos = 0, avg_neg = 0; | |
count_pos = 0; | |
count_neg = 0; | |
for (unsigned int s = 0; s < dataset_size; s++) { | |
if (target[s] != l) { // use "not the label" as positive class | |
avg_pos += dataset[s][f]; | |
count_pos++; | |
} else { // use the label as negative class | |
avg_neg += dataset[s][f]; | |
count_neg++; | |
} | |
} | |
avg_pos /= float(count_pos); | |
avg_neg /= float(count_neg); | |
// calculate weight of this label | |
m_weights[l][f] = (avg_pos - avg_neg) / (avg_pos + avg_neg); | |
} | |
// calculate average weighted score for positive/negative data | |
float avg_pos_w = 0.0, avg_neg_w = 0.0; | |
for (unsigned int s = 0; s < dataset_size; s++) { | |
float weighted_score = 0.0; | |
for (byte f = 0; f < m_features; f++) { | |
weighted_score += (dataset[s][f] * m_weights[l][f]); | |
} | |
if (target[s] != l) { | |
avg_pos_w += weighted_score; | |
} else { | |
avg_neg_w += weighted_score; | |
} | |
} | |
avg_pos_w /= float(count_pos); | |
avg_neg_w /= float(count_neg); | |
// calculate bias of this label | |
m_bias[l] = -1 * (float(count_neg) * avg_pos_w + float(count_pos) * avg_neg_w) / float(count_pos + count_neg); | |
} | |
} | |
// predict label from a single new data instance | |
uint8_t SEFR::predict(float new_data[]) { | |
float score[m_labels]; | |
for (byte l = 0; l < m_labels; l++) { | |
score[l] = 0.0; | |
for (byte f = 0; f < m_features; f++) { | |
// calculate weight of each labels | |
score[l] += (new_data[f] * m_weights[l][f]); | |
} | |
score[l] += m_bias[l]; // add bias of each labels | |
} | |
// find the min score (least possible label of "not the label") | |
float min_score = score[0]; | |
byte min_label = 0; | |
for (byte l = 1; l < m_labels; l++) { | |
if (score[l] < min_score) { | |
min_score = score[l]; | |
min_label = l; | |
} | |
} | |
return min_label; // return prediction | |
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
/*
This code is derived from the SEFR multi-class classifier algorithm developed by Alan Wang.
The original code can be found at https://github.com/alankrantas/sefr_multiclass_classifier.
The code is licensed under the MIT License.
https://opensource.org/licenses/mit-license.php
*/
#ifndef SEFR_H | |
#define SEFR_H | |
#include <Arduino.h> | |
// SEFR multi-class classifier.
//
// Typical use: setup() once with the problem dimensions, fit() on a labelled
// dataset, then predict() on individual data instances.
class SEFR {
  public:
    SEFR();
    ~SEFR();

    // Allocate the model for `features` inputs and `labels` classes.
    void setup(uint8_t features, uint8_t labels);

    // Dimensions passed to the most recent setup() call.
    uint8_t getFeatures();
    uint8_t getLabels();

    // Train on `dataset_size` rows of `dataset` with the matching `target` labels.
    void fit(float** dataset, uint8_t* target, int dataset_size);

    // Return the predicted label for one data instance.
    uint8_t predict(float new_data[]);

  private:
    uint8_t m_features;  // values per data instance
    uint8_t m_labels;    // number of classes
};
#endif |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
/*
This code is derived from the SEFR multi-class classifier algorithm developed by Alan Wang.
The original code can be found at https://github.com/alankrantas/sefr_multiclass_classifier.
The code is licensed under the MIT License.
https://opensource.org/licenses/mit-license.php
*/
#include "M5Atom.h" | |
#include "SEFR.h" | |
// ---- problem dimensions ----
#define FEATURES        4    // number of features
#define LABELS          3    // number of labels
#define DATAFACTOR      10   // scale factor of data
#define DATASET_MAXSIZE 150  // max dataset size (change this if you want to add data directly to onboard DATASET)

// ---- sketch state ----
SEFR sefr;                        // the classifier trained in setup() and queried in loop()
unsigned int dataset_size = 150;  // current dataset size
// The Iris dataset. Every value is multiplied by DATAFACTOR so it can be
// stored as an integer (saves space/memory); setup() and loop() divide by
// DATAFACTOR again. Row i belongs to the label stored in TARGET[i].
int DATASET[DATASET_MAXSIZE][FEATURES] = {
  // rows 0-49
  { 51, 35, 14, 2 }, { 49, 30, 14, 2 }, { 47, 32, 13, 2 }, { 46, 31, 15, 2 }, { 50, 36, 14, 2 },
  { 54, 39, 17, 4 }, { 46, 34, 14, 3 }, { 50, 34, 15, 2 }, { 44, 29, 14, 2 }, { 49, 31, 15, 1 },
  { 54, 37, 15, 2 }, { 48, 34, 16, 2 }, { 48, 30, 14, 1 }, { 43, 30, 11, 1 }, { 58, 40, 12, 2 },
  { 57, 44, 15, 4 }, { 54, 39, 13, 4 }, { 51, 35, 14, 3 }, { 57, 38, 17, 3 }, { 51, 38, 15, 3 },
  { 54, 34, 17, 2 }, { 51, 37, 15, 4 }, { 46, 36, 10, 2 }, { 51, 33, 17, 5 }, { 48, 34, 19, 2 },
  { 50, 30, 16, 2 }, { 50, 34, 16, 4 }, { 52, 35, 15, 2 }, { 52, 34, 14, 2 }, { 47, 32, 16, 2 },
  { 48, 31, 16, 2 }, { 54, 34, 15, 4 }, { 52, 41, 15, 1 }, { 55, 42, 14, 2 }, { 49, 31, 15, 2 },
  { 50, 32, 12, 2 }, { 55, 35, 13, 2 }, { 49, 36, 14, 1 }, { 44, 30, 13, 2 }, { 51, 34, 15, 2 },
  { 50, 35, 13, 3 }, { 45, 23, 13, 3 }, { 44, 32, 13, 2 }, { 50, 35, 16, 6 }, { 51, 38, 19, 4 },
  { 48, 30, 14, 3 }, { 51, 38, 16, 2 }, { 46, 32, 14, 2 }, { 53, 37, 15, 2 }, { 50, 33, 14, 2 },
  // rows 50-99
  { 70, 32, 47, 14 }, { 64, 32, 45, 15 }, { 69, 31, 49, 15 }, { 55, 23, 40, 13 }, { 65, 28, 46, 15 },
  { 57, 28, 45, 13 }, { 63, 33, 47, 16 }, { 49, 24, 33, 10 }, { 66, 29, 46, 13 }, { 52, 27, 39, 14 },
  { 50, 20, 35, 10 }, { 59, 30, 42, 15 }, { 60, 22, 40, 10 }, { 61, 29, 47, 14 }, { 56, 29, 36, 13 },
  { 67, 31, 44, 14 }, { 56, 30, 45, 15 }, { 58, 27, 41, 10 }, { 62, 22, 45, 15 }, { 56, 25, 39, 11 },
  { 59, 32, 48, 18 }, { 61, 28, 40, 13 }, { 63, 25, 49, 15 }, { 61, 28, 47, 12 }, { 64, 29, 43, 13 },
  { 66, 30, 44, 14 }, { 68, 28, 48, 14 }, { 67, 30, 50, 17 }, { 60, 29, 45, 15 }, { 57, 26, 35, 10 },
  { 55, 24, 38, 11 }, { 55, 24, 37, 10 }, { 58, 27, 39, 12 }, { 60, 27, 51, 16 }, { 54, 30, 45, 15 },
  { 60, 34, 45, 16 }, { 67, 31, 47, 15 }, { 63, 23, 44, 13 }, { 56, 30, 41, 13 }, { 55, 25, 40, 13 },
  { 55, 26, 44, 12 }, { 61, 30, 46, 14 }, { 58, 26, 40, 12 }, { 50, 23, 33, 10 }, { 56, 27, 42, 13 },
  { 57, 30, 42, 12 }, { 57, 29, 42, 13 }, { 62, 29, 43, 13 }, { 51, 25, 30, 11 }, { 57, 28, 41, 13 },
  // rows 100-149
  { 63, 33, 60, 25 }, { 58, 27, 51, 19 }, { 71, 30, 59, 21 }, { 63, 29, 56, 18 }, { 65, 30, 58, 22 },
  { 76, 30, 66, 21 }, { 49, 25, 45, 17 }, { 73, 29, 63, 18 }, { 67, 25, 58, 18 }, { 72, 36, 61, 25 },
  { 65, 32, 51, 20 }, { 64, 27, 53, 19 }, { 68, 30, 55, 21 }, { 57, 25, 50, 20 }, { 58, 28, 51, 24 },
  { 64, 32, 53, 23 }, { 65, 30, 55, 18 }, { 77, 38, 67, 22 }, { 77, 26, 69, 23 }, { 60, 22, 50, 15 },
  { 69, 32, 57, 23 }, { 56, 28, 49, 20 }, { 77, 28, 67, 20 }, { 63, 27, 49, 18 }, { 67, 33, 57, 21 },
  { 72, 32, 60, 18 }, { 62, 28, 48, 18 }, { 61, 30, 49, 18 }, { 64, 28, 56, 21 }, { 72, 30, 58, 16 },
  { 74, 28, 61, 19 }, { 79, 38, 64, 20 }, { 64, 28, 56, 22 }, { 63, 28, 51, 15 }, { 61, 26, 56, 14 },
  { 77, 30, 61, 23 }, { 63, 34, 56, 24 }, { 64, 31, 55, 18 }, { 60, 30, 48, 18 }, { 69, 31, 54, 21 },
  { 67, 31, 56, 24 }, { 69, 31, 51, 23 }, { 58, 27, 51, 19 }, { 68, 32, 59, 23 }, { 67, 33, 57, 25 },
  { 67, 30, 52, 23 }, { 63, 25, 50, 19 }, { 65, 30, 52, 20 }, { 62, 34, 54, 23 }, { 59, 30, 51, 18 }
};
// Labels of the Iris dataset: TARGET[i] is the label of DATASET[i].
byte TARGET[DATASET_MAXSIZE] = {
  // label 0
  0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
  0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
  0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
  0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
  0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
  // label 1
  1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
  1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
  1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
  1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
  1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
  // label 2
  2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
  2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
  2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
  2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
  2, 2, 2, 2, 2, 2, 2, 2, 2, 2
};

// Model training time. NOTE(review): not referenced by setup() or loop().
unsigned int training_time;
// ================================================== | |
void setup() { | |
M5.begin(true, false,true); | |
delay(10); | |
randomSeed(42); | |
sefr.setup(FEATURES, LABELS); | |
// データセット作成 | |
float** X_train = new float*[DATASET_MAXSIZE]; | |
uint8_t* y_train = new uint8_t[DATASET_MAXSIZE]; | |
for (int i = 0; i < DATASET_MAXSIZE; i++) { | |
X_train[i] = new float[FEATURES]; | |
for (int j = 0; j < FEATURES; j++) { | |
X_train[i][j] = float(DATASET[i][j])/float(DATAFACTOR); | |
Serial.print(X_train[i][j]); | |
Serial.print(" "); | |
} | |
y_train[i] = TARGET[i]; | |
Serial.print("=> "); | |
Serial.print(y_train[i]); | |
Serial.println(); | |
} | |
// 平均学習時間を計測 | |
long start = millis(); | |
for(long i=0;i<1000;i++){ | |
sefr.fit(X_train, y_train, DATASET_MAXSIZE); | |
} | |
long end = millis(); | |
Serial.print("Train: "); | |
Serial.print((end-start)/1000.0); | |
Serial.println(" ms"); | |
} | |
void loop() { | |
// randomly pick a random data instance in DATASET as test data | |
unsigned int test_index = random(dataset_size); | |
float test_data[FEATURES]; | |
byte test_label = TARGET[test_index]; | |
// 元のデータに揺らぎを加えたテストデータを作成 | |
Serial.print("Test data: "); | |
for (byte f = 0; f < FEATURES; f++) { | |
int sign = (random(0, 2) == 0) ? 1 : -1; | |
int change = int(DATASET[test_index][f] * float(random(4)) / 10.0); | |
test_data[f] = (DATASET[test_index][f] + change * sign) / float(DATAFACTOR); // randomly add or subtract 0-30% to each feature | |
Serial.print(test_data[f]); | |
Serial.print(" "); | |
} | |
Serial.println(); | |
// predict label | |
byte result_label = 0; | |
// 平均推論時間を計測 | |
long start = millis(); | |
for(long i=0;i<1000000;i++){ | |
result_label = sefr.predict(test_data); | |
} | |
long end = millis(); | |
Serial.print("Predict: "); | |
Serial.print((end-start)/1000000.0,5); | |
Serial.println(" ms"); | |
// compare the results | |
Serial.print("Predicted label: "); | |
Serial.print(result_label); | |
Serial.print(" / actual label: "); | |
Serial.print(test_label); | |
Serial.println(""); | |
delay(5000); | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment