This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
// Convert the given Array<String> to a primitive FloatArray.
// Note: The `CSVReader` returns a row as String[] where each String is a number. We parse each String and convert it
// to a float, keeping the original element order.
// Each element is parsed with `String.toFloat()`, so a malformed entry throws NumberFormatException.
// The result array is filled directly (rather than via `map { }.toFloatArray()`) so that no intermediate, boxed
// List<Float> is allocated for every row.
fun convertStringArrayToFloatArray( strArray : Array<String> ) : FloatArray {
    return FloatArray( strArray.size ) { i -> strArray[ i ].toFloat() }
}
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
// Set data from `rawData` to `featureColumnData` | |
private fun populateColumns() { | |
// Create an empty ArrayList to store the labels | |
val labels = ArrayList<String>() | |
// Iterate through `rawData` starting from index=1 ( as index=0 refers to the column names ) | |
for ( strSample in rawData.subList( 1 , rawData.size ) ) { | |
// Append the label which is the last element of `strSample`. | |
labels.add( strSample[ numFeatures ] ) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
... | |
// Number of features in the dataset. This number equals ( num_cols - 1 ) where num_cols is the number of | |
// columns in the CSV file. | |
// ( Note: We assume that the file has the labels column as the last column ). | |
var numFeatures = 0 | |
// Number of samples in the dataset. This number equals ( num_rows - 1 ) where num_rows is the number of | |
// rows in the CSV file. | |
// Note: We assume that the CSV file has its first row as the column names. |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
// Return the header row of `rawData` (index 0), which holds the column names of the CSV file.
private fun getColumnNames() : Array<String> = rawData[ 0 ]
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
// Helper class to load data from the given CSV file and transform it into an Array<FeatureColumn> | |
class DataFrame( context: Context , assetsFileName : String ) { | |
    // HashMap which stores the CSV data in the form ( Column_Name , ArrayList<Float> ), where the list holds | |
    // that feature's value for every sample. | |
private var featureColumnData = HashMap<String,ArrayList<Float>>() | |
// Variable to store the parsed CSV file from `CSVReader`. | |
private var rawData : List<Array<String>> |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
// Class to implement Gaussian Naive Bayes | |
class GaussianNB( private var dataFrame : DataFrame ) { | |
... | |
// Prior probabilities stored in a HashMap of form ( column_name , prior_prob ) | |
private var priorProbabilities : HashMap<String,Float> | |
... |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
// Holds data for a particular feature. | |
class FeatureColumn( var name : String , var data : FloatArray ) { | |
// Mean of given `data` | |
var featureMean : Float | |
// Variance of given `data` | |
var featureVariance : Float | |
// Standard deviation of given `data` |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
// Predict a class for the given sample using the Random Forest. | |
fun predict( x : HashMap<String,String> ) : String { | |
// Create an empty array to store class labels. | |
val treeOutputs = Array( NUM_TREES ) { "" } | |
for ( i in 0 until NUM_TREES ) { | |
// Store the output of each DecisionTree in our forest. | |
treeOutputs[ i ] = forest[i].predict( x ) | |
println( "Prediction ${i+1} DecisionTree is ${treeOutputs[i]}") | |
} | |
// Get the majority label, which is our final prediction for the given sample. |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
... | |
// The forest represented as an array of DecisionTree objects. | |
private var forest : ArrayList<DecisionTree> = ArrayList() | |
init { | |
... | |
// Initialize the forest |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
// Return samples ( in form of DataFrame object ) given their indices. | |
fun getEntries( indices : IntArray ) : DataFrame { | |
val dataFrame = DataFrame() | |
data.map { column -> | |
// `column` represent a Map -> ( String , ArrayList<String> ) | |
// column.key -> Name of the column as in the training datasets. | |
// column.value -> ArrayList<String> containing the column's data. | |
val columnData = ArrayList<String>() | |
val values = column.value | |
// Add the feature values corresponding to each index in `indices`. |