Add CUDA 8.7 arch for Jetson Orin #26

Status: Open · wants to merge 1 commit into base: main

4 changes: 2 additions & 2 deletions CMakeLists.txt
@@ -22,7 +22,7 @@ set(ignoreMe "${VLLM_PYTHON_PATH}")
set(PYTHON_SUPPORTED_VERSIONS "3.9" "3.10" "3.11" "3.12" "3.13")

# Supported NVIDIA architectures.
-set(CUDA_SUPPORTED_ARCHS "8.0;8.6;8.9;9.0")
+set(CUDA_SUPPORTED_ARCHS "8.0;8.6;8.7;8.9;9.0")
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8)
list(APPEND CUDA_SUPPORTED_ARCHS "10.0" "10.1" "12.0")
endif()
@@ -135,7 +135,7 @@ if (FA2_ENABLED)

# For CUDA we set the architectures on a per file basis
if (VLLM_GPU_LANG STREQUAL "CUDA")
-cuda_archs_loose_intersection(FA2_ARCHS "8.0;9.0;10.0;10.1;12.0" "${CUDA_ARCHS}")
+cuda_archs_loose_intersection(FA2_ARCHS "8.0;8.7;9.0;10.0;10.1;12.0" "${CUDA_ARCHS}")
message(STATUS "FA2_ARCHS: ${FA2_ARCHS}")

set_gencode_flags_for_srcs(
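
Jetson Orin's integrated Ampere GPU reports compute capability 8.7 (sm_87), which is what the "8.7" entries added above target. A minimal check, assuming PyTorch with CUDA is available on the device:

```python
import torch

# On Jetson Orin this is expected to print "8.7", matching the new
# entry in CUDA_SUPPORTED_ARCHS and FA2_ARCHS above.
if torch.cuda.is_available():
    major, minor = torch.cuda.get_device_capability()
    print(f"Compute capability: {major}.{minor}")
```
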
4 changes: 2 additions & 2 deletions csrc/ft_attention/setup.py
@@ -76,9 +76,9 @@ def append_nvcc_threads(nvcc_extra_args):
if os.environ.get("TORCH_CUDA_ARCH_LIST", None) is None and CUDA_HOME is not None:
_, bare_metal_version = get_cuda_bare_metal_version(CUDA_HOME)
if bare_metal_version >= Version("11.8"):
os.environ["TORCH_CUDA_ARCH_LIST"] = "6.0;6.1;6.2;7.0;7.5;8.0;8.6;9.0"
os.environ["TORCH_CUDA_ARCH_LIST"] = "6.0;6.1;6.2;7.0;7.5;8.0;8.6;8.7;9.0"
elif bare_metal_version >= Version("11.1"):
os.environ["TORCH_CUDA_ARCH_LIST"] = "6.0;6.1;6.2;7.0;7.5;8.0;8.6"
os.environ["TORCH_CUDA_ARCH_LIST"] = "6.0;6.1;6.2;7.0;7.5;8.0;8.6;8.7"
elif bare_metal_version == Version("11.0"):
os.environ["TORCH_CUDA_ARCH_LIST"] = "6.0;6.1;6.2;7.0;7.5;8.0"
else:
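
The defaults above only apply when TORCH_CUDA_ARCH_LIST is unset. To build the extensions just for Orin, the variable can be set explicitly before invoking the build; a small sketch of how that interacts with the logic in the diff (the "8.7"-only value is an example, not something this PR mandates):

```python
import os

# Pre-setting TORCH_CUDA_ARCH_LIST bypasses the version-gated defaults
# patched above, e.g. to compile only sm_87 kernels for Jetson Orin:
os.environ.setdefault("TORCH_CUDA_ARCH_LIST", "8.7")

# The setup scripts fall back to their built-in lists only when unset:
if os.environ.get("TORCH_CUDA_ARCH_LIST") is None:
    print("using the default arch lists from the diff")
else:
    print("building for:", os.environ["TORCH_CUDA_ARCH_LIST"])
```
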
4 changes: 2 additions & 2 deletions csrc/layer_norm/setup.py
@@ -74,9 +74,9 @@ def append_nvcc_threads(nvcc_extra_args):
if os.environ.get("TORCH_CUDA_ARCH_LIST", None) is None and CUDA_HOME is not None:
_, bare_metal_version = get_cuda_bare_metal_version(CUDA_HOME)
if bare_metal_version >= Version("11.8"):
os.environ["TORCH_CUDA_ARCH_LIST"] = "6.0;6.1;6.2;7.0;7.5;8.0;8.6;9.0"
os.environ["TORCH_CUDA_ARCH_LIST"] = "6.0;6.1;6.2;7.0;7.5;8.0;8.6;8.7;9.0"
elif bare_metal_version >= Version("11.1"):
os.environ["TORCH_CUDA_ARCH_LIST"] = "6.0;6.1;6.2;7.0;7.5;8.0;8.6"
os.environ["TORCH_CUDA_ARCH_LIST"] = "6.0;6.1;6.2;7.0;7.5;8.0;8.6;8.7"
elif bare_metal_version == Version("11.0"):
os.environ["TORCH_CUDA_ARCH_LIST"] = "6.0;6.1;6.2;7.0;7.5;8.0"
else:
4 changes: 2 additions & 2 deletions csrc/rotary/setup.py
@@ -74,9 +74,9 @@ def append_nvcc_threads(nvcc_extra_args):
if os.environ.get("TORCH_CUDA_ARCH_LIST", None) is None and CUDA_HOME is not None:
_, bare_metal_version = get_cuda_bare_metal_version(CUDA_HOME)
if bare_metal_version >= Version("11.8"):
os.environ["TORCH_CUDA_ARCH_LIST"] = "6.0;6.1;6.2;7.0;7.5;8.0;8.6;9.0"
os.environ["TORCH_CUDA_ARCH_LIST"] = "6.0;6.1;6.2;7.0;7.5;8.0;8.6;8.7;9.0"
elif bare_metal_version >= Version("11.1"):
os.environ["TORCH_CUDA_ARCH_LIST"] = "6.0;6.1;6.2;7.0;7.5;8.0;8.6"
os.environ["TORCH_CUDA_ARCH_LIST"] = "6.0;6.1;6.2;7.0;7.5;8.0;8.6;8.7"
elif bare_metal_version == Version("11.0"):
os.environ["TORCH_CUDA_ARCH_LIST"] = "6.0;6.1;6.2;7.0;7.5;8.0"
else:
4 changes: 2 additions & 2 deletions csrc/xentropy/setup.py
@@ -74,9 +74,9 @@ def append_nvcc_threads(nvcc_extra_args):
if os.environ.get("TORCH_CUDA_ARCH_LIST", None) is None and CUDA_HOME is not None:
_, bare_metal_version = get_cuda_bare_metal_version(CUDA_HOME)
if bare_metal_version >= Version("11.8"):
os.environ["TORCH_CUDA_ARCH_LIST"] = "6.0;6.1;6.2;7.0;7.5;8.0;8.6;9.0"
os.environ["TORCH_CUDA_ARCH_LIST"] = "6.0;6.1;6.2;7.0;7.5;8.0;8.6;8.7;9.0"
elif bare_metal_version >= Version("11.1"):
os.environ["TORCH_CUDA_ARCH_LIST"] = "6.0;6.1;6.2;7.0;7.5;8.0;8.6"
os.environ["TORCH_CUDA_ARCH_LIST"] = "6.0;6.1;6.2;7.0;7.5;8.0;8.6;8.7"
elif bare_metal_version == Version("11.0"):
os.environ["TORCH_CUDA_ARCH_LIST"] = "6.0;6.1;6.2;7.0;7.5;8.0"
else:
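
All four extension setup scripts gate the same defaults on the toolkit version returned by get_cuda_bare_metal_version. That helper is not shown in this diff; a self-contained sketch of how such a check is commonly implemented (an assumption, not code from this repository):

```python
import subprocess
from packaging.version import Version

def cuda_toolkit_version(cuda_home: str) -> Version:
    # Parse the "release X.Y" token from `nvcc --version` output.
    out = subprocess.check_output(
        [f"{cuda_home}/bin/nvcc", "--version"], text=True
    )
    release = out.split("release")[1].split(",")[0].strip()
    return Version(release)

# e.g. cuda_toolkit_version("/usr/local/cuda") >= Version("11.8")
# decides whether the 9.0 entry is appended to the arch list.
```
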
8 changes: 4 additions & 4 deletions hopper/flash_api.cpp
@@ -434,7 +434,7 @@ inline int get_num_splits(Flash_fwd_params const& params) {
auto kBlockMN_kernel_args_sm90 = tile_size_fwd_sm90(params.d_rounded, params.dv_rounded, params.is_causal, params.is_local, params.is_e4m3 ? 1 : 2 /*element_size*/, false /*v_colmajor*/, params.page_table && !params.pagedkv_tma, params.softcap > 0.f, use_one_mma_wg(params));
// Strictly speaking we need to pass in (varlen && params.num_splits > 1) but num_splits
// has not been set here. It's OK though because we might just underestimate kBlockN a bit
-auto kBlockMN_kernel_args_sm8x = tile_size_fwd_sm8x(params.arch == 86 || params.arch == 89, params.d_rounded, params.dv_rounded, params.is_causal, params.is_local, params.is_e4m3 ? 1 : 2 /*element_size*/, params.page_table, varlen, params.softcap > 0.f, params.knew_ptr);
+auto kBlockMN_kernel_args_sm8x = tile_size_fwd_sm8x(params.arch == 86 || params.arch == 87 || params.arch == 89, params.d_rounded, params.dv_rounded, params.is_causal, params.is_local, params.is_e4m3 ? 1 : 2 /*element_size*/, params.page_table, varlen, params.softcap > 0.f, params.knew_ptr);
int const kBlockM = params.arch >= 90 ? std::get<0>(kBlockMN_kernel_args_sm90) : std::get<0>(kBlockMN_kernel_args_sm8x);
int const kBlockN = params.arch >= 90 ? std::get<1>(kBlockMN_kernel_args_sm90) : std::get<1>(kBlockMN_kernel_args_sm8x);
int seqlen_q_packgqa = params.seqlen_q * (params.h / params.h_k);
@@ -614,7 +614,7 @@ mha_fwd_get_scheduler_metadata(

if (params.num_splits_dynamic_ptr) {
auto kBlockMN_kernel_args_sm90 = tile_size_fwd_sm90(params.d_rounded, params.dv_rounded, params.is_causal, params.is_local, params.is_e4m3 ? 1 : 2 /*element_size*/, false /*v_colmajor*/, params.page_table && !params.pagedkv_tma, params.softcap > 0.f, use_one_mma_wg(params));
-auto kBlockMN_kernel_args_sm8x = tile_size_fwd_sm8x(params.arch == 86 || params.arch == 89, params.d_rounded, params.dv_rounded, params.is_causal, params.is_local, params.is_e4m3 ? 1 : 2 /*element_size*/, params.page_table, is_varlen && params.num_splits > 1, params.softcap > 0.f, params.knew_ptr);
+auto kBlockMN_kernel_args_sm8x = tile_size_fwd_sm8x(params.arch == 86 || params.arch == 87 || params.arch == 89, params.d_rounded, params.dv_rounded, params.is_causal, params.is_local, params.is_e4m3 ? 1 : 2 /*element_size*/, params.page_table, is_varlen && params.num_splits > 1, params.softcap > 0.f, params.knew_ptr);
int const kBlockM = params.arch >= 90 ? std::get<0>(kBlockMN_kernel_args_sm90) : std::get<0>(kBlockMN_kernel_args_sm8x);
int const kBlockN = params.arch >= 90 ? std::get<1>(kBlockMN_kernel_args_sm90) : std::get<1>(kBlockMN_kernel_args_sm8x);
auto stream = at::cuda::getCurrentCUDAStream().stream();
@@ -1304,7 +1304,7 @@ std::vector<at::Tensor> mha_bwd(
: 64));
int const kBlockM_sm80 = head_size_rounded <= 64 ? 128 : 64;
int const kBlockM_sm86 = head_size_rounded <= 192 ? 64 : 32;
-int const kBlockM = arch >= 90 ? kBlockM_sm90 : (arch == 86 || arch == 89 ? kBlockM_sm86 : kBlockM_sm80);
+int const kBlockM = arch >= 90 ? kBlockM_sm90 : (arch == 86 || arch == 87 || arch == 89 ? kBlockM_sm86 : kBlockM_sm80);
int const kBlockN_sm90 = head_size_rounded <= 128
? 128
: (head_size_rounded <= 192 ? 96 : 80);
@@ -1315,7 +1315,7 @@ std::vector<at::Tensor> mha_bwd(
: (head_size_rounded <= 96 ? 128
: (head_size_rounded <= 128 ? 96
: (head_size_rounded <= 192 ? 64 : 64)));
-int const kBlockN = arch >= 90 ? kBlockN_sm90 : (arch == 86 || arch == 89 ? kBlockN_sm86 : kBlockN_sm80);
+int const kBlockN = arch >= 90 ? kBlockN_sm90 : (arch == 86 || arch == 87 || arch == 89 ? kBlockN_sm86 : kBlockN_sm80);
auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };
int const seqlen_q_rounded = round_multiple(seqlen_q, kBlockM);
int const seqlen_k_rounded = round_multiple(seqlen_k, kBlockN);
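
The arch value compared against here appears to be the device's compute capability packed as major * 10 + minor, so Jetson Orin shows up as 87 and now takes the same sm8x tile-size path as 8.6 and 8.9. A small illustration of that mapping (not repository code):

```python
import torch

def cuda_arch(device=None) -> int:
    # Pack (major, minor) the same way the dispatch above compares it:
    # Jetson Orin's (8, 7) becomes 87, matching the new `arch == 87` checks.
    major, minor = torch.cuda.get_device_capability(device)
    return major * 10 + minor
```
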
10 changes: 5 additions & 5 deletions hopper/flash_bwd_launch_template.h
@@ -308,7 +308,7 @@ void run_mha_bwd_hdim64(Flash_bwd_params &params, cudaStream_t stream) {
// With ShuffleStats we no longer have register spilling when Has_softcap and using 128 x 128 block.
run_mha_bwd_dispatch<Arch, T, 128, 128, 64, Is_causal, Is_local, Has_softcap, 2, 2, true, false, false, 2, 1, 2, 2, false>(params, stream);
}
-} else if constexpr (Arch == 86 || Arch == 89) {
+} else if constexpr (Arch == 86 || Arch == 87 || Arch == 89) {
run_mha_bwd_dispatch<Arch, T, 64, 128, 64, Is_causal, Is_local, Has_softcap, 2, 2, false, false, false, 2, 2, 4, 2, true>(params, stream);
// run_mha_bwd_dispatch<Arch, T, 96, 96, 64, Is_causal, Is_local, Has_softcap, 1, 2, false, true, true, 2, 2, 4, 4, false>(params, stream);
// run_mha_bwd_dispatch<Arch, T, 80, 128, 64, Is_causal, Is_local, Has_softcap, 1, 2, true, false, true, 2, 2, 4, 2, true>(params, stream);
@@ -324,7 +324,7 @@ void run_mha_bwd_hdim96(Flash_bwd_params &params, cudaStream_t stream) {
CAUSAL_LOCAL_SWITCH(params.is_causal, params.is_local, Is_causal, Is_local, [&] {
if constexpr (Arch >= 90) {
run_mha_bwd_dispatch<Arch, T, 64, 128, 96, Is_causal, Is_local, Has_softcap, 2, 2, true, false, false, 2, 1, 2, 1, true>(params, stream);
-} else if constexpr (Arch == 86 || Arch == 89) {
+} else if constexpr (Arch == 86 || Arch == 87 || Arch == 89) {
run_mha_bwd_dispatch<Arch, T, 64, 128, 96, Is_causal, Is_local, Has_softcap, 1, 2, false, false, false, 2, 2, 4, 2, true>(params, stream);
} else {
run_mha_bwd_dispatch<Arch, T, 64, 128, 96, Is_causal, Is_local, Has_softcap, 2, 2, false, false, false, 2, 2, 4, 2, false>(params, stream);
@@ -341,7 +341,7 @@ void run_mha_bwd_hdim128(Flash_bwd_params &params, cudaStream_t stream) {
} else {
run_mha_bwd_dispatch<Arch, T, 80, 128, 128, Is_causal, Is_local, Has_softcap, 2, 2, true, false, true, 2, 1, 2, 1, false>(params, stream);
}
-} else if constexpr (Arch == 86 || Arch == 89) {
+} else if constexpr (Arch == 86 || Arch == 87 || Arch == 89) {
run_mha_bwd_dispatch<Arch, T, 64, 96, 128, Is_causal, Is_local, Has_softcap, 1, 2, false, false, false, 2, 2, 2, 2, true>(params, stream);
} else {
run_mha_bwd_dispatch<Arch, T, 64, 128, 128, Is_causal, Is_local, Has_softcap, 2, 2, false, false, false, 2, 2, 2, 2, false>(params, stream);
@@ -354,7 +354,7 @@ void run_mha_bwd_hdim192(Flash_bwd_params &params, cudaStream_t stream) {
CAUSAL_LOCAL_SWITCH(params.is_causal, params.is_local, Is_causal, Is_local, [&] {
if constexpr (Arch >= 90) {
run_mha_bwd_dispatch<Arch, T, 64, 96, 192, Is_causal, Is_local, Has_softcap, 1, 1, false, true, false, 3, 1, 1, 1, false>(params, stream);
-} else if constexpr (Arch == 86 || Arch == 89) {
+} else if constexpr (Arch == 86 || Arch == 87 || Arch == 89) {
run_mha_bwd_dispatch<Arch, T, 64, 64, 192, Is_causal, Is_local, Has_softcap, 1, 1, false, false, false, 2, 2, 2, 2, true>(params, stream);
} else {
run_mha_bwd_dispatch<Arch, T, 64, 80, 192, Is_causal, Is_local, Has_softcap, 1, 2, false, true, false, 2, 4, 2, 2, false>(params, stream);
@@ -367,7 +367,7 @@ void run_mha_bwd_hdim256(Flash_bwd_params &params, cudaStream_t stream) {
CAUSAL_LOCAL_SWITCH(params.is_causal, params.is_local, Is_causal, Is_local, [&] {
if constexpr (Arch >= 90) {
run_mha_bwd_dispatch<Arch, T, 64, 80, 256, Is_causal, Is_local, Has_softcap, 1, 1, false, true, true, 2, 1, 1, 1, false>(params, stream);
-} else if constexpr (Arch == 86 || Arch == 89) {
+} else if constexpr (Arch == 86 || Arch == 87 || Arch == 89) {
run_mha_bwd_dispatch<Arch, T, 32, 64, 256, Is_causal, Is_local, Has_softcap, 1, 1, false, false, false, 2, 2, 2, 1, true>(params, stream);
// run_mha_bwd_dispatch<Arch, T, 64, 32, 256, Is_causal, Is_local, Has_softcap, 1, 1, false, false, false, 2, 4, 1, 2, true>(params, stream);
} else {
2 changes: 1 addition & 1 deletion hopper/flash_fwd_launch_template.h
@@ -38,7 +38,7 @@ void run_flash_fwd(Flash_fwd_params &params, cudaStream_t stream) {

// Can't use structured binding since it's not compatible with constexpr
static constexpr std::tuple<int, int, bool, bool> kBlockMN_RS_IntraWGOverlap = tile_size_fwd_sm90(kHeadDim, kHeadDimV, Is_causal, Is_local, sizeof(Element) /*element_size*/, V_colmajor, PagedKVNonTMA, Has_softcap, Use_one_mma_wg);
-static constexpr std::tuple<int, int, int, int, bool> kBlockMN_kNWarps_Stages_RS = tile_size_fwd_sm8x(Arch == 86 || Arch == 89, kHeadDim, kHeadDimV, Is_causal, Is_local, sizeof(Element) /*element_size*/, PagedKVNonTMA, Varlen && Split, Has_softcap, AppendKV);
+static constexpr std::tuple<int, int, int, int, bool> kBlockMN_kNWarps_Stages_RS = tile_size_fwd_sm8x(Arch == 86 || Arch == 87 || Arch == 89, kHeadDim, kHeadDimV, Is_causal, Is_local, sizeof(Element) /*element_size*/, PagedKVNonTMA, Varlen && Split, Has_softcap, AppendKV);
static constexpr int kBlockM = Arch >= 90 ? std::get<0>(kBlockMN_RS_IntraWGOverlap) : std::get<0>(kBlockMN_kNWarps_Stages_RS);
static constexpr int kBlockN = Arch >= 90 ? std::get<1>(kBlockMN_RS_IntraWGOverlap) : std::get<1>(kBlockMN_kNWarps_Stages_RS);
static constexpr bool MmaPV_is_RS = std::get<2>(kBlockMN_RS_IntraWGOverlap);
2 changes: 1 addition & 1 deletion hopper/static_switch.h
@@ -148,7 +148,7 @@
#else
#define ARCH_SWITCH(ARCH, ARCH_NAME, ...) \
[&] { \
-if (ARCH == 86 || ARCH == 89) { \
+if (ARCH == 86 || ARCH == 87 || ARCH == 89) { \
constexpr static int ARCH_NAME = 86; \
return __VA_ARGS__(); \
} else if (ARCH < 90) { \
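
Note that ARCH_SWITCH folds 8.7 into the same ARCH_NAME = 86 branch as 8.6 and 8.9, so Orin reuses the kernels compiled for that code path rather than getting a dedicated sm_87 build. A rough sketch of the dispatch (illustrative only; the branches below the visible hunk are assumed):

```python
def arch_name(arch: int) -> int:
    # Mirrors the visible branch of ARCH_SWITCH: SM 8.6, 8.7 and 8.9
    # all compile against the same code path (ARCH_NAME = 86).
    if arch in (86, 87, 89):
        return 86
    # The remaining branches are truncated in the diff; presumably they
    # select the sm80 path for other pre-Hopper parts and sm90 otherwise.
    return 80 if arch < 90 else 90
```
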
3 changes: 2 additions & 1 deletion vllm_flash_attn/flash_attn_interface.py
@@ -43,10 +43,11 @@ def _is_fa3_supported(device = None) -> Tuple[bool, Optional[str]]:
if torch.cuda.get_device_capability(device)[0] < 8 \
or torch.cuda.get_device_capability(device)[0] >= 10 \
or torch.cuda.get_device_capability(device) == (8, 6) \
+or torch.cuda.get_device_capability(device) == (8, 7) \
or torch.cuda.get_device_capability(device) == (8, 9):
return False, \
"FA3 is only supported on devices with compute capability >= 8" \
" excluding 8.6 and 8.9 and Blackwell archs (>=10)"
" excluding 8.6, 8.7 and 8.9 and Blackwell archs (>=10)"
return True, None

def is_fa_version_supported(fa_version: int, device = None) -> bool:
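
With this check in place, FA3 is reported as unsupported on Orin and callers are expected to fall back to FA2, whose build targets also gained 8.7 above. A hedged usage sketch built on the is_fa_version_supported helper visible in the hunk (the fallback policy shown is an assumption, not part of this PR):

```python
from vllm_flash_attn.flash_attn_interface import is_fa_version_supported

# On Jetson Orin (SM 8.7) the capability check above rules out FA3,
# so a caller would select FA2, which this PR enables for sm_87.
fa_version = 3 if is_fa_version_supported(3) else 2
print(f"Using FlashAttention version {fa_version}")
```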