Skip to content

Instantly share code, notes, and snippets.

@agibsonccc
Created August 27, 2014 02:44
Show Gist options
  • Save agibsonccc/3c6e57d7fc5a9434058f to your computer and use it in GitHub Desktop.
Save agibsonccc/3c6e57d7fc5a9434058f to your computer and use it in GitHub Desktop.
package org.deeplearning4j.nn.api;
import org.deeplearning4j.linalg.api.ndarray.INDArray;
/**
 * A classifier (this is for supervised learning).
 *
 * <p>Methods that take an {@link INDArray} of examples expect one example in each row.
 *
 * @author Adam Gibson
 */
public interface Classifier {

    /**
     * Returns the amount of error for each example.
     *
     * @param examples the examples to score (one example in each row)
     * @param labels the true labels for {@code examples}
     * @return the error score for each example, one entry per row of {@code examples}
     */
    float[] score(INDArray examples, INDArray labels);

    /**
     * Returns the number of possible labels.
     *
     * @return the number of possible labels for this classifier
     */
    int numLabels();

    /**
     * Predicts a label for each example.
     *
     * @param examples the examples to classify (one example in each row)
     * @return the predicted label for each example, in row order
     */
    int[] predict(INDArray examples);

    /**
     * Returns the probability of each label for each example, row-wise.
     *
     * <p>Note: the method name is misspelled ("Probabilites"). It is kept as-is because
     * renaming an interface method would break every existing implementation and caller.
     *
     * @param examples the examples to classify (one example in each row)
     * @return the likelihood of each label for each example, one row per example
     */
    INDArray labelProbabilites(INDArray examples);

    /**
     * Fits the model to the given examples.
     *
     * @param examples the training examples (one example in each row)
     * @param labels the example labels, as a binary outcome matrix
     */
    void fit(INDArray examples, INDArray labels);

    /**
     * Fits the model to the given examples.
     *
     * @param examples the training examples (one example in each row)
     * @param labels the label for each example; the number of labels must match
     *     the number of rows in {@code examples}
     */
    void fit(INDArray examples, int[] labels);
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment