Created
September 1, 2016 19:00
-
-
Save sumanyu/e82530f899c3447d8908adad7a8abcc7 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
/**
 * Fills missing values (encoded as `Double.NaN`) in a numeric matrix using the
 * `k` nearest rows.
 *
 * For every row that contains at least one NaN, the distance to every row of the
 * matrix is computed over the coordinates observed in both rows. Each missing
 * entry is then replaced by the mean of that column taken over the `k` closest
 * rows in which the column is observed.
 *
 * @param k number of neighbours averaged for each missing entry; must be positive
 * @throws IllegalArgumentException if `k` is not positive
 */
class KNNImpute(k: Int) {
  require(k > 0, s"k must be positive, got $k")

  /**
   * Imputes the missing values of `data` in place.
   *
   * Distances and neighbour values are read from a snapshot taken before any
   * value is written, so a value imputed for one row never influences the
   * imputation of another row. An empty matrix is left untouched.
   *
   * @param data rectangular matrix, one observation per row; NaN marks a missing value
   * @throws IllegalArgumentException if the rows do not all have the same length
   * @throws Exception if a whole row or a whole column is missing
   */
  def impute(data: Array[Array[Double]]): Unit = {
    checkRowsAndColumnsAreNotEntirelyNAN(data)

    // Snapshot of the original values; `data` itself is only ever written to.
    val snapshot = data.map(_.clone)

    // skip rows which are perfect
    for (rowIdx <- snapshot.indices if nrOfMissingValues(snapshot(rowIdx)) != 0) {
      val incomplete = snapshot(rowIdx)
      val distances = distanceBetweenRowAndAllOtherRows(snapshot, incomplete)

      // Row indices ordered from nearest to farthest. The sort is stable, so ties
      // keep their original row order; rows marked Double.MaxValue end up last and
      // are only used when nothing closer has the column observed.
      val nearestFirst = snapshot.indices.sortWith { (a, b) =>
        java.lang.Double.compare(distances(a), distances(b)) < 0
      }

      for (column <- incomplete.indices if incomplete(column).isNaN) {
        // The validation above guarantees at least one row has this column
        // observed, so the average is always taken over a non-empty array.
        val neighbourValues = nearestFirst.iterator
          .map(idx => snapshot(idx)(column))
          .filterNot(_.isNaN)
          .take(k)
          .toArray
        data(rowIdx)(column) = neighbourValues.sum / neighbourValues.length
      }
    }
  }

  /**
   * Computes the weighted squared distance between `x` and every row of `data`
   * (the result has one entry per row of `data`, in the same order).
   *
   * Only coordinates observed in both rows contribute. The sum of squared
   * differences is divided by the number of shared coordinates and scaled by the
   * row length, so rows sharing few coordinates are not favoured. A row that
   * shares no more than half of the coordinates observed in `x` gets
   * `Double.MaxValue`.
   *
   * @param data matrix whose rows are compared against `x`
   * @param x    the reference row
   * @return one distance per row of `data`
   */
  def distanceBetweenRowAndAllOtherRows(data: Array[Array[Double]], x: Array[Double]): Array[Double] = {
    val observedInX = x.length - nrOfMissingValues(x)
    data.map { y =>
      var shared = 0
      var sumOfSquares = 0.0
      for (m <- x.indices if !x(m).isNaN && !y(m).isNaN) {
        shared += 1
        val d = x(m) - y(m)
        sumOfSquares += d * d
      }
      // weight the distance: `shared` is at least 1 whenever the first branch is
      // taken, so the division is safe.
      if (shared > observedInX / 2) x.length * sumOfSquares / shared
      else Double.MaxValue
    }
  }

  /** Counts the NaN entries of `row`. */
  def nrOfMissingValues(row: Array[Double]): Int = row.count(_.isNaN)

  /** Throws if any row or any column of `data` consists solely of NaN values. */
  private def checkRowsAndColumnsAreNotEntirelyNAN(data: Array[Array[Double]]): Unit = {
    val missingValuesInColumns = getMissingValueCountsInColumns(data)
    checkColumnIsEntirelyNAN(missingValuesInColumns, data.length)
  }

  /** Throws if some column has as many missing values as there are rows. */
  private def checkColumnIsEntirelyNAN(missingValuesInColumns: Array[Int], nrOfRows: Int): Unit =
    if (missingValuesInColumns.contains(nrOfRows))
      throw new Exception("The whole column is missing")

  /**
   * Counts the NaN entries per column, failing on the first row that is entirely NaN.
   *
   * An empty matrix yields an empty result instead of an index error, and rows
   * whose length differs from the first row are rejected.
   */
  private def getMissingValueCountsInColumns(data: Array[Array[Double]]): Array[Int] = {
    val nrOfColumns = data.headOption.map(_.length).getOrElse(0)
    val missingValuesInColumns = Array.ofDim[Int](nrOfColumns)
    for (row <- data.indices) {
      val values = data(row)
      require(
        values.length == nrOfColumns,
        s"Row $row has ${values.length} columns, expected $nrOfColumns"
      )
      for (column <- values.indices if values(column).isNaN) {
        missingValuesInColumns(column) += 1
      }
      if (nrOfMissingValues(values) == values.length) {
        throw new Exception("The whole row " + row + " is missing")
      }
    }
    missingValuesInColumns
  }
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment