Instantly share code, notes, and snippets.
Last active
June 5, 2017 10:50
-
Star
0
(0)
You must be signed in to star a gist -
Fork
0
(0)
You must be signed in to fork a gist
-
Save btd/695d0aac6fa1b8977fd3de65621d09f3 to your computer and use it in GitHub Desktop.
Servlet filter that stores HTTP sessions in signed, serialized cookies instead of in the servlet container.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package ingo.webapp | |
import javax.servlet._ | |
import javax.servlet.http._ | |
import java.util.Base64 | |
import java.io._ | |
import java.security.SecureRandom | |
import javax.crypto.Mac | |
import javax.crypto.spec.SecretKeySpec | |
import ingo.commons.util.KryoSerialization | |
/**
 * Settings for the cookie-backed session filter.
 *
 * @param cookieName   name of the cookie holding the serialized attributes; the signature
 *                     cookie uses this name with `_SIG` appended
 * @param cookieMaxAge max age applied to both session cookies (seconds, as passed to
 *                     `Cookie.setMaxAge`)
 * @param signatureKey secret used to compute the HMAC-SHA256 signature cookie
 */
case class CookieSessionFilterConfig(
  cookieName: String,
  cookieMaxAge: Int,
  signatureKey: String
)
/**
 * Encoding and signing helpers for cookie-backed sessions.
 *
 * Session state travels in two cookies:
 *  - `<cookieName>`: the attribute map serialized with the project's `KryoSerialization`
 *    helper, then Base64-encoded without padding;
 *  - `<cookieName>_SIG`: unpadded Base64 of HMAC-SHA256(signatureKey, "<cookieName>=<value>").
 *
 * A value cookie is only handed to the deserializer after its signature has been verified, so
 * the signing key must stay secret: anyone holding it can feed arbitrary bytes to the
 * deserializer.
 */
object CookieSessionFilter {
  import java.nio.charset.StandardCharsets.UTF_8
  import java.security.MessageDigest
  import scala.util.Try

  val base64decoder: Base64.Decoder = Base64.getDecoder
  val base64encoder: Base64.Encoder = Base64.getEncoder.withoutPadding

  /** Attribute key under which the session id is stored inside the serialized map. */
  val ID = "$$$"

  /** Appended to the session cookie name to form the signature cookie name. */
  private val SignatureSuffix = "_SIG"

  /**
   * Unpadded Base64 MAC of `data` under `key` using the given JCA MAC algorithm.
   * The key is encoded as UTF-8 so the signature does not depend on the JVM default charset.
   */
  private def signatureMethod(algorithm: String)(key: String, data: Array[Byte]): String = {
    val mac = Mac.getInstance(algorithm)
    mac.init(new SecretKeySpec(key.getBytes(UTF_8), algorithm))
    base64encoder.encodeToString(mac.doFinal(data))
  }

  val signature_HmacSHA256: (String, Array[Byte]) => String = signatureMethod("HmacSHA256") _

  /**
   * Decodes a Base64 cookie value back into the attribute map.
   * May throw on malformed input (Base64 decoding throws IllegalArgumentException).
   */
  def deserializeCookieValue(value: String): java.util.HashMap[String, AnyRef] = {
    val bytes = base64decoder.decode(value)
    KryoSerialization.deserialize[java.util.HashMap[String, AnyRef]](bytes)
  }

  /** Serializes the attribute map and encodes it as unpadded Base64 for use as a cookie value. */
  def serializeCookieValue(value: java.util.HashMap[String, AnyRef]): String = {
    val bytes = KryoSerialization.serialize[java.util.HashMap[String, AnyRef]](value)
    base64encoder.encodeToString(bytes)
  }

  /** Builds a root-path, HttpOnly cookie; shared by the value and signature cookies. */
  private def sessionCookie(name: String, value: String): Cookie = {
    val cookie = new Cookie(name, value)
    cookie.setPath("/")
    cookie.setHttpOnly(true)
    cookie
  }

  /** Cookie named `cookieName` carrying the serialized `attributes`. */
  def valueCookie(cookieName: String, attributes: java.util.HashMap[String, AnyRef]): Cookie =
    sessionCookie(cookieName, serializeCookieValue(attributes))

  /** HMAC-SHA256 over the UTF-8 bytes of "<name>=<value>" of `cookie`. */
  def valueCookieSignature(key: String, cookie: Cookie): String =
    signature_HmacSHA256(key, s"${cookie.getName}=${cookie.getValue}".getBytes(UTF_8))

  /** Companion `<name>_SIG` cookie holding the signature of `_cookie`. */
  def signatureCookie(key: String, _cookie: Cookie): Cookie =
    sessionCookie(_cookie.getName + SignatureSuffix, valueCookieSignature(key, _cookie))

  /**
   * True when `signature` holds the HMAC of `value` under `key`.
   *
   * Compares with MessageDigest.isEqual rather than String equality so response timing does
   * not reveal how many leading characters of a forged signature were correct.
   */
  def cookieSignatureMatch(key: String, value: Cookie, signature: Cookie): Boolean = {
    val expected = valueCookieSignature(key, value).getBytes(UTF_8)
    val supplied = Option(signature.getValue).getOrElse("").getBytes(UTF_8)
    MessageDigest.isEqual(expected, supplied)
  }

  /**
   * Reads the session attributes from the request's cookies.
   *
   * @return the deserialized map when both cookies are present and the signature verifies;
   *         otherwise (missing cookie, bad signature, undecodable value) a new empty map.
   *         Never returns null.
   */
  def getSessionAttributes(req: HttpServletRequest,
                           cookieName: String,
                           key: String): java.util.HashMap[String, AnyRef] = {
    val verifiedValue = for {
      cookies <- Option(req.getCookies)
      value <- cookies.find(_.getName == cookieName)
      signature <- cookies.find(_.getName == cookieName + SignatureSuffix)
      if cookieSignatureMatch(key, value, signature)
    } yield value.getValue

    verifiedValue
      // Try only catches non-fatal throwables; an undecodable cookie just starts a fresh session.
      .flatMap(encoded => Try(deserializeCookieValue(encoded)).toOption)
      .flatMap(Option(_))
      .getOrElse(new java.util.HashMap[String, AnyRef]())
  }
}
/**
 * `HttpSession` whose whole state is the attribute map decoded from the session cookie.
 *
 * The session id is kept inside that map under `CookieSessionFilter.ID`, so it is written back
 * together with the other attributes. Creation and last-accessed times both report the moment
 * this object was constructed.
 */
class InGoHttpSession(servletContext: ServletContext, val attributes: java.util.HashMap[String, AnyRef])
  extends HttpSession {

  private val _created = System.currentTimeMillis

  // Declared before `_id` on purpose: the `_id` initializer calls updateId(), which may set it.
  private var _new = false

  private var _id: String = updateId()

  // Reported by getMaxInactiveInterval; 60 until setMaxInactiveInterval is called.
  private var maxAge = 60

  /**
   * Ensures the session has an id and returns it.
   *
   * Reuses the id already stored in `attributes`. When none is stored, generates a random UUID,
   * stores it under `CookieSessionFilter.ID` and flags the session as new. An existing id is
   * never replaced.
   */
  def updateId(): String = {
    val stored = attributes.get(CookieSessionFilter.ID)
    _id =
      if (stored != null) stored.asInstanceOf[String]
      else {
        val generated = java.util.UUID.randomUUID.toString
        _new = true
        attributes.put(CookieSessionFilter.ID, generated)
        generated
      }
    _id
  }

  def getAttribute(name: String): Object = attributes.get(name)

  def getAttributeNames(): java.util.Enumeration[String] =
    java.util.Collections.enumeration(attributes.keySet())

  def getId(): String = _id

  def getValue(name: String): Object = getAttribute(name)

  def getValueNames(): Array[String] = attributes.keySet.toArray(Array.empty[String])

  /** Clears every attribute (including the stored id), then immediately assigns a fresh id. */
  def invalidate(): Unit = {
    attributes.clear()
    updateId()
  }

  /** Stores `value` under `name`; a `null` value removes the attribute instead. */
  def setAttribute(name: String, value: AnyRef): Unit =
    if (value == null) attributes.remove(name)
    else attributes.put(name, value)

  def putValue(name: String, value: AnyRef): Unit = setAttribute(name, value)

  def removeAttribute(name: String): Unit = attributes.remove(name)

  def removeValue(name: String): Unit = removeAttribute(name)

  def isNew(): Boolean = _new

  def getCreationTime(): Long = _created

  // No access timestamp is tracked, so the construction time is reported here as well.
  def getLastAccessedTime(): Long = _created

  def getMaxInactiveInterval(): Int = maxAge

  def setMaxInactiveInterval(value: Int): Unit = maxAge = value

  def getServletContext(): ServletContext = servletContext

  // Deprecated servlet API; not supported by this implementation.
  def getSessionContext(): HttpSessionContext = null
}
/**
 * Request wrapper that replaces the container session with an `InGoHttpSession` rebuilt from
 * the signed session cookies the first time any session-related method is called.
 */
class SessionCookieRequestWrapper(req: HttpServletRequest, config: CookieSessionFilterConfig)
  extends HttpServletRequestWrapper(req) {

  /** `null` until first loaded from the cookies by fillSessionFromCookie(). */
  var session: InGoHttpSession = null

  /** (Re)builds `session` from the request cookies and applies the configured max age. */
  def fillSessionFromCookie(): Unit = {
    val attributes = CookieSessionFilter.getSessionAttributes(req, config.cookieName, config.signatureKey)
    session = new InGoHttpSession(req.getServletContext, attributes)
    session.setMaxInactiveInterval(config.cookieMaxAge)
  }

  /** Returns the session, loading it from the cookies on first use. */
  private def loadedSession: InGoHttpSession = {
    if (session == null) fillSessionFromCookie()
    session
  }

  /**
   * Assigns a new id to the current session and returns it.
   *
   * InGoHttpSession.updateId() only generates an id when none is stored, so the stored id is
   * removed first; otherwise the old id would be returned unchanged.
   */
  override def changeSessionId(): String = {
    val current = loadedSession
    current.attributes.remove(CookieSessionFilter.ID)
    current.updateId()
  }

  override def getRequestedSessionId(): String = loadedSession.getId

  override def getSession(): HttpSession = getSession(true)

  // `create` is deliberately ignored: a session (possibly empty) can always be decoded from
  // the cookies, so there is no "no session" state to report.
  override def getSession(create: Boolean): HttpSession = loadedSession

  override def isRequestedSessionIdFromCookie(): Boolean = false
  override def isRequestedSessionIdFromUrl(): Boolean = false
  override def isRequestedSessionIdFromURL(): Boolean = false
  override def isRequestedSessionIdValid(): Boolean = true
}
/**
 * Response wrapper that buffers the body so the session cookies can still be added after the
 * application has finished writing. flushBuffer() adds the cookies and forwards the buffer.
 */
class SessionCookieResponseWrapper(req: HttpServletRequest,
                                   res: HttpServletResponse,
                                   config: CookieSessionFilterConfig)
  extends HttpServletResponseWrapper(res) {

  /** In-memory body buffer; its content reaches `res` only in flushBuffer(). */
  val output = new WrappedServletOutputStream(res.getOutputStream())

  // Set once `writer` exists, so flushBuffer() does not create a writer nobody used.
  private var writerCreated = false

  /**
   * Created on first use so it encodes with the response character encoding in effect at that
   * point (e.g. set through setContentType). ISO-8859-1 is the servlet default when none is set.
   */
  lazy val writer: PrintWriter = {
    val charset = Option(getCharacterEncoding).getOrElse("ISO-8859-1")
    val created = new PrintWriter(new OutputStreamWriter(output, charset), true)
    writerCreated = true
    created
  }

  override def getOutputStream(): WrappedServletOutputStream = output

  override def getWriter(): PrintWriter = writer

  /**
   * Adds the value and signature cookies for the current session state, then forwards the
   * buffered body to the wrapped response.
   */
  override def flushBuffer(): Unit = {
    // The filter passes its SessionCookieRequestWrapper as `req`, whose session is an InGoHttpSession.
    val session = req.getSession().asInstanceOf[InGoHttpSession]
    val maxAge = session.getMaxInactiveInterval()
    val valueCookie = CookieSessionFilter.valueCookie(config.cookieName, session.attributes)
    val signatureCookie = CookieSessionFilter.signatureCookie(config.signatureKey, valueCookie)
    valueCookie.setMaxAge(maxAge)
    signatureCookie.setMaxAge(maxAge)
    res.addCookie(valueCookie)
    res.addCookie(signatureCookie)

    if (writerCreated) writer.flush()
    output.forwardBufferContent()
    // Drop the bytes just forwarded so a later flushBuffer() call does not send them again.
    output.output.reset()
  }
}
/**
 * ServletOutputStream that collects everything written into an in-memory buffer until
 * forwardBufferContent() copies it to the wrapped container stream.
 */
class WrappedServletOutputStream(_output: ServletOutputStream) extends ServletOutputStream {

  /** Body buffer; starts at 100 KiB capacity and grows as needed. */
  val output = new ByteArrayOutputStream(100 * 1024)

  /**
   * Listener registered through setWriteListener. It is stored but never invoked.
   * NOTE(review): this buffering stream does not implement non-blocking I/O callbacks, so an
   * async writer relying on onWritePossible() will not be driven by it. Confirm no caller
   * depends on that.
   */
  var writeListener: WriteListener = null

  // The original called writeListener.notify() here. That is java.lang.Object.notify(), which
  // throws IllegalMonitorStateException when the caller does not hold the monitor, and it is
  // not a WriteListener callback. It has been removed.
  def write(n: Int): Unit = output.write(n)

  /** Bulk write straight into the buffer instead of the inherited byte-by-byte loop. */
  override def write(b: Array[Byte], off: Int, len: Int): Unit = output.write(b, off, len)

  /**
   * Copies the buffered bytes to the wrapped stream, then clears the buffer so the same bytes
   * are never forwarded twice.
   */
  def forwardBufferContent(): Unit = {
    output.writeTo(_output)
    output.reset()
  }

  def setWriteListener(l: WriteListener): Unit = writeListener = l

  // Writes only go to memory, so the stream is always ready.
  def isReady(): Boolean = true
}
/**
 * Servlet filter that keeps the HTTP session in a signed cookie pair instead of server memory.
 *
 * Optional init parameters: `cookieName`, `cookieMaxAge` (integer) and `signatureKey`.
 * Empty parameter values are treated as absent.
 */
class CookieSessionFilter extends Filter {

  /**
   * Active settings.
   *
   * Without a `signatureKey` init parameter, the key is random per filter instance rather than
   * a constant written in source. Anyone who knows a fixed key could forge session cookies,
   * whose contents are then deserialized. The trade-off: with a random key, sessions do not
   * survive a restart and are not shared between nodes, so configure `signatureKey` explicitly
   * in production.
   */
  var filterConfig: CookieSessionFilterConfig =
    CookieSessionFilterConfig("SESS", 60, newRandomSignatureKey())

  /** 32 random bytes (256 bits) from SecureRandom, Base64-encoded. */
  private def newRandomSignatureKey(): String = {
    val bytes = new Array[Byte](32)
    new SecureRandom().nextBytes(bytes)
    Base64.getEncoder.encodeToString(bytes)
  }

  def destroy(): Unit = {}

  /**
   * Applies the init parameters on top of the defaults.
   * A non-integer `cookieMaxAge` fails fast with NumberFormatException.
   */
  def init(config: FilterConfig): Unit = {
    def param(name: String): Option[String] =
      Option(config.getInitParameter(name)).filter(_.nonEmpty)

    for (value <- param("cookieMaxAge").map(_.trim.toInt)) {
      filterConfig = filterConfig.copy(cookieMaxAge = value)
    }
    for (value <- param("cookieName")) {
      filterConfig = filterConfig.copy(cookieName = value)
    }
    param("signatureKey") match {
      case Some(value) =>
        filterConfig = filterConfig.copy(signatureKey = value)
      case None =>
        Option(config.getServletContext).foreach(
          _.log("CookieSessionFilter: no signatureKey init parameter; using a random per-instance key"))
    }
  }

  /**
   * Wraps HTTP requests and responses so the session is read from and written back to cookies.
   * Non-HTTP traffic has no cookies to carry a session, so it passes through unchanged.
   */
  def doFilter(req: ServletRequest, res: ServletResponse, chain: FilterChain): Unit =
    (req, res) match {
      case (httpReq: HttpServletRequest, httpRes: HttpServletResponse) =>
        val wrapperReq = new SessionCookieRequestWrapper(httpReq, filterConfig)
        val wrapperRes = new SessionCookieResponseWrapper(wrapperReq, httpRes, filterConfig)
        chain.doFilter(wrapperReq, wrapperRes)
        wrapperRes.flushBuffer()
      case _ =>
        chain.doFilter(req, res)
    }
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment