Created
February 13, 2012 19:42
-
-
Save casualjim/1819496 to your computer and use it in GitHub Desktop.
A Netty-based WebSocket client and server in Scala
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package mojolly.io | |
import org.jboss.netty.bootstrap.ClientBootstrap | |
import org.jboss.netty.channel._ | |
import socket.nio.NioClientSocketChannelFactory | |
import java.util.concurrent.Executors | |
import org.jboss.netty.handler.codec.http._ | |
import collection.JavaConversions._ | |
import websocketx._ | |
import java.net.{InetSocketAddress, URI} | |
import java.nio.charset.Charset | |
import org.jboss.netty.buffer.ChannelBuffers | |
import org.jboss.netty.util.CharsetUtil | |
import akka.actor.ActorRef | |
import mojolly.LibraryConstants | |
/**
 * Usage of the simple websocket client:
 * <pre>
 *   WebSocketClient(new URI("ws://localhost:8080/thesocket")) {
 *     case Connected(client) => println("Connection has been established to: " + client.url.toASCIIString)
 *     case Disconnected(client, _) => println("The websocket to " + client.url.toASCIIString + " disconnected.")
 *     case TextMessage(client, message) => {
 *       println("RECV: " + message)
 *       client send ("ECHO: " + message)
 *     }
 *   }
 * </pre>
 *
 * NOTE(review): known limitations of this implementation
 *  - every client builds its own `NioClientSocketChannelFactory` backed by two cached thread pools,
 *    and `disconnect` never releases them, so threads accumulate as clients come and go. Sharing one
 *    channel factory between clients would address both issues.
 *  - `wss` URLs pass validation, but no `SslHandler` is installed in the client pipeline.
 */
object WebSocketClient {

  /** Lifecycle and data events delivered to a client's [[WebSocketClient.Handler]]. */
  object Messages {
    sealed trait WebSocketClientMessage
    case object Connecting extends WebSocketClientMessage
    case class ConnectionFailed(client: WebSocketClient, reason: Option[Throwable] = None) extends WebSocketClientMessage
    case class Connected(client: WebSocketClient) extends WebSocketClientMessage
    case class TextMessage(client: WebSocketClient, text: String) extends WebSocketClientMessage
    case class WriteFailed(client: WebSocketClient, message: String, reason: Option[Throwable]) extends WebSocketClientMessage
    case object Disconnecting extends WebSocketClientMessage
    case class Disconnected(client: WebSocketClient, reason: Option[Throwable] = None) extends WebSocketClientMessage
    case class Error(client: WebSocketClient, th: Throwable) extends WebSocketClientMessage
  }

  /** Callback receiving client events; see [[Messages]] for the events that can be delivered. */
  type Handler = PartialFunction[Messages.WebSocketClientMessage, Unit]

  /** Converts an incoming data frame into the text delivered as a `TextMessage`. */
  type FrameReader = WebSocketFrame => String

  /** Returns the text of a `TextWebSocketFrame`; any other frame type raises UnsupportedOperationException. */
  val defaultFrameReader: FrameReader = {
    case f: TextWebSocketFrame => f.getText
    case _ => throw new UnsupportedOperationException("Only single text frames are supported for now")
  }

  private val ValidSchemes = Set("ws", "wss")

  /** Throws IllegalArgumentException unless `url` uses the `ws` or `wss` scheme (case-insensitive). */
  private def requireWebSocketScheme(url: URI): Unit =
    require(Option(url.getScheme).exists(s => ValidSchemes(s.toLowerCase(java.util.Locale.ENGLISH))),
      "The scheme of the url should be 'ws' or 'wss'")

  /**
   * Creates a client for `url`. The connection is not opened until `connect` is called.
   *
   * @param url     target URL; must use the `ws` or `wss` scheme
   * @param version protocol version used for the opening handshake
   * @param reader  converts received data frames into the text of `TextMessage` events
   * @param handle  receives the client's lifecycle and message events
   */
  def apply(url: URI, version: WebSocketVersion = WebSocketVersion.V13, reader: FrameReader = defaultFrameReader)(handle: Handler): WebSocketClient = {
    requireWebSocketScheme(url)
    new DefaultWebSocketClient(url, version, handle, reader)
  }

  /** Creates a client that forwards every event to the actor `handle` (URL validated by the other `apply`). */
  def apply(url: URI, handle: ActorRef): WebSocketClient =
    WebSocketClient(url) { case x => handle ! x }

  /** Bridges Netty channel events for one connection to the owning client's handler. */
  private class WebSocketClientHandler(handshaker: WebSocketClientHandshaker, client: WebSocketClient) extends SimpleChannelUpstreamHandler {
    import Messages._

    override def channelClosed(ctx: ChannelHandlerContext, e: ChannelStateEvent): Unit = {
      client.handler(Disconnected(client))
    }

    override def messageReceived(ctx: ChannelHandlerContext, e: MessageEvent): Unit = {
      e.getMessage match {
        case resp: HttpResponse if handshaker.isHandshakeComplete =>
          throw new WebSocketException("Unexpected HttpResponse (status=" + resp.getStatus + ", content="
            + resp.getContent.toString(CharsetUtil.UTF_8) + ")")
        case resp: HttpResponse =>
          handshaker.finishHandshake(ctx.getChannel, resp)
          client.handler(Connected(client))
        case ping: PingWebSocketFrame =>
          // RFC 6455 5.5.2/5.5.3: every ping must be answered by a pong carrying the same application data.
          ctx.getChannel.write(new PongWebSocketFrame(ping.getBinaryData))
        case _: PongWebSocketFrame =>
        case _: CloseWebSocketFrame => ctx.getChannel.close()
        case frame: WebSocketFrame =>
          // Data frames go through the client's configured reader (the default accepts text frames only).
          client.handler(TextMessage(client, client.reader(frame)))
        case other =>
          throw new WebSocketException("Unexpected message: " + other)
      }
    }

    override def exceptionCaught(ctx: ChannelHandlerContext, e: ExceptionEvent): Unit = {
      client.handler(Error(client, e.getCause))
      e.getChannel.close()
    }
  }

  private class DefaultWebSocketClient(
      val url: URI,
      version: WebSocketVersion,
      private[this] val _handler: Handler,
      val reader: FrameReader = defaultFrameReader) extends WebSocketClient {
    import Messages._

    val normalized = url.normalize()
    // The handshake request needs a non-empty path, so an empty path is replaced with "/".
    val tgt = if (normalized.getPath == null || normalized.getPath.trim().isEmpty) {
      new URI(normalized.getScheme, normalized.getAuthority, "/", normalized.getQuery, normalized.getFragment)
    } else normalized

    val bootstrap = new ClientBootstrap(new NioClientSocketChannelFactory(Executors.newCachedThreadPool, Executors.newCachedThreadPool))

    // A handshaker remembers that its handshake completed, so it cannot be reused for a second
    // connection. `connect` installs a fresh one before each attempt and the pipeline factory
    // (invoked synchronously by `bootstrap.connect`) reads the current value.
    @volatile private var handshaker: WebSocketClientHandshaker = _

    val self = this

    // Written by the connect listener on a Netty I/O thread and read by callers of send/disconnect.
    @volatile var channel: Channel = _

    val handler = _handler orElse defaultHandler

    /** Fallback for events the user handler does not match: print errors, ignore everything else. */
    private def defaultHandler: Handler = {
      case Error(_, ex) => ex.printStackTrace()
      case _: WebSocketClientMessage =>
    }

    private def newHandshaker(): WebSocketClientHandshaker =
      new WebSocketClientHandshakerFactory().newHandshaker(tgt, version, null, false, Map.empty[String, String])

    // URI.getPort is -1 when the URL has no explicit port; fall back to the scheme's default port.
    private def effectivePort: Int =
      if (url.getPort != -1) url.getPort
      else if ("wss".equalsIgnoreCase(url.getScheme)) 443
      else 80

    bootstrap.setPipelineFactory(new ChannelPipelineFactory {
      def getPipeline = {
        val pipeline = Channels.pipeline()
        // Hixie-76 (V00) needs a decoder that reads the challenge answer carried by the 101 response.
        if (version == WebSocketVersion.V00)
          pipeline.addLast("decoder", new WebSocketHttpResponseDecoder)
        else
          pipeline.addLast("decoder", new HttpResponseDecoder)
        pipeline.addLast("encoder", new HttpRequestEncoder)
        pipeline.addLast("ws-handler", new WebSocketClientHandler(handshaker, self))
        pipeline
      }
    })

    /**
     * Opens the TCP connection (unless already connected) and starts the handshake once it succeeds.
     * Blocks for at most 5 seconds waiting for the connect attempt; the outcome is reported to the
     * handler as `Connected` (after the handshake response) or `ConnectionFailed`.
     */
    def connect: Unit = {
      if (channel == null || !channel.isConnected) {
        val hs = newHandshaker()
        handshaker = hs
        val listener = futureListener { future =>
          if (future.isSuccess) {
            val ch = future.getChannel
            channel = ch
            hs.handshake(ch)
          } else {
            handler(ConnectionFailed(this, Option(future.getCause)))
          }
        }
        handler(Connecting)
        val fut = bootstrap.connect(new InetSocketAddress(url.getHost, effectivePort))
        fut.addListener(listener)
        fut.await(5000L)
      }
    }

    /**
     * Sends a close frame when connected. The channel itself is closed by the handler when the
     * server's close frame arrives. NOTE(review): the channel factory's thread pools are not released.
     */
    def disconnect: Unit = {
      val ch = channel
      if (ch != null && ch.isConnected) {
        handler(Disconnecting)
        ch.write(new CloseWebSocketFrame())
      }
    }

    /** Sends `message` as a text frame; failures (including "never connected") are reported as `WriteFailed`. */
    def send(message: String, charset: Charset = CharsetUtil.UTF_8): Unit = {
      val ch = channel
      if (ch == null) {
        handler(WriteFailed(this, message, Some(new WebSocketException("Not connected: " + url.toASCIIString))))
      } else {
        ch.write(new TextWebSocketFrame(ChannelBuffers.copiedBuffer(message, charset))).addListener(futureListener { fut =>
          if (!fut.isSuccess) handler(WriteFailed(this, message, Option(fut.getCause)))
        })
      }
    }

    /** Adapts a Scala function to Netty's `ChannelFutureListener`. */
    def futureListener(handleWith: ChannelFuture => Unit): ChannelFutureListener = new ChannelFutureListener {
      def operationComplete(future: ChannelFuture): Unit = handleWith(future)
    }
  }

  /**
   * Fix bug in standard HttpResponseDecoder for web socket clients. When status 101 is received for Hybi00, there are 16
   * bytes of contents expected, so such responses must NOT be treated as body-less.
   */
  class WebSocketHttpResponseDecoder extends HttpResponseDecoder {
    /** Status codes (besides 1xx) whose responses never carry a body. */
    val codes = List(204, 205, 304)

    protected override def isContentAlwaysEmpty(msg: HttpMessage): Boolean = msg match {
      case res: HttpResponse =>
        val code = res.getStatus.getCode
        // A Hixie-76 101 response has no Sec-WebSocket-Accept header and carries the 16-byte
        // challenge answer as content, so it must be read; a hybi 101 response has no body.
        if (code == 101) res.containsHeader(HttpHeaders.Names.SEC_WEBSOCKET_ACCEPT)
        else code < 200 || codes.contains(code)
      case _ => false
    }
  }

  /**
   * A WebSocket related exception
   *
   * Copied from https://github.com/cgbystrom/netty-tools
   */
  class WebSocketException(s: String, th: Throwable) extends java.io.IOException(s, th) {
    def this(s: String) = this(s, null)
  }
}
/** A single websocket client connection; instances are created through the `WebSocketClient` companion. */
trait WebSocketClient {
  /** The URL this client connects to. */
  def url: URI

  /** Converts incoming data frames into the text delivered to `handler`. */
  def reader: WebSocketClient.FrameReader

  /** Receives this client's lifecycle and message events. */
  def handler: WebSocketClient.Handler

  /** Opens the connection and starts the websocket handshake. */
  def connect: Unit

  /** Starts closing the connection. */
  def disconnect: Unit

  /** Sends `message` as a text frame encoded with `charset` (UTF-8 by default). */
  def send(message: String, charset: Charset = CharsetUtil.UTF_8): Unit
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package io.backchat.minutes.river | |
import org.jboss.netty.bootstrap.ServerBootstrap | |
import org.jboss.netty.channel.socket.nio.NioServerSocketChannelFactory | |
import java.util.concurrent.{ TimeUnit, Executors } | |
import java.net.{ InetSocketAddress } | |
import org.jboss.netty.channel._ | |
import group.{ChannelGroup, DefaultChannelGroup} | |
import org.elasticsearch.common.logging.{ESLogger, ESLoggerFactory} | |
import org.jboss.netty.handler.codec.http.{HttpRequest, HttpChunkAggregator, HttpRequestDecoder, HttpResponseEncoder} | |
import org.jboss.netty.handler.codec.http.websocketx._ | |
import org.jboss.netty.handler.codec.http.HttpHeaders.Values | |
import org.jboss.netty.handler.codec.http.HttpHeaders.Names | |
import java.util.Locale.ENGLISH | |
/**
 * Network settings for a `WebSocketServer`.
 */
trait WebSocketServerConfig {

  /** TCP port the server binds to. */
  def port: Int

  /** Host name or IP address to bind to; the server binds to the wildcard address when this is null or blank. */
  def listenOn: String
}
/**
 * Netty based WebSocketServer
 * requires netty 3.3.x or later
 *
 * Usage:
 * <pre>
 *   val conf = new WebSocketServerConfig {
 *     val port = 14567
 *     val listenOn = "0.0.0.0"
 *   }
 *
 *   val server = WebSocketServer(conf) {
 *     case Connect(_) => println("got a client connection")
 *     case TextMessage(cl, text) => cl.write(new TextWebSocketFrame("ECHO: " + text))
 *     case Disconnected(_) => println("client disconnected")
 *   }
 *   server.start
 *   // time passes......
 *   server.stop
 * </pre>
 */
object WebSocketServer {
  /** Callback for server events. */
  type WebSocketHandler = PartialFunction[WebSocketMessage, Unit]

  sealed trait WebSocketMessage
  /** A client completed the websocket handshake. */
  case class Connect(client: Channel) extends WebSocketMessage
  /** A complete text message; fragmented messages are reassembled before delivery. */
  case class TextMessage(client: Channel, content: String) extends WebSocketMessage
  /** A complete binary message; fragmented messages are reassembled before delivery. */
  case class BinaryMessage(client: Channel, content: Array[Byte]) extends WebSocketMessage
  /** An exception was raised while processing the client's channel. */
  case class Error(client: Channel, cause: Option[Throwable]) extends WebSocketMessage
  /** The channel of a client that had completed the handshake was closed (cleanly or not). */
  case class Disconnected(client: Channel) extends WebSocketMessage

  def apply(config: WebSocketServerConfig)(handler: WebSocketServer.WebSocketHandler): WebSocketServer =
    new WebSocketServer(config, handler)

  /** Keeps `channels` in sync with the set of open connections so they can be closed on shutdown. */
  private class ConnectionTracker(channels: ChannelGroup) extends SimpleChannelUpstreamHandler {
    override def channelClosed(ctx: ChannelHandlerContext, e: ChannelStateEvent): Unit = {
      channels remove e.getChannel
      ctx.sendUpstream(e)
    }

    override def channelConnected(ctx: ChannelHandlerContext, e: ChannelStateEvent): Unit = {
      channels add e.getChannel
      ctx.sendUpstream(e)
    }

    override def channelDisconnected(ctx: ChannelHandlerContext, e: ChannelStateEvent): Unit = {
      channels remove e.getChannel
      ctx.sendUpstream(e)
    }
  }

  /**
   * Performs the websocket upgrade and translates websocket frames into [[WebSocketMessage]]s.
   * Holds per-connection state, so each channel needs its own instance.
   */
  private class WebSocketPartialFunctionHandler(handler: WebSocketHandler, logger: ESLogger) extends SimpleChannelUpstreamHandler {
    // Payload of a fragmented message being reassembled; None when no fragmented message is pending.
    private[this] var fragments: Option[java.io.ByteArrayOutputStream] = None
    // Whether the pending fragmented message started with a text (true) or binary (false) frame.
    private[this] var fragmentIsText = false
    private[this] var handshaker: WebSocketServerHandshaker = _
    // Set once Connect has been emitted, so Disconnected is emitted exactly once per upgraded channel.
    private[this] var upgraded = false

    override def messageReceived(ctx: ChannelHandlerContext, e: MessageEvent): Unit = {
      e.getMessage match {
        case httpRequest: HttpRequest if isWebSocketUpgrade(httpRequest) => handleUpgrade(ctx, httpRequest)
        case m: TextWebSocketFrame if m.isFinalFragment => handler lift TextMessage(e.getChannel, m.getText)
        case m: TextWebSocketFrame => startFragments(m, isText = true)
        case m: BinaryWebSocketFrame if m.isFinalFragment => handler lift BinaryMessage(e.getChannel, payloadOf(m))
        case m: BinaryWebSocketFrame => startFragments(m, isText = false)
        case m: ContinuationWebSocketFrame => continueFragments(e.getChannel, m)
        case f: CloseWebSocketFrame =>
          // Disconnected is emitted from channelClosed so abrupt disconnects are reported as well.
          if (handshaker != null) handshaker.close(ctx.getChannel, f)
          else ctx.getChannel.close()
        case p: PingWebSocketFrame =>
          // RFC 6455 5.5.3: the pong must echo the ping's application data.
          e.getChannel.write(new PongWebSocketFrame(p.getBinaryData))
        case _ => ctx.sendUpstream(e)
      }
    }

    override def channelClosed(ctx: ChannelHandlerContext, e: ChannelStateEvent): Unit = {
      if (upgraded) {
        upgraded = false
        handler lift Disconnected(e.getChannel)
      }
      ctx.sendUpstream(e)
    }

    override def exceptionCaught(ctx: ChannelHandlerContext, e: ExceptionEvent): Unit = {
      handler lift Error(e.getChannel, Option(e.getCause))
    }

    /** Begins reassembling a fragmented message from its first (non-final) data frame. */
    private def startFragments(first: WebSocketFrame, isText: Boolean): Unit = {
      val buffer = new java.io.ByteArrayOutputStream
      buffer.write(payloadOf(first))
      fragmentIsText = isText
      fragments = Some(buffer)
    }

    /** Appends a continuation frame and delivers the whole message once the final fragment arrives. */
    private def continueFragments(channel: Channel, frame: ContinuationWebSocketFrame): Unit = {
      fragments match {
        case None =>
          logger.warn("Dropping a continuation frame that does not continue a fragmented message")
        case Some(buffer) =>
          buffer.write(payloadOf(frame))
          if (frame.isFinalFragment) {
            fragments = None
            val bytes = buffer.toByteArray
            // Decode only after reassembly so UTF-8 sequences split across fragments stay intact.
            if (fragmentIsText) handler lift TextMessage(channel, new String(bytes, "UTF-8"))
            else handler lift BinaryMessage(channel, bytes)
          }
      }
    }

    /** Copies the frame's readable payload bytes (the backing array may be larger, offset, or absent). */
    private def payloadOf(frame: WebSocketFrame): Array[Byte] = {
      val data = frame.getBinaryData
      if (data == null) Array.empty[Byte]
      else {
        val bytes = new Array[Byte](data.readableBytes)
        data.getBytes(data.readerIndex, bytes)
        bytes
      }
    }

    /** True when the request asks to upgrade to websocket; Connection may list several tokens. */
    private def isWebSocketUpgrade(httpRequest: HttpRequest): Boolean = {
      val connection = Option(httpRequest.getHeader(Names.CONNECTION))
      val upgrade = Option(httpRequest.getHeader(Names.UPGRADE))
      connection.exists(_.split(',').exists(_.trim.equalsIgnoreCase(Values.UPGRADE))) &&
        upgrade.exists(_.trim.equalsIgnoreCase(Values.WEBSOCKET))
    }

    private def handleUpgrade(ctx: ChannelHandlerContext, httpRequest: HttpRequest): Unit = {
      val handshakerFactory = new WebSocketServerHandshakerFactory(websocketLocation(httpRequest), null, false)
      handshaker = handshakerFactory.newHandshaker(httpRequest)
      if (handshaker == null) handshakerFactory.sendUnsupportedWebSocketVersionResponse(ctx.getChannel)
      else {
        handshaker.handshake(ctx.getChannel, httpRequest)
        upgraded = true
        handler lift Connect(ctx.getChannel)
      }
    }

    // NOTE(review): the original read REQUEST_URI twice (dead second lookup); the second check now
    // uses X-Forwarded-Proto, the de-facto header set by TLS-terminating proxies — confirm intent.
    private def isHttps(req: HttpRequest): Boolean = {
      def startsWithHttps(header: String): Boolean =
        Option(req.getHeader(header)).map(_.trim).exists(_.toUpperCase(ENGLISH).startsWith("HTTPS"))
      startsWithHttps("REQUEST_URI") || startsWithHttps("X-Forwarded-Proto")
    }

    /** Location advertised in the handshake (Hixie-76 clients require it to match the requested URL). */
    private def websocketLocation(req: HttpRequest): String = {
      val scheme = if (isHttps(req)) "wss://" else "ws://"
      val path = Option(req.getUri).filter(_.startsWith("/")).getOrElse("/")
      scheme + req.getHeader(Names.HOST) + path
    }
  }
}
/**
 * A Netty websocket server bound to `config.listenOn:config.port`. Events are passed to `handler`;
 * events it does not match go to a fallback that prints `Error` causes to stderr and ignores the rest.
 */
class WebSocketServer(val config: WebSocketServerConfig, val handler: WebSocketServer.WebSocketHandler) {
  import WebSocketServer._

  // Must be declared before `realHandler`: vals initialise in declaration order, and referencing a
  // later val would capture null and NPE on every event `handler` does not match.
  private[this] val devNull: WebSocketHandler = {
    case WebSocketServer.Error(_, Some(ex)) =>
      System.err.println(ex.getMessage)
      ex.printStackTrace()
    case _ =>
  }
  private[this] val realHandler = handler orElse devNull

  protected val logger = ESLoggerFactory.getLogger(getClass.getName)

  private[this] val boss = Executors.newCachedThreadPool()
  private[this] val worker = Executors.newCachedThreadPool()
  private[this] val server = {
    val bs = new ServerBootstrap(new NioServerSocketChannelFactory(boss, worker))
    bs.setOption("soLinger", 0)
    bs.setOption("reuseAddress", true)
    bs.setOption("child.tcpNoDelay", true)
    bs
  }

  private[this] val allChannels = new DefaultChannelGroup

  /** Builds a fresh pipeline (with fresh, stateful handlers) for one accepted connection. */
  protected def getPipeline = {
    val pipe = Channels.pipeline()
    pipe.addLast("connection-tracker", new ConnectionTracker(allChannels))
    pipe.addLast("decoder", new HttpRequestDecoder(4096, 8192, 8192))
    pipe.addLast("aggregator", new HttpChunkAggregator(64 * 1024))
    pipe.addLast("encoder", new HttpResponseEncoder)
    pipe.addLast("websocketmessages", new WebSocketPartialFunctionHandler(realHandler, logger))
    pipe
  }

  private[this] val servName = getClass.getSimpleName

  /** Binds the server socket; a blank `listenOn` binds to the wildcard address. */
  def start: Unit = synchronized {
    // setPipeline would share ONE set of handler instances between all accepted channels, but the
    // HTTP decoder, chunk aggregator and handshake handler are stateful. Build a pipeline per channel.
    server.setPipelineFactory(new ChannelPipelineFactory {
      def getPipeline(): ChannelPipeline = WebSocketServer.this.getPipeline
    })
    val addr =
      if (config.listenOn == null || config.listenOn.trim.isEmpty) new InetSocketAddress(config.port)
      else new InetSocketAddress(config.listenOn, config.port)
    val sc = server.bind(addr)
    allChannels add sc
    logger info "Started %s on [%s:%d]".format(servName, config.listenOn, config.port)
  }

  /** Closes every tracked channel, then releases Netty resources and waits for the thread pools. */
  def stop: Unit = synchronized {
    allChannels.close().awaitUninterruptibly()
    val thread = new Thread {
      override def run = {
        server.releaseExternalResources()
        boss.awaitTermination(5, TimeUnit.SECONDS)
        worker.awaitTermination(5, TimeUnit.SECONDS)
      }
    }
    thread.setDaemon(false)
    thread.start()
    thread.join()
    logger info "Stopped %s".format(servName)
  }
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
The client implementation creates a new Netty channel factory for every WebSocket connection. Each connection therefore gets its own threads (most likely three: one boss thread and up to two worker threads), which defeats the point of using NIO. The channel factory should be passed in to DefaultWebSocketClient so it can be shared between client connections. That change would also fix the thread leak caused by DefaultWebSocketClient.disconnect never shutting down the channel factory's thread pools.