Skip to content

Instantly share code, notes, and snippets.

Show Gist options
  • Select an option

  • Save joyoyoyoyoyo/8ede3cdfd1d2c8cf1ac42042b9e4e525 to your computer and use it in GitHub Desktop.

Select an option

Save joyoyoyoyoyo/8ede3cdfd1d2c8cf1ac42042b9e4e525 to your computer and use it in GitHub Desktop.
Implementation of a socket connection pool for use with Spark Streaming. For background, see http://stackoverflow.com/questions/30450763/spark-streaming-and-connection-pool-implementation
package net.atos.sparti.pub
import java.io.PrintStream
import java.net.Socket

import scala.util.control.NonFatal

import org.apache.commons.pool2.impl.{DefaultPooledObject, GenericObjectPool}
import org.apache.commons.pool2.{ObjectPool, PooledObject, BasePooledObjectFactory}
import org.apache.spark.streaming.dstream.DStream
/**
 * Publishes a [[DStream]] to a TCP socket using pooled connections from
 * [[PrintStreamPool]].
 *
 * The class is `Serializable` because `host` and `port` are fields of this
 * instance and are referenced inside the closures Spark ships to executors.
 *
 * @tparam T   element type of the published stream
 * @param host destination host
 * @param port destination port
 */
class PooledSocketStreamPublisher[T](host: String, port: Int)
  extends Serializable {

  /**
   * Publish the stream to a socket.
   *
   * For every partition of every RDD, one stream is borrowed from
   * [[PrintStreamPool]]. Each element is rendered with `callback` and written
   * with `println`, and the stream is handed back when the partition is done.
   *
   * @param stream   the stream to publish
   * @param callback renders one element as the line of text to send
   */
  def publish(stream: DStream[T], callback: (T) => String): Unit =
    stream foreachRDD { rdd =>
      rdd foreachPartition { partition =>
        val managed = PrintStreamPool(host, port)
        // Release in `finally`: if `callback` or the iterator throws, the
        // borrowed stream would otherwise never be returned to the pool.
        try {
          val out = managed.printStream
          partition foreach { event => out.println(callback(event)) }
        } finally {
          managed.release()
        }
      }
    }
}
/**
 * A [[PrintStream]] borrowed from an [[ObjectPool]], paired with the pool it
 * has to be handed back to.
 *
 * @param pool        the pool `printStream` was borrowed from
 * @param printStream the borrowed stream
 */
class ManagedPrintStream(private val pool: ObjectPool[PrintStream], val printStream: PrintStream) {

  /** Returns `printStream` to `pool` via `ObjectPool.returnObject`. */
  def release(): Unit = {
    val borrowed = printStream
    pool.returnObject(borrowed)
  }
}
/**
 * JVM-wide registry of connection pools, one [[GenericObjectPool]] per
 * `(host, port)` destination.
 *
 * `apply` may be called from several threads at once (e.g. concurrent Spark
 * tasks in one executor JVM). Lookups therefore read a volatile reference to an
 * immutable map without locking. Registration of a new pool happens only while
 * holding this object's monitor, so a destination gets a single pool.
 */
object PrintStreamPool {

  /**
   * Pools keyed by `(host, port)`. The map is immutable and is replaced
   * wholesale (under `PrintStreamPool.synchronized` in this object).
   * `@volatile` makes each new map visible to threads that read it without
   * taking the lock.
   */
  @volatile var hostPortPool: Map[(String, Int), ObjectPool[PrintStream]] = Map()

  // Close every registered pool when the JVM shuts down. A failure closing one
  // pool must not prevent the remaining ones from being closed.
  sys.addShutdownHook {
    hostPortPool.values.foreach { pool =>
      try pool.close()
      catch {
        case NonFatal(e) => Console.err.println(s"Failed to close connection pool: $e")
      }
    }
  }

  /**
   * Factory method: borrows a stream connected to `host:port`, creating that
   * destination's pool on first use.
   *
   * The pool is built with `GenericObjectPool`'s default configuration, so
   * whether `borrowObject` blocks or fails when the pool is exhausted follows
   * those defaults.
   *
   * @param host destination host
   * @param port destination port
   * @return the borrowed stream together with the pool to release it to
   */
  def apply(host: String, port: Int): ManagedPrintStream = {
    val pool = poolFor(host, port)
    // Borrow outside the lock: borrowObject may open a socket or wait.
    new ManagedPrintStream(pool, pool.borrowObject())
  }

  /** Returns the pool registered for `(host, port)`, creating and registering it if absent. */
  private def poolFor(host: String, port: Int): ObjectPool[PrintStream] = {
    val key = (host, port)
    // Fast path: no locking once the pool exists.
    hostPortPool.getOrElse(key, synchronized {
      // Re-check under the lock: another thread may have registered it meanwhile.
      hostPortPool.getOrElse(key, {
        val created = new GenericObjectPool[PrintStream](new SocketStreamFactory(host, port))
        hostPortPool = hostPortPool.updated(key, created)
        created
      })
    })
  }
}
/**
 * Commons Pool factory that opens a new TCP connection to `host:port` for each
 * pooled object and wraps the socket's output in a [[PrintStream]].
 *
 * @param host destination host
 * @param port destination port
 */
class SocketStreamFactory(host: String, port: Int) extends BasePooledObjectFactory[PrintStream] {

  /**
   * Connects to `host:port` and wraps the socket's output stream.
   *
   * NOTE(review): the PrintStream uses the platform default charset. Confirm
   * that this matches what the receiving end expects.
   */
  override def create(): PrintStream = {
    val socket = new Socket(host, port)
    try new PrintStream(socket.getOutputStream)
    catch {
      case NonFatal(e) =>
        // Don't leak the freshly connected socket if wrapping it fails.
        try socket.close()
        catch { case NonFatal(closeError) => e.addSuppressed(closeError) }
        throw e
    }
  }

  /** Wraps a stream in the pool's default bookkeeping object. */
  override def wrap(stream: PrintStream): DefaultPooledObject[PrintStream] =
    new DefaultPooledObject[PrintStream](stream)

  /** A stream stays valid until it records an error. `checkError` also flushes it. */
  override def validateObject(po: PooledObject[PrintStream]): Boolean =
    !po.getObject.checkError()

  /** Closing the PrintStream closes the socket's output stream, and with it the socket. */
  override def destroyObject(po: PooledObject[PrintStream]): Unit =
    po.getObject.close()

  /** Flushes buffered output before the stream goes back to the idle pool. */
  override def passivateObject(po: PooledObject[PrintStream]): Unit =
    po.getObject.flush()
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment