Skip to content

Instantly share code, notes, and snippets.

@agibsonccc
Created August 27, 2014 02:51
Show Gist options
  • Save agibsonccc/e6941a55591e7fe7baf8 to your computer and use it in GitHub Desktop.
Save agibsonccc/e6941a55591e7fe7baf8 to your computer and use it in GitHub Desktop.
package org.deeplearning4j.nn.api;
import org.deeplearning4j.linalg.api.ndarray.INDArray;
import org.deeplearning4j.linalg.dataset.DataSet;
/**
* Contract for a supervised-learning classifier: a model that can be fit on
* labeled examples, scored against true labels, and asked to predict labels
* or per-label probabilities for new examples.
*
* @author Adam Gibson
*/
public interface Classifier {
/**
* Sets the input and labels from the given data set and returns a score
* for the prediction with respect to the true labels.
* NOTE(review): the metric is not stated here; presumably it is the same
* F1 score as {@link #score(INDArray, INDArray)} - confirm against implementations.
* @param data the data to score (input, label pairs)
* @return the score for the given input, label pairs
*/
float score(DataSet data);
/**
* Returns the F1 score for the given examples against their true labels.
* It can be read roughly like a percentage right: the value is on a scale
* from 0 to 1, and the higher the number, the more the classifier got right.
* @param examples the examples to classify (one example in each row)
* @param labels the true labels
* (NOTE(review): presumably one row per example, in the same binary outcome
* matrix form accepted by {@link #fit(INDArray, INDArray)} - confirm)
* @return the F1 score for the given examples, as a single value
*/
float score(INDArray examples,INDArray labels);
/**
* Returns the number of possible labels.
* @return the number of possible labels for this classifier
*/
int numLabels();
/**
* Takes in a batch of examples and, for each row, returns a label.
* @param examples the examples to classify (one example in each row)
* @return the labels for each example
* (NOTE(review): presumably one entry per row of {@code examples}, in row
* order, each entry being a label index - confirm against implementations)
*/
int[] predict(INDArray examples);
/**
* Returns the probabilities for each label, for each example, row wise.
* The method name is spelled "labelProbabilites" (sic); it is part of the
* public interface, so the spelling is kept as is.
* @param examples the examples to classify (one example in each row)
* @return the likelihoods of each example and each label
* (NOTE(review): presumably one row per example and one column per label - confirm)
*/
INDArray labelProbabilites(INDArray examples);
/**
* Fits the model to the given examples and their labels.
* @param examples the examples to train on (one example in each row)
* @param labels the example labels (a binary outcome matrix)
*/
void fit(INDArray examples,INDArray labels);
/**
* Fits the model to the given data set.
* @param data the data to train on
*/
void fit(DataSet data);
/**
* Fits the model to the given examples, with the labels supplied as an
* int array instead of a matrix.
* @param examples the examples to train on (one example in each row)
* @param labels the labels for each example (the number of labels must match
* the number of rows in {@code examples})
*/
void fit(INDArray examples,int[] labels);
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment