Skip to content

Commit ed3605b

Browse files
committed
Fix issue in recompiling kernel with double GRF mode.
Signed-off-by: Lu,Chengjun <[email protected]>
1 parent a4bde41 commit ed3605b

File tree

2 files changed

+82
-6
lines changed

2 files changed

+82
-6
lines changed

python/test/unit/intel/test_driver.py

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
import re
2+
import tempfile
3+
import subprocess
4+
import sys
5+
import os
6+
7+
8+
def test_auto_grf():
    """Check the driver's automatic large-GRF recompilation via its debug output.

    A small Triton script is written to a temporary file and run in a child
    interpreter with ``TRITON_DEBUG=1``. The test then inspects the child's
    stdout and expects, in order:

    1. a line announcing the recompilation in large GRF mode,
    2. a line reporting the spill count of the recompiled kernel,
    3. a line reporting the spill count of the kernel that is finally used.

    The spill counts on lines 2 and 3 must be identical.
    """
    # The script is executed as a stand-alone file, so its top-level
    # statements have to start at column 0 inside the string literal.
    test_code = """
import numpy as np
import torch
import triton
import triton.language as tl

from triton._internal_testing import to_numpy


def test_auto_grf(device):
    BLOCK = 1024 * 8
    z_tri = torch.empty(BLOCK, dtype=torch.int32, device=device)

    @triton.jit
    def _kernel(z, BLOCK: tl.constexpr):
        # make it hard to re-schedule.
        off = tl.arange(0, BLOCK)
        a = tl.load(z + off)
        result = tl.sum(a, axis=0, keep_dims=True)
        tl.store(z + off, a + result)

    _kernel[(1, )](z_tri, BLOCK=BLOCK, num_warps=2)
    z_ref = torch.arange(0, BLOCK, dtype=torch.int32, device=device)

test_auto_grf("xpu")
"""

    # NOTE(review): the child opens the file by name while it is still open
    # here; whether that is permitted is platform dependent - confirm if this
    # test ever has to run outside Linux.
    with tempfile.NamedTemporaryFile(mode='w', suffix='.py') as f:
        f.write(test_code)
        # Make sure the script is on disk before the child interpreter reads it.
        f.flush()
        env = os.environ.copy()
        env["TRITON_DEBUG"] = "1"
        proc = subprocess.run(
            [sys.executable, f.name],
            capture_output=True,
            env=env,
        )

    # Surface the child's stderr so a crash in the script is diagnosable.
    assert proc.returncode == 0, (f"child script exited with {proc.returncode}:\n"
                                  f"{proc.stderr.decode('UTF-8', errors='replace')}")

    outs = [line for line in proc.stdout.decode("UTF-8").splitlines() if line]
    # Fail with a readable message instead of an IndexError when the expected
    # debug lines are missing.
    assert len(outs) >= 3, f"expected at least 3 non-empty stdout lines, got: {outs!r}"

    # The output should contain the recompiling information for large GRF mode.
    assert re.search(r"recompiling the kernel using large GRF mode", outs[0])

    # The spill size of the returned kernel should be the same as that of the
    # kernel compiled with large GRF mode.
    number = re.compile(r"\d+\.?\d*")
    recompiled_spills = number.findall(outs[1])
    returned_spills = number.findall(outs[2])
    assert recompiled_spills and returned_spills, (f"no spill count found in: {outs[1]!r} / {outs[2]!r}")
    assert recompiled_spills[0] == returned_spills[0]

third_party/intel/backend/driver.c

Lines changed: 30 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -227,15 +227,15 @@ extern "C" EXPORT_FUNC PyObject *load_binary(PyObject *args) {
227227
compileLevelZeroObjects(binary_ptr, binary_size, kernel_name, l0_device,
228228
l0_context, build_flags(), is_spv);
229229

230+
const bool debugEnabled = getBoolEnv("TRITON_DEBUG");
231+
230232
if (is_spv) {
231233
constexpr int32_t max_reg_spill = 1000;
232234
const bool is_GRF_mode_specified = build_flags.hasGRFSizeFlag();
233235

234236
// If the register mode isn't set, and the number of spills is greater
235237
// than the threshold, recompile the kernel using large GRF mode.
236238
if (!is_GRF_mode_specified && n_spills > max_reg_spill) {
237-
const std::optional<bool> debugEnabled =
238-
isEnvValueBool(getStrEnv("TRITON_DEBUG"));
239239
if (debugEnabled)
240240
std::cout << "(I): Detected " << n_spills
241241
<< " spills, recompiling the kernel using large GRF mode"
@@ -244,13 +244,32 @@ extern "C" EXPORT_FUNC PyObject *load_binary(PyObject *args) {
244244
build_flags.addLargeGRFSizeFlag();
245245

246246
try {
247-
auto [l0_module, l0_kernel, n_spills] = compileLevelZeroObjects(
248-
binary_ptr, binary_size, kernel_name, l0_device, l0_context,
249-
build_flags(), is_spv);
247+
auto [l0_module_dgrf, l0_kernel_dgrf, n_spills_dgrf] =
248+
compileLevelZeroObjects(binary_ptr, binary_size, kernel_name,
249+
l0_device, l0_context, build_flags(),
250+
is_spv);
250251

251252
if (debugEnabled)
252-
std::cout << "(I): Kernel has now " << n_spills << " spills"
253+
std::cout << "(I): Kernel has now " << n_spills_dgrf << " spills"
253254
<< std::endl;
255+
if (n_spills_dgrf < n_spills) {
256+
std::swap(l0_module, l0_module_dgrf);
257+
std::swap(l0_kernel, l0_kernel_dgrf);
258+
std::swap(n_spills, n_spills_dgrf);
259+
}
260+
// clean up the unused module and kernel.
261+
auto error_no = zeKernelDestroy(l0_kernel_dgrf);
262+
if (error_no != ZE_RESULT_SUCCESS) {
263+
std::cerr
264+
<< "[Ignoring] Intel - Error during destroy unused L0 kernel"
265+
<< std::endl;
266+
}
267+
error_no = zeModuleDestroy(l0_module_dgrf);
268+
if (error_no != ZE_RESULT_SUCCESS) {
269+
std::cerr
270+
<< "[Ignoring] Intel - Error during destroy unused L0 module"
271+
<< std::endl;
272+
}
254273
} catch (const std::exception &e) {
255274
std::cerr << "[Ignoring] Error during Intel loadBinary with large "
256275
"registers: "
@@ -261,6 +280,11 @@ extern "C" EXPORT_FUNC PyObject *load_binary(PyObject *args) {
261280
}
262281
}
263282

283+
if (debugEnabled && n_spills) {
284+
std::cout << "(I): Detected " << n_spills << " spills for \""
285+
<< kernel_name << "\"" << std::endl;
286+
}
287+
264288
auto n_regs = build_flags.n_regs();
265289

266290
auto mod = new sycl::kernel_bundle<sycl::bundle_state::executable>(

0 commit comments

Comments
 (0)