Hello, I am trying to use an attention variant with a custom mask. The JIT arguments are set up as follows:
variant_decl = r"""
struct FlashCustomMask : AttentionVariantBase {
  static constexpr bool use_softmax = true;

  uint8_t* custom_mask_ptr;
  uint32_t qo_len, kv_len;
  float sm_scale_log2;  //*
  uint32_t window_left;  //*

  template <typename Params>
  __device__ __host__ FlashCustomMask(const Params& params, uint32_t batch_idx,
                                      uint8_t* smem_ptr) {
    qo_len = params.get_qo_len(batch_idx);
    kv_len = params.get_kv_len(batch_idx);
    custom_mask_ptr = params.maybe_custom_mask + params.maybe_mask_indptr[batch_idx];
    sm_scale_log2 = math::log2e;
  }

  REGISTER_LOGITS_MASK(params, batch_idx, qo_idx, kv_idx, qo_head_idx, kv_head_idx, {
    bool mask = true;
    const uint32_t offset = qo_idx * kv_len + kv_idx;
    mask &= ((custom_mask_ptr[offset / 8] >> (offset % 8)) & 1);  //* load mask
    return mask;
  })
};
"""
# ...
jit_args = (
    "batch_prefill_flash_custom_mask",  # uri
    torch.float16,  # dtype_q
    torch.float16,  # dtype_kv
    torch.float16,  # dtype_o
    torch.int32,  # idtype
    head_dim,  # hidden_dim_qk
    head_dim,  # hidden_dim_vo
    ['maybe_custom_mask', 'maybe_mask_indptr'],  # additional_tensor_names
    ['uint8_t', 'int32_t'],  # additional_tensor_dtypes
    [],  # additional_scalar_names
    [],  # additional_scalar_dtypes
    "FlashCustomMask",  # variant name
    variant_decl,  # variant declaration (above)
)
# ...
workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.uint8, device=device)
prefill_wrapper = flashinfer.BatchPrefillWithPagedKVCacheWrapper(
    workspace_buffer, kv_layout="NHD", backend="fa2", jit_args=jit_args
)
# ...
o_custom = prefill_wrapper.run(q, kv_cache)
When I run this code directly, the following error is reported:
RuntimeError: batch_prefill_flash_custom_mask::paged_run() is missing value for argument '_15'. Declaration: batch_prefill_flash_custom_mask::paged_run(Tensor _0, Tensor _1, Tensor _2, Tensor _3, Tensor _4, Tensor _5, Tensor _6, Tensor _7, Tensor _8, Tensor _9, Tensor _10, Tensor? _11, int _12, int _13, int _14, Tensor? _15, Tensor? _16) -> ()
After digging into BatchPrefillWithPagedKVCacheWrapper, I located the problem at flashinfer/prefill.py#L1660:
run_args = [
    self._float_workspace_buffer,
    self._int_workspace_buffer,
    self._plan_info,
    q,
    k_cache,
    v_cache,
    self._qo_indptr_buf,
    sparse_indptr,
    sparse_indices,
    self._paged_kv_last_page_len_buf,
    out,
    lse,
    mask_mode,
    TensorLayout[self._kv_layout].value,
    window_left,
]
if self._jit_module is not None:  # when using attention variants
    run_args.extend(list(args))
else:
    run_args += [
        self._custom_mask_buf,
        self._mask_indptr_buf,
        _get_cache_alibi_slopes_buf(q.shape[1], q.device),
        logits_soft_cap,
        sm_scale,
        rope_scale,
        rope_theta,
    ]
self._cached_module.paged_run(*run_args)
This means that when using an attention variant, I have to pass the additional tensors explicitly as *args when calling run(). In my case these are _custom_mask_buf and _mask_indptr_buf, both of which are protected members of the wrapper. I can temporarily work around the problem with:
o_custom = prefill_wrapper.run(q, kv_cache, prefill_wrapper._custom_mask_buf, prefill_wrapper._mask_indptr_buf)
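
Since run_args.extend(list(args)) simply appends the extra positional arguments after window_left, they must line up, in order, with additional_tensor_names (here maybe_custom_mask, then maybe_mask_indptr), which correspond to the Tensor? _15 and Tensor? _16 in the error above. A sketch of a workaround that avoids touching protected members, assuming the tensors are built by the caller (e.g. with the hypothetical pack_custom_masks helper above, where bool_masks stands for the per-request boolean masks) in the layout the variant expects:

# custom_mask / mask_indptr built by the caller, in the byte-offset,
# little-endian-bit layout that the variant above dereferences.
custom_mask, mask_indptr = pack_custom_masks(bool_masks, device=device)
o_custom = prefill_wrapper.run(q, kv_cache, custom_mask, mask_indptr)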