Skip to content

Instantly share code, notes, and snippets.

@jellea
Created May 21, 2016 15:37
Small testing environment for an MLP regression model (GRT) in openFrameworks
#include "ofApp.h"
// Application mode flags, switched by the keyboard handlers.
bool learning{false};   // record (input, target) pairs from incoming OSC messages
bool training{false};   // a training pass has been requested
bool running{false};    // feed incoming inputs through the trained pipeline
// Class index chosen with the digit keys (only shown in the HUD).
int selectedClass{1};
// NOTE(review): not read anywhere in the visible code; the input count 6 is hard-coded elsewhere.
int numInputs{6};
// Latest input values received over OSC, and the values shown in the output panel.
float inputVals[6]{};
float outputVals[8]{};
void ofApp::randomizeVals() {
for(int i = 0; i < 8; i++) {
targetVector[i] = static_cast <float> (rand()) / static_cast <float> (RAND_MAX);
outputVals[i] = targetVector[i];
}
}
//--------------------------------------------------------------
// One-time initialization: seed the RNG, size the regression target, configure an
// MLP-based MultidimensionalRegression pipeline (6 inputs -> 8 targets) and start
// listening for OSC input messages on PORT.
void ofApp::setup(){
	srand(static_cast <unsigned> (time(0)));
	// One target value per output dimension; must match the 8 passed to
	// setInputAndTargetDimensions below.
	targetVector.resize(8);
	// randomizeVals(); // only safe once targetVector has been sized (above)
	trainingData.setInputAndTargetDimensions(6,8);
	MLP mlp;
	unsigned int numInputNeurons = trainingData.getNumInputDimensions();
	unsigned int numHiddenNeurons = 10;
	// 1 output per network: MultidimensionalRegression (below) handles the multiple
	// target dimensions (presumably one copy of this MLP per dimension - see GRT docs).
	unsigned int numOutputNeurons = 1;
	//Initialize the MLP; report failure instead of silently continuing with an unusable model
	if( !mlp.init(numInputNeurons, numHiddenNeurons, numOutputNeurons, Neuron::LINEAR, Neuron::SIGMOID, Neuron::SIGMOID ) ){
		cout << "WARNING: Failed to initialize MLP" << "\n";
	}
	//Set the training settings
	mlp.setMaxNumEpochs( 1000 ); //This sets the maximum number of epochs (1 epoch is 1 complete iteration of the training data) that are allowed
	mlp.setMinChange( 1.0e-7 ); //This sets the minimum change allowed in training error between any two epochs
	mlp.setLearningRate( 0.001 ); //This sets the rate at which the learning algorithm updates the weights of the neural network
	mlp.setNumRandomTrainingIterations( 5 ); //This sets the number of times the MLP will be trained, each training iteration starts with new random values
	mlp.setUseValidationSet( true ); //This sets aside a small portion of the training data to be used as a validation set to mitigate overfitting
	mlp.setValidationSetSize( 20 ); //Use 20% of the training data for validation during the training phase
	mlp.setRandomiseTrainingOrder( true ); //Randomize the order of the training data so that the training algorithm does not bias the training
	//Scaling is disabled on the MLP itself; the `true` passed to MultidimensionalRegression
	//below presumably enables scaling at that level instead - confirm against GRT docs.
	mlp.enableScaling( false );
	pipeline << MultidimensionalRegression(mlp,true);
	// OSC: listen on the given port
	cout << "listening for osc messages on port " << PORT << "\n";
	receiver.setup( PORT );
}
//--------------------------------------------------------------
void ofApp::update(){
while(receiver.hasWaitingMessages())
{
VectorFloat inputVector(6);
ofxOscMessage m;
receiver.getNextMessage(m);
// cout << "incoming message from address " << m.getAddress() << "and with num: "<< m.getNumArgs() << endl;
string vals = "inputVals:";
for(int i = 0; i < 6; i++) {
inputVals[i] = m.getArgAsFloat(i);
inputVector[i] = inputVals[i];
vals += " " + std::to_string(inputVals[i]);
}
// cout << vals;
if( learning ){
trainingData.addSample(inputVector, targetVector );
}
if(training){
if( pipeline.train( trainingData ) ){
cout << "Pipeline Trained";
training = false;
}else cout << "WARNING: Failed to train pipeline";
}
if (pipeline.getTrained() && training){
training = false;
}
if( pipeline.getTrained() && running ){
pipeline.predict(inputVector);
VectorFloat regressionData;
regressionData = pipeline.getRegressionData();
for(int i = 0; i < 8; i++) {
outputVals[i] = GRT::Util::limit(regressionData[i],0.0,1.0);
}
}
}
}
//--------------------------------------------------------------
void ofApp::draw(){
// HUD
// INPUT
ofDrawBitmapStringHighlight("Input", 40, 20);
ofDrawRectangle(40, 40, 180, 600);
string txt = "Class: " + std::to_string(selectedClass);
ofDrawBitmapStringHighlight(txt , 60, 70);
string learningYesNo = learning ? "yes" : "no";
ofDrawBitmapStringHighlight("(a) Learning? " + learningYesNo , 60, 90);
string trainingYesNo = training ? "yes" : "no";
ofDrawBitmapStringHighlight("(s) Training? " + trainingYesNo , 60, 110);
string runningYesNo = running ? "yes" : "no";
ofDrawBitmapStringHighlight("(d) Running? " + runningYesNo , 60, 130);
string invals;
for(int i = 0; i < 6; i++) {
invals += "\n" + std::to_string(inputVals[i]);
}
ofDrawBitmapStringHighlight("inputVals:" + invals, 60, 160);
// OUTPUT
ofDrawBitmapStringHighlight("Output", 260, 20);
ofDrawRectangle(260, 40, 180, 600);
ofDrawBitmapStringHighlight("(q) Randomize", 280, 70);
string vals;
for(int i = 0; i < 8; i++) {
vals += "\n" + std::to_string(outputVals[i]);
}
ofDrawBitmapStringHighlight("outputVals:" + vals, 280, 90);
}
//--------------------------------------------------------------
// Keyboard mode switching:
//   'a' -> learning only, 's' -> request training, 'd' -> run the model,
//   '0'..'9' -> select the class number shown in the HUD.
void ofApp::keyPressed(int key){
	std::cout << "key value: " << key << std::endl;
	switch (key) {
		case 'a':
			learning = true;
			running = false;
			training = false;
			break;
		case 's':
			training = true;
			running = false;
			learning = false;
			break;
		case 'd':
			// learning is deliberately left untouched here, as before.
			training = false;
			running = true;
			break;
		default:
			if (key >= '0' && key <= '9') {
				selectedClass = key - '0';
			}
			break;
	}
}
//--------------------------------------------------------------
// Key-release handling:
//   releasing 'a' stops recording samples (learning is active only while 'a' is held);
//   releasing 'q' stops prediction and picks a new random target, so update()
//   no longer overwrites the freshly randomized outputVals.
void ofApp::keyReleased(int key){
	switch (key) {
		case 'a':
			learning = false;
			break;
		case 'q':
			running = false;
			randomizeVals();
			break;
		default:
			break;
	}
}
//--------------------------------------------------------------
// The handlers below are empty: this app does not react to mouse, window,
// message or drag-and-drop events. Parameter names are omitted to avoid
// unused-parameter warnings.
void ofApp::mouseMoved(int /*x*/, int /*y*/){
}
//--------------------------------------------------------------
void ofApp::mouseDragged(int /*x*/, int /*y*/, int /*button*/){
}
//--------------------------------------------------------------
void ofApp::mousePressed(int /*x*/, int /*y*/, int /*button*/){
}
//--------------------------------------------------------------
void ofApp::mouseReleased(int /*x*/, int /*y*/, int /*button*/){
}
//--------------------------------------------------------------
void ofApp::mouseEntered(int /*x*/, int /*y*/){
}
//--------------------------------------------------------------
void ofApp::mouseExited(int /*x*/, int /*y*/){
}
//--------------------------------------------------------------
void ofApp::windowResized(int /*w*/, int /*h*/){
}
//--------------------------------------------------------------
void ofApp::gotMessage(ofMessage /*msg*/){
}
//--------------------------------------------------------------
void ofApp::dragEvent(ofDragInfo /*dragInfo*/){
}
#pragma once
#include <cstdlib>
#include <ctime>
#include "ofMain.h"
#include "ofxGui.h"
#include "ofxGrt.h"
#include "ofxOsc.h"
// UDP port the OSC receiver listens on (typed constant instead of a macro).
constexpr int PORT = 6448;
using namespace GRT;
class ofApp : public ofBaseApp{
private:
void randomizeVals();
public:
void setup();
void update();
void draw();
RegressionData trainingData; //This will store our training data
GestureRecognitionPipeline pipeline; //This is a wrapper for our classifier and any pre/post processing modules
VectorFloat targetVector;
void keyPressed(int key);
void keyReleased(int key);
void mouseMoved(int x, int y );
void mouseDragged(int x, int y, int button);
void mousePressed(int x, int y, int button);
void mouseReleased(int x, int y, int button);
void mouseEntered(int x, int y);
void mouseExited(int x, int y);
void windowResized(int w, int h);
void dragEvent(ofDragInfo dragInfo);
void gotMessage(ofMessage msg);
ofxOscReceiver receiver;
};
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment