Created May 9, 2022 19:47
Save hamnis/3a2a30d136f7016b71ca16a5e443a179 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package reloadable | |
import cats.effect._ | |
import cats.syntax.all._ | |
import org.typelevel.log4cats.LoggerFactory | |
import java.io.ByteArrayInputStream | |
import java.net.Socket | |
import java.nio.charset.StandardCharsets | |
import java.security.KeyStore | |
import java.security.cert.{CertificateFactory, X509Certificate} | |
import java.util.UUID | |
import java.util.concurrent.atomic.AtomicReference | |
import javax.net.ssl.{SSLContext, SSLEngine, TrustManager, TrustManagerFactory, X509ExtendedTrustManager} | |
import scala.concurrent.duration.{DurationInt, FiniteDuration} | |
import scala.util.control.NonFatal | |
/** An `X509ExtendedTrustManager` that first consults a swappable custom trust manager and falls
  * back to `defaultTrustManager` when no custom manager is installed or the custom one rejects
  * the certificate chain.
  *
  * The custom manager is read from `trustManager` on every call, so replacing the reference's
  * contents affects subsequent checks without rebuilding the owning `SSLContext`.
  *
  * @param defaultTrustManager trust manager used when no custom one is installed, and as fallback
  * @param trustManager        holder for the currently installed custom trust manager, if any
  */
final class ReloadableX509TrustManager private (
    defaultTrustManager: X509ExtendedTrustManager,
    trustManager: AtomicReference[Option[X509ExtendedTrustManager]]
) extends X509ExtendedTrustManager {
  private[this] val logger = org.log4s.getLogger

  // All checks delegate to `runOp`: custom manager first, default manager on rejection.

  override def checkClientTrusted(chain: Array[X509Certificate], authType: String): Unit =
    runOp(_.checkClientTrusted(chain, authType), "no client trust cert found")

  override def checkClientTrusted(chain: Array[X509Certificate], authType: String, socket: Socket): Unit =
    runOp(_.checkClientTrusted(chain, authType, socket), "no client trust cert found")

  override def checkClientTrusted(chain: Array[X509Certificate], authType: String, engine: SSLEngine): Unit =
    runOp(_.checkClientTrusted(chain, authType, engine), "no client trust cert found")

  override def checkServerTrusted(chain: Array[X509Certificate], authType: String): Unit =
    runOp(_.checkServerTrusted(chain, authType), "no server trust cert found")

  override def checkServerTrusted(chain: Array[X509Certificate], authType: String, socket: Socket): Unit =
    runOp(_.checkServerTrusted(chain, authType, socket), "no server trust cert found")

  override def checkServerTrusted(chain: Array[X509Certificate], authType: String, engine: SSLEngine): Unit =
    runOp(_.checkServerTrusted(chain, authType, engine), "no server trust cert found")

  /** Issuers accepted by the custom manager (when installed) followed by the default ones. */
  override def getAcceptedIssuers: Array[X509Certificate] =
    trustManager
      .get()
      .fold(defaultTrustManager.getAcceptedIssuers)(custom =>
        custom.getAcceptedIssuers ++ defaultTrustManager.getAcceptedIssuers
      )

  /** Runs `f` against the custom trust manager, falling back to the default one if it throws.
    *
    * @param f       the trust check to perform
    * @param onError message logged when the custom manager rejects and the default is tried
    * @throws Exception whatever the default manager throws when it also rejects; the custom
    *                   manager's failure is attached to it as a suppressed exception so the
    *                   caller sees both diagnostics
    */
  private def runOp(f: X509ExtendedTrustManager => Unit, onError: String): Unit =
    trustManager.get() match {
      case Some(custom) =>
        try f(custom)
        catch {
          case customFailure: Exception =>
            logger.warn(customFailure)(s"$onError, trying default trust manager")
            try f(defaultTrustManager)
            catch {
              case defaultFailure: Exception =>
                // addSuppressed rejects self-suppression, so guard against the same instance.
                if (defaultFailure ne customFailure) defaultFailure.addSuppressed(customFailure)
                throw defaultFailure
            }
        }
      case None => f(defaultTrustManager)
    }
}
object ReloadableX509TrustManager {

  /** Resource-managed variant of `SSLContextFor`.
    *
    * Acquires an `SSLContext` backed by a [[ReloadableX509TrustManager]] and cancels the
    * background reload fiber when the resource is released.
    *
    * @param certs    effect yielding the current custom certificates; evaluated once during
    *                 acquisition and again after every `duration`
    * @param duration delay between reload attempts (one minute by default)
    */
  def SSLContextForResource[F[_]](certs: F[List[String]], duration: FiniteDuration = 1.minute)(implicit
      A: Async[F],
      S: Spawn[F],
      loggerFactory: LoggerFactory[F]
  ): Resource[F, SSLContext] =
    Resource.make(SSLContextFor(certs, duration))(_._2.cancel).map(_._1)

  /** Builds an `SSLContext` whose trust manager consults custom certificates before the default
    * trust manager, and starts a fiber that periodically refreshes the custom certificates.
    *
    * Each string in `certs` is UTF-8 encoded and parsed with
    * `CertificateFactory.generateCertificate` (presumably one PEM certificate per string). An
    * empty list disables the custom trust manager.
    *
    * The initial evaluation of `certs` is not error-recovered: if it fails, the returned effect
    * fails. Later failures (fetching or parsing) are logged and retried on the next reload.
    *
    * @param certs    effect yielding the current custom certificates
    * @param duration delay between reload attempts
    * @return the context and the reload fiber; the caller owns the fiber and must cancel it
    */
  def SSLContextFor[F[_]](certs: F[List[String]], duration: FiniteDuration)(implicit
      A: Async[F],
      S: Spawn[F],
      loggerFactory: LoggerFactory[F]
  ): F[(SSLContext, Fiber[F, Throwable, Unit])] = {
    val logger = loggerFactory.getLogger

    // Installs a trust manager for `newCerts` (or disables the custom one when empty), then
    // records `newCerts` as current. On failure the error is logged and neither the installed
    // trust manager nor `currentCerts` changes, so the same certs are retried on the next reload.
    def make(
        newCerts: List[String],
        currentCerts: Ref[F, List[String]],
        trustManager: AtomicReference[Option[X509ExtendedTrustManager]]
    ): F[Unit] = {
      val install: F[Unit] =
        if (newCerts.isEmpty)
          logger.debug("No custom certificates, disabling custom trust manager") *>
            A.delay(trustManager.set(None))
        else
          makeTrustManager[F](newCerts).flatMap(tm => A.delay(trustManager.set(Some(tm))))

      (install *> currentCerts.set(newCerts)).recoverWith { case NonFatal(e) =>
        logger.warn(e)("Exception raised while creating trustmanager")
      }
    }

    // Every `duration`, re-reads `certs` and re-installs when they changed (an empty list is
    // always re-applied). Each iteration recovers its own errors, so a failure never ends the
    // loop and error handlers do not pile up across iterations.
    def reload(
        currentCerts: Ref[F, List[String]],
        trustManager: AtomicReference[Option[X509ExtendedTrustManager]]
    ): F[Unit] = {
      val step: F[Unit] = for {
        _ <- A.sleep(duration)
        current <- currentCerts.get
        newCerts <- certs
        shouldReload = newCerts.isEmpty || current.isEmpty || current != newCerts
        _ <- if (shouldReload) make(newCerts, currentCerts, trustManager) else A.unit
      } yield ()

      step
        .recoverWith { case NonFatal(e) =>
          logger.warn(e)("Exception raised while reloading")
        }
        .foreverM[Unit]
    }

    for {
      // Allocated inside F so every run of this effect gets its own independent state.
      trustManager <- A.delay(new AtomicReference[Option[X509ExtendedTrustManager]](None))
      ref <- Ref.of[F, List[String]](List.empty)
      _ <- certs.flatMap(newCerts => make(newCerts, ref, trustManager))
      // A null KeyStore makes the factory use its default trust material.
      default <- getTrustManager[F](null)
      ctx <- A.blocking {
        val ctx = SSLContext.getInstance("TLS")
        ctx.init(
          null,
          Array[TrustManager](
            new ReloadableX509TrustManager(default, trustManager)
          ),
          null
        )
        ctx
      }
      // Started last so a failure in any step above cannot leave an unowned fiber running.
      fiber <- S.start(reload(ref, trustManager))
    } yield (ctx, fiber)
  }

  /** Builds a trust manager that trusts exactly `additionalCerts`, logging how many are used. */
  private def makeTrustManager[F[_]: Sync: LoggerFactory](
      additionalCerts: List[String]
  ): F[X509ExtendedTrustManager] =
    for {
      _ <-
        LoggerFactory
          .getLogger[F]
          .info(s"Making new trustmanager from ${additionalCerts.size} custom certs")
      ks <- keystoreFor[F](additionalCerts)
      manager <- getTrustManager[F](ks)
    } yield manager

  /** Returns the first `X509ExtendedTrustManager` produced by the default-algorithm
    * `TrustManagerFactory` initialised with `keyStore` (`null` selects the factory's default).
    *
    * The effect fails with `IllegalStateException` if the factory yields no such manager.
    */
  private def getTrustManager[F[_]: Sync](keyStore: KeyStore): F[X509ExtendedTrustManager] =
    Sync[F].blocking {
      val tmf = TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm)
      tmf.init(keyStore)
      tmf.getTrustManagers
        .collectFirst { case x: X509ExtendedTrustManager => x }
        .getOrElse(throw new IllegalStateException("No X509TrustManager in TrustManagerFactory"))
    }

  /** Creates an empty default-type `KeyStore` holding one trusted-certificate entry (under a
    * random alias) per string in `certificates`.
    */
  private def keystoreFor[F[_]: Sync](certificates: List[String]): F[KeyStore] =
    Sync[F].blocking {
      val ks = KeyStore.getInstance(KeyStore.getDefaultType)
      ks.load(null)
      val cf = CertificateFactory.getInstance("X.509")
      certificates.foreach { certificate =>
        val cert =
          cf.generateCertificate(new ByteArrayInputStream(certificate.getBytes(StandardCharsets.UTF_8)))
        ks.setCertificateEntry(UUID.randomUUID.toString, cert)
      }
      ks
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.