-
Notifications
You must be signed in to change notification settings - Fork 83
Open
Description
Describe the bug
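In the example below, each autotuning config writes one half of the result: the first config writes the first 512 elements and the second config writes the last 512 elements. When running with a single config (the default), the assert fails as expected, reporting that about 50% of the elements are mismatched. However, when running with both configs (`REPRO=1`), the assert passes, even though whichever config the autotuner selects writes only half of the output. This can hide real correctness bugs.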
import os
import torch
import triton
import triton.language as tl
def get_configs():
    """Build the autotuning configs used by ``add_kernel``.

    One config covers elements [0, 512) and the other covers [512, 1024).
    Both configs are always constructed. The ``REPRO`` environment variable
    (default ``'0'``) is then parsed with ``int()``:

    - zero: only the first config is returned;
    - non-zero: both configs are returned;
    - not an integer literal: ``int()`` raises ``ValueError``.
    """
    lower_half = triton.Config({'OFFSET': 0, 'ELEMENTS': 512})
    upper_half = triton.Config({'OFFSET': 512, 'ELEMENTS': 512})
    if not int(os.environ.get('REPRO', '0')):
        # Default path: autotune over a single config.
        return [lower_half]
    return [lower_half, upper_half]
# NOTE(review): per the bug report, with REPRO=1 the final comparison
# reportedly passes even though each config stores only ELEMENTS of the
# output elements. Presumably autotuning launches every config against the
# same output buffer while benchmarking, so together they fill it. Consider
# `restore_value=['output_ptr']` in `triton.autotune` if that side effect is
# unwanted (TODO confirm against the Triton version in use).
@triton.autotune(
    configs=get_configs(),
    key=[],
)
@triton.jit
def add_kernel(
    x_ptr,
    y_ptr,
    output_ptr,
    OFFSET: tl.constexpr,
    ELEMENTS: tl.constexpr,
):
    """Store x + y for the ELEMENTS consecutive elements starting at OFFSET.

    OFFSET and ELEMENTS come from the selected autotuning config. The loads
    and stores are unmasked, so the buffers must hold at least
    OFFSET + ELEMENTS elements.
    """
    # Same int32 offsets as tl.arange(OFFSET, OFFSET + ELEMENTS).
    idx = OFFSET + tl.arange(0, ELEMENTS)
    lhs = tl.load(x_ptr + idx)
    rhs = tl.load(y_ptr + idx)
    tl.store(output_ptr + idx, lhs + rhs)
def get_triton_result(*, device='xpu'):
    """Launch ``add_kernel`` on two all-ones 1024-element tensors.

    ``output`` is allocated with ``torch.empty``. Any element the launched
    kernel does not store to therefore keeps whatever uninitialized value
    the allocation contained. This is what lets the comparison against the
    eager reference expose a config that writes only part of the buffer.

    Args:
        device: Device on which the input and output tensors are allocated.
            Defaults to ``'xpu'``.

    Returns:
        The 1024-element output tensor written by ``add_kernel``.
    """
    x = torch.ones(1024, device=device)
    y = torch.ones(1024, device=device)
    output = torch.empty(1024, device=device)
    # Single-program grid; OFFSET/ELEMENTS are supplied by the autotuned
    # config, not passed here.
    add_kernel[(1,)](x, y, output)
    return output
def get_torch_result(*, device='xpu'):
    """Compute the eager PyTorch reference: two all-ones 1024-element tensors summed.

    Args:
        device: Device on which the tensors are allocated. Defaults to
            ``'xpu'``.

    Returns:
        A 1024-element tensor (default dtype) whose every element is 2.
    """
    x = torch.ones(1024, device=device)
    y = torch.ones(1024, device=device)
    return x + y
# Kernel result first, then the eager reference it is compared against.
triton_output, torch_output = get_triton_result(), get_torch_result()
torch.testing.assert_close(triton_output, torch_output)

Environment details
Triton: latest