Skip to content

Instantly share code, notes, and snippets.

@malzzz
Last active October 2, 2016 04:16
Show Gist options
  • Save malzzz/db44e6be9f0626e155bbd186db6cc471 to your computer and use it in GitHub Desktop.
Save malzzz/db44e6be9f0626e155bbd186db6cc471 to your computer and use it in GitHub Desktop.
Functional 2d convolution, less broken; i.e., works (I think). Probably slow af.
def conv2[A: ClassTag](input: Vector[A], kernel: Vector[A], padding: Int, stride: Int)(implicit ev: Numeric[A]) = {
// Static sizes
val inputHW = sqrt(input.length).toInt
val kernelHW = sqrt(kernel.length).toInt
val rowSizeAfterPad = inputHW + kernelHW - 1
val rowSizeStrided = rowSizeAfterPad + stride
val opsPerPass = rowSizeAfterPad / kernelHW
val passSize = kernelHW * rowSizeStrided
val inputSizePaddedStrided = (rowSizeStrided * inputHW) + (rowSizeStrided * 2)
@tailrec
def buildKernelStream(zip: Int, totalZips: Int, chunkedKernel: Stream[Pure, Vector[A]], acc: Stream[Pure, Vector[A]]): Stream[Pure, A] = {
if (zip < totalZips) {
buildKernelStream(zip + 1, totalZips, chunkedKernel, acc.zipWith(chunkedKernel)(_ ++ _))
} else acc.flatMap(k => Stream.emits(k)).pure
}
@tailrec
def buildStream(pass: Int, totalPasses: Int, rows: Stream[Pure, A], acc: Stream[Pure, A]): Stream[Pure, A] = {
if (pass < totalPasses) {
val newAcc =
acc ++
rows
.drop(pass * rowSizeStrided * stride)
.dropRight(inputSizePaddedStrided - (pass * stride * rowSizeStrided) - passSize)
//print(s"Dropped left: ${pass * rowSizeStrided * stride}, dropped right: ${inputSizePaddedStrided - (pass * stride * rowSizeStrided) - passSize}, row size: $rowSizeStrided\n")
buildStream(pass + 1, totalPasses, rows, newAcc)
} else acc
}
// Build helpers
val padHelper = Stream.constant(ev.zero)
val paddedRow = padHelper.vectorChunkN(rowSizeAfterPad).take(1)
val rowInterleavePadding = Stream.emits(input).vectorChunkN(inputHW).map(vec => ev.zero +: vec :+ ev.zero)
val rowsAfterPadding = paddedRow ++ rowInterleavePadding ++ paddedRow
val rowStream = rowsAfterPadding.map(_.sliding(kernelHW, stride).toVector.flatten).flatMap(row => Stream.emits(row)).pure
val kernelChunked = Stream.emits(kernel).vectorChunkN(kernelHW).pure
// Prepared input (finite) and kernel (infinite) streams
val kernelStream = buildKernelStream(0, (rowSizeStrided / kernelHW) - 1, kernelChunked, kernelChunked)
val dataStream = buildStream(0, rowSizeStrided / kernelHW, rowStream, Stream.empty[Pure, A])
// Zip together, multiply elements, add elements in chunks equal to kernel width, rechunk into kernel-length sizes
dataStream
.zipWith(kernelStream.repeat)((in, kn) => in * kn)
.vectorChunkN(kernelHW)
.map(ev.sum(_))
.vectorChunkN(kernel.length)
.map(_.grouped(kernelHW).toVector.flatMap(_.zipWithIndex).sortWith(_._2 <= _._2).unzip._1.grouped(kernelHW).map(ev.sum(_)).toVector)
.flatMap(Stream.emits)
.pure
//dataStream.zipWith(kernelStream.repeat)((in, kn) => in * kn).vectorChunkN(kernel.length).map(ev.sum(_))
//(dataStream, kernelStream)
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment