Skip to content

Commit 4b4d455

Browse files
committed
Support first_statement coordinate for J.Block
1 parent 18a73f2 commit 4b4d455

File tree

3 files changed

+34
-2
lines changed

3 files changed

+34
-2
lines changed

rewrite/rewrite/java/support_types.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -359,6 +359,11 @@ def replace(self, loc: Optional[Space.Location] = None) -> JavaCoordinates:
359359

360360
@dataclass
361361
class _BlockCoordinateBuilder(_StatementCoordinateBuilder):
362+
def first_statement(self) -> JavaCoordinates:
363+
if not self.tree.statements:
364+
return self.last_statement()
365+
return self.tree.statements[0].get_coordinates().before()
366+
362367
def last_statement(self) -> JavaCoordinates:
363368
return self.before(Space.Location.BLOCK_END)
364369

rewrite/rewrite/utils.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -55,9 +55,10 @@ def list_flat_map(fn: FlatMapFnType[T], lst: List[T]) -> List[T]:
5555
if isinstance(new_items, list) and (len(new_items) != 1 or new_items[0] is not item):
5656
changed = True
5757
result.extend(new_items)
58-
elif not isinstance(new_items, list) and new_items is not item:
58+
elif not isinstance(new_items, list):
59+
if changed or new_items is not item:
60+
result.append(new_items)
5961
changed = True
60-
result.append(new_items)
6162

6263
return result if changed else lst
6364

rewrite/tests/python/all/templating/template_test.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,32 @@ def f():
8282
)
8383

8484

85+
def test_add_statement_first():
86+
rewrite_run(
87+
# language=python
88+
python(
89+
"""\
90+
def f():
91+
pass
92+
pass
93+
""",
94+
"""\
95+
def f():
96+
return
97+
pass
98+
pass
99+
"""
100+
),
101+
spec=RecipeSpec()
102+
.with_recipe(from_visitor(
103+
GenericTemplatingVisitor(
104+
lambda j: isinstance(j, MethodDeclaration) and len(j.body.statements) == 2,
105+
'return',
106+
coordinate_provider=lambda m: cast(MethodDeclaration, m).body.get_coordinates().first_statement())
107+
))
108+
)
109+
110+
85111
def test_add_statement_before():
86112
rewrite_run(
87113
# language=python

0 commit comments

Comments
 (0)