Last active
October 2, 2016 04:16
-
-
Save malzzz/db44e6be9f0626e155bbd186db6cc471 to your computer and use it in GitHub Desktop.
Functional 2D convolution, less broken; i.e., it works (I think). Probably very slow.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
/** 2D convolution of a square input with a square kernel, built from fs2 `Pure` streams.
  *
  * Both `input` and `kernel` are read as flattened square matrices: their side length is
  * `sqrt(length).toInt`, and `input` is split into rows of that width (row-major).
  * NOTE(review): non-square lengths are silently truncated by `.toInt` — confirm callers never
  * pass them.
  *
  * The input is zero-padded with one element on each side of every row and one zero row above
  * and below. The zero rows are `inputHW + kernelHW - 1` wide while the padded interior rows are
  * `inputHW + 2` wide, so the layout is only consistent when `kernelHW == 3`.
  * No kernel flip is visible, so this presumably computes cross-correlation (the usual ML
  * "convolution") — confirm if true convolution is required.
  *
  * @param input   flattened square input matrix
  * @param kernel  flattened square kernel
  * @param padding currently UNUSED: padding is fixed at one zero on each side (see above)
  * @param stride  step used when sliding the kernel window along a row and when advancing passes
  * @param ev      arithmetic for `A` (must provide `zero`, `*` and `sum`)
  * @return a pure stream of output values, in the order produced by the pass layout below
  */
def conv2[A: ClassTag](input: Vector[A], kernel: Vector[A], padding: Int, stride: Int)(implicit ev: Numeric[A]) = {
  // Static sizes derived from the (assumed square) input and kernel.
  val inputHW = sqrt(input.length).toInt
  val kernelHW = sqrt(kernel.length).toInt
  val rowSizeAfterPad = inputHW + kernelHW - 1
  val rowSizeStrided = rowSizeAfterPad + stride
  val passSize = kernelHW * rowSizeStrided
  val inputSizePaddedStrided = (rowSizeStrided * inputHW) + (rowSizeStrided * 2)

  /** Zips `chunkedKernel` onto `acc` `totalZips - zip` more times, concatenating matching row
    * vectors, so every kernel row ends up repeated side by side; then flattens to single values.
    */
  @tailrec
  def buildKernelStream(zip: Int, totalZips: Int, chunkedKernel: Stream[Pure, Vector[A]], acc: Stream[Pure, Vector[A]]): Stream[Pure, A] = {
    if (zip < totalZips) {
      buildKernelStream(zip + 1, totalZips, chunkedKernel, acc.zipWith(chunkedKernel)(_ ++ _))
    } else acc.flatMap(k => Stream.emits(k)).pure
  }

  /** Appends one slice of `rows` per pass to `acc`. Each slice starts `pass * rowSizeStrided *
    * stride` elements in and is trimmed from the right so that, assuming `rows` holds exactly
    * `inputSizePaddedStrided` elements (TODO confirm), `passSize` elements remain.
    */
  @tailrec
  def buildStream(pass: Int, totalPasses: Int, rows: Stream[Pure, A], acc: Stream[Pure, A]): Stream[Pure, A] = {
    if (pass < totalPasses) {
      val newAcc =
        acc ++
          rows
            .drop(pass * rowSizeStrided * stride)
            .dropRight(inputSizePaddedStrided - (pass * stride * rowSizeStrided) - passSize)
      buildStream(pass + 1, totalPasses, rows, newAcc)
    } else acc
  }

  // Zero row placed above and below the input.
  val padHelper = Stream.constant(ev.zero)
  val paddedRow = padHelper.vectorChunkN(rowSizeAfterPad).take(1)
  // Each input row with a single zero prepended and appended.
  val rowInterleavePadding = Stream.emits(input).vectorChunkN(inputHW).map(vec => ev.zero +: vec :+ ev.zero)
  val rowsAfterPadding = paddedRow ++ rowInterleavePadding ++ paddedRow
  // Every padded row is cut into kernel-wide windows (advancing by `stride`), then flattened.
  val rowStream = rowsAfterPadding.map(_.sliding(kernelHW, stride).toVector.flatten).flatMap(row => Stream.emits(row)).pure
  val kernelChunked = Stream.emits(kernel).vectorChunkN(kernelHW).pure

  // Prepared input (finite) and kernel (repeated below, hence unbounded) streams.
  val kernelStream = buildKernelStream(0, (rowSizeStrided / kernelHW) - 1, kernelChunked, kernelChunked)
  val dataStream = buildStream(0, rowSizeStrided / kernelHW, rowStream, Stream.empty[Pure, A])

  dataStream
    .zipWith(kernelStream.repeat)((in, kn) => in * kn)
    // Each run of kernelHW products is intended to be one window row dotted with one kernel row.
    .vectorChunkN(kernelHW)
    .map(ev.sum(_))
    // Regroup kernel.length partial sums into kernelHW groups of kernelHW and add together the
    // entries sharing the same position inside their group (a column-wise sum).
    // `sortBy` is stable and uses a valid ordering; a non-strict comparator such as `<=` breaks
    // the sort contract and can make the underlying TimSort throw for larger kernels.
    .vectorChunkN(kernel.length)
    .map(_.grouped(kernelHW).toVector.flatMap(_.zipWithIndex).sortBy(_._2).unzip._1.grouped(kernelHW).map(ev.sum(_)).toVector)
    .flatMap(Stream.emits)
    .pure
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment