Skip to content

Instantly share code, notes, and snippets.

@LionZXY
Created July 28, 2025 19:13
Show Gist options
  • Save LionZXY/d837e5b34f3c9d44b120c27050b1e10a to your computer and use it in GitHub Desktop.
Save LionZXY/d837e5b34f3c9d44b120c27050b1e10a to your computer and use it in GitHub Desktop.
metro-factory
// depends-on-plugin org.jetbrains.kotlin
// depends-on-plugin com.intellij.java.ide
import com.intellij.openapi.actionSystem.AnActionEvent
import com.intellij.openapi.command.WriteCommandAction
import com.intellij.psi.PsiFile
import liveplugin.editor
import liveplugin.psiFile
import liveplugin.registerAction
import liveplugin.show
import org.jetbrains.kotlin.com.intellij.openapi.editor.Document
import org.jetbrains.kotlin.com.intellij.psi.PsiDocumentManager
import org.jetbrains.kotlin.psi.KtClass
import org.jetbrains.kotlin.psi.KtFile
import org.jetbrains.kotlin.psi.KtPsiFactory
import org.jetbrains.kotlin.resolve.ImportPath
import com.intellij.openapi.project.Project
import org.jetbrains.kotlin.lexer.KtKeywordToken
import org.jetbrains.kotlin.lexer.KtModifierKeywordToken
import org.jetbrains.kotlin.lexer.KtTokens
import org.jetbrains.kotlin.psi.KtAnnotationEntry
import org.jetbrains.kotlin.psi.KtNamedFunction
/**
 * Finds the class that should be processed in [psiFile].
 *
 * Only the file's own declarations are inspected. The first class for which
 * [isAssistedInjectClass] returns `true` is selected.
 *
 * @param psiFile file the action was invoked on.
 * @return the matching class, or `null` when [psiFile] is not a [KtFile] or holds no such class.
 */
fun getProcessedClass(psiFile: PsiFile): KtClass? {
    if (psiFile !is KtFile) return null
    return psiFile.declarations
        .filterIsInstance<KtClass>()
        .firstOrNull { candidate -> isAssistedInjectClass(candidate) }
}
/**
 * Tells whether [ktClass] carries an `Inject` annotation and its primary constructor
 * declares at least one parameter annotated with `Assisted`.
 *
 * Annotations are compared by short name only.
 */
private fun isAssistedInjectClass(ktClass: KtClass): Boolean {
    val hasInject = ktClass.annotationEntries.any { entry -> entry.shortName?.asString() == "Inject" }
    if (!hasInject) return false

    val primaryConstructor = ktClass.primaryConstructor ?: return false
    return primaryConstructor.valueParameters.any { parameter ->
        parameter.annotationEntries.any { entry -> entry.shortName?.asString() == "Assisted" }
    }
}
/**
 * Adds an `@AssistedFactory` factory to [ktClass], or converts the nested factory it already has.
 *
 * All PSI changes run inside one write command on [project]:
 * 1. the `AssistedFactory` import is added to the containing file;
 * 2. if no nested factory class is found, a new `Factory` interface is generated and
 *    inserted before the closing brace of the class body (the body is created when the
 *    class was declared without one);
 * 3. otherwise the existing factory class is transformed in place.
 *
 * A failure reported by the generation/transformation step is rethrown from the write command.
 *
 * @param project project that owns [ktClass]; used to run the write command.
 * @param ktClass class for which the factory is generated.
 */
fun generateFactory(project: Project, ktClass: KtClass) {
    val existedFactoryClass = getExistedFactoryClass(ktClass)
    val psiFactory = KtPsiFactory(ktClass)
    WriteCommandAction.runWriteCommandAction(project) {
        addImport(psiFactory, ktClass)
        if (existedFactoryClass == null) {
            show("Generating new factory...")
            // Build the factory first so a generation failure does not leave an empty body behind.
            val newClass = generateNewFactory(psiFactory, ktClass).getOrThrow()
            // A class written as `class Foo(...)` has no body yet; create one instead of
            // silently doing nothing.
            val body = ktClass.getOrCreateBody()
            body.addBefore(newClass, body.rBrace)
        } else {
            show("Replacing existed factory...")
            transformExistedFactory(psiFactory, ktClass, existedFactoryClass).getOrThrow()
        }
    }
}
/**
 * Looks through the classes declared directly in the body of [ktClass] and returns the
 * first one accepted by [isFactoryInnerClass], or `null` when the class has no body or
 * no such nested class.
 */
private fun getExistedFactoryClass(ktClass: KtClass): KtClass? =
    ktClass.body
        ?.declarations
        ?.filterIsInstance<KtClass>()
        ?.firstOrNull { nested -> isFactoryInnerClass(nested) }
/**
 * Tells whether [ktClass] looks like a hand-written factory: it is annotated with `Inject`
 * (matched by short name) and its name contains `Factory` (case-sensitive).
 */
private fun isFactoryInnerClass(ktClass: KtClass): Boolean {
    val markedWithInject = ktClass.annotationEntries.any { entry -> entry.shortName?.asString() == "Inject" }
    val className = ktClass.name
    return markedWithInject && className != null && className.contains("Factory", ignoreCase = false)
}
/**
 * Adds `import dev.zacsweers.metro.AssistedFactory` to the file containing [ktClass]
 * unless an import with the same path is already present.
 *
 * Does nothing when the file has no import list.
 */
private fun addImport(psiFactory: KtPsiFactory, ktClass: KtClass) {
    val assistedFactoryImport = psiFactory.createImportDirective(
        ImportPath.fromString("dev.zacsweers.metro.AssistedFactory")
    )
    val fileImports = ktClass.containingKtFile.importList ?: return
    val isMissing = fileImports.imports.none { it.importPath == assistedFactoryImport.importPath }
    if (isMissing) {
        fileImports.add(assistedFactoryImport)
    }
}
/**
 * Builds a new `@AssistedFactory fun interface Factory` for [originalClass].
 *
 * The interface declares a single `create` function whose parameters are the `Assisted`
 * primary-constructor parameters of [originalClass] and whose return type is the class itself.
 * The created element is not attached to any file; the caller inserts it.
 *
 * @param psiFactory factory used to create the PSI class from text.
 * @param originalClass class the factory creates instances of.
 * @return the created class, or a failure when [originalClass] has no name or no primary
 * constructor, or when PSI creation throws.
 */
fun generateNewFactory(psiFactory: KtPsiFactory, originalClass: KtClass): Result<KtClass> = runCatching {
    // Without this check a nameless class would be rendered as the literal return type `null`.
    val className = originalClass.name ?: error("Fail to get original class name")
    val assistedParamsText = getAssistedParamsAsString(originalClass)
    val factoryText = listOf(
        "@AssistedFactory",
        "fun interface Factory {",
        "fun create(",
        assistedParamsText,
        "): $className",
        "}"
    ).joinToString(separator = "\n")
    psiFactory.createClass(factoryText)
}
/**
 * Renders the `Assisted` primary-constructor parameters of [originalClass] as
 * `name: Type` entries separated by a comma and a line break.
 *
 * Throws [IllegalStateException] when the class has no primary constructor.
 */
private fun getAssistedParamsAsString(originalClass: KtClass): String {
    val primaryConstructor = originalClass.primaryConstructor
        ?: error("Fail to get primary constructor")
    return primaryConstructor.valueParameters
        .filter { parameter ->
            parameter.annotationEntries.any { entry -> entry.shortName?.asString() == "Assisted" }
        }
        .joinToString(separator = ",\n") { parameter ->
            "${parameter.name}: ${parameter.typeReference?.text}"
        }
}
/**
 * Converts an existing hand-written nested factory class into an `@AssistedFactory` fun interface.
 *
 * Before:
 * ```
 * @Inject
 * @ContributesBinding(AppGraph::class, FinishScreenDecomposeComponent.Factory::class)
 * class Factory(
 *     private val factory: (
 *         componentContext: ComponentContext
 *     ) -> FinishScreenDecomposeComponentImpl
 * ) : FinishScreenDecomposeComponent.Factory {
 *     override fun invoke(
 *         componentContext: ComponentContext
 *     ) = factory(componentContext)
 * }
 * ```
 * After:
 * ```
 * @AssistedFactory
 * @ContributesBinding(AppGraph::class, FinishScreenDecomposeComponent.Factory::class)
 * fun interface Factory : FinishScreenDecomposeComponent.Factory {
 *     override fun invoke(
 *         componentContext: ComponentContext
 *     ): FinishScreenDecomposeComponentImpl
 * }
 * ```
 *
 * @param psiFactory factory used to create the replacement PSI elements.
 * @param originalClass class produced by the factory; its name becomes the function's return type.
 * @param existedFactory nested factory class that is rewritten in place.
 * @return success when the class was rewritten; a failure when [originalClass] has no name,
 * [existedFactory] does not declare exactly one function, or a required keyword node
 * cannot be found.
 */
fun transformExistedFactory(
    psiFactory: KtPsiFactory,
    originalClass: KtClass,
    existedFactory: KtClass
): Result<Unit> = runCatching {
    // Resolve and validate everything first so a failure leaves the existing class untouched.
    val className = originalClass.name ?: error("Fail to get original class name")
    val factoryFunction = existedFactory.declarations
        .filterIsInstance<KtNamedFunction>()
        .singleOrNull()
        ?: error("In factory should be one function")
    val classKeyword = existedFactory.node.findChildByType(KtTokens.CLASS_KEYWORD)
        ?: error("Fail to find class keyword")
    val dummyInterfaceNode = psiFactory.createDeclaration<KtClass>("fun interface Dummy {}").node
    val interfaceKeyword = dummyInterfaceNode.findChildByType(KtTokens.INTERFACE_KEYWORD)
        ?: error("Fail to create interface keyword")

    existedFactory.primaryConstructor?.delete()

    // `annotationEntries` is only a list view of the PSI children: adding to or removing
    // from that list never changes the tree, so edit the PSI elements themselves.
    val assistedFactoryAnnotation = psiFactory.createAnnotationEntry("@AssistedFactory")
    val injectAnnotations = existedFactory.annotationEntries
        .filter { it.shortName?.asString() == "Inject" }
    if (injectAnnotations.isEmpty()) {
        existedFactory.addAnnotationEntry(assistedFactoryAnnotation)
    } else {
        // Put @AssistedFactory where the first @Inject was and drop any further @Inject entries.
        injectAnnotations.first().replace(assistedFactoryAnnotation)
        injectAnnotations.drop(1).forEach { it.delete() }
    }

    // Turn the function into an abstract one: drop its body and declare the return type explicitly.
    factoryFunction.bodyExpression?.delete()
    factoryFunction.bodyBlockExpression?.delete()
    factoryFunction.equalsToken?.delete()
    factoryFunction.typeReference = psiFactory.createType(className)

    // `class` -> `fun interface`
    classKeyword.treeParent.replaceChild(classKeyword, interfaceKeyword)
    existedFactory.addModifier(KtTokens.FUN_KEYWORD)
}
/**
 * Creates a [KtModifierKeywordToken] with the same value and token id as this keyword token.
 */
private fun KtKeywordToken.toModifierToken(): KtModifierKeywordToken =
    KtModifierKeywordToken.keywordModifier(value, tokenId)
// Registers the "Generate Factory" action: it locates the assisted-inject class in the
// current file and generates (or converts) its factory. The action does nothing when
// there is no project, no PSI file, or no suitable class.
registerAction(id = "Generate Factory") { event: AnActionEvent ->
    val currentProject = event.project ?: return@registerAction
    val currentFile = event.psiFile ?: return@registerAction
    val targetClass = getProcessedClass(currentFile)
    if (targetClass == null) {
        return@registerAction
    }
    show("Find class ${targetClass.name}, start processing")
    generateFactory(currentProject, targetClass)
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment