Created
July 28, 2025 19:13
-
-
Save LionZXY/d837e5b34f3c9d44b120c27050b1e10a to your computer and use it in GitHub Desktop.
metro-factory
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
// depends-on-plugin org.jetbrains.kotlin | |
// depends-on-plugin com.intellij.java.ide | |
import com.intellij.openapi.actionSystem.AnActionEvent | |
import com.intellij.openapi.command.WriteCommandAction | |
import com.intellij.psi.PsiFile | |
import liveplugin.editor | |
import liveplugin.psiFile | |
import liveplugin.registerAction | |
import liveplugin.show | |
import org.jetbrains.kotlin.com.intellij.openapi.editor.Document | |
import org.jetbrains.kotlin.com.intellij.psi.PsiDocumentManager | |
import org.jetbrains.kotlin.psi.KtClass | |
import org.jetbrains.kotlin.psi.KtFile | |
import org.jetbrains.kotlin.psi.KtPsiFactory | |
import org.jetbrains.kotlin.resolve.ImportPath | |
import com.intellij.openapi.project.Project | |
import org.jetbrains.kotlin.lexer.KtKeywordToken | |
import org.jetbrains.kotlin.lexer.KtModifierKeywordToken | |
import org.jetbrains.kotlin.lexer.KtTokens | |
import org.jetbrains.kotlin.psi.KtAnnotationEntry | |
import org.jetbrains.kotlin.psi.KtNamedFunction | |
/**
 * Finds the first top-level class in [psiFile] that has an annotation whose short name is `Inject`
 * and at least one primary-constructor parameter annotated `@Assisted`.
 *
 * @return the matching class, or `null` when [psiFile] is not a Kotlin file or no class qualifies.
 */
fun getProcessedClass(psiFile: PsiFile): KtClass? =
    (psiFile as? KtFile)
        ?.declarations
        ?.asSequence()
        ?.filterIsInstance<KtClass>()
        ?.firstOrNull { candidate -> isAssistedInjectClass(candidate) }
/**
 * Returns `true` when [ktClass] carries an annotation whose short name is `Inject` and its primary
 * constructor declares at least one parameter annotated `@Assisted`. Both annotations are matched
 * by short name only.
 */
private fun isAssistedInjectClass(ktClass: KtClass): Boolean {
    val hasInject = ktClass.annotationEntries.any { entry -> entry.shortName?.asString() == "Inject" }
    if (!hasInject) return false
    val ctorParams = ktClass.primaryConstructor?.valueParameters ?: return false
    return ctorParams.any { param ->
        param.annotationEntries.any { entry -> entry.shortName?.asString() == "Assisted" }
    }
}
/**
 * Generates or converts the Metro `@AssistedFactory` for [ktClass] inside one undoable write command.
 *
 * - If [ktClass] already contains an inner class annotated `@AssistedFactory`, nothing is changed.
 *   This makes re-running the action safe instead of producing a duplicate `Factory`.
 * - If [getExistedFactoryClass] finds a hand-written factory, it is rewritten in place by
 *   [transformExistedFactory].
 * - Otherwise a new `Factory` produced by [generateNewFactory] is appended to the class body.
 *   The body is created first for body-less classes such as `class Foo @Inject constructor(...)`.
 *
 * The `AssistedFactory` import is added only when the PSI change succeeds. Failures are reported
 * with [show] instead of being thrown out of the write action.
 */
fun generateFactory(project: Project, ktClass: KtClass) {
    val alreadyGenerated = ktClass.body?.declarations.orEmpty()
        .filterIsInstance<KtClass>()
        .any { inner -> inner.annotationEntries.any { it.shortName?.asString() == "AssistedFactory" } }
    if (alreadyGenerated) {
        show("${ktClass.name} already has an @AssistedFactory, nothing to do")
        return
    }
    val existedFactoryClass = getExistedFactoryClass(ktClass)
    val psiFactory = KtPsiFactory(ktClass)
    WriteCommandAction.runWriteCommandAction(project) {
        val result: Result<Unit> = if (existedFactoryClass == null) {
            show("Generating new factory...")
            generateNewFactory(psiFactory, ktClass).mapCatching { newClass ->
                // Creates `{ }` when the class has no body yet, instead of silently doing nothing.
                val body = ktClass.getOrCreateBody()
                body.addBefore(newClass, body.rBrace)
                Unit
            }
        } else {
            show("Replacing existed factory...")
            transformExistedFactory(psiFactory, ktClass, existedFactoryClass)
        }
        result
            .onSuccess { addImport(psiFactory, ktClass) }
            .onFailure { error -> show("Failed to generate factory: ${error.message}") }
    }
}
/**
 * Returns the first class declared directly in [ktClass]'s body that [isFactoryInnerClass] accepts.
 * Returns `null` when [ktClass] has no body or contains no such inner class.
 */
private fun getExistedFactoryClass(ktClass: KtClass): KtClass? =
    ktClass.body
        ?.declarations
        ?.filterIsInstance<KtClass>()
        ?.firstOrNull { candidate -> isFactoryInnerClass(candidate) }
/**
 * Returns `true` for an inner class that looks like a hand-written factory. Such a class carries
 * an annotation whose short name is `Inject`, and its name contains `Factory` (case-sensitive).
 */
private fun isFactoryInnerClass(ktClass: KtClass): Boolean {
    val injectAnnotated = ktClass.annotationEntries.any { entry -> entry.shortName?.asString() == "Inject" }
    return injectAnnotated && ktClass.name?.contains("Factory", ignoreCase = false) == true
}
/**
 * Adds `import dev.zacsweers.metro.AssistedFactory` to the file that contains [ktClass].
 *
 * Nothing is added when:
 * - the file has no import list,
 * - the same import is already present, or
 * - the whole package is already imported via `import dev.zacsweers.metro.*`.
 *
 * In the star-import case an explicit import would be redundant. This function modifies PSI, so
 * it must be called under a write action.
 */
private fun addImport(psiFactory: KtPsiFactory, ktClass: KtClass) {
    val targetPath = ImportPath.fromString("dev.zacsweers.metro.AssistedFactory")
    val importList = ktClass.containingKtFile.importList ?: return
    val alreadyImported = importList.imports.any { directive ->
        val path = directive.importPath ?: return@any false
        path == targetPath || (path.isAllUnder && path.fqName == targetPath.fqName.parent())
    }
    if (!alreadyImported) {
        importList.add(psiFactory.createImportDirective(targetPath))
    }
}
/**
 * Builds a new, detached factory declaration for [originalClass]:
 *
 * ```
 * @AssistedFactory
 * fun interface Factory {
 *     fun create(
 *         <assisted constructor parameters>
 *     ): <OriginalClassName>
 * }
 * ```
 *
 * The caller is responsible for inserting the returned class into [originalClass].
 *
 * @return the created class, or a failure when the class name or the assisted parameters cannot be read.
 */
fun generateNewFactory(psiFactory: KtPsiFactory, originalClass: KtClass): Result<KtClass> = runCatching {
    val className = originalClass.name ?: error("Fail to get original class name")
    val assistedParamsText = getAssistedParamsAsString(originalClass)
    // Built line by line instead of with trimIndent(). trimIndent() runs after interpolation, and the
    // column-0 continuation lines of a multi-line parameter list would make it trim nothing at all.
    val factoryText = buildString {
        appendLine("@AssistedFactory")
        appendLine("fun interface Factory {")
        appendLine("    fun create(")
        if (assistedParamsText.isNotBlank()) {
            appendLine(assistedParamsText.prependIndent("        "))
        }
        appendLine("    ): $className")
        append("}")
    }
    psiFactory.createClass(factoryText)
}
/**
 * Renders the `@Assisted` primary-constructor parameters of [originalClass] as `name: Type` entries
 * joined by `",\n"`, for use as the parameter list of the factory's `create` function.
 *
 * If a parameter's `@Assisted` annotation has arguments (an identifier such as `@Assisted("id")`),
 * the annotation text is repeated in front of the entry. Identifiers are what disambiguate assisted
 * parameters of the same type, so they must match between the constructor and the factory.
 *
 * @throws IllegalStateException if there is no primary constructor, or if an assisted parameter has
 *   no name or no explicit type.
 */
private fun getAssistedParamsAsString(originalClass: KtClass): String {
    val constructor = originalClass.primaryConstructor ?: error("Fail to get primary constructor")
    return constructor.valueParameters
        .mapNotNull { param ->
            val assisted = param.annotationEntries
                .firstOrNull { it.shortName?.asString() == "Assisted" }
                ?: return@mapNotNull null
            // nameIdentifier keeps backticks, so keyword-named parameters stay valid Kotlin.
            val name = param.nameIdentifier?.text ?: error("Fail to get name of assisted parameter")
            val type = param.typeReference?.text ?: error("Fail to get type of assisted parameter $name")
            val prefix = if (assisted.valueArguments.isNotEmpty()) "${assisted.text} " else ""
            "$prefix$name: $type"
        }
        .joinToString(",\n")
}
/**
 * Converts a hand-written factory class into an `@AssistedFactory` fun interface, in place.
 *
 * Example input:
 * ```
 * @Inject
 * @ContributesBinding(AppGraph::class, FinishScreenDecomposeComponent.Factory::class)
 * class Factory(
 *     private val factory: (
 *         componentContext: ComponentContext
 *     ) -> FinishScreenDecomposeComponentImpl
 * ) : FinishScreenDecomposeComponent.Factory {
 *     override fun invoke(
 *         componentContext: ComponentContext
 *     ) = factory(componentContext)
 * }
 * ```
 *
 * Result (whitespace and annotation order may differ):
 * ```
 * @AssistedFactory
 * @ContributesBinding(AppGraph::class, FinishScreenDecomposeComponent.Factory::class)
 * fun interface Factory : FinishScreenDecomposeComponent.Factory {
 *     override fun invoke(
 *         componentContext: ComponentContext
 *     ): FinishScreenDecomposeComponentImpl
 * }
 * ```
 *
 * All preconditions are checked before the first PSI change, so a failure leaves [existedFactory]
 * untouched. This function modifies PSI, so it must be called under a write action.
 *
 * @return success, or a failure when [existedFactory] does not declare exactly one function, the
 *   original class has no name, or the `class` keyword cannot be found.
 */
fun transformExistedFactory(
    psiFactory: KtPsiFactory,
    originalClass: KtClass,
    existedFactory: KtClass
): Result<Unit> = runCatching {
    // --- Validation: nothing below this section runs unless all of it succeeds. ---
    val factoryFunction = existedFactory.declarations
        .filterIsInstance<KtNamedFunction>()
        .singleOrNull()
        ?: error("In factory should be one function")
    val className = originalClass.name ?: error("Fail to get original class name")
    val classKeyword = existedFactory.node.findChildByType(KtTokens.CLASS_KEYWORD)
        ?: error("Fail to find class keyword")
    val interfaceKeyword = psiFactory.createDeclaration<KtClass>("fun interface Dummy {}").node
        .findChildByType(KtTokens.INTERFACE_KEYWORD)
        ?: error("Fail to create interface keyword")

    // --- Mutation ---
    // The injected constructor (and the lambda it received) is meaningless on an interface.
    existedFactory.primaryConstructor?.delete()
    // Delete the entries from the PSI tree itself. `annotationEntries` is a snapshot list, and
    // removing from or adding to it does not change the source.
    existedFactory.annotationEntries
        .filter { it.shortName?.asString() == "Inject" }
        .forEach { it.delete() }
    existedFactory.addAnnotationEntry(psiFactory.createAnnotationEntry("@AssistedFactory"))

    // Turn `fun invoke(...) = factory(...)` or `fun invoke(...) { ... }` into the abstract
    // `fun invoke(...): ClassName`. For a block body, bodyExpression is the block itself.
    factoryFunction.bodyExpression?.delete()
    factoryFunction.equalsToken?.delete()
    factoryFunction.typeReference = psiFactory.createType(className)

    // Swap `class` for `interface`, then add `fun` to get a `fun interface`.
    classKeyword.treeParent.replaceChild(classKeyword, interfaceKeyword)
    existedFactory.addModifier(KtTokens.FUN_KEYWORD)
}
/**
 * Re-creates this keyword token as a [KtModifierKeywordToken] with the same text and token id.
 * Not referenced elsewhere in this script.
 */
private fun KtKeywordToken.toModifierToken(): KtModifierKeywordToken =
    KtModifierKeywordToken.keywordModifier(value, tokenId)
// "Generate Factory" action: finds the @Inject class with @Assisted parameters in the current file
// and generates (or converts) its Metro @AssistedFactory.
registerAction(id = "Generate Factory") { event: AnActionEvent ->
    val project = event.project ?: return@registerAction
    // Flush pending editor text into PSI so that the tree is up to date before it is read and modified.
    // The name is fully qualified because this script imports the shaded
    // `org.jetbrains.kotlin.com.intellij.psi.PsiDocumentManager` under the same simple name.
    com.intellij.psi.PsiDocumentManager.getInstance(project).commitAllDocuments()
    val psiFile = event.psiFile ?: return@registerAction
    val ktClass = getProcessedClass(psiFile)
    if (ktClass == null) {
        // Tell the user why nothing happened instead of exiting silently.
        show("No @Inject class with @Assisted constructor parameters found in ${psiFile.name}")
        return@registerAction
    }
    show("Find class ${ktClass.name}, start processing")
    generateFactory(project, ktClass)
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment