Created
February 9, 2024 14:38
-
-
Save cptwunderlich/8cbb9ae09b0d7cabdcd4a8b72183c363 to your computer and use it in GitHub Desktop.
Scalafix rule for migrating Tapir's `DecodeFailureHandler` from v1.8 to v1.9
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package v2_5 | |
import scalafix.v1._ | |
import scala.meta._ | |
/** Scalafix rule migrating Tapir `DecodeFailureHandler` implementations from v1.8 to v1.9.
  *
  * In v1.9 `DecodeFailureHandler` takes a type parameter for the effect type, and its `apply`
  * method must return its result wrapped in that effect. This rule rewrites classes that have no
  * type parameters and no constructor parameter lists, and that directly extend
  * `DecodeFailureHandler`, so that they use `Future` as the effect:
  *
  *  - `class H extends DecodeFailureHandler` becomes
  *    `class H()(implicit ec: ExecutionContext) extends DecodeFailureHandler[Future]`;
  *  - `def apply(<params>): T = body` becomes
  *    `def apply(<params>)(implicit monad: MonadError[Future]): Future[T] = Future(body)`
  *    (a block body is kept as `Future { ... }`).
  *
  * Imports for `MonadError`, `Future` and `ExecutionContext` are added when they are not already
  * imported, and only if at least one class was rewritten.
  */
class FailureHandler extends SyntacticRule("FailureHandler") {

  override def isRewrite: Boolean = true

  /** Whether `tree` contains an import bringing `name` from package `pkg` into scope, either
    * explicitly (`import pkg.name`, `import pkg.{name, ...}`) or via a wildcard (`import pkg._`).
    *
    * Checking the package (not just the simple name) matters: e.g. an unrelated
    * `java.util.concurrent.Future` import must not suppress adding `scala.concurrent.Future`.
    *
    * @param tree the tree to search; imports at any nesting depth are considered
    * @param pkg  the package path as written in the import, e.g. `scala.concurrent`
    * @param name the simple name of the imported member, e.g. `Future`
    */
  private def hasImport(tree: Tree, pkg: String, name: String): Boolean =
    tree
      .collect { case Import(importers) => importers }
      .flatten
      .exists { importer =>
        importer.ref.syntax == pkg && importer.importees.exists {
          case Importee.Name(Name.Indeterminate(imported)) => imported == name
          case Importee.Wildcard()                         => true
          case _                                           => false
        }
      }

  /** Global-import patches for the names the rewritten code refers to, skipping already-imported ones. */
  private def missingImports(tree: Tree): Patch = {
    def addUnless(present: Boolean, importer: Importer): Patch =
      if (present) Patch.empty else Patch.addGlobalImport(importer)

    addUnless(hasImport(tree, "sttp.monad", "MonadError"), importer"sttp.monad.MonadError") +
      addUnless(hasImport(tree, "scala.concurrent", "Future"), importer"scala.concurrent.Future") +
      addUnless(
        hasImport(tree, "scala.concurrent", "ExecutionContext"),
        importer"scala.concurrent.ExecutionContext"
      )
  }

  /** Rewrites a v1.8 `apply` definition (one with an explicit result type) to the v1.9 shape.
    *
    * The user's own parameter clause(s) are kept verbatim so that references to the context
    * parameter inside the body (whatever it is named) remain valid. Any other statement yields
    * an empty patch.
    */
  private def rewriteApply(stat: Stat): Patch = stat match {
    case m @ Defn.Def.After_4_7_3(mods, Term.Name("apply"), paramGroups, Some(tpe), body) =>
      // Appending "def apply" to the modifiers avoids a stray leading space when there are none.
      val prefix = (mods.map(_.syntax) :+ "def apply").mkString(" ")
      val params = paramGroups.map(_.syntax).mkString
      // `Future { ... }` is only valid for a block; `Future expr` would parse as an infix call.
      val wrappedBody = body match {
        case block: Term.Block => s"Future $block"
        case expr              => s"Future($expr)"
      }
      Patch.replaceTree(
        m,
        s"$prefix$params(implicit monad: MonadError[Future]): Future[$tpe] = $wrappedBody"
      )
    case _ => Patch.empty
  }

  override def fix(implicit doc: SyntacticDocument): Patch = {
    val classPatches = doc.tree.collect {
      case Defn.Class.After_4_6_0(
            _,
            name,
            tparams,
            Ctor.Primary.After_4_6_0(_, _, ctorParams),
            Template.After_4_4_0(
              _,
              List(
                Init.After_4_6_0(
                  handlerType @ Type.Name("DecodeFailureHandler"),
                  Name.Anonymous(),
                  Nil
                )
              ),
              _,
              stats,
              _
            )
          ) if tparams.values.isEmpty && ctorParams.isEmpty =>
        // Anchor on the matched trees rather than scanning tokens: the first identifier of a
        // class can be an annotation or an access qualifier, not the class name.
        Patch.addRight(name, "()(implicit ec: ExecutionContext)") +
          Patch.addRight(handlerType, "[Future]") +
          stats.map(rewriteApply).asPatch
    }.asPatch

    if (classPatches.isEmpty) Patch.empty
    else missingImports(doc.tree) + classPatches
  }
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment