
Problem when using attention variant with custom mask #1044

Open
@xiaozxiong

Hello, I am trying to use an attention variant with a custom mask. The JIT arguments are set as follows:

variant_decl = r"""
struct FlashCustomMask : AttentionVariantBase {
  static constexpr bool use_softmax = true;

  uint8_t* custom_mask_ptr;
  uint32_t qo_len, kv_len;
  float sm_scale_log2; //*
  uint32_t window_left; //*

  template <typename Params>
  __device__ __host__ FlashCustomMask(const Params& params, uint32_t batch_idx,
                                   uint8_t* smem_ptr) {
    qo_len = params.get_qo_len(batch_idx);
    kv_len = params.get_kv_len(batch_idx);
    custom_mask_ptr = params.maybe_custom_mask + params.maybe_mask_indptr[batch_idx];
    sm_scale_log2 = math::log2e;
  }

  REGISTER_LOGITS_MASK(params, batch_idx, qo_idx, kv_idx, qo_head_idx, kv_head_idx, {
    bool mask = true;

    const uint32_t offset = qo_idx * kv_len + kv_idx;
    mask &= ((custom_mask_ptr[offset / 8] >> (offset % 8)) & 1);  // load one bit from the packed custom mask

    return mask;
  })
};
"""
    # ...
    jit_args = (
        "batch_prefill_flash_custom_mask",  # uri
        torch.float16,  # dtype_q
        torch.float16,  # dtype_kv
        torch.float16,  # dtype_o
        torch.int32,  # idtype
        head_dim,  # hidden_dim_qk
        head_dim,  # hidden_dim_vo
        ['maybe_custom_mask', 'maybe_mask_indptr'],  # additional_tensor_names, e.g., ['maybe_custom_mask', 'maybe_mask_indptr', 'maybe_alibi_slopes']
        ['uint8_t', 'int32_t'],  # additional_tensor_dtypes
        [],  # additional_scalar_names,
        [],  # additional_scalar_dtypes
        "FlashCustomMask",
        variant_decl,
    )
    # ...
    workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.uint8, device=device)
    prefill_wrapper = flashinfer.BatchPrefillWithPagedKVCacheWrapper(
        workspace_buffer, kv_layout="NHD", backend="fa2", jit_args=jit_args
    )
    #  ...
    o_custom = prefill_wrapper.run(q, kv_cache)
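
For reference, the two additional tensors declared in additional_tensor_names can be assembled roughly like this (a sketch only; per_request_bool_masks is an illustrative name, pack_custom_mask is the helper sketched above, and the offsets are in uint8 elements because the variant adds maybe_mask_indptr[batch_idx] directly to the uint8 mask pointer):

packed, mask_indptr = [], [0]
for bool_mask in per_request_bool_masks:  # one (qo_len, kv_len) bool tensor per request
    packed.append(pack_custom_mask(bool_mask))
    mask_indptr.append(mask_indptr[-1] + packed[-1].numel())

maybe_custom_mask = torch.cat(packed).to(device)
maybe_mask_indptr = torch.tensor(mask_indptr, dtype=torch.int32, device=device)
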

When I run the code directly, the following error is reported:

RuntimeError: batch_prefill_flash_custom_mask::paged_run() is missing value for argument '_15'. Declaration: batch_prefill_flash_custom_mask::paged_run(Tensor _0, Tensor _1, Tensor _2, Tensor _3, Tensor _4, Tensor _5, Tensor _6, Tensor _7, Tensor _8, Tensor _9, Tensor _10, Tensor? _11, int _12, int _13, int _14, Tensor? _15, Tensor? _16) -> ()

After digging into BatchPrefillWithPagedKVCacheWrapper, I located the problem at flashinfer/prefill.py#L1660:

run_args = [
    self._float_workspace_buffer,
    self._int_workspace_buffer,
    self._plan_info,
    q,
    k_cache,
    v_cache,
    self._qo_indptr_buf,
    sparse_indptr,
    sparse_indices,
    self._paged_kv_last_page_len_buf,
    out,
    lse,
    mask_mode,
    TensorLayout[self._kv_layout].value,
    window_left,
]

if self._jit_module is not None:  # when using attention variants
    run_args.extend(list(args))
else:
    run_args += [
        self._custom_mask_buf,
        self._mask_indptr_buf,
        _get_cache_alibi_slopes_buf(q.shape[1], q.device),
        logits_soft_cap,
        sm_scale,
        rope_scale,
        rope_theta,
    ]

self._cached_module.paged_run(*run_args)

This means that when I use an attention variant, I have to pass the additional tensors as args when calling run(). In my case, these are _custom_mask_buf and _mask_indptr_buf, both of which are protected members.
For now, I can work around the problem by calling:

o_custom = prefill_wrapper.run(q, kv_cache, prefill_wrapper._custom_mask_buf, prefill_wrapper._mask_indptr_buf)
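
Alternatively, assuming the packed tensors are built by hand as in the sketches above, they can be passed directly as the extra args (in the order given by additional_tensor_names), which avoids touching the protected members:

o_custom = prefill_wrapper.run(q, kv_cache, maybe_custom_mask, maybe_mask_indptr)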
