Last active
December 25, 2015 10:59
-
-
Save hisui/6966090 to your computer and use it in GitHub Desktop.
Something like a software load balancer, for testing.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package jp.segfault.minlb | |
import java.nio.channels._ | |
import java.nio.channels.SelectionKey._ | |
import java.nio.ByteBuffer | |
import java.net.{HttpURLConnection, URI, InetSocketAddress} | |
import java.io.IOException | |
import java.util.concurrent.atomic.AtomicBoolean | |
import java.util.concurrent.TimeoutException | |
import scala.util.{Failure, Success, Try} | |
import scala.collection.mutable | |
/**
 * Entry point of the load balancer: `java -jar lb.jar <port>`.
 *
 * Connections accepted on `<port>` are forwarded round-robin to [[upstreams]], skipping any
 * upstream whose health check last reported failure.
 */
object Main {
  // Configure the upstreams here.
  val upstreams: Seq[UpStream] = Seq(
    UpStream(new InetSocketAddress(9000))
    //, UpStream(new InetSocketAddress(9000), Some(new URI("http://localhost:9000/check.html")))
  )

  def main(args: Array[String]): Unit = {
    // Validate everything before starting any (non-daemon) health-check thread, so a bad
    // argument cannot leave the JVM running with only the checkers alive.
    val port = args.headOption.flatMap(arg => Try(arg.toInt).toOption) match {
      case Some(p) if p >= 0 && p <= 0xFFFF => p
      case _ =>
        System.err.println("Usage: java -jar lb.jar <port>")
        sys.exit(-1)
    }
    if (upstreams.isEmpty) {
      System.err.println("[error] No upstream configured")
      sys.exit(-1)
    }

    // One health flag per upstream; upstreams without a health-check URI are never disabled.
    val flags = upstreams.map(_ -> new AtomicBoolean(true)).toVector
    // Keep each checker together with its thread so both can be stopped on shutdown.
    val checkers = flags.flatMap { case (ups, flag) =>
      ups.healthCheckURI.map { uri =>
        val checker = new HealthChecker(uri)(flag.set)
        (checker, new Thread(checker))
      }
    }
    checkers.foreach(_._2.start())

    // Round-robin upstream selection: scan at most one full round starting at `counter`.
    var counter = 0
    @scala.annotation.tailrec
    def choose(): UpStream = {
      val round = Iterator.fill(flags.size) {
        val entry = flags(counter)
        counter = (counter + 1) % flags.size
        entry
      }
      round.collectFirst { case (ups, flag) if flag.get() => ups } match {
        case Some(ups) => ups
        case None =>
          // NOTE(review): this runs on the server's event-loop thread, so the sleep stalls
          // every other connection as well.
          println("[warn] No upstream available..")
          Thread.sleep(300)
          choose()
      }
    }

    try {
      new Server(port, choose()).run()
    } catch {
      case e: IOException =>
        println("[error] I/O error: " + e.getMessage)
    } finally {
      // HealthChecker.run loops while `running`; without stopping the checkers the joins
      // below would never return.
      checkers.foreach { case (checker, thread) =>
        checker.running = false
        thread.interrupt() // wake it from the sleep between checks
      }
      checkers.foreach { case (_, thread) => thread.join() }
    }
  }
}
/**
 * An upstream (backend) server that accepted connections are forwarded to.
 *
 * @param address        where connections for this upstream are sent
 * @param healthCheckURI URL polled to decide whether the upstream is alive; when `None`,
 *                       no health check is started for it (see `Main`)
 */
case class UpStream(
  address: InetSocketAddress,
  healthCheckURI: Option[URI] = None
)
/**
 * Periodically probes `uri` with an HTTP GET and reports the result to `f`.
 *
 * An upstream counts as alive only when it answers with status 200 (redirects are followed).
 * Connection errors, timeouts, malformed or non-HTTP URIs all count as not alive.
 *
 * @param uri      health-check URL
 * @param interval pause between probes, in milliseconds
 * @param timeout  connect and read timeout of each probe, in milliseconds
 * @param f        receives the result of every probe; called from `run()`
 */
class HealthChecker(uri: URI, interval: Int = 1000 * 5, timeout: Int = 1000)(f: Boolean => Unit)
    extends Runnable {

  /** Loop flag for `run()`; volatile so another thread can stop the checker. */
  @volatile var running = true

  def run(): Unit = {
    while (running) {
      val b = alive()
      println(s"[info] health check: uri='$uri', alive=$b")
      f(b)
      // An interrupt only cuts the pause short; the loop then re-checks `running`.
      try Thread.sleep(interval) catch { case _: InterruptedException => }
    }
  }

  /** Performs one probe; never throws for connection or URI problems. */
  private def alive(): Boolean =
    try {
      uri.toURL.openConnection() match {
        case conn: HttpURLConnection =>
          try {
            conn.setRequestMethod("GET")
            conn.setInstanceFollowRedirects(true)
            conn.setConnectTimeout(timeout)
            // Without a read timeout a hung upstream would block this checker forever.
            conn.setReadTimeout(timeout)
            conn.connect()
            conn.getResponseCode == 200
          } finally {
            // Release the underlying socket instead of leaking one per probe.
            conn.disconnect()
          }
        case _ => false // not an HTTP(S) URL
      }
    } catch {
      // IllegalArgumentException: `uri` is not absolute, so it cannot become a URL.
      case _: IOException | _: IllegalArgumentException => false
    }
}
/**
 * A deadline registered for `key`.
 *
 * `time` is the absolute expiry time (the socket manager computes it from
 * `System.currentTimeMillis()`), and `id` is a registration sequence number that keeps entries
 * with equal `time` distinct inside a sorted set.
 *
 * Entries are ordered by `time`, then by `id`. The comparison does not subtract, so large gaps
 * between times or wrapped-around ids cannot overflow and invert the order.
 */
case class Timeout(key: SelectionKey, time: Long, id: Int) extends Ordered[Timeout] {
  def compare(rhs: Timeout): Int = {
    val byTime = java.lang.Long.compare(time, rhs.time)
    if (byTime != 0) byTime else java.lang.Integer.compare(id, rhs.id)
  }
}
/**
 * A single-threaded, polling event loop over one NIO selector.
 *
 * Every registration attaches a callback that receives either `Success(key)` when the key is
 * selected or `Failure(TimeoutException)` when its deadline expires first. All callbacks run on
 * the thread that calls [[start]]. Nothing here is synchronized, so apart from clearing
 * [[running]], use an instance only from that thread (or before [[start]]).
 */
class SocketManager {
  /** Loop flag for [[start]]; volatile so another thread can stop the loop. */
  @volatile var running = true

  private[this] val selector = Selector.open()
  /** Pending deadlines, earliest first. */
  private[this] val queue = mutable.SortedSet[Timeout]()
  /** Next registration sequence number (tie-breaker for equal deadlines). */
  private[this] var count = 0

  /** Runs the event loop on the calling thread until [[running]] becomes false. */
  def start(): Unit = {
    while (running) {
      tick(System.currentTimeMillis())
      // selectNow() never blocks, and channels watched for OP_WRITE are ready almost all the
      // time, so pause briefly instead of spinning the CPU.
      Thread.sleep(10)
    }
  }

  /**
   * One non-blocking iteration: dispatch every selected key to its callback, then fire the
   * callback of every deadline with `time <= time`.
   */
  def tick(time: Long): Unit = {
    if (selector.selectNow() > 0) {
      selector.selectedKeys().iterator().consume { key =>
        // An earlier callback in this batch may have closed this key's channel.
        if (key.isValid) key.handler(Try(key))
      }
    }
    // The queue is sorted by deadline, so due entries form a prefix. Copy it first, because
    // each callback removes its own entry from `queue`.
    val due = queue.takeWhile(_.time <= time).toList
    due.foreach { e =>
      e.key.cancel()
      e.key.handler(Failure(new TimeoutException()))
    }
  }

  /**
   * Accepts connections on `port` (wildcard address) and passes each accepted channel to `f`
   * on the event-loop thread. Newly accepted channels start out in blocking mode.
   *
   * @throws java.io.IOException if the server socket cannot be opened or bound
   */
  def listen(port: Int)(f: SocketChannel => Unit): Unit = {
    val server = ServerSocketChannel.open()
    server.configureBlocking(false)
    server.register(OP_ACCEPT, 0) { _ =>
      // A non-blocking accept() returns null when no connection is actually pending.
      Option(server.accept()).foreach(f)
    }
    try server.bind(new InetSocketAddress(port))
    catch {
      case e: IOException =>
        server.close()
        throw e
    }
  }

  /**
   * Opens a non-blocking connection to `address` and reports the outcome to `f` exactly once.
   * On success `f` gets `Success(channel)`. Otherwise it gets `Failure(e)`, where `e` is the
   * connect error, or a `TimeoutException` if the connection is not established within
   * `timeout` ms. On failure the channel is closed before `f` runs.
   */
  def connect(address: InetSocketAddress, timeout: Int = 1000)(f: Try[SocketChannel] => Unit): Unit = {
    val socket = SocketChannel.open()
    def fail(e: Throwable): Unit = {
      Try(socket.close())
      f(Failure(e))
    }
    Try { socket.configureBlocking(false); socket.connect(address) } match {
      // A local connection may be established immediately; no OP_CONNECT event follows then.
      case Success(true) => f(Success(socket))
      case Success(false) =>
        socket.register(OP_CONNECT, timeout) { e =>
          e.flatMap(_ => Try(socket.finishConnect())) match {
            case Success(true)  => f(Success(socket))
            case Success(false) => fail(new IllegalStateException("(*_*) bug?"))
            case Failure(err)   => fail(err)
          }
        }
      case Failure(err) => fail(err)
    }
  }

  /**
   * Watches `channel` until `wantToClose` becomes true.
   *
   * Each time the channel is selected, `reader` runs if it is readable and `writer` if it is
   * writable. Afterwards the channel is closed if `wantToClose` holds; otherwise it is
   * re-registered with a fresh deadline.
   *
   * An `IOException` from `reader`/`writer` closes only this channel and does not stop the
   * event loop. `timeoutHandler` runs if the channel is not selected within `timeout` ms;
   * a `timeout` below 1 disables the deadline.
   */
  def watch(channel: SocketChannel)(
      reader: => Unit,
      writer: => Unit,
      wantToClose: => Boolean,
      timeoutHandler: => Unit = (),
      timeout: Int = 0): Unit = {
    def handler(e: Try[SelectionKey]): Unit = e match {
      case Success(key) =>
        val failed =
          try {
            if (key.isReadable) reader
            if (key.isWritable) writer
            false
          } catch {
            case _: IOException => true
          }
        if (failed || wantToClose) {
          channel.close()
        } else {
          // Interest ops include OP_WRITE, so this handler runs (almost) every tick even when
          // there is nothing to write.
          channel.register(OP_READ | OP_WRITE, timeout)(handler)
        }
      case Failure(_: TimeoutException) => timeoutHandler
      case Failure(err) =>
        throw new IllegalStateException("(*_*) bug?", err)
    }
    channel.configureBlocking(false)
    channel.register(OP_READ | OP_WRITE, timeout)(handler)
  }

  /** Reads back the callback stored as a key's attachment. */
  implicit class SelectionKeyOps(raw: SelectionKey) {
    def handler: Try[SelectionKey] => Unit = raw.attachment().asInstanceOf[Try[SelectionKey] => Unit]
  }

  /** Registration with an attached callback and an optional deadline. */
  implicit class SelectableChannelOps(raw: SelectableChannel) {
    /**
     * Registers (or re-registers) `raw` with this manager's selector for `ops`, attaching `f`.
     *
     * With `timeout >= 1`, a deadline `timeout` ms from now is queued. The attached wrapper
     * drops that deadline before calling `f`, and throws `IllegalStateException` if it is
     * invoked again for the same registration.
     */
    def register(ops: Int, timeout: Int)(f: Try[SelectionKey] => Unit): Unit = {
      if (timeout < 1) {
        raw.register(selector, ops, f)
      } else {
        var entry: Timeout = null
        val key = raw.register(selector, ops, (e: Try[SelectionKey]) => {
          if (queue.remove(entry)) f(e)
          else throw new IllegalStateException("(*_*) bug?")
        })
        entry = Timeout(key, timeout + System.currentTimeMillis(), count)
        count += 1
        queue.add(entry)
      }
    }
  }

  /** Drains a Java iterator, removing each element after `f` handles it (even if `f` throws). */
  implicit class JavaIteratorOps[T](raw: java.util.Iterator[T]) {
    def consume(f: T => Unit): Unit = {
      while (raw.hasNext) {
        try f(raw.next()) finally raw.remove()
      }
    }
  }
}
/**
 * A TCP relay: accepts clients on `port` and pipes each one to an upstream.
 *
 * The upstream comes from `upstream`, a by-name parameter that is re-evaluated for every
 * accepted connection.
 */
class Server(port: Int, upstream: => UpStream) extends Runnable {
  private[this] val manager = new SocketManager

  // Deadline (ms) passed to SocketManager.watch for both sides of a connection.
  // NOTE(review): the deadline is renewed whenever a channel is selected, and channels are
  // watched for OP_WRITE, so in practice this fires when a channel is neither readable nor
  // writable for this long, not after this long without any data.
  val Timeout = 300

  /** Starts listening on `port` and runs the event loop on the calling thread. */
  def run(): Unit = {
    manager.listen(port) { accepted =>
      Option(accepted).foreach(pipe)
    }
    manager.start()
  }

  /** Connects to an upstream for client `a` and, once connected, relays both directions. */
  private[this] def pipe(a: SocketChannel): Unit =
    manager.connect(upstream.address) {
      case Failure(e) =>
        println("[warn] Failed to connect to upstream: " + e)
        a.close()
      case Success(b) =>
        def onTimedOut(): Unit = {
          println("[warn] Connection timed out..")
          // close() on an already-closed channel has no effect.
          a.close()
          b.close()
        }
        println("[info] New edge: " + a.getRemoteAddress + " -> " + b.getRemoteAddress)
        val a2b = new Flows(a, b)
        val b2a = new Flows(b, a)
        def wantToClose = a2b.shutdown || b2a.shutdown
        manager.watch(a)(a2b.read(), b2a.write(), wantToClose, onTimedOut(), Timeout)
        manager.watch(b)(b2a.read(), a2b.write(), wantToClose, onTimedOut(), Timeout)
    }

  /**
   * One direction of a connection: bytes read from `reader` are staged in a 4 KiB buffer and
   * written to `writer`.
   *
   * `read()` / `write()` are called when `reader` is readable / `writer` is writable. Each
   * records that readiness and moves as much data as possible without blocking.
   *
   * `shutdown` becomes true in two cases:
   *  - `reader` has reached end-of-stream (or was closed) and every buffered byte was written;
   *  - `writer` is closed, or either side fails with an `IOException`.
   */
  class Flows(reader: SocketChannel, writer: SocketChannel) {
    // Between calls the buffer is in "fill" mode: position() is the number of pending bytes.
    private[this] val buf = ByteBuffer.allocateDirect(4096)
    // `reader` may have more data: it was just selected, or the last read filled the buffer.
    private[this] var readable = false
    // `writer` may accept more data: it was just selected, or the last write was complete.
    private[this] var writable = false
    // `reader` is at end-of-stream or closed; nothing more will be read from it.
    private[this] var eof = false
    var shutdown = false

    def proceed(): Unit =
      if (!shutdown) {
        if (!writer.isOpen) {
          shutdown = true
        } else {
          if (!reader.isOpen) eof = true
          try pump() catch { case _: IOException => shutdown = true }
          // Finish only once everything read so far has been forwarded.
          if (eof && buf.position() == 0) shutdown = true
        }
      }

    /** Alternates reading and writing until neither side makes progress (never spins). */
    private[this] def pump(): Unit = {
      var progressed = true
      while (progressed) {
        progressed = false
        if (readable && !eof && buf.hasRemaining()) {
          val n = reader.read(buf)
          if (n < 0) {
            eof = true
            readable = false
          } else {
            // A read that did not fill the buffer has most likely drained the socket.
            readable = !buf.hasRemaining()
            if (n > 0) progressed = true
          }
        }
        if (writable && buf.position() > 0) {
          buf.flip()
          val n = writer.write(buf)
          // A partial write presumably means the send buffer is full; wait for the next write().
          writable = !buf.hasRemaining()
          buf.compact()
          if (n > 0) progressed = true
        }
      }
    }

    def read(): Unit = {
      readable = true
      proceed()
    }

    def write(): Unit = {
      writable = true
      proceed()
    }
  }
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment