Created November 17, 2015 07:38
-
-
Save kylewlacy/38681caaee2b98949dd8 to your computer and use it in GitHub Desktop.
Cartesian product function in Scala
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import scala.reflect.ClassTag | |
object CartesianProduct {

  /**
   * Extends a partial Cartesian product by one more dimension.
   *
   * Every tuple in `a` is paired with every item in `b`, with the item
   * appended to the end of the tuple. The result is ordered by tuple in `a`
   * first, then by item in `b`.
   *
   * {{{
   * val partialProduct = Array(Array(1, 4), Array(1, 5), Array(2, 4), Array(2, 5))
   * val items = Array(6, 7)
   * partialCartesian(partialProduct, items)
   * // returns (element-wise; note Arrays compare by reference under `==`):
   * //   Array(Array(1, 4, 6),
   * //         Array(1, 4, 7),
   * //         Array(1, 5, 6),
   * //         Array(1, 5, 7),
   * //         Array(2, 4, 6),
   * //         Array(2, 4, 7),
   * //         Array(2, 5, 6),
   * //         Array(2, 5, 7))
   * }}}
   *
   * @param a the partial product built so far; each element is one tuple
   * @param b the items to append to each tuple
   * @tparam T element type; the `ClassTag` is required to build `Array[T]`
   * @return an array of `a.length * b.length` tuples; empty if either input is empty
   */
  private def partialCartesian[T: ClassTag](a: Array[Array[T]], b: Array[T]): Array[Array[T]] =
    a.flatMap(xs => b.map(y => xs :+ y))

  /**
   * Computes the Cartesian product `lists(0) × lists(1) × ... × lists(n - 1)`,
   * where `n` is `lists.length`.
   *
   * Tuples are produced in order: the first list varies slowest and the last
   * list varies fastest. If any inner list is empty, the result is empty.
   *
   * {{{
   * scala> import CartesianProduct._
   * scala> val lists = Array(Array(1, 2), Array(4, 5), Array(6, 7))
   * scala> cartesianProduct(lists)
   * res0: Array[Array[Int]] = Array(Array(1, 4, 6),
   *                                 Array(1, 4, 7),
   *                                 Array(1, 5, 6),
   *                                 Array(1, 5, 7),
   *                                 Array(2, 4, 6),
   *                                 Array(2, 4, 7),
   *                                 Array(2, 5, 6),
   *                                 Array(2, 5, 7))
   * }}}
   *
   * @param lists the input lists, one per dimension of the product
   * @tparam T element type; the `ClassTag` is required to build `Array[T]`
   * @return one array per tuple, each of length `lists.length`. For an empty
   *         `lists` this returns an empty array (no tuples), not the single
   *         empty tuple of the mathematical definition.
   */
  def cartesianProduct[T: ClassTag](lists: Array[Array[T]]): Array[Array[T]] =
    lists.headOption match {
      case Some(head) =>
        // Seed with one single-element tuple per item of the first list,
        // then extend the tuples by each remaining list in turn.
        val init: Array[Array[T]] = head.map(item => Array(item))
        lists.tail.foldLeft(init)((acc, list) => partialCartesian(acc, list))
      case None =>
        // Kept for compatibility: existing callers get "no tuples" here.
        Array.empty[Array[T]]
    }

  /**
   * Prints the Cartesian product of `lists` to stdout, one tuple per line,
   * with the elements separated by single spaces. For example, for
   * `Array(Array(1, 2), Array(4, 5), Array(6, 7))`:
   *
   * {{{
   * 1 4 6
   * 1 4 7
   * 1 5 6
   * 1 5 7
   * 2 4 6
   * 2 4 7
   * ...
   * }}}
   *
   * @param lists the input lists, one per dimension of the product
   */
  private def printCartesianProduct(lists: Array[Array[Int]]): Unit =
    cartesianProduct(lists).foreach(product => println(product.mkString(" ")))

  /** Demo entry point: prints the product of three small integer lists. */
  def main(args: Array[String]): Unit = {
    val a = Array(1, 2, 3)
    val b = Array(4, 5)
    val c = Array(6, 7, 8)
    printCartesianProduct(Array(a, b, c))
  }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.
This helps a lot, many thanks! It cost me an afternoon to figure out how to get it done.