Skip to content

Instantly share code, notes, and snippets.

@alexanderjarvis
Forked from casualjim/WebSocketClient.scala
Created April 23, 2013 17:40
Show Gist options
  • Save alexanderjarvis/5445720 to your computer and use it in GitHub Desktop.
Save alexanderjarvis/5445720 to your computer and use it in GitHub Desktop.
package mojolly.io
import org.jboss.netty.bootstrap.ClientBootstrap
import org.jboss.netty.channel._
import socket.nio.NioClientSocketChannelFactory
import java.util.concurrent.Executors
import org.jboss.netty.handler.codec.http._
import collection.JavaConversions._
import websocketx._
import java.net.{InetSocketAddress, URI}
import java.nio.charset.Charset
import org.jboss.netty.buffer.ChannelBuffers
import org.jboss.netty.util.CharsetUtil
import akka.actor.ActorRef
import mojolly.LibraryConstants
/**
* Usage of the simple websocket client:
* <pre>
* WebSocketClient(new URI("ws://localhost:8080/thesocket")) {
* case Connected(client) => println("Connection has been established to: " + client.url.toASCIIString)
* case Disconnected(client, _) => println("The websocket to " + client.url.toASCIIString + " disconnected.")
* case TextMessage(client, message) => {
* println("RECV: " + message)
* client send ("ECHO: " + message)
* }
* }
* </pre>
*/
object WebSocketClient {
/** Events that a websocket client reports to its handler. */
object Messages {

  /** Common super type of every client event; sealed so matches can be checked for exhaustiveness. */
  sealed trait WebSocketClientMessage

  /** A connection attempt is about to be started. */
  case object Connecting extends WebSocketClientMessage

  /** The connection attempt did not succeed; `reason` carries the cause when one is known. */
  case class ConnectionFailed(client: WebSocketClient, reason: Option[Throwable] = None)
    extends WebSocketClientMessage

  /** The websocket handshake has completed. */
  case class Connected(client: WebSocketClient) extends WebSocketClientMessage

  /** A text frame with content `text` was received. */
  case class TextMessage(client: WebSocketClient, text: String) extends WebSocketClientMessage

  /** Writing `message` to the channel failed; `reason` carries the cause when one is known. */
  case class WriteFailed(client: WebSocketClient, message: String, reason: Option[Throwable])
    extends WebSocketClientMessage

  /** The client has been asked to disconnect and is about to send a close frame. */
  case object Disconnecting extends WebSocketClientMessage

  /** The underlying channel has been closed. */
  case class Disconnected(client: WebSocketClient, reason: Option[Throwable] = None)
    extends WebSocketClientMessage

  /** An exception was caught in the channel pipeline. */
  case class Error(client: WebSocketClient, th: Throwable) extends WebSocketClientMessage
}
/** Callback type through which a client reports its events. */
type Handler = PartialFunction[Messages.WebSocketClientMessage, Unit]

/** Converts a received websocket frame into its textual content. */
type FrameReader = WebSocketFrame => String

/**
 * Default frame reader: yields the text of a text frame and rejects every other frame type
 * with an `UnsupportedOperationException`.
 */
val defaultFrameReader: WebSocketFrame => String = {
  case textFrame: TextWebSocketFrame =>
    textFrame.getText
  case _ =>
    throw new UnsupportedOperationException("Only single text frames are supported for now")
}
/**
 * Creates a websocket client for `url`. The client is only constructed here; call `connect`
 * on the result to open the connection.
 *
 * @param url     the websocket endpoint; its scheme must be `ws` or `wss` (case-insensitive)
 * @param version the websocket protocol version used for the handshake
 * @param reader  the frame reader exposed by the created client
 * @param handle  callback that receives the client events
 * @throws IllegalArgumentException when the url has no scheme or a scheme other than `ws`/`wss`
 */
def apply(url: URI, version: WebSocketVersion = WebSocketVersion.V13, reader: FrameReader = defaultFrameReader)(handle: Handler): WebSocketClient = {
  // URI.getScheme is null for scheme-less URIs, so wrap it instead of dereferencing it directly.
  // An exact comparison (rather than startsWith) keeps schemes such as "wsx" out.
  val scheme = Option(url.getScheme).map(_.toLowerCase(java.util.Locale.ENGLISH))
  require(scheme.exists(s => s == "ws" || s == "wss"), "The scheme of the url should be 'ws' or 'wss'")
  new DefaultWebSocketClient(url, version, handle, reader)
}
/**
 * Creates a websocket client for `url` that forwards every client event to the actor `handle`.
 *
 * @param url    the websocket endpoint; its scheme must be `ws` or `wss` (case-insensitive)
 * @param handle the actor that receives all client events as messages
 * @throws IllegalArgumentException when the url has no scheme or a scheme other than `ws`/`wss`
 */
def apply(url: URI, handle: ActorRef): WebSocketClient = {
  // URI.getScheme is null for scheme-less URIs, so wrap it instead of dereferencing it directly.
  // An exact comparison (rather than startsWith) keeps schemes such as "wsx" out.
  val scheme = Option(url.getScheme).map(_.toLowerCase(java.util.Locale.ENGLISH))
  require(scheme.exists(s => s == "ws" || s == "wss"), "The scheme of the url should be 'ws' or 'wss'")
  WebSocketClient(url) { case event => handle ! event }
}
/**
 * Upstream handler that completes the client handshake and translates channel events into
 * client messages delivered to `client.handler`.
 *
 * @param handshaker the handshaker used to finish the websocket handshake
 * @param client     the client on whose behalf events are reported
 */
private class WebSocketClientHandler(handshaker: WebSocketClientHandshaker, client: WebSocketClient) extends SimpleChannelUpstreamHandler {
  import Messages._

  /** Reports the closed channel as a `Disconnected` event. */
  override def channelClosed(ctx: ChannelHandlerContext, e: ChannelStateEvent): Unit = {
    client.handler(Disconnected(client))
  }

  /**
   * Dispatches an incoming message: finishes the handshake on the first HTTP response, forwards
   * text frames, answers pings, ignores pongs and closes the channel on a close frame.
   */
  override def messageReceived(ctx: ChannelHandlerContext, e: MessageEvent): Unit = {
    e.getMessage match {
      case resp: HttpResponse if handshaker.isHandshakeComplete =>
        throw new WebSocketException("Unexpected HttpResponse (status=" + resp.getStatus + ", content="
          + resp.getContent.toString(CharsetUtil.UTF_8) + ")")
      case resp: HttpResponse =>
        handshaker.finishHandshake(ctx.getChannel, resp)
        client.handler(Connected(client))
      case f: TextWebSocketFrame =>
        client.handler(TextMessage(client, f.getText))
      case ping: PingWebSocketFrame =>
        // A ping must be answered with a pong carrying the same payload; previously a ping fell
        // through the match, raised a MatchError and tore the connection down.
        ctx.getChannel.write(new PongWebSocketFrame(ping.getBinaryData))
      case _: PongWebSocketFrame => // nothing to do
      case _: CloseWebSocketFrame =>
        ctx.getChannel.close()
      case other =>
        // Still fails loudly (reported through exceptionCaught) but with a descriptive error
        // instead of a bare MatchError. NOTE(review): client.reader is not consulted here.
        throw new WebSocketException("Unsupported message received: " + other.getClass.getName)
    }
  }

  /** Reports the failure as an `Error` event and then closes the channel. */
  override def exceptionCaught(ctx: ChannelHandlerContext, e: ExceptionEvent): Unit = {
    client.handler(Error(client, e.getCause))
    e.getChannel.close()
  }
}
/**
 * Netty based implementation of [[WebSocketClient]].
 *
 * @param url      the websocket endpoint to connect to
 * @param version  the websocket protocol version used for the handshake
 * @param _handler the user supplied event handler; events it does not handle go to `defaultHandler`
 * @param reader   the frame reader exposed through the client interface
 */
private class DefaultWebSocketClient(
  val url: URI,
  version: WebSocketVersion,
  private[this] val _handler: Handler,
  val reader: FrameReader = defaultFrameReader) extends WebSocketClient {

  import Messages._

  val normalized = url.normalize()

  // The handshake needs a non-empty path, so an empty one is replaced with "/".
  val tgt = if (normalized.getPath == null || normalized.getPath.trim().isEmpty) {
    new URI(normalized.getScheme, normalized.getAuthority, "/", normalized.getQuery, normalized.getFragment)
  } else normalized

  val bootstrap = new ClientBootstrap(new NioClientSocketChannelFactory(Executors.newCachedThreadPool, Executors.newCachedThreadPool))
  val handshaker = new WebSocketClientHandshakerFactory().newHandshaker(tgt, version, null, false, Map.empty[String, String])
  val self = this
  var channel: Channel = _

  /** The user handler, falling back to `defaultHandler` for events it is not defined for. */
  val handler: Handler = _handler orElse defaultHandler

  /** Fallback handler: prints the stack trace of errors and ignores every other event. */
  private def defaultHandler: Handler = {
    case Error(_, ex) => ex.printStackTrace()
    case _: WebSocketClientMessage =>
  }

  bootstrap.setPipelineFactory(new ChannelPipelineFactory {
    def getPipeline = {
      val pipeline = Channels.pipeline()
      if (version == WebSocketVersion.V00)
        pipeline.addLast("decoder", new WebSocketHttpResponseDecoder)
      else
        pipeline.addLast("decoder", new HttpResponseDecoder)
      pipeline.addLast("encoder", new HttpRequestEncoder)
      pipeline.addLast("ws-handler", new WebSocketClientHandler(handshaker, self))
      pipeline
    }
  })

  /**
   * The port to connect to. `URI.getPort` returns -1 when the url carries no explicit port, which
   * `InetSocketAddress` rejects, so fall back to the scheme's default port (443 for wss, else 80).
   */
  private def targetPort: Int =
    if (url.getPort != -1) url.getPort
    else if ("wss".equalsIgnoreCase(url.getScheme)) 443
    else 80

  /**
   * Opens the connection unless the channel is already connected. Emits `Connecting`, starts the
   * connection attempt and waits up to 5 seconds for it; on success the handshake is started, on
   * failure `ConnectionFailed` is emitted.
   */
  def connect: Unit = {
    if (channel == null || !channel.isConnected) {
      val listener = futureListener { future =>
        if (future.isSuccess) {
          synchronized { channel = future.getChannel }
          handshaker.handshake(channel)
        } else {
          handler(ConnectionFailed(this, Option(future.getCause)))
        }
      }
      handler(Connecting)
      val fut = bootstrap.connect(new InetSocketAddress(url.getHost, targetPort))
      fut.addListener(listener)
      fut.await(5000L)
    }
  }

  /** Emits `Disconnecting` and sends a close frame when the channel is connected; otherwise does nothing. */
  def disconnect: Unit = {
    if (channel != null && channel.isConnected) {
      handler(Disconnecting)
      channel.write(new CloseWebSocketFrame())
    }
  }

  /**
   * Sends `message` as a text frame encoded with `charset`. A failed write is reported to the
   * handler as `WriteFailed`.
   *
   * @param message the text to send
   * @param charset the charset used to encode the text
   */
  def send(message: String, charset: Charset = CharsetUtil.UTF_8): Unit = {
    Option(channel) match {
      case Some(ch) =>
        ch.write(new TextWebSocketFrame(ChannelBuffers.copiedBuffer(message, charset))).addListener(futureListener { fut =>
          if (!fut.isSuccess) {
            handler(WriteFailed(this, message, Option(fut.getCause)))
          }
        })
      case None =>
        // No channel exists before a successful connect; report the failure through the same
        // event as any other failed write instead of throwing a NullPointerException.
        handler(WriteFailed(this, message, Some(new java.nio.channels.NotYetConnectedException)))
    }
  }

  /** Adapts a plain function into a Netty `ChannelFutureListener`. */
  def futureListener(handleWith: ChannelFuture => Unit): ChannelFutureListener = new ChannelFutureListener {
    def operationComplete(future: ChannelFuture): Unit = handleWith(future)
  }
}
/**
 * Fix bug in standard HttpResponseDecoder for web socket clients. When status 101 is received for Hybi00, there are 16
 * bytes of contents expected, so a 101 response must not be treated as having an empty body.
 */
class WebSocketHttpResponseDecoder extends HttpResponseDecoder {

  /** Status codes (besides the informational 1xx range) whose responses never carry a body. */
  val codes = List(204, 205, 304)

  /**
   * Tells the decoder whether `msg` can never have content.
   *
   * @return `false` for a 101 response (its challenge bytes must be read) and for non-response
   *         messages; `true` for other 1xx responses and for the statuses in `codes`
   */
  protected override def isContentAlwaysEmpty(msg: HttpMessage): Boolean = {
    msg match {
      case res: HttpResponse =>
        val code = res.getStatus.getCode
        // 101 was previously listed as "always empty", which made the decoder skip the handshake
        // challenge; 200 was listed too, which discarded ordinary response bodies.
        if (code == 101) false
        else code < 200 || codes.contains(code)
      case _ => false
    }
  }
}
/**
 * Signals a websocket related failure; it is an `IOException`.
 *
 * Copied from https://github.com/cgbystrom/netty-tools
 *
 * @param s  the detail message
 * @param th the underlying cause, may be null
 */
class WebSocketException(s: String, th: Throwable)
  extends java.io.IOException(s, th) {

  /** Creates an exception that has a message but no cause. */
  def this(s: String) = this(s, null)
}
}
/** A websocket client that reports its events through a handler. */
trait WebSocketClient {

  /** The websocket endpoint of this client. */
  def url: URI

  /** The function that turns a received frame into text. */
  def reader: WebSocketClient.FrameReader

  /** The callback that receives this client's events. */
  def handler: WebSocketClient.Handler

  /** Opens the connection. */
  def connect: Unit

  /** Closes the connection. */
  def disconnect: Unit

  /**
   * Sends a text message.
   *
   * @param message the text to send
   * @param charset the charset used to encode the text, UTF-8 unless specified
   */
  def send(message: String, charset: Charset = CharsetUtil.UTF_8): Unit
}
package io.backchat.minutes.river
import org.jboss.netty.bootstrap.ServerBootstrap
import org.jboss.netty.channel.socket.nio.NioServerSocketChannelFactory
import java.util.concurrent.{ TimeUnit, Executors }
import java.net.{ InetSocketAddress }
import org.jboss.netty.channel._
import group.{ChannelGroup, DefaultChannelGroup}
import org.elasticsearch.common.logging.{ESLogger, ESLoggerFactory}
import org.jboss.netty.handler.codec.http.{HttpRequest, HttpChunkAggregator, HttpRequestDecoder, HttpResponseEncoder}
import org.jboss.netty.handler.codec.http.websocketx._
import org.jboss.netty.handler.codec.http.HttpHeaders.Values
import org.jboss.netty.handler.codec.http.HttpHeaders.Names
import java.util.Locale.ENGLISH
/** Settings that determine where a websocket server listens. */
trait WebSocketServerConfig {

  /** Host name or address to bind to. */
  def listenOn: String

  /** Port to bind to. */
  def port: Int
}
/**
* Netty based WebSocketServer
* requires netty 3.3.x or later
*
* Usage:
* <pre>
* val conf = new WebSocketServerConfig {
* val port = 14567
* val listenOn = "0.0.0.0"
* }
*
* val server = WebSocketServer(conf) {
* case Connect(_) => println("got a client connection")
* case TextMessage(cl, text) => cl.write(new TextWebSocketFrame("ECHO: " + text))
* case Disconnected(_) => println("client disconnected")
* }
* server.start
* // time passes......
* server.stop
* </pre>
*/
object WebSocketServer {
/** Callback type through which the server reports its events. */
type WebSocketHandler = PartialFunction[WebSocketMessage, Unit]

/** Common super type of every server event; sealed so matches can be checked for exhaustiveness. */
sealed trait WebSocketMessage

/** A client completed the websocket handshake on channel `client`. */
case class Connect(client: Channel) extends WebSocketMessage

/** Text `content` was received from `client`. */
case class TextMessage(client: Channel, content: String) extends WebSocketMessage

/** Binary `content` was received from `client`. */
case class BinaryMessage(client: Channel, content: Array[Byte]) extends WebSocketMessage

/** An exception was caught on the channel `client`; `cause` is empty when none was provided. */
case class Error(client: Channel, cause: Option[Throwable]) extends WebSocketMessage

/** The client on channel `client` sent a close frame. */
case class Disconnected(client: Channel) extends WebSocketMessage
/**
 * Creates a websocket server. The server is only constructed here; it is not started.
 *
 * @param config  where the server should listen
 * @param handler callback that receives the server events
 */
def apply(config: WebSocketServerConfig)(handler: WebSocketServer.WebSocketHandler): WebSocketServer = {
  val server = new WebSocketServer(config, handler)
  server
}
/**
 * Keeps `channels` in sync with the channels that are currently connected: a channel is added
 * when it connects and removed when it disconnects or closes. Every event is passed on upstream.
 */
private class ConnectionTracker(channels: ChannelGroup) extends SimpleChannelUpstreamHandler {

  /** Drops the event's channel from the group and forwards the event. */
  private def untrack(ctx: ChannelHandlerContext, event: ChannelStateEvent): Unit = {
    channels.remove(event.getChannel)
    ctx.sendUpstream(event)
  }

  override def channelClosed(ctx: ChannelHandlerContext, e: ChannelStateEvent): Unit =
    untrack(ctx, e)

  override def channelConnected(ctx: ChannelHandlerContext, e: ChannelStateEvent): Unit = {
    channels.add(e.getChannel)
    ctx.sendUpstream(e)
  }

  override def channelDisconnected(ctx: ChannelHandlerContext, e: ChannelStateEvent): Unit =
    untrack(ctx, e)
}
/**
 * Upstream handler that performs the server side websocket handshake and translates incoming
 * frames into [[WebSocketMessage]]s for `handler`. The instance keeps per-connection state
 * (the handshaker and partially received messages), so it must not be shared between channels.
 *
 * @param handler callback that receives the server events; events it is not defined for are dropped
 * @param logger  logger for this handler (currently unused)
 */
private class WebSocketPartialFunctionHandler(handler: WebSocketHandler, logger: ESLogger) extends SimpleChannelUpstreamHandler {

  // Text of the fragments received so far for a fragmented message that is not complete yet.
  private[this] var collectedFrames: Vector[String] = Vector.empty[String]
  private[this] var handshaker: WebSocketServerHandshaker = _

  /** Dispatches an incoming message; anything not handled here is passed on upstream. */
  override def messageReceived(ctx: ChannelHandlerContext, e: MessageEvent): Unit = {
    e.getMessage match {
      case httpRequest: HttpRequest if isWebSocketUpgrade(httpRequest) => handleUpgrade(ctx, httpRequest)
      case m: TextWebSocketFrame if !m.isFinalFragment =>
        // First fragment of a fragmented text message: start collecting instead of delivering
        // an incomplete message.
        collectedFrames = Vector(m.getText)
      case m: TextWebSocketFrame => handler.lift(TextMessage(e.getChannel, m.getText))
      case m: BinaryWebSocketFrame =>
        // Copy only the readable bytes; the backing array may be absent or larger than the payload.
        val data = m.getBinaryData
        val bytes = new Array[Byte](data.readableBytes)
        data.getBytes(data.readerIndex, bytes)
        handler.lift(BinaryMessage(e.getChannel, bytes))
      case m: ContinuationWebSocketFrame =>
        // NOTE(review): continuation frames are always treated as text, also after a binary start frame.
        collectedFrames :+= m.getText
        if (m.isFinalFragment) {
          // The final fragment's own text is part of the message; mkString is also safe when
          // nothing was collected before.
          val text = collectedFrames.mkString
          collectedFrames = Vector.empty[String]
          handler.lift(TextMessage(e.getChannel, text))
        }
      case f: CloseWebSocketFrame =>
        if (handshaker != null) handshaker.close(ctx.getChannel, f)
        handler.lift(Disconnected(e.getChannel))
      case ping: PingWebSocketFrame =>
        // A pong has to echo the payload of the ping it answers.
        e.getChannel.write(new PongWebSocketFrame(ping.getBinaryData))
      case _ => ctx.sendUpstream(e)
    }
  }

  /** Reports the caught exception as an `Error` event. */
  override def exceptionCaught(ctx: ChannelHandlerContext, e: ExceptionEvent): Unit = {
    handler.lift(Error(e.getChannel, Option(e.getCause)))
  }

  /**
   * Whether `httpRequest` asks for an upgrade to websocket. The `Connection` header is a comma
   * separated token list (for example "keep-alive, Upgrade"), so each token is checked.
   */
  private def isWebSocketUpgrade(httpRequest: HttpRequest): Boolean = {
    val connHdr = httpRequest.getHeader(Names.CONNECTION)
    val upgrHdr = httpRequest.getHeader(Names.UPGRADE)
    (connHdr != null && connHdr.split(",").exists(_.trim.equalsIgnoreCase(Values.UPGRADE))) &&
      (upgrHdr != null && upgrHdr.equalsIgnoreCase(Values.WEBSOCKET))
  }

  /**
   * Performs the handshake for `httpRequest` and emits `Connect`, or answers with an
   * "unsupported version" response when no handshaker exists for the requested version.
   */
  private def handleUpgrade(ctx: ChannelHandlerContext, httpRequest: HttpRequest): Unit = {
    val handshakerFactory = new WebSocketServerHandshakerFactory(websocketLocation(httpRequest), null, false)
    handshaker = handshakerFactory.newHandshaker(httpRequest)
    if (handshaker == null) handshakerFactory.sendUnsupportedWebSocketVersionResponse(ctx.getChannel)
    else {
      handshaker.handshake(ctx.getChannel, httpRequest)
      handler.lift(Connect(ctx.getChannel))
    }
  }

  /**
   * Whether the request's `REQUEST_URI` header starts with "HTTPS" (case-insensitive).
   * NOTE(review): the original read this same header twice; the second lookup was presumably
   * meant for a different header - confirm which one.
   */
  private def isHttps(req: HttpRequest): Boolean =
    Option(req.getHeader("REQUEST_URI")).exists(_.toUpperCase(ENGLISH).startsWith("HTTPS"))

  /** Builds the websocket location from the request's `Host` header, using `wss` for https requests. */
  private def websocketLocation(req: HttpRequest): String = {
    val scheme = if (isHttps(req)) "wss://" else "ws://"
    scheme + req.getHeader(Names.HOST) + "/"
  }
}
}
/**
 * Netty based websocket server.
 *
 * @param config  where the server listens
 * @param handler callback that receives the server events; events it is not defined for are
 *                handled by a fallback that prints errors and ignores everything else
 */
class WebSocketServer(val config: WebSocketServerConfig, val handler: WebSocketServer.WebSocketHandler) {
  import WebSocketServer._

  // Declared before realHandler on purpose: vals are initialised in declaration order, and the
  // previous order captured a still-null fallback, failing on every unhandled event.
  private[this] val devNull: WebSocketHandler = {
    case WebSocketServer.Error(_, Some(ex)) =>
      System.err.println(ex.getMessage)
      ex.printStackTrace()
    case _ =>
  }
  private[this] val realHandler = handler orElse devNull

  protected val logger = ESLoggerFactory.getLogger(getClass.getName)
  private[this] val boss = Executors.newCachedThreadPool()
  private[this] val worker = Executors.newCachedThreadPool()
  private[this] val server = {
    val bs = new ServerBootstrap(new NioServerSocketChannelFactory(boss, worker))
    bs.setOption("soLinger", 0)
    bs.setOption("reuseAddress", true)
    bs.setOption("child.tcpNoDelay", true)
    bs
  }
  private[this] val allChannels = new DefaultChannelGroup

  /** Builds a fresh pipeline with new handler instances. */
  protected def getPipeline: ChannelPipeline = {
    val pipe = Channels.pipeline()
    pipe.addLast("connection-tracker", new ConnectionTracker(allChannels))
    pipe.addLast("decoder", new HttpRequestDecoder(4096, 8192, 8192))
    pipe.addLast("aggregator", new HttpChunkAggregator(64 * 1024))
    pipe.addLast("encoder", new HttpResponseEncoder)
    pipe.addLast("websocketmessages", new WebSocketPartialFunctionHandler(realHandler, logger))
    pipe
  }

  private[this] val servName = getClass.getSimpleName

  /** Binds the server to the configured address (all local addresses when `listenOn` is null or blank). */
  def start: Unit = synchronized {
    // A factory gives every accepted connection its own handler instances. setPipeline would
    // reuse the same instances for all connections, which is wrong for the stateful websocket handler.
    server.setPipelineFactory(new ChannelPipelineFactory {
      def getPipeline: ChannelPipeline = WebSocketServer.this.getPipeline
    })
    val addr = if (config.listenOn == null || config.listenOn.trim.isEmpty) new InetSocketAddress(config.port)
    else new InetSocketAddress(config.listenOn, config.port)
    val sc = server.bind(addr)
    allChannels add sc
    logger info "Started %s on [%s:%d]".format(servName, config.listenOn, config.port)
  }

  /** Closes all tracked channels, then releases the bootstrap's resources on a separate thread and waits for it. */
  def stop: Unit = synchronized {
    allChannels.close().awaitUninterruptibly()
    val thread = new Thread {
      override def run = {
        server.releaseExternalResources()
        boss.awaitTermination(5, TimeUnit.SECONDS)
        worker.awaitTermination(5, TimeUnit.SECONDS)
      }
    }
    thread.setDaemon(false)
    thread.start()
    thread.join()
    logger info "Stopped %s".format(servName)
  }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment