Skip to content
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
package org.openrewrite.java;

import org.junit.jupiter.api.Test;
import org.junitpioneer.jupiter.ExpectedToFail;
import org.openrewrite.ExecutionContext;
import org.openrewrite.java.tree.Expression;
import org.openrewrite.java.tree.J;
Expand Down Expand Up @@ -163,6 +164,88 @@ void test() {
);
}

@Test
void setsDifferenceMultimapRecipe() {
rewriteRun(
spec -> spec.parser(JavaParser.fromJavaVersion().classpath("guava"))
.recipe(toRecipe(() -> new JavaIsoVisitor<>() {
final JavaTemplate template = JavaTemplate.builder("#{set:any(java.util.Set<T>)}.stream().filter(java.util.function.Predicate.not(#{multimap:any(com.google.common.collect.Multimap<K, V>)}::containsKey)).collect(com.google.common.collect.ImmutableSet.toImmutableSet())")
.bindType("com.google.common.collect.ImmutableSet<T>")
.genericTypes("T", "K", "V")
.javaParser(JavaParser.fromJavaVersion().classpath("guava"))
.build();

@Override
public J.MethodInvocation visitMethodInvocation(J.MethodInvocation method, ExecutionContext executionContext) {
return template.matches(getCursor()) ? SearchResult.found(method) : super.visitMethodInvocation(method, executionContext);
}
})),
java(
"""
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.ImmutableSetMultimap;

import java.util.function.Predicate;

class Test {
void test() {
ImmutableSet.of(1).stream().filter(Predicate.not(ImmutableSetMultimap.of(2, 3)::containsKey)).collect(ImmutableSet.toImmutableSet());
}
}
""",
"""
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.ImmutableSetMultimap;

import java.util.function.Predicate;

class Test {
void test() {
/*~~>*/ImmutableSet.of(1).stream().filter(Predicate.not(ImmutableSetMultimap.of(2, 3)::containsKey)).collect(ImmutableSet.toImmutableSet());
}
}
"""
)
);
}

@Test
void emptyStreamRecipe() {
rewriteRun(
spec -> spec.recipe(toRecipe(() -> new JavaIsoVisitor<>() {
final JavaTemplate template = JavaTemplate.builder("java.util.stream.Stream.of()")
.bindType("java.util.stream.Stream<T>")
.genericTypes("T")
.build();

@Override
public J.MethodInvocation visitMethodInvocation(J.MethodInvocation method, ExecutionContext executionContext) {
return template.matches(getCursor()) ? SearchResult.found(method) : super.visitMethodInvocation(method, executionContext);
}
})),
java(
"""
import java.util.stream.Stream;

class Test {
Stream<String> test() {
return Stream.of();
}
}
""",
"""
import java.util.stream.Stream;

class Test {
Stream<String> test() {
return /*~~>*/Stream.of();
}
}
"""
)
);
}

@Test
void methodModifiersMismatch() {
rewriteRun(
Expand Down Expand Up @@ -206,4 +289,118 @@ Stream<Integer> test() {
)
);
}

@Test
@ExpectedToFail
void replaceMemberReferenceToLambda() {
//noinspection Convert2MethodRef
rewriteRun(
spec -> spec
.expectedCyclesThatMakeChanges(1).cycles(1)
.recipe(toRecipe(() -> new JavaVisitor<>() {
final JavaTemplate refTemplate = JavaTemplate.builder("T::toString")
.bindType("java.util.function.Function<T, String>")
.genericTypes("T")
.build();
final JavaTemplate lambdaTemplate = JavaTemplate.builder("e -> e.toString()")
.bindType("java.util.function.Function<T, String>")
.genericTypes("T")
.build();

@Override
public J visitMemberReference(J.MemberReference memberRef, ExecutionContext executionContext) {
JavaTemplate.Matcher matcher = refTemplate.matcher(getCursor());
if (matcher.find()) {
return lambdaTemplate.apply(getCursor(), memberRef.getCoordinates().replace(), matcher.getMatchResult().getMatchedParameters().toArray());
} else {
return super.visitMemberReference(memberRef, executionContext);
}
}
})),
//language=java
java(
"""
import java.util.function.Function;

class Foo {
void test() {
test(Object::toString);
}

void test(Function<Object, String> fn) {
}
}
""",
"""
import java.util.function.Function;

class Foo {
void test() {
test(e -> e.toString());
}

void test(Function<Object, String> fn) {
}
}
"""
)
);
}

@Test
@ExpectedToFail
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This one it's more to showcase a scenario we are not handling, even if the apply had arguments, T is not really a parameter, and #{type:any(T)} would not match a JavaType

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We could have some other kind of argument. #{type(T)} where instead of using __P__.<T>any() we could add the TypeUtils.toString() that would also make the test above pass, being an argument the containing would match.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've considered doing exactly this before, but I didn't yet see a use case for it.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I have try this templates in error-prone refaster and they work. But I haven't found any use on the Picnic tests.

  static final class RefToLambda<T> {
    @BeforeTemplate
    Function<T, String> before() {
      return T::toString;
    }

    @AfterTemplate
    Function<T, String> after() {
      return e -> e.toString();
    }
  }

  static final class LambdaToRef<T> {
    @BeforeTemplate
    Function<T, String> before() {
      return e -> e.toString();
    }

    @AfterTemplate
    Function<T, String> after() {
      return T::toString;
    }
  }

---

  ImmutableSet<Function<Object, String>> testRefToLambda() {
    return ImmutableSet.of(
        Object::toString);
  }

  ImmutableSet<Function<Object, String>> testLambdaToRef() {
    return ImmutableSet.of(
        e -> e.toString());
  }

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I remember the use case now. Having a #{type()} construct could indeed help here, but I don't yet see how we would be using that from a Refaster recipe. The problem as I see it is that there wouldn't be any corresponding parameter in the before template, so how do we correctly generate the recipe to set that parameter?

void replaceLambdaToMemberReference() {
//noinspection Convert2MethodRef
rewriteRun(
spec -> spec
.expectedCyclesThatMakeChanges(1).cycles(1)
.recipe(toRecipe(() -> new JavaVisitor<>() {
final JavaTemplate lambdaTemplate = JavaTemplate.builder("e -> e.toString()")
.bindType("java.util.function.Function<T, String>")
.genericTypes("T")
.build();
final JavaTemplate refTemplate = JavaTemplate.builder("T::toString")
.bindType("java.util.function.Function<T, String>")
.genericTypes("T")
.build();

@Override
public J visitLambda(J.Lambda lambda, ExecutionContext executionContext) {
JavaTemplate.Matcher matcher = lambdaTemplate.matcher(getCursor());
if (matcher.find()) {
return refTemplate.apply(getCursor(), lambda.getCoordinates().replace(), matcher.getMatchResult().getMatchedParameters().toArray());
} else {
return super.visitLambda(lambda, executionContext);
}
}
})),
//language=java
java(
"""
import java.util.function.Function;

class Foo {
void test() {
test(e -> e.toString());
}

void test(Function<Object, String> fn) {
}
}
""",
"""
import java.util.function.Function;

class Foo {
void test() {
test(Object::toString);
}

void test(Function<Object, String> fn) {
}
}
"""
)
);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -1046,4 +1046,72 @@ Predicate<Object> test() {
)
);
}

@Test
void matchMemberReferenceAndLambda() {
//noinspection Convert2MethodRef
rewriteRun(
spec -> spec
.expectedCyclesThatMakeChanges(1).cycles(1)
.recipe(toRecipe(() -> new JavaVisitor<>() {
final JavaTemplate refTemplate = JavaTemplate.builder("String::valueOf")
.bindType("java.util.function.Function<Object, String>")
.build();
final JavaTemplate lambdaTemplate = JavaTemplate.builder("(e)->e.toString()")
.bindType("java.util.function.Function<Object, String>")
.build();

@Override
public J visitMemberReference(J.MemberReference memberRef, ExecutionContext executionContext) {
var matcher = refTemplate.matcher(getCursor());
if (matcher.find()) {
return lambdaTemplate.apply(getCursor(), memberRef.getCoordinates().replace(), matcher.getMatchResult().getMatchedParameters().toArray());
} else {
return super.visitMemberReference(memberRef, executionContext);
}
}

@Override
public J visitLambda(J.Lambda lambda, ExecutionContext executionContext) {
var matcher = lambdaTemplate.matcher(getCursor());
if (matcher.find()) {
return refTemplate.apply(getCursor(), lambda.getCoordinates().replace(), matcher.getMatchResult().getMatchedParameters().toArray());
} else {
return lambdaTemplate.matches(getCursor()) ? SearchResult.found(lambda, "lambda") : super.visitLambda(lambda, executionContext);
}
}
})),
//language=java
java(
"""
import java.util.function.Function;

class Foo {
void test() {
test(String::valueOf);
test(e -> e.toString());
test(x -> x.toString());
}

void test(Function<Object, String> fn) {
}
}
""",
"""
import java.util.function.Function;

class Foo {
void test() {
test((e) -> e.toString());
test(String::valueOf);
test(String::valueOf);
}

void test(Function<Object, String> fn) {
}
}
"""
)
);
}
}
26 changes: 23 additions & 3 deletions rewrite-java/src/main/java/org/openrewrite/java/JavaTemplate.java
Original file line number Diff line number Diff line change
Expand Up @@ -87,9 +87,9 @@ protected static Path getTemplateClasspathDir() {
private final Consumer<String> onAfterVariableSubstitution;
private final JavaTemplateParser templateParser;

private JavaTemplate(boolean contextSensitive, JavaParser.Builder<?, ?> parser, String code, Set<String> imports,
private JavaTemplate(boolean contextSensitive, JavaParser.Builder<?, ?> parser, String code, String bindType, Set<String> imports,
Set<String> genericTypes, Consumer<String> onAfterVariableSubstitution, Consumer<String> onBeforeParseTemplate) {
this(code, genericTypes, onAfterVariableSubstitution, new JavaTemplateParser(contextSensitive, augmentClasspath(parser), onAfterVariableSubstitution, onBeforeParseTemplate, imports));
this(code, genericTypes, onAfterVariableSubstitution, new JavaTemplateParser(contextSensitive, augmentClasspath(parser), onAfterVariableSubstitution, onBeforeParseTemplate, imports, bindType));
}

private static JavaParser.Builder<?, ?> augmentClasspath(JavaParser.Builder<?, ?> parserBuilder) {
Expand Down Expand Up @@ -181,6 +181,7 @@ public static class Builder {
private final Set<String> genericTypes = new HashSet<>();

private boolean contextSensitive;
private String bindType = "Object";

private JavaParser.Builder<?, ?> parser = org.openrewrite.java.JavaParser.fromJavaVersion();

Expand Down Expand Up @@ -212,6 +213,25 @@ public Builder contextSensitive() {
return this;
}

/**
* In context-free templates involving generic types, the type often cannot be inferred automatically.
* <p>
* Common examples include:
* <ul>
* <li>{@code new ArrayList<>()}</li>
* <li>{@code Collections.emptyList()}</li>
* <li>{@code String::valueOf}</li>
* </ul>
* In such cases, the type must be specified manually.
*/
public Builder bindType(String bindType) {
if (StringUtils.isBlank(bindType)) {
throw new IllegalArgumentException("Type must not be blank");
}
this.bindType = bindType;
return this;
}

public Builder imports(String... fullyQualifiedTypeNames) {
for (String typeName : fullyQualifiedTypeNames) {
validateImport(typeName);
Expand Down Expand Up @@ -259,7 +279,7 @@ public Builder doBeforeParseTemplate(Consumer<String> beforeParseTemplate) {
}

public JavaTemplate build() {
return new JavaTemplate(contextSensitive, parser.clone(), code, imports, genericTypes,
return new JavaTemplate(contextSensitive, parser.clone(), code, bindType, imports, genericTypes,
onAfterVariableSubstitution, onBeforeParseTemplate);
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,11 @@ public class BlockStatementTemplateGenerator {

protected final Set<String> imports;
private final boolean contextSensitive;
private final String bindType;

public BlockStatementTemplateGenerator(Set<String> imports, boolean contextSensitive) {
this(imports, contextSensitive, "Object");
}

public String template(Cursor cursor, String template, Collection<JavaType.GenericTypeVariable> typeVariables, Space.Location location, JavaCoordinates.Mode mode) {
//noinspection ConstantConditions
Expand Down Expand Up @@ -205,24 +210,26 @@ private boolean isTemplateStopComment(Comment comment) {
protected void contextFreeTemplate(Cursor cursor, J j, Collection<JavaType.GenericTypeVariable> typeVariables, StringBuilder before, StringBuilder after) {
String classDeclaration = typeVariables.isEmpty() ? "Template" :
"Template<" + typeVariables.stream().map(TypeUtils::toGenericTypeString).collect(Collectors.joining(", ")) + ">";
if (j instanceof J.Lambda) {
if (j instanceof J.Lambda && "Object".equals(bindType)) {
throw new IllegalArgumentException(
"Templating a lambda requires a cursor so that it can be properly parsed and type-attributed. " +
"Mark this template as context-sensitive by calling JavaTemplate.Builder#contextSensitive().");
} else if (j instanceof J.MemberReference) {
"Mark this template as context-sensitive by calling JavaTemplate.Builder#contextSensitive() or " +
"specify the type by calling JavaTemplate.Builder#bindType()");
} else if (j instanceof J.MemberReference && "Object".equals(bindType)) {
throw new IllegalArgumentException(
"Templating a method reference requires a cursor so that it can be properly parsed and type-attributed. " +
"Mark this template as context-sensitive by calling JavaTemplate.Builder#contextSensitive().");
"Mark this template as context-sensitive by calling JavaTemplate.Builder#contextSensitive() or " +
"specify the type by calling JavaTemplate.Builder#bindType()");
} else if (j instanceof J.MethodInvocation) {
before.insert(0, String.format("class %s {{\n", classDeclaration));
JavaType.Method methodType = ((J.MethodInvocation) j).getMethodType();
if (methodType == null || methodType.getReturnType() != JavaType.Primitive.Void) {
before.append("Object o = ");
before.append(bindType).append(" o = ");
}
after.append(";\n}}");
} else if (j instanceof Expression && !(j instanceof J.Assignment)) {
before.insert(0, String.format("class %s {\n", classDeclaration));
before.append("Object o = ");
before.append(bindType).append(" o = ");
after.append(";\n}");
} else if ((j instanceof J.MethodDeclaration || j instanceof J.VariableDeclarations || j instanceof J.Block || j instanceof J.ClassDeclaration) &&
cursor.getValue() instanceof J.Block &&
Expand Down
Loading
Loading