Skip to content

Add expression type in JavaTemplate #5384

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 11 commits into from
May 28, 2025
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
package org.openrewrite.java;

import org.junit.jupiter.api.Test;
import org.junitpioneer.jupiter.ExpectedToFail;
import org.openrewrite.ExecutionContext;
import org.openrewrite.java.tree.Expression;
import org.openrewrite.java.tree.J;
Expand Down Expand Up @@ -163,6 +164,88 @@ void test() {
);
}

@Test
void setsDifferenceMultimapRecipe() {
rewriteRun(
spec -> spec.parser(JavaParser.fromJavaVersion().classpath("guava"))
.recipe(toRecipe(() -> new JavaIsoVisitor<>() {
final JavaTemplate template = JavaTemplate.builder("#{set:any(java.util.Set<T>)}.stream().filter(java.util.function.Predicate.not(#{multimap:any(com.google.common.collect.Multimap<K, V>)}::containsKey)).collect(com.google.common.collect.ImmutableSet.toImmutableSet())")
.bindType("com.google.common.collect.ImmutableSet<T>")
.genericTypes("T", "K", "V")
.javaParser(JavaParser.fromJavaVersion().classpath("guava"))
.build();

@Override
public J.MethodInvocation visitMethodInvocation(J.MethodInvocation method, ExecutionContext executionContext) {
return template.matches(getCursor()) ? SearchResult.found(method) : super.visitMethodInvocation(method, executionContext);
}
})),
java(
"""
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.ImmutableSetMultimap;

import java.util.function.Predicate;

class Test {
void test() {
ImmutableSet.of(1).stream().filter(Predicate.not(ImmutableSetMultimap.of(2, 3)::containsKey)).collect(ImmutableSet.toImmutableSet());
}
}
""",
"""
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.ImmutableSetMultimap;

import java.util.function.Predicate;

class Test {
void test() {
/*~~>*/ImmutableSet.of(1).stream().filter(Predicate.not(ImmutableSetMultimap.of(2, 3)::containsKey)).collect(ImmutableSet.toImmutableSet());
}
}
"""
)
);
}

@Test
void emptyStreamRecipe() {
rewriteRun(
spec -> spec.recipe(toRecipe(() -> new JavaIsoVisitor<>() {
final JavaTemplate template = JavaTemplate.builder("java.util.stream.Stream.of()")
.bindType("java.util.stream.Stream<T>")
.genericTypes("T")
.build();

@Override
public J.MethodInvocation visitMethodInvocation(J.MethodInvocation method, ExecutionContext executionContext) {
return template.matches(getCursor()) ? SearchResult.found(method) : super.visitMethodInvocation(method, executionContext);
}
})),
java(
"""
import java.util.stream.Stream;

class Test {
Stream<String> test() {
return Stream.of();
}
}
""",
"""
import java.util.stream.Stream;

class Test {
Stream<String> test() {
return /*~~>*/Stream.of();
}
}
"""
)
);
}

@Test
void methodModifiersMismatch() {
rewriteRun(
Expand Down Expand Up @@ -206,4 +289,118 @@ Stream<Integer> test() {
)
);
}

@Test
@ExpectedToFail
void replaceMemberReferenceToLambda() {
//noinspection Convert2MethodRef
rewriteRun(
spec -> spec
.expectedCyclesThatMakeChanges(1).cycles(1)
.recipe(toRecipe(() -> new JavaVisitor<>() {
final JavaTemplate refTemplate = JavaTemplate.builder("T::toString")
.bindType("java.util.function.Function<T, String>")
.genericTypes("T")
.build();
final JavaTemplate lambdaTemplate = JavaTemplate.builder("e -> e.toString()")
.bindType("java.util.function.Function<T, String>")
.genericTypes("T")
.build();

@Override
public J visitMemberReference(J.MemberReference memberRef, ExecutionContext executionContext) {
JavaTemplate.Matcher matcher = refTemplate.matcher(getCursor());
if (matcher.find()) {
return lambdaTemplate.apply(getCursor(), memberRef.getCoordinates().replace(), matcher.getMatchResult().getMatchedParameters().toArray());
} else {
return super.visitMemberReference(memberRef, executionContext);
}
}
})),
//language=java
java(
"""
import java.util.function.Function;

class Foo {
void test() {
test(Object::toString);
}

void test(Function<Object, String> fn) {
}
}
""",
"""
import java.util.function.Function;

class Foo {
void test() {
test(e -> e.toString());
}

void test(Function<Object, String> fn) {
}
}
"""
)
);
}

@Test
@ExpectedToFail
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This one it's more to showcase a scenario we are not handling, even if the apply had arguments, T is not really a parameter, and #{type:any(T)} would not match a JavaType

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We could have some other kind of argument. #{type(T)} where instead of using __P__.<T>any() we could add the TypeUtils.toString() that would also make the test above pass, being an argument the containing would match.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've considered doing exactly this before, but I didn't yet see a use case for it.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I have try this templates in error-prone refaster and they work. But I haven't found any use on the Picnic tests.

  static final class RefToLambda<T> {
    @BeforeTemplate
    Function<T, String> before() {
      return T::toString;
    }

    @AfterTemplate
    Function<T, String> after() {
      return e -> e.toString();
    }
  }

  static final class LambdaToRef<T> {
    @BeforeTemplate
    Function<T, String> before() {
      return e -> e.toString();
    }

    @AfterTemplate
    Function<T, String> after() {
      return T::toString;
    }
  }

---

  ImmutableSet<Function<Object, String>> testRefToLambda() {
    return ImmutableSet.of(
        Object::toString);
  }

  ImmutableSet<Function<Object, String>> testLambdaToRef() {
    return ImmutableSet.of(
        e -> e.toString());
  }

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I remember the use case now. Having a #{type()} construct could indeed help here, but I don't yet see how we would be using that from a Refaster recipe. The problem as I see it is that there wouldn't be any corresponding parameter in the before template, so how do we correctly generate the recipe to set that parameter?

void replaceLambdaToMemberReference() {
//noinspection Convert2MethodRef
rewriteRun(
spec -> spec
.expectedCyclesThatMakeChanges(1).cycles(1)
.recipe(toRecipe(() -> new JavaVisitor<>() {
final JavaTemplate lambdaTemplate = JavaTemplate.builder("e -> e.toString()")
.bindType("java.util.function.Function<T, String>")
.genericTypes("T")
.build();
final JavaTemplate refTemplate = JavaTemplate.builder("T::toString")
.bindType("java.util.function.Function<T, String>")
.genericTypes("T")
.build();

@Override
public J visitLambda(J.Lambda lambda, ExecutionContext executionContext) {
JavaTemplate.Matcher matcher = lambdaTemplate.matcher(getCursor());
if (matcher.find()) {
return refTemplate.apply(getCursor(), lambda.getCoordinates().replace(), matcher.getMatchResult().getMatchedParameters().toArray());
} else {
return super.visitLambda(lambda, executionContext);
}
}
})),
//language=java
java(
"""
import java.util.function.Function;

class Foo {
void test() {
test(e -> e.toString());
}

void test(Function<Object, String> fn) {
}
}
""",
"""
import java.util.function.Function;

class Foo {
void test() {
test(Object::toString);
}

void test(Function<Object, String> fn) {
}
}
"""
)
);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -1046,4 +1046,72 @@ Predicate<Object> test() {
)
);
}

@Test
void matchMemberReferenceAndLambda() {
//noinspection Convert2MethodRef
rewriteRun(
spec -> spec
.expectedCyclesThatMakeChanges(1).cycles(1)
.recipe(toRecipe(() -> new JavaVisitor<>() {
final JavaTemplate refTemplate = JavaTemplate.builder("String::valueOf")
.bindType("java.util.function.Function<Object, String>")
.build();
final JavaTemplate lambdaTemplate = JavaTemplate.builder("(e)->e.toString()")
.bindType("java.util.function.Function<Object, String>")
.build();

@Override
public J visitMemberReference(J.MemberReference memberRef, ExecutionContext executionContext) {
var matcher = refTemplate.matcher(getCursor());
if (matcher.find()) {
return lambdaTemplate.apply(getCursor(), memberRef.getCoordinates().replace(), matcher.getMatchResult().getMatchedParameters().toArray());
} else {
return super.visitMemberReference(memberRef, executionContext);
}
}

@Override
public J visitLambda(J.Lambda lambda, ExecutionContext executionContext) {
var matcher = lambdaTemplate.matcher(getCursor());
if (matcher.find()) {
return refTemplate.apply(getCursor(), lambda.getCoordinates().replace(), matcher.getMatchResult().getMatchedParameters().toArray());
} else {
return lambdaTemplate.matches(getCursor()) ? SearchResult.found(lambda, "lambda") : super.visitLambda(lambda, executionContext);
}
}
})),
//language=java
java(
"""
import java.util.function.Function;

class Foo {
void test() {
test(String::valueOf);
test(e -> e.toString());
test(x -> x.toString());
}

void test(Function<Object, String> fn) {
}
}
""",
"""
import java.util.function.Function;

class Foo {
void test() {
test((e) -> e.toString());
test(String::valueOf);
test(String::valueOf);
}

void test(Function<Object, String> fn) {
}
}
"""
)
);
}
}
26 changes: 23 additions & 3 deletions rewrite-java/src/main/java/org/openrewrite/java/JavaTemplate.java
Original file line number Diff line number Diff line change
Expand Up @@ -87,9 +87,9 @@ protected static Path getTemplateClasspathDir() {
private final Consumer<String> onAfterVariableSubstitution;
private final JavaTemplateParser templateParser;

private JavaTemplate(boolean contextSensitive, JavaParser.Builder<?, ?> parser, String code, Set<String> imports,
private JavaTemplate(boolean contextSensitive, JavaParser.Builder<?, ?> parser, String code, String bindType, Set<String> imports,
Set<String> genericTypes, Consumer<String> onAfterVariableSubstitution, Consumer<String> onBeforeParseTemplate) {
this(code, genericTypes, onAfterVariableSubstitution, new JavaTemplateParser(contextSensitive, augmentClasspath(parser), onAfterVariableSubstitution, onBeforeParseTemplate, imports));
this(code, genericTypes, onAfterVariableSubstitution, new JavaTemplateParser(contextSensitive, augmentClasspath(parser), onAfterVariableSubstitution, onBeforeParseTemplate, imports, bindType));
}

private static JavaParser.Builder<?, ?> augmentClasspath(JavaParser.Builder<?, ?> parserBuilder) {
Expand Down Expand Up @@ -181,6 +181,7 @@ public static class Builder {
private final Set<String> genericTypes = new HashSet<>();

private boolean contextSensitive;
private String bindType = "Object";

private JavaParser.Builder<?, ?> parser = org.openrewrite.java.JavaParser.fromJavaVersion();

Expand Down Expand Up @@ -212,6 +213,25 @@ public Builder contextSensitive() {
return this;
}

/**
* In context-free templates involving generic types, the type often cannot be inferred automatically.
* <p>
* Common examples include:
* <ul>
* <li>{@code new ArrayList<>()}</li>
* <li>{@code Collections.emptyList()}</li>
* <li>{@code String::valueOf}</li>
* </ul>
* In such cases, the type must be specified manually.
*/
public Builder bindType(String bindType) {
if (StringUtils.isBlank(bindType)) {
throw new IllegalArgumentException("Type must not be blank");
}
this.bindType = bindType;
return this;
}

public Builder imports(String... fullyQualifiedTypeNames) {
for (String typeName : fullyQualifiedTypeNames) {
validateImport(typeName);
Expand Down Expand Up @@ -259,7 +279,7 @@ public Builder doBeforeParseTemplate(Consumer<String> beforeParseTemplate) {
}

public JavaTemplate build() {
return new JavaTemplate(contextSensitive, parser.clone(), code, imports, genericTypes,
return new JavaTemplate(contextSensitive, parser.clone(), code, bindType, imports, genericTypes,
onAfterVariableSubstitution, onBeforeParseTemplate);
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,11 @@ public class BlockStatementTemplateGenerator {

protected final Set<String> imports;
private final boolean contextSensitive;
private final String bindType;

public BlockStatementTemplateGenerator(Set<String> imports, boolean contextSensitive) {
this(imports, contextSensitive, "Object");
}

public String template(Cursor cursor, String template, Collection<JavaType.GenericTypeVariable> typeVariables, Space.Location location, JavaCoordinates.Mode mode) {
//noinspection ConstantConditions
Expand Down Expand Up @@ -205,24 +210,26 @@ private boolean isTemplateStopComment(Comment comment) {
protected void contextFreeTemplate(Cursor cursor, J j, Collection<JavaType.GenericTypeVariable> typeVariables, StringBuilder before, StringBuilder after) {
String classDeclaration = typeVariables.isEmpty() ? "Template" :
"Template<" + typeVariables.stream().map(TypeUtils::toGenericTypeString).collect(Collectors.joining(", ")) + ">";
if (j instanceof J.Lambda) {
if (j instanceof J.Lambda && "Object".equals(bindType)) {
throw new IllegalArgumentException(
"Templating a lambda requires a cursor so that it can be properly parsed and type-attributed. " +
"Mark this template as context-sensitive by calling JavaTemplate.Builder#contextSensitive().");
} else if (j instanceof J.MemberReference) {
"Mark this template as context-sensitive by calling JavaTemplate.Builder#contextSensitive() or " +
"specify the type by calling JavaTemplate.Builder#bindType()");
} else if (j instanceof J.MemberReference && "Object".equals(bindType)) {
throw new IllegalArgumentException(
"Templating a method reference requires a cursor so that it can be properly parsed and type-attributed. " +
"Mark this template as context-sensitive by calling JavaTemplate.Builder#contextSensitive().");
"Mark this template as context-sensitive by calling JavaTemplate.Builder#contextSensitive() or " +
"specify the type by calling JavaTemplate.Builder#bindType()");
} else if (j instanceof J.MethodInvocation) {
before.insert(0, String.format("class %s {{\n", classDeclaration));
JavaType.Method methodType = ((J.MethodInvocation) j).getMethodType();
if (methodType == null || methodType.getReturnType() != JavaType.Primitive.Void) {
before.append("Object o = ");
before.append(bindType).append(" o = ");
}
after.append(";\n}}");
} else if (j instanceof Expression && !(j instanceof J.Assignment)) {
before.insert(0, String.format("class %s {\n", classDeclaration));
before.append("Object o = ");
before.append(bindType).append(" o = ");
after.append(";\n}");
} else if ((j instanceof J.MethodDeclaration || j instanceof J.VariableDeclarations || j instanceof J.Block || j instanceof J.ClassDeclaration) &&
cursor.getValue() instanceof J.Block &&
Expand Down
Loading
Loading