Skip to content

Commit b91cf62

Browse files
fmodestogithub-actions[bot]timtebeekknutwannheden
authored
Add expression type in JavaTemplate (#5384)
* Add expression type * rename * More tests * Extra working test * fix * Update rewrite-java-test/src/test/java/org/openrewrite/java/JavaTemplateGenericsTest.java Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * Improve `SemanticallyEqual` for method references * Polish test --------- Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> Co-authored-by: Tim te Beek <[email protected]> Co-authored-by: Knut Wannheden <[email protected]>
1 parent cbe16ac commit b91cf62

File tree

6 files changed

+375
-18
lines changed

6 files changed

+375
-18
lines changed

rewrite-java-test/src/test/java/org/openrewrite/java/JavaTemplateGenericsTest.java

Lines changed: 260 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
package org.openrewrite.java;
1717

1818
import org.junit.jupiter.api.Test;
19+
import org.junitpioneer.jupiter.ExpectedToFail;
1920
import org.openrewrite.DocumentExample;
2021
import org.openrewrite.ExecutionContext;
2122
import org.openrewrite.java.tree.Expression;
@@ -50,7 +51,7 @@ void genericTypes() {
5051
rewriteRun(
5152
spec -> spec.recipe(toRecipe(() -> new JavaVisitor<>() {
5253
@Override
53-
public J visitVariableDeclarations(J.VariableDeclarations multiVariable, ExecutionContext executionContext) {
54+
public J visitVariableDeclarations(J.VariableDeclarations multiVariable, ExecutionContext ctx) {
5455
J.VariableDeclarations.NamedVariable variable = multiVariable.getVariables().getFirst();
5556
if ("o".equals(variable.getSimpleName())) {
5657
Expression exp = Objects.requireNonNull(variable.getInitializer());
@@ -64,7 +65,7 @@ public J visitVariableDeclarations(J.VariableDeclarations multiVariable, Executi
6465
assertThat(res4.getMethodType()).isNotNull();
6566
return res3;
6667
}
67-
return super.visitVariableDeclarations(multiVariable, executionContext);
68+
return super.visitVariableDeclarations(multiVariable, ctx);
6869
}
6970
})),
7071
java(
@@ -134,8 +135,8 @@ void recursiveType() {
134135
rewriteRun(
135136
spec -> spec.recipe(toRecipe(() -> new JavaIsoVisitor<>() {
136137
@Override
137-
public J.MethodInvocation visitMethodInvocation(J.MethodInvocation method, ExecutionContext executionContext) {
138-
return template.matches(getCursor()) ? SearchResult.found(method) : super.visitMethodInvocation(method, executionContext);
138+
public J.MethodInvocation visitMethodInvocation(J.MethodInvocation method, ExecutionContext ctx) {
139+
return template.matches(getCursor()) ? SearchResult.found(method) : super.visitMethodInvocation(method, ctx);
139140
}
140141
})),
141142
java(
@@ -165,6 +166,88 @@ void test() {
165166
);
166167
}
167168

169+
@Test
170+
void setsDifferenceMultimapRecipe() {
171+
rewriteRun(
172+
spec -> spec.parser(JavaParser.fromJavaVersion().classpath("guava"))
173+
.recipe(toRecipe(() -> new JavaIsoVisitor<>() {
174+
final JavaTemplate template = JavaTemplate.builder("#{set:any(java.util.Set<T>)}.stream().filter(java.util.function.Predicate.not(#{multimap:any(com.google.common.collect.Multimap<K, V>)}::containsKey)).collect(com.google.common.collect.ImmutableSet.toImmutableSet())")
175+
.bindType("com.google.common.collect.ImmutableSet<T>")
176+
.genericTypes("T", "K", "V")
177+
.javaParser(JavaParser.fromJavaVersion().classpath("guava"))
178+
.build();
179+
180+
@Override
181+
public J.MethodInvocation visitMethodInvocation(J.MethodInvocation method, ExecutionContext ctx) {
182+
return template.matches(getCursor()) ? SearchResult.found(method) : super.visitMethodInvocation(method, ctx);
183+
}
184+
})),
185+
java(
186+
"""
187+
import com.google.common.collect.ImmutableSet;
188+
import com.google.common.collect.ImmutableSetMultimap;
189+
190+
import java.util.function.Predicate;
191+
192+
class Test {
193+
void test() {
194+
ImmutableSet.of(1).stream().filter(Predicate.not(ImmutableSetMultimap.of(2, 3)::containsKey)).collect(ImmutableSet.toImmutableSet());
195+
}
196+
}
197+
""",
198+
"""
199+
import com.google.common.collect.ImmutableSet;
200+
import com.google.common.collect.ImmutableSetMultimap;
201+
202+
import java.util.function.Predicate;
203+
204+
class Test {
205+
void test() {
206+
/*~~>*/ImmutableSet.of(1).stream().filter(Predicate.not(ImmutableSetMultimap.of(2, 3)::containsKey)).collect(ImmutableSet.toImmutableSet());
207+
}
208+
}
209+
"""
210+
)
211+
);
212+
}
213+
214+
@Test
215+
void emptyStreamRecipe() {
216+
rewriteRun(
217+
spec -> spec.recipe(toRecipe(() -> new JavaIsoVisitor<>() {
218+
final JavaTemplate template = JavaTemplate.builder("java.util.stream.Stream.of()")
219+
.bindType("java.util.stream.Stream<T>")
220+
.genericTypes("T")
221+
.build();
222+
223+
@Override
224+
public J.MethodInvocation visitMethodInvocation(J.MethodInvocation method, ExecutionContext ctx) {
225+
return template.matches(getCursor()) ? SearchResult.found(method) : super.visitMethodInvocation(method, ctx);
226+
}
227+
})),
228+
java(
229+
"""
230+
import java.util.stream.Stream;
231+
232+
class Test {
233+
Stream<String> test() {
234+
return Stream.of();
235+
}
236+
}
237+
""",
238+
"""
239+
import java.util.stream.Stream;
240+
241+
class Test {
242+
Stream<String> test() {
243+
return /*~~>*/Stream.of();
244+
}
245+
}
246+
"""
247+
)
248+
);
249+
}
250+
168251
@Test
169252
void methodModifiersMismatch() {
170253
rewriteRun(
@@ -174,8 +257,8 @@ void methodModifiersMismatch() {
174257
.build();
175258

176259
@Override
177-
public J.MethodInvocation visitMethodInvocation(J.MethodInvocation method, ExecutionContext executionContext) {
178-
return template.matches(getCursor()) ? SearchResult.found(method) : super.visitMethodInvocation(method, executionContext);
260+
public J.MethodInvocation visitMethodInvocation(J.MethodInvocation method, ExecutionContext ctx) {
261+
return template.matches(getCursor()) ? SearchResult.found(method) : super.visitMethodInvocation(method, ctx);
179262
}
180263
})),
181264
java(
@@ -208,4 +291,175 @@ Stream<Integer> test() {
208291
)
209292
);
210293
}
294+
295+
@Test
296+
void replaceMemberReferenceToLambda() {
297+
//noinspection Convert2MethodRef
298+
rewriteRun(
299+
spec -> spec
300+
.expectedCyclesThatMakeChanges(1).cycles(1)
301+
.recipe(toRecipe(() -> new JavaVisitor<>() {
302+
final JavaTemplate refTemplate = JavaTemplate.builder("T::toString")
303+
.bindType("java.util.function.Function<T, String>")
304+
.genericTypes("T")
305+
.build();
306+
final JavaTemplate lambdaTemplate = JavaTemplate.builder("e -> e.toString()")
307+
.bindType("java.util.function.Function<T, String>")
308+
.genericTypes("T")
309+
.build();
310+
311+
@Override
312+
public J visitMemberReference(J.MemberReference memberRef, ExecutionContext ctx) {
313+
JavaTemplate.Matcher matcher = refTemplate.matcher(getCursor());
314+
if (matcher.find()) {
315+
return lambdaTemplate.apply(getCursor(), memberRef.getCoordinates().replace(), matcher.getMatchResult().getMatchedParameters().toArray());
316+
} else {
317+
return super.visitMemberReference(memberRef, ctx);
318+
}
319+
}
320+
})),
321+
//language=java
322+
java(
323+
"""
324+
import java.util.function.Function;
325+
326+
class Foo {
327+
void test() {
328+
test(Object::toString);
329+
}
330+
331+
void test(Function<Object, String> fn) {
332+
}
333+
}
334+
""",
335+
"""
336+
import java.util.function.Function;
337+
338+
class Foo {
339+
void test() {
340+
test(e -> e.toString());
341+
}
342+
343+
void test(Function<Object, String> fn) {
344+
}
345+
}
346+
"""
347+
)
348+
);
349+
}
350+
351+
@Test
352+
@ExpectedToFail
353+
void replaceLambdaToMemberReference() {
354+
//noinspection Convert2MethodRef
355+
rewriteRun(
356+
spec -> spec
357+
.expectedCyclesThatMakeChanges(1).cycles(1)
358+
.recipe(toRecipe(() -> new JavaVisitor<>() {
359+
final JavaTemplate lambdaTemplate = JavaTemplate.builder("e -> e.toString()")
360+
.bindType("java.util.function.Function<T, String>")
361+
.genericTypes("T")
362+
.build();
363+
final JavaTemplate refTemplate = JavaTemplate.builder("T::toString")
364+
.bindType("java.util.function.Function<T, String>")
365+
.genericTypes("T")
366+
.build();
367+
368+
@Override
369+
public J visitLambda(J.Lambda lambda, ExecutionContext ctx) {
370+
JavaTemplate.Matcher matcher = lambdaTemplate.matcher(getCursor());
371+
if (matcher.find()) {
372+
return refTemplate.apply(getCursor(), lambda.getCoordinates().replace(), matcher.getMatchResult().getMatchedParameters().toArray());
373+
} else {
374+
return super.visitLambda(lambda, ctx);
375+
}
376+
}
377+
})),
378+
//language=java
379+
java(
380+
"""
381+
import java.util.function.Function;
382+
383+
class Foo {
384+
void test() {
385+
test(e -> e.toString());
386+
}
387+
388+
void test(Function<Object, String> fn) {
389+
}
390+
}
391+
""",
392+
"""
393+
import java.util.function.Function;
394+
395+
class Foo {
396+
void test() {
397+
test(Object::toString);
398+
}
399+
400+
void test(Function<Object, String> fn) {
401+
}
402+
}
403+
"""
404+
)
405+
);
406+
}
407+
408+
@Test
409+
void memberReferenceToLambda() {
410+
//noinspection Convert2MethodRef
411+
rewriteRun(
412+
spec -> spec
413+
.expectedCyclesThatMakeChanges(1).cycles(1)
414+
.recipe(toRecipe(() -> new JavaVisitor<>() {
415+
final JavaTemplate refTemplate = JavaTemplate.builder("#{any(java.util.Set<T>)}::contains")
416+
.bindType("java.util.function.Predicate<T>")
417+
.genericTypes("T")
418+
.build();
419+
final JavaTemplate lambdaTemplate = JavaTemplate.builder("e -> #{any(java.util.Set<T>)}.contains(e)")
420+
.bindType("java.util.function.Predicate<T>")
421+
.genericTypes("T")
422+
.build();
423+
424+
@Override
425+
public J visitMemberReference(J.MemberReference memberRef, ExecutionContext ctx) {
426+
JavaTemplate.Matcher matcher = refTemplate.matcher(getCursor());
427+
if (matcher.find()) {
428+
return lambdaTemplate.apply(getCursor(), memberRef.getCoordinates().replace(), matcher.getMatchResult().getMatchedParameters().toArray());
429+
} else {
430+
return super.visitMemberReference(memberRef, ctx);
431+
}
432+
}
433+
})),
434+
//language=java
435+
java(
436+
"""
437+
import java.util.*;
438+
import java.util.function.*;
439+
440+
class Foo {
441+
List<Integer> test(List<Integer> list) {
442+
Set<Integer> set = Set.of(1, 2, 3);
443+
return list.stream()
444+
.filter(set::contains)
445+
.toList();
446+
}
447+
}
448+
""",
449+
"""
450+
import java.util.*;
451+
import java.util.function.*;
452+
453+
class Foo {
454+
List<Integer> test(List<Integer> list) {
455+
Set<Integer> set = Set.of(1, 2, 3);
456+
return list.stream()
457+
.filter(e -> set.contains(e))
458+
.toList();
459+
}
460+
}
461+
"""
462+
)
463+
);
464+
}
211465
}

rewrite-java-test/src/test/java/org/openrewrite/java/JavaTemplateMatchTest.java

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1046,4 +1046,72 @@ Predicate<Object> test() {
10461046
)
10471047
);
10481048
}
1049+
1050+
@Test
1051+
void matchMemberReferenceAndLambda() {
1052+
//noinspection Convert2MethodRef
1053+
rewriteRun(
1054+
spec -> spec
1055+
.expectedCyclesThatMakeChanges(1).cycles(1)
1056+
.recipe(toRecipe(() -> new JavaVisitor<>() {
1057+
final JavaTemplate refTemplate = JavaTemplate.builder("String::valueOf")
1058+
.bindType("java.util.function.Function<Object, String>")
1059+
.build();
1060+
final JavaTemplate lambdaTemplate = JavaTemplate.builder("(e)->e.toString()")
1061+
.bindType("java.util.function.Function<Object, String>")
1062+
.build();
1063+
1064+
@Override
1065+
public J visitMemberReference(J.MemberReference memberRef, ExecutionContext executionContext) {
1066+
var matcher = refTemplate.matcher(getCursor());
1067+
if (matcher.find()) {
1068+
return lambdaTemplate.apply(getCursor(), memberRef.getCoordinates().replace(), matcher.getMatchResult().getMatchedParameters().toArray());
1069+
} else {
1070+
return super.visitMemberReference(memberRef, executionContext);
1071+
}
1072+
}
1073+
1074+
@Override
1075+
public J visitLambda(J.Lambda lambda, ExecutionContext executionContext) {
1076+
var matcher = lambdaTemplate.matcher(getCursor());
1077+
if (matcher.find()) {
1078+
return refTemplate.apply(getCursor(), lambda.getCoordinates().replace(), matcher.getMatchResult().getMatchedParameters().toArray());
1079+
} else {
1080+
return lambdaTemplate.matches(getCursor()) ? SearchResult.found(lambda, "lambda") : super.visitLambda(lambda, executionContext);
1081+
}
1082+
}
1083+
})),
1084+
//language=java
1085+
java(
1086+
"""
1087+
import java.util.function.Function;
1088+
1089+
class Foo {
1090+
void test() {
1091+
test(String::valueOf);
1092+
test(e -> e.toString());
1093+
test(x -> x.toString());
1094+
}
1095+
1096+
void test(Function<Object, String> fn) {
1097+
}
1098+
}
1099+
""",
1100+
"""
1101+
import java.util.function.Function;
1102+
1103+
class Foo {
1104+
void test() {
1105+
test((e) -> e.toString());
1106+
test(String::valueOf);
1107+
test(String::valueOf);
1108+
}
1109+
1110+
void test(Function<Object, String> fn) {
1111+
}
1112+
}
1113+
"""
1114+
)
1115+
);
1116+
}
10491117
}

0 commit comments

Comments
 (0)