Created
April 13, 2023 20:26
-
-
Save samspills/fba4be59de41c73262bc23d7939e977a to your computer and use it in GitHub Desktop.
Toy example that runs a mock JWKS server and demonstrates decoding a JWT using a public JWK.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
//> using scala "2.13.10" | |
//> using lib "org.typelevel::cats-effect:3.4.8" | |
//> using lib "org.http4s::http4s-ember-client:0.23.18" | |
//> using lib "org.http4s::http4s-ember-server:0.23.18" | |
//> using lib "org.http4s::http4s-server:0.23.18" | |
//> using lib "org.http4s::http4s-dsl:0.23.18" | |
//> using lib "org.http4s::http4s-circe:0.23.18" | |
//> using lib "io.circe::circe-core:0.14.5" | |
//> using lib "io.circe::circe-generic:0.14.5" | |
//> using lib "org.scodec::scodec-bits:1.1.37" | |
//> using lib "co.fs2::fs2-core:3.6.1" | |
//> using lib "com.github.jwt-scala::jwt-core:9.2.0" | |
//> using lib "io.circe::circe-fs2:0.14.1" | |
import java.math.BigInteger | |
import java.security.{KeyFactory, PublicKey} | |
import java.security.spec.RSAPublicKeySpec | |
import cats.effect.{IO, IOApp} | |
import cats.syntax.all._ | |
import com.comcast.ip4s._ | |
import fs2.{Fallible, Stream} | |
import io.circe.{Decoder, Encoder} | |
import io.circe.fs2.{byteStreamParser, decoder} | |
import io.circe.generic.semiauto._ | |
import io.circe.syntax._ | |
import org.http4s._ | |
import org.http4s.circe._ | |
import org.http4s.client.Client | |
import org.http4s.dsl.io._ | |
import org.http4s.ember.client.EmberClientBuilder | |
import org.http4s.ember.server._ | |
import org.http4s.implicits._ | |
import pdi.jwt.{Jwt, JwtAlgorithm} | |
import scodec.bits.Bases | |
/** A JSON Web Key Set: the list of keys published by the JWKS endpoint.
  *
  * @param keys the individual JWK entries, in the order they were published
  */
case class JWKS(keys: List[Key]) {

  /** Effect producing a lookup table from key id (`kid`) to the corresponding public key.
    *
    * The effect fails with the error of the first key that cannot be turned into a
    * public key. If several keys share a `kid`, the last one in `keys` wins.
    */
  val keyMap: IO[Map[String, PublicKey]] =
    keys
      .traverse { key =>
        // IO.defer postpones building the key until the effect is run and turns
        // anything thrown while building it into a failed IO, so constructing a
        // JWKS value (e.g. while decoding JSON) never throws.
        IO.defer(IO.fromEither(key.publicKey)).map(publicKey => key.kid -> publicKey)
      }
      .map(_.toMap)
}
/** JSON codecs and the http4s entity decoder for [[JWKS]]. */
object JWKS {
  implicit val encoderJWKS: Encoder[JWKS] = deriveEncoder[JWKS]
  implicit val decoderJWKS: Decoder[JWKS] = deriveDecoder[JWKS]

  // Implicit definitions carry an explicit type: an inferred type on an implicit
  // can make it invisible at use sites that appear earlier in the file, and
  // Scala 3 rejects untyped implicits outright.
  implicit val entityDecoder: EntityDecoder[IO, JWKS] = jsonOf[IO, JWKS]
}
/** A single JWK entry describing an RSA public key.
  *
  * @param kty key type as published in the JWK
  * @param n   base64url-encoded modulus
  * @param e   base64url-encoded public exponent
  * @param kid key id used to match a token header against this key
  */
case class Key(kty: String, n: String, e: String, kid: String) {

  /** Decodes a base64url string and reads the bytes as a non-negative integer. */
  private def decodeUnsigned(encoded: String): Either[Throwable, BigInteger] =
    Stream(encoded)
      .through(
        fs2.text.base64.decodeWithAlphabet[Fallible](Bases.Alphabets.Base64Url)
      )
      .compile
      .to(Array)
      // signum = 1: the bytes are a big-endian magnitude, never a negative number
      .map(bytes => new BigInteger(1, bytes))

  /** The RSA modulus decoded from `n`, or the decoding error. */
  def modulus: Either[Throwable, BigInteger] = decodeUnsigned(n)

  /** The RSA public exponent decoded from `e`, or the decoding error. */
  def exponent: Either[Throwable, BigInteger] = decodeUnsigned(e)

  /** Builds the RSA public key from the decoded modulus and exponent.
    *
    * @return the key, or a `Left` holding the first failure: a decoding error for
    *         `n`, then for `e`, then any error raised by the JCA key factory
    */
  def publicKey: Either[Throwable, PublicKey] =
    for {
      mod <- modulus
      exp <- exponent
      // KeyFactory.getInstance / generatePublic report problems by throwing
      // (e.g. InvalidKeySpecException); capture that so the Either return type
      // is honest and callers never see an exception.
      key <- Either.catchNonFatal(
        KeyFactory.getInstance("RSA").generatePublic(new RSAPublicKeySpec(mod, exp))
      )
    } yield key
}
/** Circe codecs for [[Key]], derived from its constructor fields. */
object Key {
  implicit val decoderKey: Decoder[Key] = deriveDecoder
  implicit val encoderKey: Encoder[Key] = deriveEncoder
}
/** The fields of a token header that this example reads.
  *
  * @param kid id of the key the token refers to
  * @param alg algorithm name carried in the header
  */
case class TokenHeader(
    kid: String,
    alg: String
)

/** Circe decoder for [[TokenHeader]], derived from its constructor fields. */
object TokenHeader {
  implicit val decoderTokenHeader: Decoder[TokenHeader] = deriveDecoder
}
/** A raw, dot-separated token string together with its lazily decoded header.
  *
  * @param raw the token exactly as received
  */
case class Token(raw: String) {
  // Everything before the first '.'. Taking the prefix is total, whereas
  // `raw.split('.').head` throws from the constructor when `raw` consists only
  // of dots (split drops trailing empty segments, leaving an empty array).
  private val headerSegment: String = raw.takeWhile(_ != '.')

  /** Effect that base64url-decodes the first token segment and parses it as a
    * [[TokenHeader]]. Malformed input fails this effect rather than the
    * construction of the `Token`.
    */
  val header: IO[TokenHeader] =
    Stream[IO, String](headerSegment)
      .through(fs2.text.base64.decodeWithAlphabet(Bases.Alphabets.Base64Url))
      .through(byteStreamParser)
      .through(decoder[IO, TokenHeader])
      .compile
      .lastOrError
}
/** Mock JWKS provider: a fixed key set served over HTTP at `/jwks.json`. */
object JwksService {

  /** The key set handed out by the mock endpoint: one RSA key with kid "123". */
  def jwks: JWKS = {
    val signingKey = Key(
      kty = "RSA",
      n =
        "u1SU1LfVLPHCozMxH2Mo4lgOEePzNm0tRgeLezV6ffAt0gunVTLw7onLRnrq0_IzW7yWR7QkrmBL7jTKEn5u-qKhbwKfBstIs-bMY2Zkp18gnTxKLxoS2tFczGkPLPgizskuemMghRniWaoLcyehkd3qqGElvW_VDL5AaWTg0nLVkjRo9z-40RQzuVaE8AkAFmxZzow3x-VJYKdjykkJ0iT9wCS0DRTXu269V264Vf_3jvredZiKRkgwlL9xNAwxXFg0x_XFw005UWVRIkdgcKWTjpBP2dPwVZ4WWC-9aGVd-Gyn1o0CLelf4rEjGoXbAAEgAqeGUxrcIlbjXfbcmw",
      e = "AQAB",
      kid = "123"
    )
    JWKS(signingKey :: Nil)
  }

  /** Routes of the mock server: `GET /jwks.json` answers 200 with the key set as JSON. */
  def routes: HttpRoutes[IO] =
    HttpRoutes.of[IO] {
      case GET -> Root / "jwks.json" => Ok(jwks.asJson)
    }

  /** Resource that runs the mock server on 0.0.0.0:8080 for as long as it is held. */
  val server: cats.effect.Resource[IO, org.http4s.server.Server] = {
    val httpApp = routes.orNotFound
    EmberServerBuilder
      .default[IO]
      .withHost(ipv4"0.0.0.0")
      .withPort(port"8080")
      .withHttpApp(httpApp)
      .build
  }
}
/** Entry point: starts the mock JWKS server, fetches its key set and uses the
  * key named in the sample token's header to decode that token.
  */
object JWKSApp extends IOApp.Simple {

  /** Fetches the key set from the local mock endpoint and converts it into a
    * `kid -> public key` lookup table.
    *
    * @param client HTTP client used for the request
    */
  def getJWKS(client: Client[IO]): IO[Map[String, PublicKey]] =
    client
      .expect[JWKS]("http://localhost:8080/jwks.json")
      .flatMap(jwks => jwks.keyMap)

  /** Sample token whose header carries kid "123" and alg "RS256". */
  val token = Token(
    "eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCIsImtpZCI6IjEyMyJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IlNhbSBQaWxsc3dvcnRoIiwiZW1haWwiOiJjb21tZW50c0BibGVyZi5jYSIsImFkbWluIjp0cnVlLCJpYXQiOjE1MTYyMzkwMjJ9.Y8BgOshApJeQWE4H5na3YzLvgUd48YPgfg7zWwEdDaVcwX19xS3owHolMLXj9wwSH2xhsUF-hXfuYJ4zWik64tUbzWV3_Z6EFKG8AXrZhJDWa0N8XH4Ttkc-gPXv9dKHGrq955dkNw95KKav-VOaia81MaYH2mWR_lrOvyFiV06ggM--1Alb3Vedh5yxDegc1jYTgvzP7lD9Fq3pFGQ0voXLZ_l6MxHtsvpzi8y3daD1YYXGnmvuy3JMATN917RAukiH5PkfT7OmsSCay4w8Yrb_VKCN9MX1GX2T2IFMK-5GwPJ6cebphvkfSi_uqn58Eg7JKVsXo_m11iL_-EgCaA"
  )

  /** Runs the whole demo while the mock server is up and prints the decoded token.
    *
    * Fails with a `NoSuchElementException` naming the `kid` when the fetched key
    * set has no entry for the token's key id.
    */
  val run: IO[Unit] = JwksService.server.use { _ =>
    for {
      // The client is only needed for the fetch, so it is released right after.
      keys <- EmberClientBuilder
        .default[IO]
        .build
        .use(client => getJWKS(client))
      header <- token.header
      key <- IO.fromOption(keys.get(header.kid))(
        new NoSuchElementException(
          s"no key with kid '${header.kid}' in the fetched JWKS"
        )
      )
      // Only RS256 is accepted, regardless of what the token header claims.
      decoded <- IO.fromTry(Jwt.decodeAll(token.raw, key, Seq(JwtAlgorithm.RS256)))
      _ <- IO.println(decoded)
    } yield ()
  }
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment