Created June 15, 2016 22:11
-
-
Save sirolf2009/0c4d93e9c22f069740e9eb2db68f6702 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
// Time slices read by nd4jDataSet(); each slice supplies predictionRed, predictionBlack and
// predictionDiff. NOTE(review): nd4jDataSet() treats slices.get(i + 1) as the step following
// slices.get(i), so this list presumably must be in chronological order — confirm with callers.
List<TimeSlice> slices | |
/**
 * Builds an ND4J {@link DataSet} for next-step prediction from {@code slices}.
 *
 * Both arrays have shape [1, 3, steps], where steps = slices.size - 1. The feature rows are:
 * 0 = predictionRed, 1 = predictionBlack, 2 = predictionDiff. For every index i, the input
 * column holds slice i and the label column holds slice i + 1. In other words, every
 * adjacent (current, next) pair of slices becomes one column.
 *
 * Each value is multiplied by 10000 and cast to int (truncating toward zero), i.e. quantized
 * to four decimal places, before it is stored.
 *
 * @return a DataSet whose features are the current slices and whose labels are the next slices
 * @throws IllegalArgumentException if fewer than two slices are available, since no
 *         (current, next) pair can be formed
 */
def DataSet nd4jDataSet() {
	// Fixed-point scale: keep four decimal places, truncating toward zero on the int cast.
	val scale = 10000
	// n slices form n - 1 adjacent pairs. The previous bound (length - 2) silently
	// dropped the final transition, i.e. the last slice was never used as a label.
	val steps = slices.size - 1
	if (steps < 1) {
		throw new IllegalArgumentException(
			"nd4jDataSet needs at least two time slices, got " + slices.size)
	}
	val input = Nd4j.zeros(1, 3, steps)
	val label = Nd4j.zeros(1, 3, steps)
	for (i : 0 ..< steps) {
		val current = slices.get(i)
		val next = slices.get(i + 1)
		input.putScalar(#[0, 0, i], (current.predictionRed * scale) as int)
		input.putScalar(#[0, 1, i], (current.predictionBlack * scale) as int)
		input.putScalar(#[0, 2, i], (current.predictionDiff * scale) as int)
		label.putScalar(#[0, 0, i], (next.predictionRed * scale) as int)
		label.putScalar(#[0, 1, i], (next.predictionBlack * scale) as int)
		label.putScalar(#[0, 2, i], (next.predictionDiff * scale) as int)
	}
	new DataSet(input, label)
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.