Skip to content

Commit 9691959

Browse files
authored
[TRITON] Unit test fixes to la tests and gemm afp4wfp4 (ROCm#1494)
* Fix: Add debug mode for mismatch reporting in lean attention tests * Fix: Adjust x_scales_shuffled slicing according to x_scales size
1 parent 006fe31 commit 9691959

File tree

2 files changed

+6
-3
lines changed

2 files changed

+6
-3
lines changed

op_tests/triton_tests/test_gemm_afp4wfp4.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -122,7 +122,7 @@ def generate_gemm_afp4wfp4_inputs(
122122
w_shuffed,
123123
x_scales[:M],
124124
w_scales,
125-
x_scales_shuffled,
125+
x_scales_shuffled[:M],
126126
w_scales_shuffled,
127127
out_dtype,
128128
y,

op_tests/triton_tests/test_la.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@
1414
import aiter.ops.triton.utils._triton.arch_info as arch_info
1515
import pytest
1616

17+
DEBUG_MODE = False
18+
1719

1820
def get_lean_attn_inputs(
1921
batch: int,
@@ -422,8 +424,9 @@ def test_persistent_lean_attention(
422424
try:
423425
torch.testing.assert_close(ref_out, la_out, atol=atol, rtol=rtol)
424426
except AssertionError:
425-
print("Assertion failed! Showing mismatches:")
426-
print_mismatches(ref_out, la_out, atol, rtol)
427+
if DEBUG_MODE:
428+
print("Assertion failed! Showing mismatches:")
429+
print_mismatches(ref_out, la_out, atol, rtol)
427430
raise # Re-raise the exception after printing mismatches
428431

429432

0 commit comments

Comments
 (0)