⚡️ Speed up function function_has_return_statement by 41% #451

codeflash-ai[bot]
Contributor

@codeflash-ai codeflash-ai bot commented Jun 27, 2025

📄 41% (0.41x) speedup for function_has_return_statement in codeflash/discovery/functions_to_optimize.py

⏱️ Runtime : 3.69 milliseconds → 2.62 milliseconds (best of 67 runs)
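
As a quick check on the headline number, the percentage can be recomputed from the two runtimes reported above (original first, then optimized). This is a minimal sketch that assumes the speedup is defined as original time divided by optimized time, minus one; the variable names are illustrative only.

```python
# Runtimes reported in this PR, in milliseconds.
original_ms = 3.69
optimized_ms = 2.62

# Assumed definition: relative speedup = original / optimized - 1.
speedup = original_ms / optimized_ms - 1

# Print the speedup both as a multiplier and as a percentage.
print(f"{speedup:.2f}x ({speedup:.0%})")
```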

📝 Explanation and details

Here's an optimized version of your program, focused on reducing per-node overhead and, in particular, the cost of calling ast.iter_child_nodes(node) inside the traversal loop, which accounts for almost 80% of total runtime.

Optimization Strategies

  • Inline the implementation of ast.iter_child_nodes instead of calling it for every node. This avoids the overhead of the stdlib implementation, which uses getattr, a generator, and repeated attribute accesses.
  • Use a deque for the stack, popping from the right end instead of popping from a Python list.
  • Bind frequently used methods to local variables (the stack_pop = stack.pop trick) to avoid repeated attribute lookups on the hot path.
  • Do not change the function signature or semantics.
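
The strategies above can be sketched as follows. This is not the code from this PR's diff; it is a minimal illustration that assumes the original function is a stack-based depth-first search over the AST. The name has_return_sketch and the local aliases are hypothetical.

```python
import ast
from collections import deque


def has_return_sketch(function_node: ast.AST) -> bool:
    """Return True as soon as any ast.Return node is reached (hypothetical sketch)."""
    stack = deque([function_node])
    # Cache bound methods and classes in locals to skip attribute lookups in the loop.
    stack_pop = stack.pop
    stack_append = stack.append
    ast_node = ast.AST
    return_node = ast.Return

    while stack:
        node = stack_pop()
        if isinstance(node, return_node):
            return True
        # Inlined equivalent of ast.iter_child_nodes: walk the declared fields
        # and push child nodes directly, without creating a generator.
        for name in node._fields:
            value = getattr(node, name, None)
            if isinstance(value, ast_node):
                stack_append(value)
            elif isinstance(value, list):
                for item in value:
                    if isinstance(item, ast_node):
                        stack_append(item)
    return False


func = ast.parse("def foo(x):\n    if x > 0:\n        return x\n    y = 2\n").body[0]
print(has_return_sketch(func))
```

Whether each of these micro-optimizations pays off depends on the interpreter version, so it is worth measuring them separately before adopting all of them.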


Summary of improvements

  • No generator allocation or function call for each child visit: fast_iter_child_nodes is inlined and avoids unnecessary attribute access inside collections.
  • The stack is a deque, which supports O(1) pops from the right end.
  • Bound methods for stack operations are reused to minimize attribute lookup overhead.
  • Output and semantics are unchanged, and the traversal still stops immediately at the first Return.

This rewrite should make the largest difference for large function bodies, where per-node overhead dominates.
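
The relative cost of the two stack types mentioned in the summary is easy to measure locally. The following is a minimal timeit sketch; the stack size, repeat count, and function names are arbitrary choices for illustration, and no particular outcome is claimed because results vary by interpreter version and machine.

```python
import timeit
from collections import deque

N = 100_000  # arbitrary stack size chosen for illustration


def drain_list() -> int:
    """Pop every element from the right end of a list."""
    stack = list(range(N))
    pop = stack.pop
    count = 0
    while stack:
        pop()
        count += 1
    return count


def drain_deque() -> int:
    """Pop every element from the right end of a deque."""
    stack = deque(range(N))
    pop = stack.pop
    count = 0
    while stack:
        pop()
        count += 1
    return count


for fn in (drain_list, drain_deque):
    seconds = timeit.timeit(fn, number=20)
    print(f"{fn.__name__}: {seconds / 20 * 1000:.2f} ms per drain")
```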

Correctness verification report:

| Test | Status |
|---|---|
| ⚙️ Existing Unit Tests | 🔘 None Found |
| 🌀 Generated Regression Tests | 76 Passed |
| ⏪ Replay Tests | 🔘 None Found |
| 🔎 Concolic Coverage Tests | 🔘 None Found |
| 📊 Tests Coverage | 100.0% |
🌀 Generated Regression Tests and Runtime
from __future__ import annotations

import ast
from _ast import AsyncFunctionDef, FunctionDef

# imports
import pytest  # used for our unit tests
from codeflash.discovery.functions_to_optimize import \
    function_has_return_statement

# ---------------------------
# Helper function for tests
# ---------------------------

def get_first_funcdef(source: str) -> FunctionDef | AsyncFunctionDef:
    """
    Parse source code and return the first FunctionDef or AsyncFunctionDef node.
    """
    tree = ast.parse(source)
    for node in ast.walk(tree):
        if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)):
            return node
    raise ValueError("No function definition found in source.")

# ---------------------------
# Basic Test Cases
# ---------------------------

def test_simple_function_with_return():
    # Function with a single return statement
    src = """
def foo():
    return 42
"""
    func_node = get_first_funcdef(src)
    codeflash_output = function_has_return_statement(func_node)

def test_simple_function_without_return():
    # Function with no return statement
    src = """
def foo():
    x = 5
    y = x + 2
"""
    func_node = get_first_funcdef(src)
    codeflash_output = function_has_return_statement(func_node)

def test_function_with_multiple_returns():
    # Function with multiple return statements
    src = """
def foo(x):
    if x > 0:
        return 1
    else:
        return -1
"""
    func_node = get_first_funcdef(src)
    codeflash_output = function_has_return_statement(func_node)

def test_function_with_return_in_if():
    # Function with return inside an if block
    src = """
def foo(x):
    if x > 0:
        return x
    y = 2
"""
    func_node = get_first_funcdef(src)
    codeflash_output = function_has_return_statement(func_node)

def test_function_with_return_in_loop():
    # Function with return inside a for loop
    src = """
def foo(lst):
    for x in lst:
        if x == 0:
            return x
    return None
"""
    func_node = get_first_funcdef(src)
    codeflash_output = function_has_return_statement(func_node)

def test_function_with_return_none():
    # Function with explicit 'return None'
    src = """
def foo():
    return None
"""
    func_node = get_first_funcdef(src)
    codeflash_output = function_has_return_statement(func_node)

# ---------------------------
# Edge Test Cases
# ---------------------------

def test_function_with_return_in_nested_function():
    # Return is only in a nested function, not the outer one
    src = """
def outer():
    def inner():
        return 1
    x = 2
"""
    func_node = get_first_funcdef(src)
    # Should be False: only inner() has a return
    codeflash_output = function_has_return_statement(func_node)

def test_function_with_return_in_nested_class_method():
    # Return is only in a method of a nested class
    src = """
def outer():
    class Inner:
        def method(self):
            return 1
    x = 2
"""
    func_node = get_first_funcdef(src)
    # Should be False: only method() has a return
    codeflash_output = function_has_return_statement(func_node)

def test_function_with_return_in_try_except_finally():
    # Return in try, except, and finally blocks
    src = """
def foo(x):
    try:
        return x
    except Exception:
        return -1
    finally:
        return 0
"""
    func_node = get_first_funcdef(src)
    codeflash_output = function_has_return_statement(func_node)

def test_function_with_return_in_comprehension():
    # Return inside a comprehension (should not exist, but test for robustness)
    src = """
def foo():
    x = [i for i in range(5)]
"""
    func_node = get_first_funcdef(src)
    codeflash_output = function_has_return_statement(func_node)

def test_function_with_pass_only():
    # Function with only pass
    src = """
def foo():
    pass
"""
    func_node = get_first_funcdef(src)
    codeflash_output = function_has_return_statement(func_node)

def test_function_with_docstring_only():
    # Function with only a docstring
    src = '''
def foo():
    """This is a docstring."""
'''
    func_node = get_first_funcdef(src)
    codeflash_output = function_has_return_statement(func_node)

def test_function_with_yield_but_no_return():
    # Function with yield but no return
    src = """
def foo():
    yield 1
"""
    func_node = get_first_funcdef(src)
    codeflash_output = function_has_return_statement(func_node)

def test_function_with_return_in_lambda():
    # Return in a lambda (should not count as function return)
    src = """
def foo():
    x = lambda y: y + 1
"""
    func_node = get_first_funcdef(src)
    codeflash_output = function_has_return_statement(func_node)

def test_async_function_with_return():
    # Async function with return
    src = """
async def foo():
    return 123
"""
    func_node = get_first_funcdef(src)
    codeflash_output = function_has_return_statement(func_node)

def test_async_function_without_return():
    # Async function with no return
    src = """
async def foo():
    await bar()
"""
    func_node = get_first_funcdef(src)
    codeflash_output = function_has_return_statement(func_node)

def test_function_with_return_in_match_case():
    # Return inside match/case (Python 3.10+)
    src = """
def foo(x):
    match x:
        case 1:
            return 'one'
        case _:
            return 'other'
"""
    func_node = get_first_funcdef(src)
    codeflash_output = function_has_return_statement(func_node)

def test_function_with_return_in_while():
    # Return inside a while loop
    src = """
def foo(x):
    while x > 0:
        return x
    return 0
"""
    func_node = get_first_funcdef(src)
    codeflash_output = function_has_return_statement(func_node)

def test_function_with_return_in_deeply_nested_blocks():
    # Return deeply nested in if/for/while
    src = """
def foo(lst):
    for x in lst:
        if x > 0:
            while x < 100:
                if x % 17 == 0:
                    return x
    return None
"""
    func_node = get_first_funcdef(src)
    codeflash_output = function_has_return_statement(func_node)

def test_function_with_return_in_try_except_else():
    # Return in try, except, and else blocks
    src = """
def foo(x):
    try:
        y = 1 / x
    except ZeroDivisionError:
        return None
    else:
        return y
"""
    func_node = get_first_funcdef(src)
    codeflash_output = function_has_return_statement(func_node)

def test_function_with_return_in_decorator():
    # Function with a decorator, return inside function
    src = """
@staticmethod
def foo():
    return 42
"""
    func_node = get_first_funcdef(src)
    codeflash_output = function_has_return_statement(func_node)

def test_function_with_return_in_decorator_but_not_body():
    # Function with a decorator, but no return in function body
    src = """
@staticmethod
def foo():
    x = 1
"""
    func_node = get_first_funcdef(src)
    codeflash_output = function_has_return_statement(func_node)

# ---------------------------
# Large Scale Test Cases
# ---------------------------

def test_large_function_with_no_return():
    # Large function with many statements but no return
    body = "\n".join([f"    x{i} = {i}" for i in range(900)])
    src = f"""
def foo():
{body}
"""
    func_node = get_first_funcdef(src)
    codeflash_output = function_has_return_statement(func_node)

def test_large_function_with_return_at_end():
    # Large function with many statements and return at the end
    body = "\n".join([f"    x{i} = {i}" for i in range(900)])
    src = f"""
def foo():
{body}
    return 999
"""
    func_node = get_first_funcdef(src)
    codeflash_output = function_has_return_statement(func_node)

def test_large_function_with_return_in_middle():
    # Large function with return in the middle
    body = "\n".join([f"    x{i} = {i}" for i in range(450)])
    src = f"""
def foo():
{body}
    return 'mid'
"""
    func_node = get_first_funcdef(src)
    codeflash_output = function_has_return_statement(func_node)


def foo():
{body}
{nested}
"""
    func_node = get_first_funcdef(src)
    codeflash_output = function_has_return_statement(func_node)

def test_large_function_with_many_small_functions_inside():
    # Large function with many nested function definitions, only inner functions have returns
    inner_funcs = "\n".join([
        f"    def inner_{i}():\n        return {i}" for i in range(50)
    ])
    src = f"""
def foo():
{inner_funcs}
    x = 1
"""
    func_node = get_first_funcdef(src)
    # Outer function has no return, only inner functions do
    codeflash_output = function_has_return_statement(func_node)

def test_large_function_with_comments_and_blank_lines():
    # Large function with many comments and blank lines, and a return at the end
    body = "\n".join([
        "    # comment" if i % 2 == 0 else "" for i in range(500)
    ])
    src = f"""
def foo():
{body}
    return 'done'
"""
    func_node = get_first_funcdef(src)
    codeflash_output = function_has_return_statement(func_node)

def test_large_function_with_return_only_in_nested_class_methods():
    # Large function with a class definition inside, only class methods have returns
    class_methods = "\n".join([
        f"        def method_{i}(self):\n            return {i}" for i in range(20)
    ])
    src = f"""
def foo():
    class Inner:
{class_methods}
    x = 2
"""
    func_node = get_first_funcdef(src)
    # Outer function has no return, only class methods do
    codeflash_output = function_has_return_statement(func_node)
# codeflash_output is used to check that the output of the original code is the same as that of the optimized code.

from __future__ import annotations

import ast
from _ast import AsyncFunctionDef, FunctionDef

# imports
import pytest  # used for our unit tests
from codeflash.discovery.functions_to_optimize import \
    function_has_return_statement


# Helper to extract the first function node from source code
def get_first_function_node(src: str) -> FunctionDef | AsyncFunctionDef:
    """
    Parses the source code and returns the first FunctionDef or AsyncFunctionDef node.
    Raises ValueError if no function is found.
    """
    module = ast.parse(src)
    for node in module.body:
        if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)):
            return node
    raise ValueError("No function found in source")

# ----------------------
# Basic Test Cases
# ----------------------

def test_simple_function_with_return():
    # Basic function with a single return
    src = """
def foo():
    return 1
"""
    func = get_first_function_node(src)
    codeflash_output = function_has_return_statement(func)

def test_simple_function_without_return():
    # Basic function with no return statement
    src = """
def foo():
    x = 5
    y = x + 2
"""
    func = get_first_function_node(src)
    codeflash_output = function_has_return_statement(func)

def test_function_with_return_none():
    # Function with 'return None'
    src = """
def foo():
    return None
"""
    func = get_first_function_node(src)
    codeflash_output = function_has_return_statement(func)

def test_function_with_multiple_returns():
    # Function with multiple return statements
    src = """
def foo(x):
    if x > 0:
        return 1
    else:
        return 2
"""
    func = get_first_function_node(src)
    codeflash_output = function_has_return_statement(func)

def test_function_with_return_in_if():
    # Function with return inside an if block
    src = """
def foo(x):
    if x > 0:
        return x
    y = x + 1
"""
    func = get_first_function_node(src)
    codeflash_output = function_has_return_statement(func)

def test_async_function_with_return():
    # Async function with a return
    src = """
async def foo():
    return 123
"""
    func = get_first_function_node(src)
    codeflash_output = function_has_return_statement(func)

def test_async_function_without_return():
    # Async function without a return
    src = """
async def foo():
    await bar()
"""
    func = get_first_function_node(src)
    codeflash_output = function_has_return_statement(func)

# ----------------------
# Edge Test Cases
# ----------------------

def test_function_with_return_in_nested_function():
    # Return in a nested function should not count for the outer function
    src = """
def foo():
    def bar():
        return 1
    x = 2
"""
    func = get_first_function_node(src)
    codeflash_output = function_has_return_statement(func)

def test_function_with_return_in_class_method():
    # Return in a class method, test detection
    src = """
class C:
    def foo(self):
        return 42
"""
    module = ast.parse(src)
    class_node = next(n for n in module.body if isinstance(n, ast.ClassDef))
    method_node = next(n for n in class_node.body if isinstance(n, ast.FunctionDef))
    codeflash_output = function_has_return_statement(method_node)

def test_function_with_return_in_try_except():
    # Return inside the try block of a try/except
    src = """
def foo():
    try:
        x = 1
        return x
    except Exception:
        pass
"""
    func = get_first_function_node(src)
    codeflash_output = function_has_return_statement(func)

def test_function_with_return_in_except():
    # Return inside except block
    src = """
def foo():
    try:
        x = 1/0
    except ZeroDivisionError:
        return 0
"""
    func = get_first_function_node(src)
    codeflash_output = function_has_return_statement(func)

def test_function_with_return_in_finally():
    # Return inside finally block
    src = """
def foo():
    try:
        x = 1
    finally:
        return 0
"""
    func = get_first_function_node(src)
    codeflash_output = function_has_return_statement(func)

def test_function_with_return_in_loop():
    # Return inside a loop
    src = """
def foo():
    for i in range(10):
        if i == 5:
            return i
"""
    func = get_first_function_node(src)
    codeflash_output = function_has_return_statement(func)

def test_function_with_return_in_comprehension():
    # Return inside a comprehension is not possible, but test for no false positives
    src = """
def foo():
    x = [i for i in range(10)]
"""
    func = get_first_function_node(src)
    codeflash_output = function_has_return_statement(func)

def test_function_with_only_pass():
    # Function with only 'pass'
    src = """
def foo():
    pass
"""
    func = get_first_function_node(src)
    codeflash_output = function_has_return_statement(func)

def test_function_with_docstring_only():
    # Function with only a docstring
    src = '''
def foo():
    """This is a docstring."""
'''
    func = get_first_function_node(src)
    codeflash_output = function_has_return_statement(func)

def test_function_with_yield_but_no_return():
    # Generator function with yield but no return
    src = """
def foo():
    yield 1
    yield 2
"""
    func = get_first_function_node(src)
    codeflash_output = function_has_return_statement(func)

def test_function_with_return_in_nested_class():
    # Return in a nested class method should not count for the outer function
    src = """
def foo():
    class Bar:
        def baz(self):
            return 1
    x = 2
"""
    func = get_first_function_node(src)
    codeflash_output = function_has_return_statement(func)

def test_function_with_return_in_lambda():
    # Lambda inside function, but no explicit return in function
    src = """
def foo():
    x = lambda y: y + 1
    z = x(3)
"""
    func = get_first_function_node(src)
    codeflash_output = function_has_return_statement(func)

def test_function_with_return_after_comment():
    # Return after comments and blank lines
    src = """
def foo():
    # this is a comment

    return 5
"""
    func = get_first_function_node(src)
    codeflash_output = function_has_return_statement(func)

def test_function_with_return_in_match_case():
    # Return inside a match/case (Python 3.10+)
    src = """
def foo(x):
    match x:
        case 1:
            return "one"
        case _:
            pass
"""
    func = get_first_function_node(src)
    codeflash_output = function_has_return_statement(func)

def test_function_with_return_in_while_else():
    # Return in while-else block
    src = """
def foo(x):
    while x > 0:
        x -= 1
        if x == 2:
            return x
    else:
        return -1
"""
    func = get_first_function_node(src)
    codeflash_output = function_has_return_statement(func)

def test_function_with_return_in_if_elif_else():
    # Return in elif and else branches
    src = """
def foo(x):
    if x == 1:
        pass
    elif x == 2:
        return 2
    else:
        return 3
"""
    func = get_first_function_node(src)
    codeflash_output = function_has_return_statement(func)

def test_function_with_return_in_deeply_nested_blocks():
    # Return in deeply nested blocks
    src = """
def foo(x):
    for i in range(10):
        if i % 2 == 0:
            for j in range(5):
                if j == 3:
                    return i, j
"""
    func = get_first_function_node(src)
    codeflash_output = function_has_return_statement(func)

# ----------------------
# Large Scale Test Cases
# ----------------------

def test_large_function_with_no_return():
    # Large function (many statements), but no return
    body = "\n    ".join([f"x{i} = {i}" for i in range(500)])
    src = f"""
def foo():
    {body}
"""
    func = get_first_function_node(src)
    codeflash_output = function_has_return_statement(func)

def test_large_function_with_return_at_end():
    # Large function with return at the very end
    body = "\n    ".join([f"x{i} = {i}" for i in range(499)])
    src = f"""
def foo():
    {body}
    return 999
"""
    func = get_first_function_node(src)
    codeflash_output = function_has_return_statement(func)

def test_large_function_with_return_in_middle():
    # Large function with return somewhere in the middle
    stmts = [f"x{i} = {i}" for i in range(250)]
    stmts.append("return 123")
    stmts.extend([f"y{i} = {i}" for i in range(250, 500)])
    body = "\n    ".join(stmts)
    src = f"""
def foo():
    {body}
"""
    func = get_first_function_node(src)
    codeflash_output = function_has_return_statement(func)


def test_many_small_functions_some_with_return():
    # Test many small functions, some with and some without return
    src = "\n".join([
        f"def foo{i}():\n    return {i}" if i % 2 == 0 else f"def foo{i}():\n    pass"
        for i in range(20)
    ])
    module = ast.parse(src)
    for i, node in enumerate(module.body):
        if isinstance(node, ast.FunctionDef):
            expected = (i % 2 == 0)
            codeflash_output = function_has_return_statement(node)

def test_function_with_huge_if_else_chain_and_one_return():
    # Function with a long run of if statements; only the final else branch returns
    chain = ""
    for i in range(999):
        chain += f"    if x == {i}:\n        pass\n"
    chain += f"    else:\n        return 42\n"
    src = f"def foo(x):\n{chain}"
    func = get_first_function_node(src)
    codeflash_output = function_has_return_statement(func)
# codeflash_output is used to check that the output of the original code is the same as that of the optimized code.
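
The comparison that the closing comment describes can be sketched roughly as below. This is not Codeflash's actual harness; it only illustrates the idea of running an original and a candidate implementation on the same inputs and comparing their outputs. Both implementations and the helper outputs_match are hypothetical stand-ins.

```python
import ast


def original_impl(node: ast.AST) -> bool:
    # Stand-in for the pre-optimization function (hypothetical).
    return any(isinstance(n, ast.Return) for n in ast.walk(node))


def candidate_impl(node: ast.AST) -> bool:
    # Stand-in for the optimized function (hypothetical).
    stack = [node]
    while stack:
        current = stack.pop()
        if isinstance(current, ast.Return):
            return True
        stack.extend(ast.iter_child_nodes(current))
    return False


# A few inputs borrowed from the generated tests above.
SOURCES = [
    "def foo():\n    return 42\n",
    "def foo():\n    x = 5\n",
    "async def foo():\n    await bar()\n",
    "def outer():\n    def inner():\n        return 1\n    x = 2\n",
]


def outputs_match(sources: list[str]) -> bool:
    """Return True when both implementations agree on every source."""
    for src in sources:
        func = ast.parse(src).body[0]
        if original_impl(func) != candidate_impl(func):
            return False
    return True


print(outputs_match(SOURCES))
```

Comparing outputs rather than asserting fixed values means the check passes whenever behavior is preserved, regardless of what the expected value for a given input is.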

To edit these changes, git checkout codeflash/optimize-function_has_return_statement-mcfagouq and push.

@codeflash-ai codeflash-ai bot added the ⚡️ codeflash Optimization PR opened by Codeflash AI label Jun 27, 2025
@codeflash-ai codeflash-ai bot requested a review from KRRT7 June 27, 2025 20:52
@KRRT7
Contributor

KRRT7 commented Jun 27, 2025

Note: this PR came from a test run to look for a potential regression on the main branch, but everything appears to be working fine, and all our tests, including the E2E tests, pass.

@codeflash-ai codeflash-ai bot deleted the codeflash/optimize-function_has_return_statement-mcfagouq branch June 28, 2025 01:47