Commit 9605c50

naromero77amd authored and pytorchmergebot committed
[ROCm][TunableOp] Speed-up matmul_small_brute_force_tunableop unit test (pytorch#147659)
This PR speeds up a unit test (UT) and refactors some tests.

A previous PR, pytorch#142422, fixed the matmul_small_brute_force_tunableop test for the FP16 data type by adding TunableOp numerical checks. As an unfortunate side effect, it significantly increased the execution time for the FP32 and FP64 data types. This PR *reduces* the execution time by 20+ minutes by enabling the numerical check only for FP16.

We also move a hipBLASLt version check to a different TunableOp UT for simplicity.

Pull Request resolved: pytorch#147659
Approved by: https://github.com/jeffdaily
1 parent 69c4f6f commit 9605c50

1 file changed: +5 -5 lines changed

test/test_linalg.py

Lines changed: 5 additions & 5 deletions
@@ -4589,7 +4589,8 @@ def test_matmul_small_brute_force_tunableop(self, device, dtype):
         try:
             set_tunableop_defaults()
             torch.cuda.tunable.set_rotating_buffer_size(0)
-            os.environ["PYTORCH_TUNABLEOP_NUMERICAL_CHECK"] = "1"
+            if dtype is torch.half:
+                os.environ["PYTORCH_TUNABLEOP_NUMERICAL_CHECK"] = "1"
             ordinal = torch.cuda.current_device()
             torch.cuda.tunable.set_filename(f"tunableop_results{ordinal}.csv")

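The hunk above enables the numerical check only when the test runs with `torch.half`. As a minimal sketch of the same gating, the helper below sets the variable conditionally and restores its previous state afterwards. `numerical_check_env` is a hypothetical name used for illustration; it is not part of PyTorch or of this commit.

```python
import os
from contextlib import contextmanager


@contextmanager
def numerical_check_env(enabled, name="PYTORCH_TUNABLEOP_NUMERICAL_CHECK"):
    """Set `name` to "1" only when `enabled` is true; restore the old state on exit.

    Hypothetical helper for illustration, not a PyTorch API.
    """
    previous = os.environ.get(name)  # None if the variable was unset
    try:
        if enabled:
            os.environ[name] = "1"
        yield
    finally:
        # Put the environment back exactly as it was before entering.
        if previous is None:
            os.environ.pop(name, None)
        else:
            os.environ[name] = previous
```

In a test like the one in this diff, the call site could read `with numerical_check_env(dtype is torch.half): ...`, so that FP32 and FP64 runs never pay for the numerical check.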
@@ -4610,10 +4611,6 @@ def test_matmul_small_brute_force_tunableop(self, device, dtype):
             filename3 = "tunableop_results_tmp2.csv"
             ordinal = torch.cuda.current_device()
             assert filename1 == f"tunableop_results{ordinal}.csv"
-            validators = get_tunableop_validators()
-            if torch.version.hip:
-                assert "HIPBLASLT_VERSION" in validators
-                assert re.match(r'^\d+-[a-z0-9]+$', validators["HIPBLASLT_VERSION"])
             assert len(torch.cuda.tunable.get_results()) > 0

             assert torch.cuda.tunable.write_file()  # use default filename
@@ -4953,9 +4950,12 @@ def test_validator_tunableop_rocm(self, device, dtype):
             self.assertEqual(len(torch.cuda.tunable.get_validators()), validator_num_lines)

             validators = get_tunableop_validators()
+            # Check for rocBLAS and hipBLASLt
             self.assertTrue("ROCBLAS_VERSION" in validators)
             # format: [major].[minor].[patch].[tweak].[commit id]
             self.assertTrue(re.match(r'^\d+.\d+.\d+.\d+.[a-z0-9]+$', validators["ROCBLAS_VERSION"]))
+            self.assertTrue("HIPBLASLT_VERSION" in validators)
+            self.assertTrue(re.match(r'^\d+-[a-z0-9]+$', validators["HIPBLASLT_VERSION"]))

             # disable TunableOp
             torch.cuda.tunable.enable(False)
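
The two version checks in the last hunk rely on regular expressions. The sketch below exercises those patterns on made-up version strings (the sample values are assumptions, not real library versions). It also shows that the unescaped dots in the rocBLAS pattern match any character; the escaped variant is a stricter alternative and is not part of this commit.

```python
import re

# Patterns copied from the test in the diff.
ROCBLAS_LOOSE = re.compile(r'^\d+.\d+.\d+.\d+.[a-z0-9]+$')
HIPBLASLT = re.compile(r'^\d+-[a-z0-9]+$')

# Stricter variant with escaped dots (an assumption, not in the commit).
ROCBLAS_STRICT = re.compile(r'^\d+\.\d+\.\d+\.\d+\.[a-z0-9]+$')

# Made-up version strings, for illustration only.
rocblas_sample = "4.3.0.0.a1b2c3d"    # [major].[minor].[patch].[tweak].[commit id]
hipblaslt_sample = "800-f4e5d6c7"     # digits, a hyphen, then an alphanumeric id
odd_sample = "4x3x0x0xa1b2c3d"        # same shape, but "x" in place of each "."

# Well-formed samples match the patterns used by the test.
loose_ok = ROCBLAS_LOOSE.match(rocblas_sample) is not None
strict_ok = ROCBLAS_STRICT.match(rocblas_sample) is not None
hipblaslt_ok = HIPBLASLT.match(hipblaslt_sample) is not None

# An unescaped "." matches any character, so the loose pattern accepts the
# odd sample while the strict pattern rejects it.
loose_accepts_odd = ROCBLAS_LOOSE.match(odd_sample) is not None
strict_accepts_odd = ROCBLAS_STRICT.match(odd_sample) is not None
```

Whether the looser pattern is intentional is not stated in the commit; the comparison is only meant to make the difference visible.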
