Skip to content

Fix/mixin fixes #2470

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
May 19, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -953,19 +953,23 @@ object MEExpressionCompletionUtil {
isStore: Boolean,
mixinClass: PsiClass,
): List<EliminableLookup> {
val isStatic = targetMethod.hasAccess(Opcodes.ACC_STATIC)
// ignore "this"
if (!targetMethod.hasAccess(Opcodes.ACC_STATIC) && index == 0) {
if (!isStatic && index == 0) {
return emptyList()
}

var argumentsSize = Type.getArgumentsAndReturnSizes(targetMethod.desc) shr 2
if (targetMethod.hasAccess(Opcodes.ACC_STATIC)) {
if (isStatic) {
argumentsSize--
}
val isArgsOnly = index < argumentsSize

if (targetMethod.localVariables != null) {
val localsHere = targetMethod.localVariables.filter { localVariable ->
if (!isStatic && localVariable.index == 0) {
return@filter false
}
val firstValidInstruction = if (isStore) {
generateSequence<AbstractInsnNode>(localVariable.start) { it.previous }
.firstOrNull { it.opcode >= 0 }
Expand Down Expand Up @@ -1012,7 +1016,11 @@ object MEExpressionCompletionUtil {

// fallback to ASM dataflow
val localTypes = AsmDfaUtil.getLocalVariableTypes(project, targetClass, targetMethod, originalInsn)
?.toMutableList()
?: return emptyList()
if (!isStatic) {
localTypes[0] = null
}
val localType = localTypes.getOrNull(index) ?: return emptyList()
val ordinal = localTypes.asSequence().take(index).filter { it == localType }.count()
val localName = localType.typeNameToInsert().replace("[]", "Array") + (ordinal + 1)
Expand Down
32 changes: 17 additions & 15 deletions src/main/kotlin/platform/mixin/expression/MEExpressionInjector.kt
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ package com.demonwav.mcdev.platform.mixin.expression
import com.demonwav.mcdev.platform.mixin.util.MixinConstants
import com.demonwav.mcdev.util.findContainingModifierList
import com.demonwav.mcdev.util.findContainingNameValuePair
import com.demonwav.mcdev.util.parseArray
import com.intellij.lang.injection.InjectedLanguageManager
import com.intellij.lang.injection.MultiHostInjector
import com.intellij.lang.injection.MultiHostRegistrar
Expand Down Expand Up @@ -117,23 +118,24 @@ class MEExpressionInjector : MultiHostInjector {
}
}
} else if (annotation.hasQualifiedName(MixinConstants.MixinExtras.EXPRESSION)) {
val valueExpr = annotation.findDeclaredAttributeValue("value") ?: continue
val places = mutableListOf<Pair<PsiLanguageInjectionHost, TextRange>>()
iterateConcatenation(valueExpr) { op ->
if (op is PsiLanguageInjectionHost) {
for (textRange in getTextRanges(op)) {
places += op to textRange
for (valueExpr in annotation.findDeclaredAttributeValue("value")?.parseArray { it }.orEmpty()) {
val places = mutableListOf<Pair<PsiLanguageInjectionHost, TextRange>>()
iterateConcatenation(valueExpr) { op ->
if (op is PsiLanguageInjectionHost) {
for (textRange in getTextRanges(op)) {
places += op to textRange
}
} else {
isFrankenstein = true
}
} else {
isFrankenstein = true
}
}
if (places.isNotEmpty()) {
for ((i, place) in places.withIndex()) {
val (host, range) = place
val prefix = "\ndo { ".takeIf { i == 0 }
val suffix = " }".takeIf { i == places.size - 1 }
registrar.addPlace(prefix, suffix, host, range)
if (places.isNotEmpty()) {
for ((i, place) in places.withIndex()) {
val (host, range) = place
val prefix = "\ndo { ".takeIf { i == 0 }
val suffix = " }".takeIf { i == places.size - 1 }
registrar.addPlace(prefix, suffix, host, range)
}
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -219,9 +219,10 @@ object MEExpressionMatchUtil {
physicalInsn
}

val unfilteredLocals = localInfo.getLocals(module, targetClass, targetMethod, actualInsn)
?: return@addMember false
val filteredLocals = localInfo.matchLocals(unfilteredLocals, CollectVisitor.Mode.MATCH_ALL)
val filteredLocals = localInfo.matchLocals(
module, targetClass, targetMethod, actualInsn,
CollectVisitor.Mode.MATCH_ALL
) ?: return@addMember false
filteredLocals.any { it.index == virtualInsn.`var` }
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,8 +63,10 @@ class ModifyVariableHandler : InjectorAnnotationHandler() {

val possibleTypes = mutableSetOf<String>()
for (insn in targets) {
val locals = info.getLocals(module, targetClass, targetMethod, insn.insn) ?: continue
val matchedLocals = info.matchLocals(locals, CollectVisitor.Mode.COMPLETION, matchType = false)
val matchedLocals = info.matchLocals(
module, targetClass, targetMethod, insn.insn,
CollectVisitor.Mode.COMPLETION, matchType = false
) ?: continue
for (local in matchedLocals) {
possibleTypes += local.desc!!
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -304,11 +304,11 @@ abstract class AbstractLoadInjectionPoint(private val store: Boolean) : Injectio
}

val shiftedInsn = if (store) insn.next ?: insn else insn
val locals = info.getLocals(module, targetClass, methodNode, shiftedInsn) ?: continue
val locals = info.matchLocals(module, targetClass, methodNode, shiftedInsn, mode) ?: continue

val elementFactory = JavaPsiFacade.getElementFactory(module.project)

for (result in info.matchLocals(locals, mode)) {
for (result in locals) {
addResult(shiftedInsn, elementFactory.createExpressionFromText(result.name, null))
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -188,21 +188,18 @@ class ExpressionInjectionPoint : InjectionPoint<PsiElement>() {
(exprAnnotation.findDeclaredAttributeValue("id")?.constantStringValue ?: "") == atId
}
.flatMap { exprAnnotation ->
val expressionElements = exprAnnotation.findDeclaredAttributeValue("value")?.parseArray { it }
?: return@flatMap emptySequence<Pair<Expression, MEStatement>>()
expressionElements.asSequence().mapNotNull { expressionElement ->
val text = expressionElement.constantStringValue ?: return@mapNotNull null
val rootStatementPsi = InjectedLanguageManager.getInstance(project)
.getInjectedPsiFiles(expressionElement)?.firstOrNull()
?.let {
(it.first as? MEExpressionFile)?.statements?.firstOrNull { stmt ->
stmt.findMultiInjectionHost()?.parentOfType<PsiAnnotation>() == exprAnnotation
}
}
?: project.meExpressionElementFactory.createFile("do {$text}").statements.singleOrNull()
?: project.meExpressionElementFactory.createStatement("empty")
MEExpressionMatchUtil.createExpression(text)?.let { it to rootStatementPsi }
}
exprAnnotation.findDeclaredAttributeValue("value")?.parseArray { it }.orEmpty()
}
.mapIndexedNotNull { exprIndex, expressionElement ->
val text = expressionElement.constantStringValue ?: return@mapIndexedNotNull null
val rootStatementPsi = InjectedLanguageManager.getInstance(project)
.getInjectedPsiFiles(expressionElement)?.firstOrNull()
?.let {
(it.first as? MEExpressionFile)?.statements?.getOrNull(exprIndex)
}
?: project.meExpressionElementFactory.createFile("do {$text}").statements.singleOrNull()
?: project.meExpressionElementFactory.createStatement("empty")
MEExpressionMatchUtil.createExpression(text)?.let { it to rootStatementPsi }
}
.toList()
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -88,10 +88,10 @@ class ModifyVariableArgsOnlyInspection : MixinInspection() {

val wantedDesc = wantedType.descriptor

for ((targetClass, targetMethod) in methodTargets) {
for ((_, targetMethod) in methodTargets) {
val argTypes = mutableListOf<String?>()
if (!targetMethod.hasAccess(Opcodes.ACC_STATIC)) {
argTypes += "L${targetClass.name};"
argTypes += null
}
for (arg in Type.getArgumentTypes(targetMethod.desc)) {
argTypes += arg.descriptor
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,9 +59,10 @@ class UnresolvedLocalCaptureInspection : MixinInspection() {
val localInfo = LocalInfo.fromAnnotation(parameter.type.unwrapLocalRef(), localAnnotation)

for (target in targets) {
val locals = localInfo.getLocals(module, target.method.clazz, target.method.method, target.result.insn)
?: continue
val matchingLocals = localInfo.matchLocals(locals, CollectVisitor.Mode.MATCH_ALL)
val matchingLocals = localInfo.matchLocals(
module, target.method.clazz, target.method.method, target.result.insn,
CollectVisitor.Mode.MATCH_ALL
) ?: continue
if (matchingLocals.size != 1) {
holder.registerProblem(
localAnnotation.nameReferenceElement ?: localAnnotation,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ import com.demonwav.mcdev.platform.mixin.inspection.MixinInspection
import com.demonwav.mcdev.platform.mixin.util.MixinConstants
import com.demonwav.mcdev.platform.mixin.util.isAssignable
import com.demonwav.mcdev.util.McdevDfaUtil
import com.demonwav.mcdev.util.toObjectType
import com.intellij.codeInsight.intention.FileModifier.SafeFieldForPreview
import com.intellij.codeInspection.LocalQuickFixOnPsiElement
import com.intellij.codeInspection.ProblemsHolder
Expand Down Expand Up @@ -52,6 +53,7 @@ class WrongOperationParametersInspection : MixinInspection() {
if (expression.resolveMethod()?.containingClass?.qualifiedName != MixinConstants.MixinExtras.OPERATION) {
return
}
val project = expression.project

val containingMethod = PsiTreeUtil.getParentOfType(
expression,
Expand Down Expand Up @@ -85,7 +87,7 @@ class WrongOperationParametersInspection : MixinInspection() {
if (expression.argumentList.expressionCount == expectedParamTypes.size) {
val allValid = expression.argumentList.expressions.zip(expectedParamTypes).all { (expr, expectedType) ->
val exprType = McdevDfaUtil.getDataflowType(expr) ?: return@all true
isAssignable(expectedType, exprType, false)
isAssignable(expectedType.toObjectType(project), exprType.toObjectType(project))
}
if (allValid) {
return
Expand Down
12 changes: 9 additions & 3 deletions src/main/kotlin/platform/mixin/util/LocalInfo.kt
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ class LocalInfo(
val ordinal: Int?,
val names: Set<String>,
) {
fun getLocals(
private fun getLocals(
module: Module,
targetClass: ClassNode,
methodNode: MethodNode,
Expand All @@ -69,10 +69,16 @@ class LocalInfo(
}

fun matchLocals(
locals: Array<LocalVariables.LocalVariable?>,
module: Module,
targetClass: ClassNode,
methodNode: MethodNode,
insn: AbstractInsnNode,
mode: CollectVisitor.Mode,
matchType: Boolean = true,
): List<LocalVariables.LocalVariable> {
): List<LocalVariables.LocalVariable>? {
val locals = getLocals(module, targetClass, methodNode, insn)
?.drop(if (methodNode.hasAccess(Opcodes.ACC_STATIC)) 0 else 1)
?: return null
val typeDesc = type?.descriptor
if (ordinal != null) {
val ordinals = mutableMapOf<String, Int>()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -70,8 +70,8 @@ class SourceCodeLocationInfo(val index: Int, val lineNumber: Int?, val indexInLi
}

if (count++ == index) {
myResult = t
if (lineNumber == null) {
myResult = t
return true
}
}
Expand Down
12 changes: 12 additions & 0 deletions src/main/kotlin/util/psi-utils.kt
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ import com.intellij.openapi.editor.Document
import com.intellij.openapi.module.Module
import com.intellij.openapi.module.ModuleManager
import com.intellij.openapi.module.ModuleUtilCore
import com.intellij.openapi.project.Project
import com.intellij.openapi.roots.LibraryOrderEntry
import com.intellij.openapi.roots.ModuleRootManager
import com.intellij.openapi.roots.ProjectFileIndex
Expand All @@ -53,6 +54,7 @@ import com.intellij.psi.PsiExpression
import com.intellij.psi.PsiFile
import com.intellij.psi.PsiKeyword
import com.intellij.psi.PsiLanguageInjectionHost
import com.intellij.psi.PsiManager
import com.intellij.psi.PsiMember
import com.intellij.psi.PsiMethod
import com.intellij.psi.PsiMethodReferenceExpression
Expand All @@ -62,11 +64,13 @@ import com.intellij.psi.PsiModifierList
import com.intellij.psi.PsiNameValuePair
import com.intellij.psi.PsiParameter
import com.intellij.psi.PsiParameterList
import com.intellij.psi.PsiPrimitiveType
import com.intellij.psi.PsiReference
import com.intellij.psi.PsiReferenceExpression
import com.intellij.psi.PsiType
import com.intellij.psi.ResolveResult
import com.intellij.psi.filters.ElementFilter
import com.intellij.psi.search.GlobalSearchScope
import com.intellij.psi.util.CachedValue
import com.intellij.psi.util.CachedValueProvider
import com.intellij.psi.util.CachedValuesManager
Expand Down Expand Up @@ -277,6 +281,14 @@ fun PsiType.normalize(): PsiType {
return normalized
}

fun PsiType.toObjectType(project: Project): PsiType =
when (val normalized = normalize()) {
is PsiPrimitiveType ->
normalized.getBoxedType(PsiManager.getInstance(project), GlobalSearchScope.allScope(project))
?: normalized
else -> normalized
}

val PsiMethod.nameAndParameterTypes: String
get() = "$name(${parameterList.parameters.joinToString(", ") { it.type.presentableText }})"

Expand Down
Loading