(注)この記事はScala関数型プログラミング&デザイン7章前半の劣化版まとめです
整数列の足し算について考えてみよう。たたみ込みなどがパッと思い付くだろう。
def sum(ints: Seq[Int]): Int = ints.foldLeft(0)((a, b) => a + b)
// 注: scalaでは整数列のたたみ込み演算についてあらかじめ関数が用意されていて、
// IntelliJを使っている場合は上記のコードに対して以下のsum関数を用いることを勧められるかもしれない
// at scala.TraversableOnce
def sum[B >: A](implicit num: Numeric[B]): B = foldLeft(num.zero)(num.plus)
さて上記のたたみ込みだと、列の端から順に足し算を繰り返すことになり、計算を並列化することはできない。 容易に分割できる整数列であるならば、分割したそれぞれのパーツを並列に計算して最後に足し合わせることができそうである。 そこで上のコードを少し改良してみよう。
def sum(ints: IndexedSeq[Int]): Int
if (ints.size <= 1) ints.headOption.getOrElse(0)
else {
val (l, r) = ints.splitAt(ints.length / 2)
sum(l) + sum(r)
}
さて、Scalaでは式は正格評価されるため sum(l) + sum(r)
の部分について、
左側の sum(l)
の評価が終わってから右側の sum(r)
の評価が始まる。
評価を遅らせるためには単純にthunkを作ればよい。
つまり、() => sum(l)
, () => sum(r)
としてやればよい。
def sum(ints: IndexedSeq[Int]): Int =
if (ints.size <= 1) ints.headOption.getOrElse(0)
else {
val (l, r) = ints.splitAt(ints.size / 2)
val sumL = () => sum(l)
val sumR = () => sum(r)
sumL() + sumR()
}
しかしこのコードも sumL()
の評価が終わってから、sumR()
の評価が始まるため何も解決されていません。
val sumL = () => sum(l)
, val sumR = () => sum(r)
の部分で別スレッドで評価を始める計算を得て、
sumL() + sumR()
の部分で両スレッドの計算を待ち、値を返すようになれば並列化ができそうである。
そこでそういった機能を持つ関数を実装は置いておいて、とりあえずインターフェースだけ定義しておきましょう。
trait Par[A] {
}
object Par {
// 未評価なA型の式を受け取り、別スレッドで評価するための計算を返す
def unit[A](a: => A): Par[A] = ???
// 並列計算結果を取り出す
def get[A](a: Par[A]): A = ???
}
これを用いると以下のように整数列の足し算を書き直すことができる。
def sum(ints: IndexedSeq[Int]): Int =
if (ints.size <= 1) ints.headOption.getOrElse(0)
else {
val (l, r) = ints.splitAt(ints.size / 2)
val sumL: Par[Int] = Par.unit(sum(l))
val sumR: Par[Int] = Par.unit(sum(r))
Par.get(sumL) + Par.get(sumR)
}
さて unit
は引数を受け取った瞬間、その評価を別のスレッドで直ちに開始する実装だとした場合、
確かにsum関数は並列化することを達成できる。
しかしながら、Par.get(sumL) + Par.get(sumR)
をインライン展開したとすれば、Par.get(Par.unit(sum(l))) + Par.get(Par.unit(sum(r)))
となり並列性が失われる。なぜならば、get
は Par[Int] の計算が終わるまで待機するからである。つまり、unit
は get
に対し副作用を持っているということになる。
ということは、sum関数からPar[Int]を直接返してしまえば問題ないのである。Par.get(sumL) + Par.get(sumR)
としている部分は sumL
と sumR
を合成したPar[Int]�値を返せば良い。そしてこの時点でunitは引数の評価を非正格にする必要がなくなる。
object Par {
def unit[A](a: A): Par[A] = ???
def map2[A, B, C](a: Par[A], b: Par[B])(f: (A, B) => C): Par[C] = ???
def sum(ints: IndexedSeq[Int]): Par[Int] =
if (ints.size <= 1) Par.unit(ints.headOption.getOrElse(0))
else {
val (l, r) = ints.splitAt(ints.size / 2)
Par.map2(sum(l), sum(r))(_ + _)
}
}
さて、このようにしてできたsum関数に IndexedSeq(1, 2, 3, 4)
を渡してみると、 Par.map2(sum(l), sum(r))(_ + _)
の部分で左側の引数 sum(l)
のほうが先に展開されてしまうという問題がある。したがってmap2関数の引数評価を遅らせる必要があるように思える。ただ Par.map2(Par.unit(1), Par.unit(2))(_ + _)
のような単純な計算については即座に引数を評価したい。そこで、明示的に別スレッドで実行すべきであるという意味をもたせたfork関数を導入しよう。
object Par {
def unit[A](a: A): Par[A] = ???
def fork[A](a: => Par[A]): Par[A] = ???
def map2[A, B, C](a: Par[A], b: Par[B])(f: (A, B) => C): Par[C] = ???
def sum(ints: IndexedSeq[Int]): Par[Int] =
if (ints.size <= 1) Par.unit(ints.headOption.getOrElse(0))
else {
val (l, r) = ints.splitAt(ints.size / 2)
Par.map2(fork(sum(l)), fork(sum(r)))(_ + _)
}
}
こうすることにより forkで包まれた sum
関数は直ちに評価されないため、sum(l)
と sum(r)
は同時に計算が開始される。
さてこれを java.concurrent.ExecutorService
を用いて実装してみよう。以下のようになるだろう。
import java.util.concurrent.{Callable, TimeUnit, Future, ExecutorService}
object Par {
type Par[A] = ExecutorService => Future[A]
/*
* primitives
*/
def unit[A](a: A): Par[A] = (es: ExecutorService) => UnitFuture(a)
private case class UnitFuture[A](get: A) extends Future[A] {
override def isDone: Boolean = true
override def get(timeout: Long, units: TimeUnit): A = get
override def isCancelled: Boolean = false
override def cancel(mayInterruptIfRunning: Boolean): Boolean = false
}
def fork[A](a: => Par[A]): Par[A] =
(es: ExecutorService) => {
es.submit(new Callable[A] {
override def call(): A = a(es).get
})
}
def map2[A, B, C](a: Par[A], b: Par[B])(f: (A, B) => C): Par[C] =
(es: ExecutorService) => {
val af = a(es)
val bf = b(es)
UnitFuture(f(af.get, bf.get))
}
def flatMap[A, B](pa: Par[A])(f: A => Par[B]): Par[B] =
(es: ExecutorService) => {
val a: A = run(es)(pa).get
run(es)(f(a))
}
def run[A](es: ExecutorService)(a: Par[A]): Future[A] = a(es)
/*
* derivative-combinators
*/
def lazyUnit[A](a: => A): Par[A] = fork(unit(a))
def asyncF[A, B](f: A => B): A => Par[B] = a => lazyUnit(f(a))
def map[A, B](par: Par[A])(f: A => B): Par[B] =
map2(par, unit(()))((a, _) => f(a))
def sum(ints: IndexedSeq[Int]): Par[Int] =
if (ints.size <= 1) Par.unit(ints.headOption.getOrElse(0))
else {
val (l, r) = ints.splitAt(ints.size / 2)
Par.map2(fork(sum(l)), fork(sum(r)))(_ + _)
}
}