diff --git a/rewrite-java-test/src/test/java/org/openrewrite/java/ChangeTypeTest.java b/rewrite-java-test/src/test/java/org/openrewrite/java/ChangeTypeTest.java index 936acbc7e22..1c6d504c488 100644 --- a/rewrite-java-test/src/test/java/org/openrewrite/java/ChangeTypeTest.java +++ b/rewrite-java-test/src/test/java/org/openrewrite/java/ChangeTypeTest.java @@ -1592,8 +1592,8 @@ public A1 method(A1 a1) { import a.A2; public class Example { - public A2 method(A2 a1) { - return a1; + public A2 method(A2 a2) { + return a2; } } """ @@ -1630,6 +1630,49 @@ public class Test { ); } + @Test + void doNotRenameRandomVariablesMatchingClassName() { + rewriteRun( + spec -> spec.recipe(new ChangeType("a.A1", "a.A2", false)), + java( + """ + package a; + public class A1 { + } + """, + """ + package a; + public class A2 { + } + """ + ), + java( + """ + package org.foo; + + import a.A1; + + public class Example { + public String method(A1 a, String a1) { + return a1; + } + } + """, + """ + package org.foo; + + import a.A2; + + public class Example { + public String method(A2 a, String a1) { + return a1; + } + } + """ + ) + ); + } + @Test void updateVariableType() { rewriteRun( diff --git a/rewrite-java/src/main/java/org/openrewrite/java/ChangeType.java b/rewrite-java/src/main/java/org/openrewrite/java/ChangeType.java index f4b02a7fd9a..eac4167fcc3 100644 --- a/rewrite-java/src/main/java/org/openrewrite/java/ChangeType.java +++ b/rewrite-java/src/main/java/org/openrewrite/java/ChangeType.java @@ -20,6 +20,7 @@ import org.jspecify.annotations.Nullable; import org.openrewrite.*; import org.openrewrite.internal.ListUtils; +import org.openrewrite.java.VariableNameUtils.GenerationStrategy; import org.openrewrite.java.search.UsesType; import org.openrewrite.java.tree.*; import org.openrewrite.marker.Markers; @@ -368,6 +369,23 @@ public J visitIdentifier(J.Identifier ident, ExecutionContext ctx) { } } + // Rename variable if it matches class name (starting with a lowercase character) + if (ident.getSimpleName().equals(decapitalize(className))) { + if (targetType instanceof JavaType.FullyQualified) { + String newName = VariableNameUtils.generateVariableName( + decapitalize(((JavaType.FullyQualified) targetType).getClassName()), + getCursor(), + GenerationStrategy.INCREMENT_NUMBER + ); + + ident = ident.withSimpleName(newName); + + if (ident.getFieldType() != null) { + ident = ident.withFieldType(ident.getFieldType().withName(newName)); + } + } + } + // Recreate any static imports as needed if (sf != null) { for (J.Import anImport : sf.getImports()) { @@ -387,6 +405,15 @@ public J visitIdentifier(J.Identifier ident, ExecutionContext ctx) { return visitAndCast(ident, ctx, super::visitIdentifier); } + @Override + public J.VariableDeclarations.NamedVariable visitVariable(J.VariableDeclarations.NamedVariable variable, ExecutionContext ctx) { + J.VariableDeclarations.NamedVariable v = (J.VariableDeclarations.NamedVariable) super.visitVariable(variable, ctx); + if (v.getVariableType() != null && !v.getSimpleName().equals(v.getVariableType().getName())) { + return v.withVariableType(v.getVariableType().withName(v.getSimpleName())); + } + return v; + } + @Override public J visitMethodInvocation(J.MethodInvocation method, ExecutionContext ctx) { if (method.getMethodType() != null && method.getMethodType().hasFlags(Flag.Static)) { @@ -778,4 +805,11 @@ private static boolean hasSameFQN(J.Import import_, JavaType targetType) { return fqn != null && fqn.equals(curFqn); } + + private static String decapitalize(@Nullable String string) { + if (string != null && !string.isEmpty()) { + return Character.toLowerCase(string.charAt(0)) + string.substring(1); + } + return ""; + } } diff --git a/rewrite-kotlin/src/test/java/org/openrewrite/kotlin/ChangeTypeTest.java b/rewrite-kotlin/src/test/java/org/openrewrite/kotlin/ChangeTypeTest.java index 59c163d443c..df16c863fb4 100644 --- a/rewrite-kotlin/src/test/java/org/openrewrite/kotlin/ChangeTypeTest.java +++ b/rewrite-kotlin/src/test/java/org/openrewrite/kotlin/ChangeTypeTest.java @@ -195,7 +195,7 @@ fun test(original: Original) { } import x.y.Target - fun test(original: Target) { } + fun test(target: Target) { } """ ) ); @@ -231,7 +231,7 @@ fun test(original: MyAlias) { } import x.y.Target as MyAlias - fun test(original: MyAlias) { } + fun test(target: MyAlias) { } """ ) ); @@ -263,7 +263,7 @@ fun test(original: a.b.Original) { } import x.y.Target - fun test(original: Target) { } + fun test(target: Target) { } """ ) );