Skip to content

Instantly share code, notes, and snippets.

@samspills
Created April 13, 2023 20:26
Show Gist options
  • Save samspills/fba4be59de41c73262bc23d7939e977a to your computer and use it in GitHub Desktop.
Save samspills/fba4be59de41c73262bc23d7939e977a to your computer and use it in GitHub Desktop.
Toy example that runs a mock JWKS server and demonstrates verifying and decoding a JWT with a public JWK.
//> using scala "2.13.10"
//> using lib "org.typelevel::cats-effect:3.4.8"
//> using lib "org.http4s::http4s-ember-client:0.23.18"
//> using lib "org.http4s::http4s-ember-server:0.23.18"
//> using lib "org.http4s::http4s-server:0.23.18"
//> using lib "org.http4s::http4s-dsl:0.23.18"
//> using lib "org.http4s::http4s-circe:0.23.18"
//> using lib "io.circe::circe-core:0.14.5"
//> using lib "io.circe::circe-generic:0.14.5"
//> using lib "org.scodec::scodec-bits:1.1.37"
//> using lib "co.fs2::fs2-core:3.6.1"
//> using lib "com.github.jwt-scala::jwt-core:9.2.0"
//> using lib "io.circe::circe-fs2:0.14.1"
import java.math.BigInteger
import java.security.{KeyFactory, PublicKey}
import java.security.spec.RSAPublicKeySpec
import cats.effect.{IO, IOApp}
import cats.syntax.all._
import com.comcast.ip4s._
import fs2.{Fallible, Stream}
import io.circe.{Decoder, Encoder}
import io.circe.fs2.{byteStreamParser, decoder}
import io.circe.generic.semiauto._
import io.circe.syntax._
import org.http4s._
import org.http4s.circe._
import org.http4s.client.Client
import org.http4s.dsl.io._
import org.http4s.ember.client.EmberClientBuilder
import org.http4s.ember.server._
import org.http4s.implicits._
import pdi.jwt.{Jwt, JwtAlgorithm}
import scodec.bits.Bases
/** A JSON Web Key Set: the list of keys served by a JWKS endpoint. */
case class JWKS(keys: List[Key]) {

  /** Builds a `kid -> PublicKey` lookup table.
    *
    * Key decoding is deferred into the IO, so constructing or JSON-encoding a JWKS does no
    * crypto work, and any exception thrown while decoding a key (including ones escaping
    * `Key.publicKey`) becomes a failed IO instead of being thrown from this constructor.
    * The first key that fails to decode fails the whole IO. If several keys share a `kid`,
    * the last one in `keys` wins.
    */
  val keyMap: IO[Map[String, PublicKey]] =
    keys
      .traverse { key =>
        IO(key.publicKey)
          .flatMap(IO.fromEither(_))
          .map(pk => key.kid -> pk)
      }
      .map(_.toMap)
}
/** Circe codecs and the http4s entity decoder for [[JWKS]]. */
object JWKS {
  implicit val encoderJWKS: Encoder[JWKS] = deriveEncoder[JWKS]
  implicit val decoderJWKS: Decoder[JWKS] = deriveDecoder[JWKS]

  // Explicit result type: implicit definitions should not rely on type inference
  // (Scala 3 requires the annotation; in Scala 2 lookup of inferred-type implicits is
  // sensitive to definition order).
  implicit val entityDecoder: EntityDecoder[IO, JWKS] = jsonOf[IO, JWKS]
}
/** One RSA public key from a JWKS: key type, modulus `n` and public exponent `e` (both
  * base64url-encoded unsigned big-endian integers), and key id `kid`.
  */
case class Key(kty: String, n: String, e: String, kid: String) {

  /** The modulus `n` as an unsigned big-endian integer; `Left` if `n` is not valid base64url. */
  def modulus: Either[Throwable, BigInteger] = decodeUnsignedBase64Url(n)

  /** The public exponent `e` as an unsigned big-endian integer; `Left` if `e` is not valid base64url. */
  def exponent: Either[Throwable, BigInteger] = decodeUnsignedBase64Url(e)

  /** Builds a JCA RSA public key from `n` and `e`.
    *
    * Never throws: returns `Left` when `kty` is not "RSA", when either component is malformed,
    * or when the JCA provider rejects the key spec (e.g. a modulus that is too short).
    */
  def publicKey: Either[Throwable, PublicKey] =
    for {
      _ <- Either.cond[Throwable, Unit](
             kty == "RSA",
             (),
             new IllegalArgumentException(s"Unsupported key type '$kty' for kid '$kid'; expected RSA")
           )
      m  <- modulus
      ex <- exponent
      // getInstance / generatePublic throw checked JCA exceptions; surface them as Left.
      pk <- scala.util
              .Try(KeyFactory.getInstance("RSA").generatePublic(new RSAPublicKeySpec(m, ex)))
              .toEither
    } yield pk

  // Decodes base64url text, padded or not (the JDK URL decoder does not require '=').
  // Signum 1 forces an unsigned interpretation so a leading high bit is not read as negative.
  private def decodeUnsignedBase64Url(text: String): Either[Throwable, BigInteger] =
    scala.util.Try(new BigInteger(1, java.util.Base64.getUrlDecoder.decode(text))).toEither
}
/** Circe codecs for [[Key]], derived from its field names (`kty`, `n`, `e`, `kid`). */
object Key {
  implicit val decoderKey: Decoder[Key] = deriveDecoder
  implicit val encoderKey: Encoder[Key] = deriveEncoder
}
/** The fields of a JWT header that this example reads: the key id and the signing algorithm. */
case class TokenHeader(kid: String, alg: String)

/** Circe decoder for [[TokenHeader]], derived from its field names. */
object TokenHeader {
  implicit val decoderTokenHeader: Decoder[TokenHeader] = deriveDecoder
}
/** A raw compact-serialized JWT (`header.payload.signature`). */
case class Token(raw: String) {
  private val split = raw.split('.')

  /** Decodes the first (header) segment: base64url -> JSON -> [[TokenHeader]].
    *
    * The segment lookup happens inside the IO, so a token with no non-empty header segment
    * (e.g. input made only of dots, for which `split` is empty) yields a failed IO rather than
    * making the `Token` constructor throw.
    */
  val header: IO[TokenHeader] =
    IO.fromOption(split.headOption.filter(_.nonEmpty))(
      new IllegalArgumentException("JWT has no header segment")
    ).flatMap { encodedHeader =>
      Stream[IO, String](encodedHeader)
        .through(fs2.text.base64.decodeWithAlphabet(Bases.Alphabets.Base64Url))
        .through(byteStreamParser)
        .through(decoder[IO, TokenHeader])
        .compile
        .lastOrError
    }
}
/** Mock JWKS endpoint: serves a fixed one-key JWKS as JSON at `GET /jwks.json` on port 8080. */
object JwksService {

  /** The key set being served: a single RSA key with kid "123" and exponent "AQAB". */
  def jwks: JWKS =
    JWKS(
      keys = List(
        Key(
          kty = "RSA",
          n = "u1SU1LfVLPHCozMxH2Mo4lgOEePzNm0tRgeLezV6ffAt0gunVTLw7onLRnrq0_IzW7yWR7QkrmBL7jTKEn5u-qKhbwKfBstIs-bMY2Zkp18gnTxKLxoS2tFczGkPLPgizskuemMghRniWaoLcyehkd3qqGElvW_VDL5AaWTg0nLVkjRo9z-40RQzuVaE8AkAFmxZzow3x-VJYKdjykkJ0iT9wCS0DRTXu269V264Vf_3jvredZiKRkgwlL9xNAwxXFg0x_XFw005UWVRIkdgcKWTjpBP2dPwVZ4WWC-9aGVd-Gyn1o0CLelf4rEjGoXbAAEgAqeGUxrcIlbjXfbcmw",
          e = "AQAB",
          kid = "123"
        )
      )
    )

  /** Answers `GET /jwks.json` with `jwks` encoded as JSON; other requests are left unmatched. */
  def routes: HttpRoutes[IO] =
    HttpRoutes.of[IO] {
      case GET -> Root / "jwks.json" => Ok(jwks.asJson)
    }

  /** Ember server resource on 0.0.0.0:8080 serving `routes`, with unmatched requests sent to `orNotFound`. */
  val server =
    EmberServerBuilder
      .default[IO]
      .withPort(port"8080")
      .withHost(ipv4"0.0.0.0")
      .withHttpApp(routes.orNotFound)
      .build
}
/** Entry point: starts the mock JWKS server, fetches its keys over HTTP, and verifies `token`. */
object JWKSApp extends IOApp.Simple {

  /** Fetches the JWKS from the local mock server and decodes it into a `kid -> PublicKey` map. */
  def getJWKS(client: Client[IO]): IO[Map[String, PublicKey]] =
    client
      .expect[JWKS]("http://localhost:8080/jwks.json")
      .flatMap(jwks => jwks.keyMap)

  /** Sample token; its header declares alg RS256 and kid "123". */
  val token = Token(
    "eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCIsImtpZCI6IjEyMyJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IlNhbSBQaWxsc3dvcnRoIiwiZW1haWwiOiJjb21tZW50c0BibGVyZi5jYSIsImFkbWluIjp0cnVlLCJpYXQiOjE1MTYyMzkwMjJ9.Y8BgOshApJeQWE4H5na3YzLvgUd48YPgfg7zWwEdDaVcwX19xS3owHolMLXj9wwSH2xhsUF-hXfuYJ4zWik64tUbzWV3_Z6EFKG8AXrZhJDWa0N8XH4Ttkc-gPXv9dKHGrq955dkNw95KKav-VOaia81MaYH2mWR_lrOvyFiV06ggM--1Alb3Vedh5yxDegc1jYTgvzP7lD9Fq3pFGQ0voXLZ_l6MxHtsvpzi8y3daD1YYXGnmvuy3JMATN917RAukiH5PkfT7OmsSCay4w8Yrb_VKCN9MX1GX2T2IFMK-5GwPJ6cebphvkfSi_uqn58Eg7JKVsXo_m11iL_-EgCaA"
  )

  /** While the server is up: fetch the keys (client closed afterwards), pick the key named by
    * the token header's `kid`, verify/decode the token with RS256 only, and print the result.
    * Fails with `NoSuchElementException` when the JWKS has no key for that `kid`.
    */
  val run: IO[Unit] = JwksService.server.use { _ =>
    for {
      keys   <- EmberClientBuilder.default[IO].build.use(client => getJWKS(client))
      header <- token.header
      key <- IO.fromOption(keys.get(header.kid))(
               new NoSuchElementException(s"No key with kid '${header.kid}' in the JWKS")
             )
      decoded <- IO.fromTry(Jwt.decodeAll(token.raw, key, Seq(JwtAlgorithm.RS256)))
      _       <- IO.println(decoded)
    } yield ()
  }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment