Last active: December 21, 2015 10:48
-
-
Save mathieuancelin/6294116 to your computer and use it in GitHub Desktop.
Scala object that turns the body of an HTTP GET response into a stream of Play `Enumerator[A]` chunks.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package play.api.libs.ws | |
import play.api.libs.iteratee.{Enumeratee, Concurrent, Enumerator} | |
import play.api.libs.concurrent.Execution.Implicits._ | |
import com.ning.http.client._ | |
import com.ning.http.client.AsyncHandler.STATE | |
import play.api.Logger | |
import scala.concurrent.{Future, Promise} | |
/**
 * Streams the body of an HTTP GET request as a Play `Enumerator`.
 *
 * The returned `Future` completes once the response status and headers have been received.
 * It fails if the status code is >= 300 or if the request fails before that point.
 * When the consuming iteratee finishes early, the underlying WS call is aborted.
 */
object WSEnumerator {

  /** Marker passed to `abort` when the consumer is done; `onThrowable` treats it as a non-error. */
  private class AbortOnIterateeDone() extends RuntimeException

  private val logger = Logger("WSEnumerator")

  /**
   * Streams the body of `url`, converting every received chunk with `f`.
   *
   * @param url     URL to GET
   * @param timeout value passed to `PerRequestConfig.setRequestTimeoutInMs`
   *                (NOTE(review): -1 presumably means "no timeout" — confirm against AsyncHttpClient)
   * @param f       conversion applied to each body chunk
   * @return a future enumerator of converted chunks; fails under the same conditions as `getRawStream`
   */
  def getStream[A](url: String, timeout: Int = -1)(f: Array[Byte] => A): Future[Enumerator[A]] =
    getRawStream(url, timeout).map(stream => stream.through(Enumeratee.map[Array[Byte]](bytes => f(bytes))))

  /**
   * Streams the raw body chunks of `url`.
   *
   * @param url     URL to GET
   * @param timeout value passed to `PerRequestConfig.setRequestTimeoutInMs`
   * @return a future that completes with the body enumerator once status and headers are received.
   *         It fails with `IllegalStateException` when the status code is >= 300, or with the
   *         client's exception when the request fails before status/headers arrive.
   */
  def getRawStream(url: String, timeout: Int = -1): Future[Enumerator[Array[Byte]]] = {
    val promiseStatus = Promise[Int]()
    val promiseHeader = Promise[HttpResponseHeaders]()
    val config = new PerRequestConfig()
    config.setRequestTimeoutInMs(timeout)
    val client = WS.client.prepareGet(url).setPerRequestConfig(config)
    // NOTE(review): Concurrent.broadcast drops chunks pushed while no iteratee is attached.
    // Body parts arriving between the headers and the caller attaching an iteratee may be lost.
    val (enumerator, channel) = Concurrent.broadcast[Array[Byte]]

    val listenableFuture = client.execute(new AsyncHandler[Unit]() {
      override def onThrowable(p1: Throwable) {
        p1 match {
          case _: AbortOnIterateeDone => logger.debug(s"WS call aborted on purpose : $p1")
          case _ =>
            logger.debug("Actual exception, closing enumerator channel and failing pending promises")
            // Without this, a failure before status/headers arrive (connection refused, timeout, ...)
            // leaves the returned future pending forever. try* because they may already be completed.
            promiseStatus.tryFailure(p1)
            promiseHeader.tryFailure(p1)
            channel.eofAndEnd()
        }
      }

      override def onBodyPartReceived(p1: HttpResponseBodyPart): STATE = {
        channel.push(p1.getBodyPartBytes)
        STATE.CONTINUE
      }

      override def onStatusReceived(p1: HttpResponseStatus): STATE = {
        val code = p1.getStatusCode
        if (code >= 300) {
          promiseStatus.tryFailure(new IllegalStateException(s"HTTP status is ${code} for URL ${url}"))
          // The returned future fails, so no one can ever consume the body: stop downloading it.
          STATE.ABORT
        } else {
          promiseStatus.trySuccess(code)
          STATE.CONTINUE
        }
      }

      override def onHeadersReceived(p1: HttpResponseHeaders): STATE = {
        // trySuccess: this callback may fire more than once (e.g. trailing headers of a chunked
        // response); a second `success` would throw inside the client's callback.
        promiseHeader.trySuccess(p1)
        STATE.CONTINUE
      }

      override def onCompleted() {
        logger.debug("Closing channel as WS call is completed")
        channel.eofAndEnd()
      }
    })

    // When the consumer stops early, abort the HTTP call so the rest of the body is not downloaded.
    val stream = enumerator.through(Enumeratee.onIterateeDone[Array[Byte]] { () =>
      logger.debug("Iteratee is done ...")
      if (!listenableFuture.isDone) {
        listenableFuture.abort(new AbortOnIterateeDone())
        channel.eofAndEnd()
        logger.debug("Aborting WS call")
      } else {
        logger.debug("WS Call already finished")
      }
    })

    for {
      _ <- promiseStatus.future
      _ <- promiseHeader.future
    } yield stream
  }
}
Sign up for free to join this conversation on GitHub.
Already have an account?
Sign in to comment