Created December 5, 2014 21:16
-
-
Save torao/0cb8f1e533ad99ed7651 to your computer and use it in GitHub Desktop.
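マルチスレッドから DatagramChannel で送信する場合の同期/非同期 I/O パフォーマンス比較 (Synchronous vs. asynchronous I/O performance when sending through a DatagramChannel from multiple threads)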
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import java.net.InetSocketAddress
import java.nio.ByteBuffer
import java.nio.channels.{CancelledKeyException, ClosedSelectorException, DatagramChannel, SelectionKey, Selector}
import java.util.concurrent.atomic.{AtomicLong, LongAdder}
import java.util.concurrent.{BlockingQueue, Executors, LinkedBlockingQueue, TimeUnit}
/**
 * Shared state for one benchmark run: a single `DatagramChannel`, a template payload
 * buffer, and counters that the worker threads update while the run is active.
 *
 * @param threadCount  number of worker threads started by [[parallel]]
 * @param payloadSize  capacity, in bytes, of the template payload buffer
 * @param directBuffer `true` to allocate the payload with `ByteBuffer.allocateDirect`,
 *                     otherwise with `ByteBuffer.allocate` (heap)
 */
class Competition(threadCount: Int, payloadSize: Int, directBuffer: Boolean) {

  /**
   * Template payload. `DatagramChannel.send` advances the position of the buffer it is
   * given, so senders should pass a `duplicate()` rather than this instance.
   */
  val buffer: ByteBuffer =
    if (directBuffer) ByteBuffer.allocateDirect(payloadSize) else ByteBuffer.allocate(payloadSize)

  /** Channel shared by all senders; `DatagramChannel.open()` returns it in blocking mode. */
  val channel: DatagramChannel = DatagramChannel.open()

  /** `System.nanoTime()` at the start of the current measurement round; advanced by [[wraptime]]. */
  var roundStart: Long = System.nanoTime()

  /** Send requests issued by the workers. */
  val sendCallCount: LongAdder = new LongAdder()

  /** Datagrams actually passed to the channel. */
  val channelSendCount: LongAdder = new LongAdder()

  /** Send requests that were dropped instead of sent. */
  val lostCount: LongAdder = new LongAdder()

  /**
   * Run flag polled by every worker loop in [[parallel]] and cleared by [[close]] from a
   * different thread. `@volatile` guarantees the workers see that write; without it the
   * read may be hoisted out of the loop and the workers would never stop.
   */
  @volatile var continue: Boolean = true

  /**
   * Starts `threadCount` threads named `Worker-<i>`. Each one evaluates `r` repeatedly
   * while [[continue]] is `true`. This method blocks until all of them have finished.
   *
   * @param r by-name body, re-evaluated on every loop iteration
   */
  def parallel(r: => Unit): Unit = {
    val threads = (0 until threadCount).map { i =>
      new Thread(s"Worker-$i") {
        override def run(): Unit = while (continue) { r }
      }
    }
    threads.foreach(_.start())
    threads.foreach(_.join())
  }

  /**
   * Ends the current measurement round and starts the next one.
   *
   * Each counter is read and cleared in a single `sumThenReset()` call. A separate
   * `sum()` followed by `reset()` would discard every increment landing between the two.
   *
   * @return `(elapsedNanos, sendCalls, channelSends, lost)` accumulated since the previous
   *         call (or since construction)
   */
  def wraptime(): (Long, Long, Long, Long) = {
    val calls = sendCallCount.sumThenReset()
    val sends = channelSendCount.sumThenReset()
    val lost = lostCount.sumThenReset()
    val now = System.nanoTime()
    val elapsed = now - roundStart
    roundStart = now
    (elapsed, calls, sends, lost)
  }

  /**
   * Tells the workers to stop, waits a fixed 500 ms for in-flight sends to drain, then
   * closes the channel.
   */
  def close(): Unit = {
    continue = false
    Thread.sleep(500)
    channel.close()
  }
}
/** Send strategies compared by the benchmark, plus its command-line entry point. */
object Competition {

  /** Address every datagram is sent to: UDP port 11999 on localhost. */
  val Destination: InetSocketAddress = new InetSocketAddress("localhost", 11999)

  /**
   * Blocking strategy: every worker thread calls `send` on the shared channel itself.
   *
   * `send` advances the source buffer's position, so each call sends a fresh
   * `c.buffer.duplicate()`, which has the same bytes but an independent position and limit.
   * Passing `c.buffer` directly would transmit the payload only once and send empty
   * datagrams afterwards. It would also let the workers race on that buffer's position.
   */
  def sync(c: Competition): Unit = c.parallel {
    c.channel.send(c.buffer.duplicate(), Destination)
    c.sendCallCount.increment()
    c.channelSendCount.increment()
  }

  /**
   * Non-blocking strategy. Worker threads enqueue payloads into a bounded queue (capacity
   * 1024). A single "AsyncPump" thread writes them out whenever the `Selector` reports the
   * channel writable. An enqueue that finds the queue full is counted in `lostCount`.
   *
   * `OP_WRITE` interest is switched on by the producer whose offer brings the queue to
   * size 1, and switched off by the pump when it finds the queue empty. Both transitions
   * run inside `lock.synchronized`.
   */
  def async(c: Competition): Unit = {
    val lock = new Object()
    val selector = Selector.open()
    val queue = new LinkedBlockingQueue[ByteBuffer](1024)
    c.channel.configureBlocking(false)
    val key = c.channel.register(selector, SelectionKey.OP_WRITE, queue)
    startPump(c, selector, queue, lock)

    c.parallel {
      val accepted = lock.synchronized {
        // duplicate(): each queued datagram needs its own position (see `sync`).
        val offered = queue.offer(c.buffer.duplicate())
        if (offered && queue.size() == 1) {
          key.interestOps(SelectionKey.OP_WRITE)
          // A select() already in progress does not see a changed interest set. Once the
          // pump has cleared OP_WRITE it blocks in select() with no timeout, so wake it here.
          selector.wakeup()
        }
        offered
      }
      c.sendCallCount.increment()
      if (!accepted) c.lostCount.increment()
    }
    // Closing the selector makes the pump's select() throw ClosedSelectorException and exit.
    selector.close()
  }

  /**
   * Starts the "AsyncPump" thread, which drains `queue` through `c.channel` whenever the
   * channel is selected as writable, until `selector` is closed.
   */
  private def startPump(
      c: Competition,
      selector: Selector,
      queue: BlockingQueue[ByteBuffer],
      lock: AnyRef): Unit = {
    val pump = new Thread("AsyncPump") {
      override def run(): Unit = try {
        while (true) {
          selector.select()
          val it = selector.selectedKeys().iterator()
          while (it.hasNext) {
            val key = it.next()
            // Remove through the iterator. Removing from the set directly while iterating
            // can throw ConcurrentModificationException when several keys are selected.
            it.remove()
            if (key.isWritable) {
              // peek, not poll: a datagram the socket did not accept stays queued for retry.
              val payload = queue.peek()
              if (payload != null) {
                // In non-blocking mode send() transmits nothing, and leaves the buffer
                // untouched, when the socket lacks room. Dequeue and count only once sent.
                c.channel.send(payload, Destination)
                if (!payload.hasRemaining) {
                  queue.poll()
                  c.channelSendCount.increment()
                }
              } else lock.synchronized {
                if (queue.isEmpty) key.interestOps(0)
              }
            }
          }
        }
      } catch {
        // Both are raised once async() has closed the selector: normal shutdown.
        case _: ClosedSelectorException | _: CancelledKeyException => ()
        case ex: Exception => ex.printStackTrace()
      }
    }
    pump.start()
  }

  /**
   * Command line: `<sync|async> <threadCount> <payloadSize> [direct]`. Any first argument
   * other than "async" selects `sync`.
   *
   * Starts the chosen strategy and discards a 3-second warm-up. It then measures for
   * 10 seconds and prints one CSV line: send calls/s, channel sends/s, lost sends/s.
   */
  def main(args: Array[String]): Unit = {
    if (args.length < 3) {
      System.err.println("usage: Competition <sync|async> <threadCount> <payloadSize> [direct]")
      sys.exit(1)
    }
    val seconds = 10
    val runner: Competition => Unit = if (args(0) == "async") async else sync
    val threadCount = args(1).toInt
    val payloadSize = args(2).toInt
    val directBuffer = args.length >= 4 && args(3) == "direct"
    val c = new Competition(threadCount, payloadSize, directBuffer)
    new Thread() { override def run(): Unit = runner(c) }.start()
    Thread.sleep(3 * 1000)
    c.wraptime() // discard the warm-up round
    Thread.sleep(seconds * 1000)
    val (elapsedNanos, sendCalls, channelSends, lost) = c.wraptime()
    c.close()
    // elapsedNanos already spans the whole measurement window, so count * 1e9 / elapsedNanos
    // is a per-second rate. Dividing by `seconds` as well would under-report it 10x.
    def perSecond(count: Long): Double = count * 1e9 / elapsedNanos
    val sendCallPerSec = perSecond(sendCalls)
    val sendPerSec = perSecond(channelSends)
    val failPerSec = perSecond(lost)
    System.out.println(f"$sendCallPerSec%.2f,$sendPerSec%.2f,$failPerSec%.2f")
  }
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment