Created
October 22, 2011 21:20
-
-
Save casualjim/1306507 to your computer and use it in GitHub Desktop.
CORS support for Scalatra
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package backchat | |
package web | |
import javax.servlet.http.{ HttpServletResponse, HttpServletRequest } | |
import org.scalatra._ | |
import collection.JavaConversions._ | |
/**
 * Header names and lookup tables shared by the `CORSSupport` mixin and its tests.
 *
 * The request-side names are read from incoming requests; the allow/max-age names
 * are written onto responses. The upper-cased tables are compared against
 * upper-cased header names and content types.
 */
object CORSSupport {
  // Headers a browser sends with cross-origin (and preflight) requests.
  val ORIGIN_HEADER: String = "Origin"
  val ACCESS_CONTROL_REQUEST_METHOD_HEADER: String = "Access-Control-Request-Method"
  val ACCESS_CONTROL_REQUEST_HEADERS_HEADER: String = "Access-Control-Request-Headers"

  // Headers the server writes to grant cross-origin access.
  val ACCESS_CONTROL_ALLOW_ORIGIN_HEADER: String = "Access-Control-Allow-Origin"
  val ACCESS_CONTROL_ALLOW_METHODS_HEADER: String = "Access-Control-Allow-Methods"
  val ACCESS_CONTROL_ALLOW_HEADERS_HEADER: String = "Access-Control-Allow-Headers"
  val ACCESS_CONTROL_MAX_AGE_HEADER: String = "Access-Control-Max-Age"
  val ACCESS_CONTROL_ALLOW_CREDENTIALS_HEADER: String = "Access-Control-Allow-Credentials"
  // private val ACCESS_CONTROL_EXPOSE_HEADERS_HEADER = "Access-Control-Expose-Headers"

  /** Configured origin value that allows every origin. */
  private val ANY_ORIGIN: String = "*"

  /** Upper-cased request headers that never need explicit permission. */
  private val SIMPLE_HEADERS =
    ORIGIN_HEADER.toUpperCase(ENGLISH) :: "ACCEPT" :: "ACCEPT-LANGUAGE" :: "CONTENT-LANGUAGE" :: Nil

  /** Upper-cased content-type prefixes under which Content-Type counts as a simple header. */
  private val SIMPLE_CONTENT_TYPES =
    "APPLICATION/X-WWW-FORM-URLENCODED" :: "MULTIPART/FORM-DATA" :: "TEXT/PLAIN" :: Nil

  /** Every CORS-related header name; these are always accepted in Access-Control-Request-Headers. */
  val CORS_HEADERS =
    ORIGIN_HEADER ::
      ACCESS_CONTROL_ALLOW_CREDENTIALS_HEADER ::
      ACCESS_CONTROL_ALLOW_HEADERS_HEADER ::
      ACCESS_CONTROL_ALLOW_METHODS_HEADER ::
      ACCESS_CONTROL_ALLOW_ORIGIN_HEADER ::
      ACCESS_CONTROL_MAX_AGE_HEADER ::
      ACCESS_CONTROL_REQUEST_HEADERS_HEADER ::
      ACCESS_CONTROL_REQUEST_METHOD_HEADER ::
      Nil
  // private val SIMPLE_RESPONSE_HEADERS = List("CACHE-CONTROL", "CONTENT-LANGUAGE", "EXPIRES", "LAST-MODIFIED", "PRAGMA", "CONTENT-TYPE")
}
/**
 * Mixin that adds Cross-Origin Resource Sharing (CORS) handling to a Scalatra kernel.
 *
 * Settings are read from `corsConfig` once, while the trait is initialised. `handle`
 * dispatches each request:
 *  - a valid preflight `OPTIONS` request is answered directly, without running any route;
 *  - an enabled CORS request from an allowed origin gets the allow-origin (and, if
 *    configured, allow-credentials) headers before normal processing;
 *  - every other request, including CORS requests from disallowed origins, passes
 *    through untouched.
 *
 * The numbered comments (5.2.x, 6.x.x) refer to steps of the CORS specification.
 *
 * NOTE(review): the private settings are strict vals evaluated during trait
 * initialisation. An override of `corsConfig` should therefore be a `def` or `lazy val`;
 * a plain `val` in a subclass would not be initialised yet at that point.
 */
trait CORSSupport extends Handler { self: ScalatraKernel with Logging ⇒
  import CORSSupport._

  /** Source of the CORS settings; defaults to `Config.CORS`. */
  protected def corsConfig = Config.CORS

  private val anyOriginAllowed: Boolean = corsConfig.allowedOrigins.contains(ANY_ORIGIN)
  private val allowedOrigins = corsConfig.allowedOrigins
  private val allowedMethods = corsConfig.allowedMethods
  private val allowedHeaders = corsConfig.allowedHeaders
  private val preflightMaxAge: Int = corsConfig.preflightMaxAge
  private val allowCredentials: Boolean = corsConfig.allowCredentials

  logger debug "Enabled CORS Support with:\nallowedOrigins: %s\nallowedMethods: %s\nallowedHeaders: %s".format(
    allowedOrigins mkString ", ",
    allowedMethods mkString ", ",
    allowedHeaders mkString ", ")

  /**
   * Answers a validated preflight request.
   *
   * Adds the origin/credentials headers, the max-age (only when positive), and the
   * configured allowed methods and headers. It then flushes the response; no route
   * is invoked.
   */
  protected def handlePreflightRequest() {
    logger trace "handling preflight request"
    // 5.2.7
    augmentSimpleRequest()
    // 5.2.8
    if (preflightMaxAge > 0) response.setHeader(ACCESS_CONTROL_MAX_AGE_HEADER, preflightMaxAge.toString)
    // 5.2.9
    response.setHeader(ACCESS_CONTROL_ALLOW_METHODS_HEADER, allowedMethods mkString ",")
    // 5.2.10
    response.setHeader(ACCESS_CONTROL_ALLOW_HEADERS_HEADER, allowedHeaders mkString ",")
    response.flushBuffer()
    response.getOutputStream.flush()
  }

  /**
   * Adds Access-Control-Allow-Origin, plus Access-Control-Allow-Credentials when credentials are allowed.
   *
   * The literal `*` is only sent when any origin is allowed AND credentials are
   * disabled. Otherwise the request's own Origin value is echoed back, so callers
   * must have validated the origin first.
   */
  protected def augmentSimpleRequest() {
    val hdr = if (anyOriginAllowed && !allowCredentials) ANY_ORIGIN else request.getHeader(ORIGIN_HEADER)
    response.setHeader(ACCESS_CONTROL_ALLOW_ORIGIN_HEADER, hdr)
    if (allowCredentials) response.setHeader(ACCESS_CONTROL_ALLOW_CREDENTIALS_HEADER, "true")
    /*
    if (allowedHeaders.nonEmpty) {
      val hdrs = allowedHeaders.filterNot(hn => SIMPLE_RESPONSE_HEADERS.contains(hn.toUpperCase(ENGLISH))).mkString(",")
      response.addHeader(ACCESS_CONTROL_ALLOW_HEADERS_HEADER, hdrs)
    }
    */
  }

  /** True when any origin is allowed or the Origin header exactly equals an allowed origin (6.2.2). */
  private def originMatches = // 6.2.2
    anyOriginAllowed || (allowedOrigins contains request.getHeader(ORIGIN_HEADER))

  /**
   * False for WebSocket upgrade handshakes and for any path containing "eb_ping";
   * CORS handling is skipped for those requests.
   */
  private def isEnabled =
    !("Upgrade".equalsIgnoreCase(request.getHeader("Connection")) &&
      "WebSocket".equalsIgnoreCase(request.getHeader("Upgrade"))) &&
      !requestPath.contains("eb_ping") // don't do anything for the ping endpoint

  /** True when `routes.matchingMethods` is non-empty (assumed: some route matches this path — TODO confirm). */
  private def isValidRoute: Boolean = routes.matchingMethods.nonEmpty

  /**
   * Whether this request is a preflight that should be answered.
   *
   * Every check is evaluated up front so that all of them can be logged together at trace level.
   */
  private def isPreflightRequest = {
    val isCors = isCORSRequest
    val validRoute = isValidRoute
    val isPreflight = request.getHeader(ACCESS_CONTROL_REQUEST_METHOD_HEADER).isNotBlank
    val enabled = isEnabled
    val matchesOrigin = originMatches
    val methodAllowed = allowsMethod
    val allowsHeaders = headersAreAllowed
    val result = isCors && validRoute && isPreflight && enabled && matchesOrigin && methodAllowed && allowsHeaders
    logger trace "This is a preflight validation check. valid? %s".format(result)
    logger trace "cors? %s, route? %s, preflight? %s, enabled? %s, origin? %s, method? %s, header? %s".format(
      isCors, validRoute, isPreflight, enabled, matchesOrigin, methodAllowed, allowsHeaders)
    result
  }

  /** A request takes part in CORS when it carries a non-blank Origin header. */
  private def isCORSRequest = { // 6.x.1
    val h = request.getHeader(ORIGIN_HEADER)
    val result = h.isNotBlank
    if (!result) logger trace ("No origin found in the request")
    else logger trace ("We found the origin: %s".format(h))
    result
  }

  /**
   * Whether `header` needs no explicit permission.
   *
   * A header qualifies when it is one of SIMPLE_HEADERS (case-insensitive), or when it
   * is Content-Type and the current request's content type starts with one of
   * SIMPLE_CONTENT_TYPES.
   */
  private def isSimpleHeader(header: String) = header.toOption exists { h ⇒
    val hu = h.toUpperCase(ENGLISH)
    SIMPLE_HEADERS.contains(hu) || (hu == "CONTENT-TYPE" && hasSimpleContentType)
  }

  /**
   * Whether the request's content type starts with a simple content type.
   *
   * False when the request has no content type; preflights usually have none, and
   * the null was previously dereferenced.
   */
  private def hasSimpleContentType =
    Option(request.getContentType) exists { ct ⇒
      val ctu = ct.toUpperCase(ENGLISH)
      SIMPLE_CONTENT_TYPES exists (prefix ⇒ ctu.startsWith(prefix))
    }

  /**
   * Whether every origin in the (space-separated) Origin header is allowed (6.1.2).
   * A configured wildcard origin allows every origin.
   */
  private def allOriginsMatch = // 6.1.2
    request.getHeader(ORIGIN_HEADER).toOption exists { h ⇒
      val origins = h.split(" ")
      origins.nonEmpty && (anyOriginAllowed || origins.forall(allowedOrigins.contains))
    }

  /**
   * A GET/POST/HEAD CORS request that may be granted access.
   *
   * Requires the request to be enabled, all origins to be allowed, and every request
   * header name to be a simple header.
   */
  private def isSimpleRequest = {
    val isCors = isCORSRequest
    val enabled = isEnabled
    val allOrigins = allOriginsMatch
    val res = isCors && enabled && allOrigins && request.getHeaderNames.forall(isSimpleHeader)
    logger trace "This is a simple request: %s, because: %s, %s, %s".format(res, isCors, enabled, allOrigins)
    res
  }

  /** Any other CORS request (e.g. PUT/DELETE) that may be granted access: enabled and from allowed origins. */
  private def isGrantableCORSRequest = isCORSRequest && isEnabled && allOriginsMatch

  /** Whether Access-Control-Request-Method is present and, upper-cased, among the allowed methods. */
  private def allowsMethod = { // 5.2.3 and 5.2.5
    val accessControlRequestMethod = request.getHeader(ACCESS_CONTROL_REQUEST_METHOD_HEADER)
    logger.trace("%s is %s" format (ACCESS_CONTROL_REQUEST_METHOD_HEADER, accessControlRequestMethod))
    val result = accessControlRequestMethod.isNotBlank && allowedMethods.contains(accessControlRequestMethod.toUpperCase(ENGLISH))
    logger.trace("Method %s is%s among allowed methods %s".format(accessControlRequestMethod, if (result) "" else " not", allowedMethods))
    result
  }

  /**
   * Whether every header listed in Access-Control-Request-Headers is permitted.
   *
   * A header is permitted when it is a configured allowed header, a CORS header, or a
   * simple header. An absent or blank Access-Control-Request-Headers permits the request.
   */
  private def headersAreAllowed = { // 5.2.4 and 5.2.6
    val accessControlRequestHeaders = request.getHeader(ACCESS_CONTROL_REQUEST_HEADERS_HEADER).toOption
    logger.trace("%s is %s".format(ACCESS_CONTROL_REQUEST_HEADERS_HEADER, accessControlRequestHeaders))
    val ah = (allowedHeaders ++ CORS_HEADERS).map(_.trim.toUpperCase(ENGLISH))
    val result = accessControlRequestHeaders forall { hdr ⇒
      val hdrs = hdr.split(",").map(_.trim.toUpperCase(ENGLISH))
      logger.debug("Headers [%s]".format(hdrs.mkString(", ")))
      // Check each listed header on its own; simple headers need no explicit permission.
      hdrs.nonEmpty && hdrs.forall { h ⇒ ah.contains(h) || isSimpleHeader(h) }
    }
    logger.trace("Headers [%s] are%s among allowed headers %s".format(
      accessControlRequestHeaders getOrElse "No headers", if (result) "" else " not", ah))
    result
  }

  /**
   * Entry point: binds the request/response and routes the request.
   *
   * A valid preflight is answered without calling the routes. An enabled CORS request
   * from an allowed origin is augmented, then handled normally. Everything else is
   * delegated unchanged to the next handler.
   */
  abstract override def handle(req: HttpServletRequest, res: HttpServletResponse) {
    _request.withValue(req) {
      logger trace "the headers are: %s".format(req.getHeaderNames.mkString(", "))
      _response.withValue(res) {
        request.method match {
          case Options if isPreflightRequest ⇒
            handlePreflightRequest()
          case Get | Post | Head if isSimpleRequest ⇒
            augmentSimpleRequest()
            super.handle(req, res)
          case _ if isGrantableCORSRequest ⇒
            // Never grant access (and never echo the Origin) for disallowed origins or disabled requests.
            augmentSimpleRequest()
            super.handle(req, res)
          case _ ⇒
            super.handle(req, res)
        }
      }
    }
  }
}
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package backchat | |
package web | |
package tests | |
import org.scalatra.test.specs2.ScalatraSpec | |
import org.scalatra.ScalatraServlet | |
/**
 * specs2 acceptance spec for the CORS mixin.
 *
 * It runs against a servlet whose only allowed origin is http://www.example.com, which
 * allows GET/HEAD/POST and a fixed set of request headers, with credentials enabled.
 */
class CORSSupportSpec extends ScalatraSpec {
  addServlet(new ScalatraServlet with Logging with CORSSupport {
    override protected lazy val corsConfig = CORSConfig(
      List("http://www.example.com"),
      List("GET", "HEAD", "POST"),
      "X-Requested-With,Authorization,Content-Type,Accept,Origin".split(","),
      true)

    get("/") { "OK" }
  }, "/*")

  def is =
    "The CORS support should" ^
      "augment a valid simple request" ! context.validSimpleRequest ^
      "not touch a regular request" ! context.dontTouchRegularRequest ^
      "respond to a valid preflight request" ! context.validPreflightRequest ^
      "respond to a valid preflight request with headers" ! context.validPreflightRequestWithHeaders ^ end

  object context {
    // The single origin the servlet above is configured to allow.
    private val exampleOrigin = "http://www.example.com"
    private val fromExampleOrigin = CORSSupport.ORIGIN_HEADER -> exampleOrigin
    private val requestsGet = CORSSupport.ACCESS_CONTROL_REQUEST_METHOD_HEADER -> "GET"
    private val jsonBody = "Content-Type" -> "application/json"

    /** The Access-Control-Allow-Origin value of the response currently in scope. */
    private def allowOrigin = response.getHeader(CORSSupport.ACCESS_CONTROL_ALLOW_ORIGIN_HEADER)

    def validSimpleRequest = get("/", headers = Map(fromExampleOrigin)) {
      allowOrigin must_== exampleOrigin
    }

    def dontTouchRegularRequest = get("/") {
      allowOrigin must beNull
    }

    def validPreflightRequest = options("/", headers = Map(fromExampleOrigin, requestsGet, jsonBody)) {
      allowOrigin must_== exampleOrigin
    }

    def validPreflightRequestWithHeaders = {
      val requestedHeaders = CORSSupport.ACCESS_CONTROL_REQUEST_HEADERS_HEADER -> "Origin, Authorization, Accept"
      options("/", headers = Map(fromExampleOrigin, requestsGet, requestedHeaders, jsonBody)) {
        allowOrigin must_== exampleOrigin
      }
    }
  }
}
CORS is Cross-Origin Resource Sharing, something browsers implement to allow you to make cross-domain requests to servers with which you have a "trusted" relationship.
Thanks, Ivan, for that explanation. At first I easily mistook it for CQRS (Command Query Responsibility Segregation), but a deeper look at the code suggested otherwise.
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
@casualjim, what is CORS, and what exactly are you trying to achieve? It appears to be configuring certain... pros