Created
March 3, 2015 12:49
-
-
Save bantonsson/881f831db93ec474f9bd to your computer and use it in GitHub Desktop.
Akka HTTP 1.0-M4 File upload using a form
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
/* | |
* Copyright (C) 2009-2015 Typesafe Inc. <http://www.typesafe.com> | |
*/ | |
import akka.actor.ActorSystem | |
import akka.http.Http | |
import akka.http.Http.IncomingConnection | |
import akka.http.model.HttpEntity | |
import akka.http.model.Multipart.FormData | |
import akka.http.server.{ Directives, Route, RoutingSetup } | |
import akka.http.unmarshalling.Unmarshal | |
import akka.stream.{ ActorFlowMaterializer, FlowMaterializer } | |
import akka.stream.scaladsl.{ Keep, Flow, Sink, Source } | |
import akka.testkit.TestKit | |
import akka.util.ByteString | |
import com.typesafe.config.{ Config, ConfigFactory } | |
import java.io.{ OutputStream, BufferedOutputStream, FileOutputStream, File } | |
import java.net.InetSocketAddress | |
import java.nio.channels.ServerSocketChannel | |
import org.scalatest.concurrent.Timeouts | |
import org.scalatest.time.{ Milliseconds, Span } | |
import org.scalatest.{ BeforeAndAfterAll, Matchers, WordSpecLike } | |
import scala.annotation.tailrec | |
import scala.concurrent.duration._ | |
import scala.concurrent.{ Await, ExecutionContext, Future } | |
import scala.sys.process._ | |
import scala.util.Random | |
object FormUpload { | |
val debug = true | |
def temporaryServerAddress(interface: String = "127.0.0.1"): InetSocketAddress = { | |
val serverSocket = ServerSocketChannel.open() | |
try { | |
serverSocket.socket.bind(new InetSocketAddress(interface, 0)) | |
val port = serverSocket.socket.getLocalPort | |
new InetSocketAddress(interface, port) | |
} finally serverSocket.close() | |
} | |
def temporaryServerAddressAndPort(interface: String = "127.0.0.1"): (String, Int) = { | |
val socketAddress = temporaryServerAddress(interface) | |
socketAddress.getAddress.getHostAddress -> socketAddress.getPort | |
} | |
def withHttpServer(route: Route)(block: (String, Int) ⇒ Unit)( | |
implicit system: ActorSystem, materializer: ActorFlowMaterializer, setup: RoutingSetup): Unit = { | |
val (address, port) = temporaryServerAddressAndPort() | |
val connectionSource = Http().bind(address, port) | |
val handleSink = Flow[IncomingConnection].toMat(Sink.foreach(_.handleWith(route)))(Keep.right) | |
val (bindingFuture, completionFuture) = connectionSource.toMat(handleSink)(Keep.both).run() | |
val binding = Await.result(bindingFuture, 5 seconds) | |
try { | |
block(address, port) | |
} finally { | |
binding.unbind() | |
} | |
} | |
val blobPartName = "theBlob" | |
case class BlobPart(filename: String, source: Source[ByteString, Unit]) | |
def readBlob(bodyParts: Source[FormData.BodyPart, Unit])(implicit mat: FlowMaterializer, ec: ExecutionContext): Future[String] = { | |
def readBlob(blob: BlobPart): Long = { | |
if (blob.filename.nonEmpty) { | |
Await.result(blob.source.runFold(0L) { | |
case (sum, bytes) ⇒ | |
val newSum = sum + bytes.size | |
if (debug) println(s"Total $newSum after reading ${bytes.size}") | |
newSum | |
}, maxDuration) | |
} else 0 | |
} | |
val blobParts = bodyParts.collect { | |
case part @ FormData.BodyPart(name, entitiy, _, _) ⇒ | |
if (name == blobPartName && part.filename.isDefined) | |
BlobPart(part.filename.get, entitiy.dataBytes) | |
else { | |
// ignore form parts that we don't recognize | |
// (potential DOS since it will read all data | |
entitiy.getDataBytes.runWith(Sink.ignore()) | |
BlobPart("", Source.empty()) | |
} | |
} | |
blobParts.map(readBlob).map(_.toString + " ").runFold("")(_ + _) | |
} | |
def createBlobOfSize(size: Long): String = { | |
val bytesSize: Int = 1024 | |
val bytes = Array.ofDim[Byte](bytesSize) | |
@tailrec | |
def writeRandomBytes(random: Random, remaining: Long, output: OutputStream): Unit = remaining match { | |
case r if r > 0 ⇒ | |
random.nextBytes(bytes) | |
val writeSize = if (r > bytesSize) bytesSize else r.toInt | |
output.write(bytes, 0, writeSize) | |
writeRandomBytes(random, r - writeSize, output) | |
case r if r == 0 ⇒ // done | |
case r ⇒ throw new IllegalArgumentException(s"Cant write less than 0 bytes [$remaining]") | |
} | |
val rnd = new Random(size) | |
val name = rnd.nextLong.toHexString | |
val temp = File.createTempFile(name, ".tmp") | |
val out = new BufferedOutputStream(new FileOutputStream(temp)) | |
try { | |
writeRandomBytes(rnd, size, out) | |
} finally { | |
out.close() | |
} | |
temp.getCanonicalPath | |
} | |
val testConfig: Config = ConfigFactory.parseString(""" | |
akka.event-handlers = ["akka.testkit.TestEventListener"] | |
akka.loglevel = DEBUG""") | |
val maxDuration = 10 seconds | |
} | |
class FormUpload extends WordSpecLike with Matchers with BeforeAndAfterAll with Directives with Timeouts { | |
import FormUpload._ | |
implicit val system = ActorSystem(getClass.getSimpleName, testConfig) | |
var blobPath: String = "" | |
override def beforeAll { | |
blobPath = createBlobOfSize(1024 * 1024 * 2.3 toLong) | |
} | |
final override def afterAll { | |
new File(blobPath).delete() | |
TestKit.shutdownActorSystem(system) | |
} | |
import system.dispatcher | |
implicit val materializer = ActorFlowMaterializer() | |
"An HTTP Server" should { | |
"Accept a simple large multipart form upload" in { | |
withHttpServer( | |
// format: OFF | |
path("form") { | |
post { | |
extractRequest { request => | |
complete(loadSimpleForm(request.entity)) | |
} | |
} | |
// format: ON | |
}) { (address, port) ⇒ | |
failAfter(Span(maxDuration.toMillis, Milliseconds)) { | |
uploadSimpleForm(address, port) should be(0) | |
} | |
} | |
} | |
} | |
private def uploadSimpleForm(address: String, port: Int): Int = { | |
val command = List( | |
"curl", | |
"-v", // verbose | |
"-H", "Expect:", // remove Expect: 100-continue | |
"--form", s"""ignoredFormField=@"$blobPath"""", | |
"--form", s"""$blobPartName=@"$blobPath"""", | |
s"http://$address:$port/form") | |
if (debug) println(s"About to execute: ${command.mkString(" ")}") | |
command.! | |
} | |
private def loadSimpleForm(entity: HttpEntity): Future[String] = | |
for { | |
multiPartFormData ← Unmarshal(entity).to[FormData] | |
blobResult ← readBlob(multiPartFormData.parts) | |
} yield "Read the blob " + blobResult | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Hello Björn, thanks for this piece of code.
I made it work, however, I quickly ran into:
So I changed akka.http.server.parsing.max-content-length to have a greater value.
However, if I increase the size of the blob I am sending (with httpie), I quickly run into:
It seems the Http component 1) parses the whole request (=> OOM) then 2) streams it down.
Am I missing something? Is there any way to stream in the whole chain?
Thanks again,
Sylvain