Skip to content

Instantly share code, notes, and snippets.

@hisui
Last active December 25, 2015 10:59
Show Gist options
  • Save hisui/6966090 to your computer and use it in GitHub Desktop.
Save hisui/6966090 to your computer and use it in GitHub Desktop.
Something like a software load balancer, for testing.
package jp.segfault.minlb
import java.nio.channels._
import java.nio.channels.SelectionKey._
import java.nio.ByteBuffer
import java.net.{HttpURLConnection, URI, InetSocketAddress}
import java.io.IOException
import java.util.concurrent.atomic.AtomicBoolean
import java.util.concurrent.TimeoutException
import scala.util.{Failure, Success, Try}
import scala.collection.mutable
/**
 * Entry point of the load balancer.
 *
 * Wires together the configured upstreams, one health-checker thread per
 * upstream that declares a health-check URI, and the proxying [[Server]].
 */
object Main {

  // Describe the upstream configuration here.
  val upstreams: Seq[UpStream] = Seq(
    UpStream(new InetSocketAddress(9000))
    //, UpStream(new InetSocketAddress(9000), Some(new URI("http://localhost:9000/check.html")))
  )

  /**
   * Starts the health checkers and runs the server on the port given as the
   * first command-line argument.
   *
   * @param args command-line arguments; `args(0)` must be a port number (0-65535)
   */
  def main(args: Array[String]): Unit = {
    // Validate the port before anything is started, so that a bad argument
    // prints the usage text instead of dying with a NumberFormatException
    // while health-check threads are already running.
    val parsedPort = args.headOption
      .flatMap(raw => Try(raw.trim.toInt).toOption)
      .filter(p => p >= 0 && p <= 65535)
    val port = parsedPort match {
      case Some(p) => p
      case None =>
        System.err.println("Usage: java -jar lb.jar <port>")
        sys.exit(-1)
    }

    // One liveness flag per upstream; upstreams without a health check stay `true`.
    val flags = upstreams.map(_ -> new AtomicBoolean(true)).toVector
    val checkers = flags.flatMap { case (ups, flag) =>
      ups.healthCheckURI.map(uri => new HealthChecker(uri)(flag.set))
    }
    val threads = checkers.map(checker => new Thread(checker))
    threads.foreach(_.start())

    // Upstream selection: round-robin over the upstreams currently flagged alive.
    var counter = 0
    @scala.annotation.tailrec
    def choose(): UpStream = {
      require(flags.nonEmpty)
      val size = flags.size
      val hit = (0 until size)
        .map(offset => (counter + offset) % size)
        .find(index => flags(index)._2.get())
      hit match {
        case Some(index) =>
          counter = (index + 1) % size
          flags(index)._1
        case None =>
          // Nothing is alive: wait a little and scan again.
          println("[warn] No upstream available..")
          Thread.sleep(300)
          choose()
      }
    }

    try {
      val server = new Server(port, choose())
      server.run()
    } catch {
      case e: IOException =>
        println("[error] I/O error: " + e.getMessage)
    } finally {
      // Without this the join below never returns: the checkers loop until
      // `running` is cleared, and they swallow interrupts while it is set.
      checkers.foreach(checker => checker.running = false)
      threads.foreach(_.interrupt())
    }
    threads.foreach(_.join())
  }
}
/**
 * A backend that incoming connections can be forwarded to.
 *
 * @param address        socket address of the backend
 * @param healthCheckURI optional URI used to probe the backend; `None` by default
 */
case class UpStream(
  address: InetSocketAddress,
  healthCheckURI: Option[URI] = None
)
/**
 * Periodically probes `uri` with an HTTP GET and reports the outcome to `f`.
 *
 * @param uri      the URI to probe; only HTTP(S) URIs can be reported as alive
 * @param interval pause between two probes, in milliseconds
 * @param f        callback receiving `true` when the probe answered with
 *                 status 200 and `false` otherwise
 */
class HealthChecker(uri: URI, interval: Int = 1000 * 5)(f: Boolean => Unit) extends Runnable {

  /** Clear this flag to make `run()` return after the current iteration. */
  // Volatile because the flag is meant to be cleared from a thread other than
  // the one executing `run()`.
  @volatile var running = true

  /** Probes, reports and sleeps in a loop until `running` becomes `false`. */
  def run(): Unit = {
    while (running) {
      val b = alive()
      println(s"[info] health check: uri='$uri', alive=$b")
      f(b)
      // An interrupt only cuts the pause short; the loop condition decides
      // whether to continue.
      try Thread.sleep(interval) catch { case e: InterruptedException => }
    }
  }

  /** Performs one probe; any failure to get a 200 response counts as "not alive". */
  private def alive(): Boolean =
    try {
      uri.toURL.openConnection() match {
        case conn: HttpURLConnection =>
          try {
            conn.setRequestMethod("GET")
            conn.setInstanceFollowRedirects(true)
            conn.setConnectTimeout(1000)
            // Without a read timeout a backend that accepts the connection but
            // never answers would block this thread forever.
            conn.setReadTimeout(math.max(interval, 1000))
            conn.connect()
            conn.getResponseCode == 200
          } finally {
            conn.disconnect()
          }
        case _ =>
          // Not an HTTP(S) URI, so there is no status code to check.
          false
      }
    } catch {
      case e: IOException => false
      // `URI.toURL` rejects non-absolute URIs with an IllegalArgumentException.
      case e: IllegalArgumentException => false
    }
}
/**
 * A pending deadline for a selection key, ordered by deadline and then by id.
 *
 * @param key  the selection key the deadline belongs to
 * @param time the deadline as a timestamp in milliseconds
 * @param id   tie-breaker that keeps entries with the same deadline distinct
 */
case class Timeout(key: SelectionKey, time: Long, id: Int) extends Ordered[Timeout] {

  /**
   * Orders by `time`, then by `id`.
   *
   * Uses overflow-safe comparisons: subtracting and truncating to `Int` gives
   * the wrong sign once the operands are far enough apart.
   */
  def compare(rhs: Timeout): Int = {
    val byTime = java.lang.Long.compare(time, rhs.time)
    if (byTime != 0) byTime else java.lang.Integer.compare(id, rhs.id)
  }
}
/**
 * A minimal single-threaded event loop on top of a NIO `Selector`.
 *
 * Callbacks are stored as selection-key attachments of type
 * `Try[SelectionKey] => Unit`. They receive `Success(key)` when the key is
 * selected and `Failure(TimeoutException)` when a registration made with a
 * positive timeout expires first.
 */
class SocketManager {

  /** Set to `false` to make `start()` return after the current tick. */
  var running = true

  private[this] val selector = Selector.open()
  // Pending deadlines, one per registration made with a positive timeout.
  private[this] val queue = mutable.SortedSet[Timeout]()
  // Source of unique ids for the entries in `queue`.
  private[this] var count = 0

  /** Runs `tick` roughly every 10 ms until `running` is cleared. */
  def start(): Unit = {
    while (running) {
      tick(System.currentTimeMillis())
      Thread.sleep(10)
    }
  }

  /**
   * Dispatches every key that is ready right now, then fires every deadline
   * that is due.
   *
   * @param time the current time in milliseconds
   */
  def tick(time: Long): Unit = {
    if (selector.selectNow() > 0) {
      selector.selectedKeys().iterator().consume(key => key.handler(Try(key)))
    }
    // `filter` builds a separate set, so handlers may modify `queue` while we iterate.
    val expired = queue.filter(_.time <= time)
    expired.foreach { entry =>
      entry.key.cancel()
      entry.key.handler(Failure(new TimeoutException()))
    }
  }

  /**
   * Listens on `port` and hands every accepted connection to `f`.
   *
   * @param port the local port to bind to
   * @param f    callback invoked with each accepted channel
   */
  def listen(port: Int)(f: SocketChannel => Unit): Unit = {
    val server = ServerSocketChannel.open()
    server.configureBlocking(false)
    server.register(OP_ACCEPT, 0) { _ =>
      // In non-blocking mode accept() returns null when no connection is
      // pending; never pass that on to the callback.
      Option(server.accept()).foreach(f)
    }
    server.bind(new InetSocketAddress(port))
  }

  /**
   * Opens a non-blocking connection to `address`.
   *
   * `f` is invoked exactly once, with the connected channel or with the reason
   * for the failure. On failure or timeout the half-open socket is closed.
   *
   * @param address the remote address to connect to
   * @param timeout connect timeout in milliseconds; values below 1 disable it
   * @param f       completion callback
   */
  def connect(address: InetSocketAddress, timeout: Int = 1000)(f: Try[SocketChannel] => Unit): Unit = {
    val socket = SocketChannel.open()
    socket.configureBlocking(false)
    Try(socket.connect(address)) match {
      case Success(true) =>
        // Established immediately (possible for local connections): there is
        // no pending connection left for the selector to report.
        f(Success(socket))
      case Success(false) =>
        socket.register(OP_CONNECT, timeout) { event =>
          val outcome = event.map { _ =>
            if (!socket.finishConnect()) {
              throw new IllegalStateException("(*_*) bug?")
            }
            socket
          }
          // Timeout or failed finishConnect(): release the descriptor.
          if (outcome.isFailure) closeQuietly(socket)
          f(outcome)
        }
      case Failure(cause) =>
        closeQuietly(socket)
        f(Failure(cause))
    }
  }

  /**
   * Watches `channel` for readiness and keeps re-registering it until
   * `wantToClose` becomes true, an I/O error occurs or the timeout fires.
   *
   * @param channel        the channel to watch
   * @param reader         evaluated whenever the channel is readable
   * @param writer         evaluated whenever the channel is writable
   * @param wantToClose    evaluated after each dispatch; `true` closes the channel
   * @param timeoutHandler evaluated when the registration times out
   * @param timeout        timeout in milliseconds; values below 1 disable it
   */
  def watch(channel: SocketChannel)(
    reader: => Unit,
    writer: => Unit,
    // wantToWrite: => Boolean,
    wantToClose: => Boolean,
    timeoutHandler: => Unit = (),
    timeout: Int = 0): Unit =
  {
    def handler(e: Try[SelectionKey]): Unit = e match {
      case Success(key) =>
        // An I/O error on one connection must not escape the event loop and
        // stop every other connection; drop only this channel.
        val healthy =
          try {
            if (key.isReadable) reader
            if (key.isWritable) writer
            true
          } catch {
            case ioe: IOException =>
              println("[warn] I/O error: " + ioe.getMessage)
              false
          }
        if (!healthy || wantToClose) {
          closeQuietly(channel)
        } else {
          // Registering with `OP_WRITE` makes the handler fire on every tick,
          // even when there is nothing to write.
          channel.register(OP_READ | OP_WRITE, timeout)(handler)
        }
      case Failure(e: TimeoutException) => timeoutHandler
      case Failure(e) =>
        throw new IllegalStateException("(*_*) bug?", e)
    }
    channel.configureBlocking(false)
    channel.register(OP_READ | OP_WRITE, timeout)(handler)
  }

  /** Closes `channel`, ignoring any error raised while closing. */
  private def closeQuietly(channel: Channel): Unit = {
    Try(channel.close())
    ()
  }

  /** Gives access to the callback stored as a key's attachment. */
  implicit class SelectionKeyOps(raw: SelectionKey) {
    def handler = raw.attachment().asInstanceOf[Try[SelectionKey] => Unit]
  }

  /** Adds callback-based registration with an optional timeout. */
  implicit class SelectableChannelOps(raw: SelectableChannel) {

    /**
     * Registers `raw` with the selector and attaches `f` as its callback.
     *
     * @param ops     interest set
     * @param timeout timeout in milliseconds; values below 1 disable it
     * @param f       callback for readiness or timeout
     */
    def register(ops: Int, timeout: Int)(f: Try[SelectionKey] => Unit): Unit = {
      if (timeout < 1) {
        raw.register(selector, ops, f)
      } else {
        // `entry` is assigned right after registration; the wrapper only reads
        // it later, when the key is dispatched.
        var entry: Timeout = null
        val key = raw.register(selector, ops, (e: Try[SelectionKey]) => {
          // Every dispatch consumes the pending deadline.
          if (queue.remove(entry)) f(e)
          else {
            throw new IllegalStateException("(*_*) bug?")
          }
        })
        entry = Timeout(key, timeout + System.currentTimeMillis(), count)
        count += 1
        queue.add(entry)
      }
    }
  }

  /** Adds a draining traversal to Java iterators. */
  implicit class JavaIteratorOps[T](raw: java.util.Iterator[T]) {

    /** Applies `f` to every element, removing each one even when `f` throws. */
    def consume(f: T => Unit): Unit = {
      while (raw.hasNext) try f(raw.next()) finally raw.remove()
    }
  }
}
/**
 * TCP proxy: accepts connections on `port` and relays each one to an upstream.
 *
 * @param port     local port to listen on
 * @param upstream evaluated once per accepted connection to pick the upstream
 */
class Server(port: Int, upstream: => UpStream) extends Runnable {

  private[this] val manager = new SocketManager

  // Timeout, in milliseconds, passed to the manager for both sides of a
  // proxied connection; meant to drop connections on which no data flows.
  val Timeout = 300

  /** Starts listening and then runs the socket manager's loop on this thread. */
  def run(): Unit = {
    manager.listen(port) { a =>
      manager.connect(upstream.address) {
        case Failure(e) => a.close()
        case Success(b) =>
          def onTimedOut(): Unit = {
            println("[warn] Connection timed out..")
            if (a.isConnected) a.close()
            if (b.isConnected) b.close()
          }
          println("[info] New edge: " + a.getRemoteAddress + " -> " + b.getRemoteAddress)
          val a2b = new Flows(a, b)
          val b2a = new Flows(b, a)
          // Either direction shutting down tears down the whole edge.
          def wantToClose = a2b.shutdown || b2a.shutdown
          manager.watch(a)(a2b.read(), b2a.write(), wantToClose, onTimedOut(), Timeout)
          manager.watch(b)(b2a.read(), a2b.write(), wantToClose, onTimedOut(), Timeout)
      }
    }
    manager.start()
  }

  /**
   * One direction of a proxied connection: copies bytes from `reader` to
   * `writer` through a single 4096-byte direct buffer.
   *
   * @param reader channel the bytes are read from
   * @param writer channel the bytes are written to
   */
  class Flows(reader: SocketChannel, writer: SocketChannel) {

    // Kept in "fill" mode between calls; flipped only around a write.
    private[this] val buf = ByteBuffer.allocateDirect(4096)
    // The source may have more bytes to read.
    private[this] var readable = false
    // The sink may accept more bytes.
    private[this] var writable = false

    /** Becomes `true` at end-of-stream or once either channel is closed. */
    var shutdown = false

    /**
     * Moves as many bytes as currently possible from `reader` to `writer`.
     *
     * The loop runs only while a pass actually transferred bytes, so it stops
     * once the buffer is full and the writer cannot take more. Looping on the
     * `readable`/`writable` flags alone would spin forever in that state.
     */
    def proceed(): Unit = {
      if (!(writer.isOpen && reader.isOpen)) {
        shutdown = true
      } else {
        var progressed = true
        while (progressed) {
          progressed = false
          if (!shutdown && readable && buf.hasRemaining) {
            val n = reader.read(buf)
            if (n < 0) {
              // End of stream.
              readable = false
              shutdown = true
            } else {
              // A completely filled buffer suggests more input is waiting.
              readable = !buf.hasRemaining
              if (n > 0) progressed = true
            }
          }
          if (writable) {
            buf.flip()
            if (buf.hasRemaining) {
              val n = writer.write(buf)
              // A partial write means the sink is full for now.
              writable = !buf.hasRemaining
              if (n > 0) progressed = true
            }
            buf.compact()
          }
        }
      }
    }

    /** Called when `reader` became readable. */
    def read(): Unit = {
      readable = true
      proceed()
    }

    /** Called when `writer` became writable. */
    def write(): Unit = {
      writable = true
      proceed()
    }
  }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment