From fcee96dde81b5deafd61a57462763a2c898ffe87 Mon Sep 17 00:00:00 2001 From: zhanyuan Date: Mon, 25 May 2026 13:21:38 +0000 Subject: [PATCH] Adjust build config and add benchmarks - Change MHA_NUM_JOBS calculation to round up - Set default HDIM=128, DTYPE=BF16 - Change make parallel jobs to fixed value 6 - Add .gitignore and benchmarks directory Co-Authored-By: Claude Opus 4.7 --- .gitignore | 12 ++ CMakeLists.txt | 97 ++++++++--- Makefile | 116 ++++++++++--- README_MX.md | 83 ++++++++- benchmarks/benchmark_kvcache.py | 146 ++++++++++++++++ cmake/build_flash_attn.cmake | 164 +++++++++++++++++- csrc/flash_attn/flash_api/flash_api_bwd.cpp | 57 ++++++ csrc/flash_attn/flash_api/flash_api_fwd.cpp | 12 ++ .../flash_api/flash_api_fwd_kvcache.cpp | 12 ++ .../flash_fwd_dispatch_template.h | 4 + csrc/flash_attn/flash_run/run_mha_bwd.cpp | 8 + setup.py | 26 ++- test_jobs.sh | 5 + tools/build_scripts/build_projects_related.sh | 0 14 files changed, 679 insertions(+), 63 deletions(-) create mode 100644 .gitignore create mode 100644 benchmarks/benchmark_kvcache.py create mode 100644 test_jobs.sh mode change 100644 => 100755 tools/build_scripts/build_projects_related.sh diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..34f8c6e --- /dev/null +++ b/.gitignore @@ -0,0 +1,12 @@ +build_kernel/ +build_kernel_* +build/ +flash_attn.egg-info/ +dist/ +*.egg +*.egg-info +__pycache__/ +*.so +*.dylib +*.pyd +*.log diff --git a/CMakeLists.txt b/CMakeLists.txt index aab24b5..ff5d1e1 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -54,13 +54,27 @@ endif() option(BUILD_WITH_CPP "Build cpp library" true) option(BUILD_WITH_KERNEL "Build kernel" true) -option(FAST_BUILD "Fast build" OFF) -option(GEN_KERNEL "Generate kernel" OFF) +option(BUILD_WITH_HOST "Build host static library" true) +option(BUILD_WITH_BWD_KERNEL "Build backward kernels" OFF) +option(FWD_ENABLE_LOCAL "Build forward local-attention kernel variants" OFF) +option(FWD_ENABLE_ALIBI "Build forward ALiBi kernel variants" OFF) +option(FWD_ENABLE_SOFTCAP "Build forward softcap kernel variants" OFF) +option(FWD_ENABLE_APPENDKV "Build forward append-KV kernel variants" OFF) +option(FWD_ENABLE_CAUSAL "Build forward causal kernel variants" OFF) +set(FWD_MN_LIST "DEFAULT" CACHE STRING "Comma-separated xcore1000 fwd tile list, or DEFAULT for the default dispatch tiles") +set(FWD_SPLIT_MN_LIST "DEFAULT" CACHE STRING "Comma-separated xcore1000 fwd_split tile list, or DEFAULT for the default dispatch tiles") message(STATUS "BUILD_WITH_CPP:${BUILD_WITH_CPP}") message(STATUS "BUILD_WITH_KERNEL:${BUILD_WITH_KERNEL}") -message(STATUS "FAST_BUILD:${FAST_BUILD}") -message(STATUS "GEN_KERNEL:${GEN_KERNEL}") +message(STATUS "BUILD_WITH_HOST:${BUILD_WITH_HOST}") +message(STATUS "BUILD_WITH_BWD_KERNEL:${BUILD_WITH_BWD_KERNEL}") +message(STATUS "FWD_MN_LIST:${FWD_MN_LIST}") +message(STATUS "FWD_SPLIT_MN_LIST:${FWD_SPLIT_MN_LIST}") +message(STATUS "FWD_ENABLE_LOCAL:${FWD_ENABLE_LOCAL}") +message(STATUS "FWD_ENABLE_ALIBI:${FWD_ENABLE_ALIBI}") +message(STATUS "FWD_ENABLE_SOFTCAP:${FWD_ENABLE_SOFTCAP}") +message(STATUS "FWD_ENABLE_APPENDKV:${FWD_ENABLE_APPENDKV}") +message(STATUS "FWD_ENABLE_CAUSAL:${FWD_ENABLE_CAUSAL}") # set_property(GLOBAL PROPERTY RULE_LAUNCH_COMPILE "${CMAKE_COMMAND} -E time") # set_property(GLOBAL PROPERTY RULE_LAUNCH_COMPILE "../compile_time.sh") @@ -99,22 +113,52 @@ add_compile_definitions(USE_MACA) add_compile_definitions(NV_ARCH_A100) add_compile_definitions(__FAST_HALF_CVT__) add_compile_definitions(__MERGE_LDS_B64) - -if(NOT DEFINED HDIM) - set(HDIM 0) +if(NOT BUILD_WITH_BWD_KERNEL) + add_compile_definitions(FLASHATTENTION_DISABLE_BACKWARD) + add_compile_definitions(FLASHATTENTION_DISABLE_DROPOUT) + add_compile_definitions(DROPOUT_FALSE) +endif() +if(NOT FWD_ENABLE_LOCAL) + add_compile_definitions(FLASHATTENTION_DISABLE_LOCAL) + add_compile_definitions(LOCAL_FALSE) +endif() +if(NOT FWD_ENABLE_ALIBI) + add_compile_definitions(FLASHATTENTION_DISABLE_ALIBI) + add_compile_definitions(ALIBI_FALSE) +endif() +if(NOT FWD_ENABLE_SOFTCAP) + add_compile_definitions(FLASHATTENTION_DISABLE_SOFTCAP) + add_compile_definitions(SOFTCAP_FALSE) +endif() +if(NOT FWD_ENABLE_APPENDKV) + add_compile_definitions(FLASHATTENTION_DISABLE_APPENDKV) + add_compile_definitions(APPENDKV_FALSE) +endif() +if(NOT FWD_ENABLE_CAUSAL) + add_compile_definitions(FLASHATTENTION_DISABLE_CAUSAL) + add_compile_definitions(CAUSAL_FALSE) endif() if(NOT DEFINED FA_TYPE) set(FA_TYPE "ALL") endif() -message(STATUS "HDIM=${HDIM},FA_TYPE=${FA_TYPE}") - -STRING(TOUPPER "${FA_TYPE}" FA_TYPE_UPPER) -if(${HDIM} STREQUAL "0" AND NOT ${FA_TYPE_UPPER} STREQUAL "ALL") - MESSAGE(STATUS "Setting DTYPE to ALL because HDIM is 0") - set(FA_TYPE_UPPER "ALL") +if(NOT DEFINED HDIM) + if(DEFINED HDIM_CONFIG_LIST AND NOT "${HDIM_CONFIG_LIST}" STREQUAL "") + # Internal multi-hdim sentinel for source selection. Keep the user-facing + # configure log explicit so it does not look like hdim 0 is being built. + set(HDIM 0) + set(HDIM_STATUS "HDIM_CONFIG_LIST=${HDIM_CONFIG_LIST}") + else() + set(HDIM 0) + set(HDIM_STATUS "HDIM=ALL") + endif() +else() + set(HDIM_STATUS "HDIM=${HDIM}") endif() +message(STATUS "${HDIM_STATUS},FA_TYPE=${FA_TYPE}") + +STRING(TOUPPER "${FA_TYPE}" FA_TYPE_UPPER) if (${FA_TYPE_UPPER} STREQUAL "ALL") add_compile_definitions(FA_DTYPE_ALL) @@ -130,7 +174,16 @@ if(NOT HDIM STREQUAL "0") add_compile_definitions(HDIM_CONFIG=${HDIM}) add_compile_definitions(HDIM_${HDIM}) else() - add_compile_definitions(HDIM_ALL) + if(DEFINED HDIM_CONFIG_LIST AND NOT "${HDIM_CONFIG_LIST}" STREQUAL "") + add_compile_definitions("HDIM_CONFIG=${HDIM_CONFIG_LIST}") + string(REPLACE "," ";" HDIM_CONFIG_ITEMS "${HDIM_CONFIG_LIST}") + foreach(HDIM_CONFIG_ITEM IN LISTS HDIM_CONFIG_ITEMS) + string(STRIP "${HDIM_CONFIG_ITEM}" HDIM_CONFIG_ITEM) + add_compile_definitions(HDIM_${HDIM_CONFIG_ITEM}) + endforeach() + else() + add_compile_definitions(HDIM_ALL) + endif() endif() include_directories( @@ -158,21 +211,9 @@ if (BUILD_WITH_KERNEL) STRING(TOLOWER "${FA_TYPE}" FA_TYPE_LOWER) - if(FAST_BUILD AND GEN_KERNEL) - execute_process( - COMMAND ${CMAKE_CURRENT_SOURCE_DIR}/run_generator.sh capi ${HDIM} ${FA_TYPE_LOWER} ${MACA_ARCH} - WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR} - RESULT_VARIABLE result - OUTPUT_VARIABLE output - ERROR_VARIABLE error - ) - - if(NOT result EQUAL 0) - message(ERROR ": Generate kernel files failed, output:\n${error}") - endif() + if(BUILD_WITH_HOST) + build_flash_attn_host(mcFlashAttnHostStatic) endif() - - build_flash_attn_host(mcFlashAttnHostStatic) if(${MACA_ARCH} STREQUAL "xcore1000") build_flash_attn_kernel(mcFlashAttnKernelXcore1000Static "xcore1000" ${HDIM} ${FA_TYPE}) target_compile_options(mcFlashAttnKernelXcore1000Static PRIVATE --offload-arch=xcore1000) diff --git a/Makefile b/Makefile index ec8b22f..65ef664 100644 --- a/Makefile +++ b/Makefile @@ -1,36 +1,96 @@ -CURRENT_CPU_NUM:=$(shell grep -c processor /proc/cpuinfo) +CURRENT_CPU_NUM:=$(shell grep "cpu cores" /proc/cpuinfo | head -1 | awk '{print $$4}') # If in maca build, restrict num_jobs to 140 to resolve x86 docker which using 384 cores leads to OOM ifdef BUILDROOT # In x86 docker maca building, "make -j112" is a stable parallel number, so we use 140 here to get 140*0.8=112 CURRENT_CPU_NUM:=$(shell awk -v a=$(CURRENT_CPU_NUM) 'BEGIN {if(a>140) print 140; else print a}') endif -MHA_NUM_JOBS:=$(shell awk -v n=$(CURRENT_CPU_NUM) 'BEGIN {print int(n * 0.8)}') -HDIM ?= 0 -DTYPE ?= all -# FAST_BUILD is not currently supported. -FAST_BUILD ?= 0 -GEN_KERNEL ?= 0 +MHA_NUM_JOBS:=$(shell awk -v n=$(CURRENT_CPU_NUM) 'BEGIN {print int(n*0.8)}') +# HDIM_LIST ?= 128 +HDIM_LIST ?= 128 256 +empty := +space := $(empty) $(empty) +comma := , +HDIM_CONFIG_LIST := $(subst $(space),$(comma),$(strip $(HDIM_LIST))) +DTYPE ?= BF16 +BUILD_WITH_BWD_KERNEL ?= FALSE +FWD_MN_LIST ?= DEFAULT +FWD_SPLIT_MN_LIST ?= DEFAULT +FWD_ENABLE_LOCAL ?= FALSE +FWD_ENABLE_ALIBI ?= FALSE +FWD_ENABLE_SOFTCAP ?= FALSE +FWD_ENABLE_APPENDKV ?= FALSE +FWD_ENABLE_CAUSAL ?= FALSE SUB_MODULE ?= fused_dense_lib BUILD_PROJECTS_SCRIPT_PATH := ./tools/build_scripts/build_projects_related.sh TORCH_EXTENSION_SCRIPT_PATH := ./tools/build_scripts/torch_extension_related.sh run_build_projects_script_%: - @$(BUILD_PROJECTS_SCRIPT_PATH) $* || (echo "Execution failed with code $$?") + @chmod +x $(BUILD_PROJECTS_SCRIPT_PATH) && $(BUILD_PROJECTS_SCRIPT_PATH) $* || (echo "Execution failed with code $$?") run_torch_extension_script_%: - @$(TORCH_EXTENSION_SCRIPT_PATH) $* || (echo "Execution failed with code $$?") + @chmod +x $(TORCH_EXTENSION_SCRIPT_PATH) && $(TORCH_EXTENSION_SCRIPT_PATH) $* || (echo "Execution failed with code $$?") kernel: - mkdir -p build_kernel - cd build_kernel \ - && cmake \ - -DMACA_PATH=${MACA_PATH} \ - -DBUILD_WITH_KERNEL=TRUE \ - -DBUILD_WITH_CPP=FALSE \ - -DHDIM=${HDIM} \ - -DFAST_BUILD=${FAST_BUILD} \ - -DGEN_KERNEL=$(GEN_KERNEL) \ - -DFA_TYPE=${DTYPE} \ - .. \ - && make -j$(MHA_NUM_JOBS) + @mkdir -p build_kernel build_host + @cd build_host && \ + cmake \ + -DMACA_PATH=${MACA_PATH} \ + -DBUILD_WITH_KERNEL=TRUE \ + -DBUILD_WITH_CPP=FALSE \ + -DBUILD_WITH_HOST=TRUE \ + -DBUILD_WITH_BWD_KERNEL=${BUILD_WITH_BWD_KERNEL} \ + -DFWD_MN_LIST=${FWD_MN_LIST} \ + -DFWD_SPLIT_MN_LIST=${FWD_SPLIT_MN_LIST} \ + -DFWD_ENABLE_LOCAL=${FWD_ENABLE_LOCAL} \ + -DFWD_ENABLE_ALIBI=${FWD_ENABLE_ALIBI} \ + -DFWD_ENABLE_SOFTCAP=${FWD_ENABLE_SOFTCAP} \ + -DFWD_ENABLE_APPENDKV=${FWD_ENABLE_APPENDKV} \ + -DFWD_ENABLE_CAUSAL=${FWD_ENABLE_CAUSAL} \ + -DHDIM_CONFIG_LIST=${HDIM_CONFIG_LIST} \ + -DFA_TYPE=${DTYPE} \ + .. && \ + make -j$(MHA_NUM_JOBS) mcFlashAttnHostStatic || exit 1; \ + mv libmcFlashAttnHostStatic.a ../build_kernel/; \ + cd .. && rm -rf build_host + @for hd in $(HDIM_LIST); do \ + mkdir -p build_kernel_$$hd && \ + cd build_kernel_$$hd && \ + cmake \ + -DMACA_PATH=${MACA_PATH} \ + -DBUILD_WITH_KERNEL=TRUE \ + -DBUILD_WITH_CPP=FALSE \ + -DBUILD_WITH_HOST=FALSE \ + -DHDIM=$$hd \ + -DBUILD_WITH_BWD_KERNEL=${BUILD_WITH_BWD_KERNEL} \ + -DFWD_MN_LIST=${FWD_MN_LIST} \ + -DFWD_SPLIT_MN_LIST=${FWD_SPLIT_MN_LIST} \ + -DFWD_ENABLE_LOCAL=${FWD_ENABLE_LOCAL} \ + -DFWD_ENABLE_ALIBI=${FWD_ENABLE_ALIBI} \ + -DFWD_ENABLE_SOFTCAP=${FWD_ENABLE_SOFTCAP} \ + -DFWD_ENABLE_APPENDKV=${FWD_ENABLE_APPENDKV} \ + -DFWD_ENABLE_CAUSAL=${FWD_ENABLE_CAUSAL} \ + -DFA_TYPE=${DTYPE} \ + .. && \ + make -j$(MHA_NUM_JOBS) || exit 1; \ + cd .. || exit 1; \ + done + @cd build_kernel && \ + for arch in Xcore1000 Xcore1500; do \ + if ! ls ../build_kernel_*/libmcFlashAttnKernel$${arch}Static.a 1> /dev/null 2>&1; then \ + echo "Skipping $$arch - not found"; \ + continue; \ + fi; \ + rm -rf tmp_$$arch && mkdir -p tmp_$$arch && cd tmp_$$arch && \ + for hd in $(HDIM_LIST); do \ + if [ -f ../../build_kernel_$$hd/libmcFlashAttnKernel$${arch}Static.a ]; then \ + ar x ../../build_kernel_$$hd/libmcFlashAttnKernel$${arch}Static.a && echo "Extracted from $$hd"; \ + else \ + echo "Warning: build_kernel_$$hd/libmcFlashAttnKernel$${arch}Static.a not found"; \ + fi; \ + done && \ + ls -la *.o 2>/dev/null | head -5 && \ + rm -f ../libmcFlashAttnKernel$${arch}Static.a && \ + ar -r ../libmcFlashAttnKernel$${arch}Static.a *.o && \ + cd .. && rm -rf tmp_$$arch; \ + done cplus_api: run_build_projects_script_sdk kernel mkdir -p build_cpp @@ -42,13 +102,19 @@ cplus_api: run_build_projects_script_sdk kernel -DBUILD_WITH_CPP=TRUE \ -DCMAKE_INSTALL_PREFIX=./install \ -DHDIM=${HDIM} \ - -DFAST_BUILD=${FAST_BUILD} \ - -DGEN_KERNEL=$(GEN_KERNEL) \ + -DBUILD_WITH_BWD_KERNEL=${BUILD_WITH_BWD_KERNEL} \ + -DFWD_MN_LIST=${FWD_MN_LIST} \ + -DFWD_SPLIT_MN_LIST=${FWD_SPLIT_MN_LIST} \ + -DFWD_ENABLE_LOCAL=${FWD_ENABLE_LOCAL} \ + -DFWD_ENABLE_ALIBI=${FWD_ENABLE_ALIBI} \ + -DFWD_ENABLE_SOFTCAP=${FWD_ENABLE_SOFTCAP} \ + -DFWD_ENABLE_APPENDKV=${FWD_ENABLE_APPENDKV} \ + -DFWD_ENABLE_CAUSAL=${FWD_ENABLE_CAUSAL} \ -DFA_TYPE=${DTYPE} \ .. && make -j$(MHA_NUM_JOBS) && make install python: run_build_projects_script_pytorch kernel - python ./setup.py bdist_wheel ; \ + BUILD_WITH_BWD_KERNEL=${BUILD_WITH_BWD_KERNEL} FWD_ENABLE_LOCAL=${FWD_ENABLE_LOCAL} FWD_ENABLE_ALIBI=${FWD_ENABLE_ALIBI} FWD_ENABLE_SOFTCAP=${FWD_ENABLE_SOFTCAP} FWD_ENABLE_APPENDKV=${FWD_ENABLE_APPENDKV} FWD_ENABLE_CAUSAL=${FWD_ENABLE_CAUSAL} python ./setup.py bdist_wheel ; \ mla: run_build_projects_script_pytorch mkdir -p dist @@ -64,7 +130,7 @@ clean_mla: rm -rf ./csrc/flash_mla/build clean_kernel: - rm -rf ./build_kernel + rm -rf ./build_kernel* clean_capi: rm -rf ./build_cpp diff --git a/README_MX.md b/README_MX.md index 44bb438..3d8ab50 100644 --- a/README_MX.md +++ b/README_MX.md @@ -46,15 +46,86 @@ make cplus_api make kernel ``` -### Fast build (!Currently Unavailable) -Specify hdim and dtype, and compile only the specified combinations of bool switches based on the configuration in `tools/generator/bool_switch.ini`. Refer to the comments in the file for the configuration of `bool_switch.ini`. +### Fast build + +Use `DEFAULT` to compile only the default dispatch MN tiles. Forward feature variants are disabled by default to reduce build time. Backward kernels are also disabled by default; because dropout forward is only useful together with backward in this build flow, dropout forward kernels are disabled when `BUILD_WITH_BWD_KERNEL=FALSE`. + +Running `make python` with no extra options uses these defaults: + +| Option | Default used by `make python` | Effect when changed | +| --- | --- | --- | +| `FLASHATTN_BUILD_PROJECTS` | unset | If unset, build both C500 and C600. Set `FLASHATTN_BUILD_PROJECTS=C500` or `C600` to build one architecture. | +| `HDIM_LIST` | `128 256` | Select the head dimensions to compile. | +| `DTYPE` | `BF16` | Select the dtype to compile. | +| `FWD_MN_LIST` | `DEFAULT` | Select xcore1000/xcore1500 fwd MN tiles; `DEFAULT` means the dispatch default tiles for each architecture. | +| `FWD_SPLIT_MN_LIST` | `DEFAULT` | Select xcore1000/xcore1500 fwd_split MN tiles; `DEFAULT` means the dispatch default tiles for each architecture. | +| `BUILD_WITH_BWD_KERNEL` | `FALSE` | Set to `TRUE` to build backward kernels and enable backward API support. | +| `FWD_ENABLE_LOCAL` | `FALSE` | Set to `TRUE` to build local/sliding-window forward variants. | +| `FWD_ENABLE_ALIBI` | `FALSE` | Set to `TRUE` to build ALiBi forward variants. | +| `FWD_ENABLE_SOFTCAP` | `FALSE` | Set to `TRUE` to build softcap forward variants. | +| `FWD_ENABLE_APPENDKV` | `FALSE` | Set to `TRUE` to build append-KV variants for `flash_attn_with_kvcache`. | +| `FWD_ENABLE_CAUSAL` | `FALSE` | Set to `TRUE` to build causal forward variants. | + +With these defaults, `make python` builds BF16 forward-only kernels for hdim 128 and 256, uses only the default dispatch MN tiles, disables backward/dropout, and excludes local, ALiBi, softcap, append-KV, and causal forward variants. The default MN tiles are: + +| arch | hdim | fwd default MN | fwd_split default MN | +| --- | --- | --- | --- | +| xcore1000 | 128 | 64x64 | 64x64 | +| xcore1000 | 256 | 64x32 | 64x64 | +| xcore1500 | 128 | 128x64 | 16x32, 128x64 | +| xcore1500 | 256 | 128x64 | 128x64 | + +Enable only the variants needed by your workload: +```bash +# build only C500 +FLASHATTN_BUILD_PROJECTS=C500 make python + +# build backward and dropout-capable forward kernels +make python BUILD_WITH_BWD_KERNEL=TRUE + +# support causal=True +make python FWD_ENABLE_CAUSAL=TRUE + +# support local attention and ALiBi +make python FWD_ENABLE_LOCAL=TRUE FWD_ENABLE_ALIBI=TRUE + +# support append KV in flash_attn_with_kvcache +make python FWD_ENABLE_APPENDKV=TRUE + +# enable multiple variants together +make python BUILD_WITH_BWD_KERNEL=TRUE FWD_ENABLE_CAUSAL=TRUE FWD_ENABLE_LOCAL=TRUE ``` -# fast build with generate kernel -make python HDIM=128 DTYPE=FP16 FAST_BUILD=1 GEN_KERNEL=1 -# fast build without generate kernel -make python HDIM=128 DTYPE=FP16 FAST_BUILD=1 GEN_KERNEL=0 + +Override `FWD_MN_LIST` and `FWD_SPLIT_MN_LIST` to include more forward tiles: +```bash +make python FWD_MN_LIST=64x32,64x64 FWD_SPLIT_MN_LIST=64x64 ``` +The generated xcore1000 sources currently provide these MN choices: + +| hdim | fwd MN choices | fwd_split MN choices | dispatch default | +| --- | --- | --- | --- | +| 32 | 128x64, 128x128 | 64x64 | fwd: 128x128; fwd_split: 64x64 | +| 64 | 16x16, 32x32, 64x64, 128x64, 128x128 | 16x16, 64x64 | fwd: 64x64; fwd_split: 64x64 | +| 96 | 64x64, 128x64 | 64x64 | fwd: 128x64; fwd_split: 64x64 | +| 128 | 64x32, 64x64, 128x32, 128x64 | 16x16, 32x32, 64x32, 64x64, 128x64 | fwd: 64x64; fwd_split: 64x64 | +| 160 | 64x32, 64x64, 128x64 | 64x64 | fwd: 64x32; fwd_split: 64x64 | +| 192 | 64x64; 128x64 for hdimv128 | 64x64 | fwd: 64x64 and 128x64 for hdimv128; fwd_split: 64x64 | +| 256 | 64x32, 64x64 | 64x32, 64x64 | fwd: 64x32 without dropout, 64x64 for dropout; fwd_split: 64x64 | +| 512 | 64x32 | 32x32 | fwd: 64x32; fwd_split: 32x32 | + +The generated xcore1500 sources currently provide these MN choices: + +| hdim | fwd MN choices | fwd_split MN choices | dispatch default | +| --- | --- | --- | --- | +| 32 | 128x64, 128x128 | 64x64 | fwd: 128x64 and 128x128; fwd_split: 64x64 | +| 64 | 128x64 | 128x64 | fwd: 128x64; fwd_split: 128x64 | +| 96 | 128x64 | 64x64 | fwd: 128x64; fwd_split: 64x64 | +| 128 | 128x64 | 16x32, 128x64 | fwd: 128x64; fwd_split: 16x32 for short seqlen_q, 128x64 otherwise | +| 160 | 128x64 | 64x64 | fwd: 128x64; fwd_split: 64x64 | +| 192 | 128x64 | 128x64 | fwd: 128x64; fwd_split: 128x64 | +| 256 | 128x64 | 128x64 | fwd: 128x64; fwd_split: 128x64 | + ### ‌Multi-SKU build The build process can be controlled through the environment variable `FLASHATTN_BUILD_PROJECTS`: diff --git a/benchmarks/benchmark_kvcache.py b/benchmarks/benchmark_kvcache.py new file mode 100644 index 0000000..e2e436a --- /dev/null +++ b/benchmarks/benchmark_kvcache.py @@ -0,0 +1,146 @@ +from flash_attn.flash_attn_interface import flash_attn_with_kvcache +import torch +import math +from einops import rearrange +from datetime import datetime +import csv + + +def run_with_profiler(fn, warmup=10, reps=100, print_result=False, target_kernels=None): + """Run function with torch.profiler and return sum of specific kernel times in ms""" + for _ in range(warmup): + fn() + torch.cuda.synchronize() + + with torch.profiler.profile( + activities=[torch.profiler.ProfilerActivity.CUDA], + record_shapes=False, + profile_memory=False, + with_stack=False, + ) as prof: + for _ in range(reps): + fn() + torch.cuda.synchronize() + + if print_result: + print(prof.key_averages().table(sort_by="device_time", row_limit=20)) + + if target_kernels is None: + target_kernels = [] + + kernel_times_us = 0.0 + for evt in prof.key_averages(): + if any(k in evt.key for k in target_kernels): + kernel_times_us += evt.device_time + + ms = kernel_times_us / 1e3 + return ms + + +def calc_bandwidth(batch_size, seqlen_q, seqlen_k, num_heads, num_heads_k, headdim, dtype, ms): + """Calculate bandwidth in GB/s""" + bytes_per_elem = 2 if dtype == torch.bfloat16 else 4 + q_bytes = batch_size * seqlen_q * num_heads * headdim * bytes_per_elem + kv_bytes = batch_size * seqlen_k * num_heads_k * headdim * bytes_per_elem * 2 + total_bytes = q_bytes + kv_bytes + bw_gb_s = (total_bytes / 1e9) / (ms / 1e3) + return bw_gb_s + + +def benchmark_kvcache(batch_size, seqlen_k, seqlen_q, num_heads, num_heads_k, headdim, page_block_size, device, dtype=torch.bfloat16, causal=False): + num_blocks = math.ceil(seqlen_k / page_block_size) * batch_size * 3 + num_blocks = max(1024, num_blocks) + paged_kv_block_size = page_block_size + + nheads = num_heads + nheads_k = num_heads_k + d = headdim + + torch.random.manual_seed(0) + window_size = (-1, -1) + + q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype) + + k_cache_paged = torch.randn( + num_blocks, paged_kv_block_size, nheads_k, d, device=device, dtype=dtype + ) + v_cache_paged = torch.randn( + num_blocks, paged_kv_block_size, nheads_k, d, device=device, dtype=dtype + ) + block_table = rearrange( + torch.randperm(num_blocks, dtype=torch.int32, device=device), + "(b nblocks) -> b nblocks", + b=batch_size, + ) + + cache_seqlens = torch.full((batch_size,), seqlen_k, dtype=torch.int32, device=device) + + def run_fn(): + flash_attn_with_kvcache( + q, k_cache_paged, v_cache_paged, None, None, + cache_seqlens=cache_seqlens, + cache_batch_idx=None, + block_table=block_table, + causal=causal, + window_size=window_size, + rotary_interleaved=False, + alibi_slopes=None, + num_splits=1, + ) + + return run_fn + + +def main(): + headdims = [128, 256] + page_block_size = 16 + batch_sizes = [1, 2, 4, 8, 16, 32, 64, 128] + seq_lens_kv = [512, 1024, 2048, 4096, 8192, 16384] + + device = "cuda" if torch.cuda.is_available() else "cpu" + dtype = torch.bfloat16 + causal = False + warmup = 10 + repeat = 100 + + num_heads = 8 + num_heads_k = 8 + seqlen_q = 1 + + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") + csv_path = f"benchmark_kvcache_{timestamp}.csv" + + with open(csv_path, "w", newline="") as f: + writer = csv.writer(f) + writer.writerow(["batch_size", "seq_len_kv", "heads", "headdim", "time_ms", "bandwidth_GB_s"]) + print(f"{'batch_size':>10} {'seq_len_kv':>12} {'heads':>6} {'headdim':>8} {'time_ms':>10} {'bandwidth_GB_s':>15}") + print("-" * 75) + for headdim in headdims: + for seqlen_k in seq_lens_kv: + for batch_size in batch_sizes: + try: + run_fn = benchmark_kvcache( + batch_size=batch_size, + seqlen_k=seqlen_k, + seqlen_q=seqlen_q, + num_heads=num_heads, + num_heads_k=num_heads_k, + headdim=headdim, + page_block_size=page_block_size, + device=device, + dtype=dtype, + causal=causal, + ) + ms = run_with_profiler(run_fn, warmup=warmup, reps=repeat, target_kernels=["flash"]) + bw = calc_bandwidth(batch_size, seqlen_q, seqlen_k, num_heads, num_heads_k, headdim, dtype, ms) + writer.writerow([batch_size, seqlen_k, num_heads, headdim, f"{ms:.4f}", f"{bw:.2f}"]) + print(f"{batch_size:>10} {seqlen_k:>12} {num_heads:>6} {headdim:>8} {ms:>10.4f} {bw:>15.2f}") + except Exception as e: + writer.writerow([batch_size, seqlen_k, num_heads, headdim, "OOM", "OOM"]) + print(f"{batch_size:>10} {seqlen_k:>12} {num_heads:>6} {headdim:>8} {'OOM':>10} {'OOM':>15} # {e}") + + print(f"\nResults saved to {csv_path}") + + +if __name__ == "__main__": + main() diff --git a/cmake/build_flash_attn.cmake b/cmake/build_flash_attn.cmake index 929d142..d3b71c0 100644 --- a/cmake/build_flash_attn.cmake +++ b/cmake/build_flash_attn.cmake @@ -1,3 +1,116 @@ +function(_flash_attn_get_dispatch_mn OUT_VAR MACA_ARCH SOURCE_KIND HDIM) + if("${MACA_ARCH}" STREQUAL "xcore1000") + if("${SOURCE_KIND}" STREQUAL "fwd") + if("${HDIM}" STREQUAL "0") + set(MN_LIST "32x32,64x32,64x64,128x64,128x128") + elseif("${HDIM}" STREQUAL "32") + set(MN_LIST "128x128") + elseif("${HDIM}" STREQUAL "64") + set(MN_LIST "64x64") + elseif("${HDIM}" STREQUAL "96") + set(MN_LIST "128x64") + elseif("${HDIM}" STREQUAL "128") + set(MN_LIST "64x64") + elseif("${HDIM}" STREQUAL "160") + set(MN_LIST "64x32") + elseif("${HDIM}" STREQUAL "192") + set(MN_LIST "64x64,128x64") + elseif("${HDIM}" STREQUAL "256") + if(BUILD_WITH_BWD_KERNEL) + set(MN_LIST "64x32,64x64") + else() + set(MN_LIST "64x32") + endif() + elseif("${HDIM}" STREQUAL "512") + set(MN_LIST "64x32") + else() + set(MN_LIST "") + endif() + elseif("${SOURCE_KIND}" STREQUAL "fwd_split") + if("${HDIM}" STREQUAL "0") + set(MN_LIST "32x32,64x64") + elseif("${HDIM}" STREQUAL "512") + set(MN_LIST "32x32") + else() + set(MN_LIST "64x64") + endif() + else() + message(FATAL_ERROR "Unknown ${MACA_ARCH} fwd source kind: ${SOURCE_KIND}") + endif() + elseif("${MACA_ARCH}" STREQUAL "xcore1500") + if("${SOURCE_KIND}" STREQUAL "fwd") + if("${HDIM}" STREQUAL "0") + set(MN_LIST "128x64,128x128") + elseif("${HDIM}" STREQUAL "32") + set(MN_LIST "128x64,128x128") + elseif("${HDIM}" STREQUAL "512") + set(MN_LIST "") + else() + set(MN_LIST "128x64") + endif() + elseif("${SOURCE_KIND}" STREQUAL "fwd_split") + if("${HDIM}" STREQUAL "0") + set(MN_LIST "16x32,64x64,128x64") + elseif("${HDIM}" STREQUAL "32") + set(MN_LIST "64x64") + elseif("${HDIM}" STREQUAL "64") + set(MN_LIST "128x64") + elseif("${HDIM}" STREQUAL "96") + set(MN_LIST "64x64") + elseif("${HDIM}" STREQUAL "128") + set(MN_LIST "16x32,128x64") + elseif("${HDIM}" STREQUAL "160") + set(MN_LIST "64x64") + elseif("${HDIM}" STREQUAL "192") + set(MN_LIST "128x64") + elseif("${HDIM}" STREQUAL "256") + set(MN_LIST "128x64") + else() + set(MN_LIST "") + endif() + else() + message(FATAL_ERROR "Unknown ${MACA_ARCH} fwd source kind: ${SOURCE_KIND}") + endif() + else() + set(MN_LIST "") + endif() + + set(${OUT_VAR} "${MN_LIST}" PARENT_SCOPE) +endfunction() + +function(_flash_attn_filter_mn_sources SOURCE_VAR MN_LIST LABEL MACA_ARCH) + if("${MN_LIST}" STREQUAL "") + return() + endif() + + string(REPLACE "," ";" FWD_MN_ITEMS "${MN_LIST}") + set(FWD_MN_REGEX "") + foreach(FWD_MN_ITEM IN LISTS FWD_MN_ITEMS) + string(STRIP "${FWD_MN_ITEM}" FWD_MN_ITEM) + if("${FWD_MN_ITEM}" STREQUAL "") + continue() + endif() + string(TOLOWER "${FWD_MN_ITEM}" FWD_MN_ITEM) + string(REPLACE "x" "n" FWD_MN_TOKEN "${FWD_MN_ITEM}") + set(FWD_MN_TOKEN "m${FWD_MN_TOKEN}") + if("${FWD_MN_REGEX}" STREQUAL "") + set(FWD_MN_REGEX "_(${FWD_MN_TOKEN}") + else() + set(FWD_MN_REGEX "${FWD_MN_REGEX}|${FWD_MN_TOKEN}") + endif() + endforeach() + + if("${FWD_MN_REGEX}" STREQUAL "") + return() + endif() + + set(FWD_MN_REGEX "${FWD_MN_REGEX})_") + set(SOURCES ${${SOURCE_VAR}}) + list(FILTER SOURCES INCLUDE REGEX "${FWD_MN_REGEX}") + set(${SOURCE_VAR} "${SOURCES}" PARENT_SCOPE) + message(STATUS "Filter ${MACA_ARCH} ${LABEL} kernels by MN_LIST=${MN_LIST}, regex=${FWD_MN_REGEX}") +endfunction() + macro(build_flash_attn_kernel LIB_NAME MACA_ARCH HDIM DTYPE) set(HDIM_FILTER "") @@ -24,12 +137,53 @@ macro(build_flash_attn_kernel LIB_NAME MACA_ARCH HDIM DTYPE) file(GLOB FWD_TRAITS_SRC "${SRC_PARENT}/run_flash_template/${MACA_ARCH}/fwd/flash_fwd_hdimqk${SRC_FILTER}.cpp") file(GLOB FWD_SPLIT_TRAITS_SRC "${SRC_PARENT}/run_flash_template/${MACA_ARCH}/fwd_split/flash_fwd_splitkv_hdimqk${SRC_FILTER}.cpp") - file(GLOB BWD_TRAITS_SRC "${SRC_PARENT}/run_flash_template/${MACA_ARCH}/bwd/flash_bwd_hdimqk${SRC_FILTER}.cpp") file(GLOB FWD_KERNEL_SRC "${SRC_PARENT}/full_kernels/${MACA_ARCH}/fwd/flash_fwd_hdimqk${SRC_FILTER}.cpp") file(GLOB FWD_SPLIT_KERNEL_SRC "${SRC_PARENT}/full_kernels/${MACA_ARCH}/fwd_split/flash_fwd_splitkv_hdimqk${SRC_FILTER}.cpp") - file(GLOB BWD_KERNEL_SRC "${SRC_PARENT}/full_kernels/${MACA_ARCH}/bwd/flash_bwd_hdimqk${SRC_FILTER}.cpp") + if("${MACA_ARCH}" STREQUAL "xcore1000" OR "${MACA_ARCH}" STREQUAL "xcore1500") + set(FWD_MN_EFFECTIVE "${FWD_MN_LIST}") + string(TOUPPER "${FWD_MN_EFFECTIVE}" FWD_MN_EFFECTIVE_UPPER) + if("${FWD_MN_EFFECTIVE_UPPER}" STREQUAL "DEFAULT") + _flash_attn_get_dispatch_mn(FWD_MN_EFFECTIVE "${MACA_ARCH}" "fwd" "${HDIM}") + endif() + + set(FWD_SPLIT_MN_EFFECTIVE "${FWD_SPLIT_MN_LIST}") + string(TOUPPER "${FWD_SPLIT_MN_EFFECTIVE}" FWD_SPLIT_MN_EFFECTIVE_UPPER) + if("${FWD_SPLIT_MN_EFFECTIVE_UPPER}" STREQUAL "DEFAULT") + _flash_attn_get_dispatch_mn(FWD_SPLIT_MN_EFFECTIVE "${MACA_ARCH}" "fwd_split" "${HDIM}") + endif() + + _flash_attn_filter_mn_sources(FWD_TRAITS_SRC "${FWD_MN_EFFECTIVE}" "fwd traits" "${MACA_ARCH}") + _flash_attn_filter_mn_sources(FWD_KERNEL_SRC "${FWD_MN_EFFECTIVE}" "fwd full" "${MACA_ARCH}") + _flash_attn_filter_mn_sources(FWD_SPLIT_TRAITS_SRC "${FWD_SPLIT_MN_EFFECTIVE}" "fwd_split traits" "${MACA_ARCH}") + _flash_attn_filter_mn_sources(FWD_SPLIT_KERNEL_SRC "${FWD_SPLIT_MN_EFFECTIVE}" "fwd_split full" "${MACA_ARCH}") + endif() + + if(NOT BUILD_WITH_BWD_KERNEL) + list(FILTER FWD_KERNEL_SRC EXCLUDE REGEX "_dropout_") + message(STATUS "Disable dropout fwd kernels because BUILD_WITH_BWD_KERNEL is OFF") + endif() + if(NOT FWD_ENABLE_LOCAL) + list(FILTER FWD_SPLIT_KERNEL_SRC EXCLUDE REGEX "_Is_local") + message(STATUS "Disable local fwd_split full kernels") + endif() + if(NOT FWD_ENABLE_ALIBI) + list(FILTER FWD_SPLIT_KERNEL_SRC EXCLUDE REGEX "_alibi") + message(STATUS "Disable alibi fwd_split full kernels") + endif() + if(NOT FWD_ENABLE_CAUSAL) + list(FILTER FWD_SPLIT_KERNEL_SRC EXCLUDE REGEX "_causal") + message(STATUS "Disable causal fwd_split full kernels") + endif() + + if(BUILD_WITH_BWD_KERNEL) + file(GLOB BWD_TRAITS_SRC "${SRC_PARENT}/run_flash_template/${MACA_ARCH}/bwd/flash_bwd_hdimqk${SRC_FILTER}.cpp") + file(GLOB BWD_KERNEL_SRC "${SRC_PARENT}/full_kernels/${MACA_ARCH}/bwd/flash_bwd_hdimqk${SRC_FILTER}.cpp") + else() + set(BWD_TRAITS_SRC "") + set(BWD_KERNEL_SRC "") + endif() add_library(${LIB_NAME} STATIC ${FWD_KERNEL_SRC} @@ -48,11 +202,15 @@ macro(build_flash_attn_host LIB_NAME) csrc/flash_attn/flash_run/flash_performance_mode.cpp csrc/flash_attn/flash_run/flash_launch_parameter.cpp csrc/flash_attn/flash_run/run_mha_fwd.cpp - csrc/flash_attn/flash_run/run_mha_bwd.cpp csrc/flash_attn/utils/print_parameter.cpp csrc/common/process_str.cpp csrc/common/logger.cpp ) + if(BUILD_WITH_BWD_KERNEL) + list(APPEND FLASH_ATTN_SRC csrc/flash_attn/flash_run/run_mha_bwd.cpp) + else() + message(STATUS "Disable backward host sources because BUILD_WITH_BWD_KERNEL is OFF") + endif() add_library(${LIB_NAME} STATIC ${FLASH_ATTN_SRC} ) diff --git a/csrc/flash_attn/flash_api/flash_api_bwd.cpp b/csrc/flash_attn/flash_api/flash_api_bwd.cpp index 0c095f0..dc76e4b 100644 --- a/csrc/flash_attn/flash_api/flash_api_bwd.cpp +++ b/csrc/flash_attn/flash_api/flash_api_bwd.cpp @@ -6,6 +6,62 @@ using namespace mcFlashAttn; +#ifdef FLASHATTENTION_DISABLE_BACKWARD + +std::vector +mha_bwd(const at::Tensor &dout, + const at::Tensor &q, + const at::Tensor &k, + const at::Tensor &v, + const at::Tensor &out, + const at::Tensor &softmax_lse, + c10::optional &dq_, + c10::optional &dk_, + c10::optional &dv_, + c10::optional &alibi_slopes_, + c10::optional &attn_mask_, + const float p_dropout, + const float softmax_scale, + const bool is_causal, + int window_size_left, + int window_size_right, + const float softcap, + const bool deterministic, + c10::optional gen_, + c10::optional &rng_state) { + TORCH_CHECK(false, "This flash attention build does not support backward."); +} + +std::vector +mha_varlen_bwd(const at::Tensor &dout, + const at::Tensor &q, + const at::Tensor &k, + const at::Tensor &v, + const at::Tensor &out, + const at::Tensor &softmax_lse, + c10::optional &dq_, + c10::optional &dk_, + c10::optional &dv_, + const at::Tensor &cu_seqlens_q, + const at::Tensor &cu_seqlens_k, + c10::optional &alibi_slopes_, + const int max_seqlen_q, + const int max_seqlen_k, + const float p_dropout, + const float softmax_scale, + const bool zero_tensors, + const bool is_causal, + int window_size_left, + int window_size_right, + const float softcap, + const bool deterministic, + c10::optional gen_, + c10::optional &rng_state) { + TORCH_CHECK(false, "This flash attention build does not support backward."); +} + +#else + /* *@attn_mask_ support [batch_size or 1, num_heads or 1, seqlen_q or 1, seqlen_k or 1] * [num_heads or 1, seqlen_q or 1, seqlen_k or 1] @@ -495,3 +551,4 @@ mha_varlen_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size return { dq, dk, dv, softmax_d }; } +#endif diff --git a/csrc/flash_attn/flash_api/flash_api_fwd.cpp b/csrc/flash_attn/flash_api/flash_api_fwd.cpp index c81b2f4..838e539 100644 --- a/csrc/flash_attn/flash_api/flash_api_fwd.cpp +++ b/csrc/flash_attn/flash_api/flash_api_fwd.cpp @@ -69,12 +69,18 @@ mha_fwd(at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size if (softcap > 0.f) { TORCH_CHECK(p_dropout == 0.f, "Softcapping does not support dropout for now"); } +#ifdef FLASHATTENTION_DISABLE_DROPOUT + TORCH_CHECK(p_dropout == 0.f, "This flash attention build does not support dropout."); +#endif if (window_size_left >= seqlen_k) { window_size_left = -1; } if (window_size_right >= seqlen_k) { window_size_right = -1; } // causal=true is the same as causal=false in this case if (seqlen_q == 1 && !alibi_slopes_.has_value()) { is_causal = false; } +#ifdef FLASHATTENTION_DISABLE_CAUSAL + TORCH_CHECK(!is_causal, "This flash attention build does not support causal attention."); +#endif if (is_causal) { window_size_right = 0; } // Faster to transpose q from (b, 1, (nheads_kv ngroups), d) to (b, ngroups, nheads_kv, d) in this case @@ -339,6 +345,9 @@ mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_size, total_q := \s if (softcap > 0.f) { TORCH_CHECK(p_dropout == 0.f, "Softcapping does not support dropout for now"); } +#ifdef FLASHATTENTION_DISABLE_DROPOUT + TORCH_CHECK(p_dropout == 0.f, "This flash attention build does not support dropout."); +#endif const int max_num_blocks_per_seq = !paged_KV ? 0 : block_table.size(1); const int num_blocks = !paged_KV ? 0 : k.size(0); @@ -346,6 +355,9 @@ mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_size, total_q := \s TORCH_CHECK(!paged_KV || (page_block_size > 0 && (page_block_size & (page_block_size - 1)) == 0), "Paged KV cache block size must be a power of 2!"); if (max_seqlen_q == 1 && !alibi_slopes_.has_value()) { is_causal = false; } // causal=true is the same as causal=false in this case +#ifdef FLASHATTENTION_DISABLE_CAUSAL + TORCH_CHECK(!is_causal, "This flash attention build does not support causal attention."); +#endif if (is_causal) { window_size_right = 0; } void *cu_seqlens_q_d = cu_seqlens_q.data_ptr(); diff --git a/csrc/flash_attn/flash_api/flash_api_fwd_kvcache.cpp b/csrc/flash_attn/flash_api/flash_api_fwd_kvcache.cpp index fa6a330..8cd55b3 100644 --- a/csrc/flash_attn/flash_api/flash_api_fwd_kvcache.cpp +++ b/csrc/flash_attn/flash_api/flash_api_fwd_kvcache.cpp @@ -76,6 +76,9 @@ mha_fwd_kvcache(at::Tensor &q, // batch_size x seqlen_q x num_he // causal=true is the same as causal=false in this case if (seqlen_q == 1 && !alibi_slopes_.has_value()) { is_causal = false; } +#ifdef FLASHATTENTION_DISABLE_CAUSAL + TORCH_CHECK(!is_causal, "This flash attention build does not support causal attention."); +#endif if (is_causal) { window_size_right = 0; } // Faster to transpose q from (b, 1, (nheads_kv ngroups), d) to (b, ngroups, nheads_kv, d) in this case @@ -167,6 +170,9 @@ mha_fwd_kvcache(at::Tensor &q, // batch_size x seqlen_q x num_he at::Tensor k, v, k_padded, v_padded; if (k_.has_value()) { +#ifdef FLASHATTENTION_DISABLE_APPENDKV + TORCH_CHECK(false, "This flash attention build does not support append KV."); +#endif TORCH_CHECK(v_.has_value(), "If key is supplied, value must also be passed in"); TORCH_CHECK(seqlens_k_.has_value(), "If key is supplied, seqlens_k must also be passed in"); TORCH_CHECK(seqlen_q <= seqlen_k, "If key is supplied, it must have seqlen <= the seqlen of the KV cache"); @@ -380,6 +386,9 @@ mha_fwd_kvcache_dequant(at::Tensor &q, // batch_size x seqlen_q x // causal=true is the same as causal=false in this case if (seqlen_q == 1 && !alibi_slopes_.has_value()) { is_causal = false; } +#ifdef FLASHATTENTION_DISABLE_CAUSAL + TORCH_CHECK(!is_causal, "This flash attention build does not support causal attention."); +#endif if (is_causal) { window_size_right = 0; } // Faster to transpose q from (b, 1, (nheads_kv ngroups), d) to (b, ngroups, nheads_kv, d) in this case @@ -487,6 +496,9 @@ mha_fwd_kvcache_dequant(at::Tensor &q, // batch_size x seqlen_q x at::Tensor k, v, k_padded, v_padded; if (k_.has_value()) { +#ifdef FLASHATTENTION_DISABLE_APPENDKV + TORCH_CHECK(false, "This flash attention build does not support append KV."); +#endif TORCH_CHECK(v_.has_value(), "If key is supplied, value must also be passed in"); TORCH_CHECK(seqlens_k_.has_value(), "If key is supplied, seqlens_k must also be passed in"); TORCH_CHECK(seqlen_q <= seqlen_k, "If key is supplied, it must have seqlen <= the seqlen of the KV cache"); diff --git a/csrc/flash_attn/flash_dispatch/flash_fwd_dispatch_template.h b/csrc/flash_attn/flash_dispatch/flash_fwd_dispatch_template.h index adbf8d8..eb185ba 100644 --- a/csrc/flash_attn/flash_dispatch/flash_fwd_dispatch_template.h +++ b/csrc/flash_attn/flash_dispatch/flash_fwd_dispatch_template.h @@ -502,6 +502,9 @@ inline void run_mha_fwd_dispatch<256, Arch::xcore1000>(Flash_fwd_params ¶ms, launch_params.rowblock_parallel = 0; launch_params.block_type = 5; FP16_SWITCH(!params.is_bf16, [&] { +#ifdef FLASHATTENTION_DISABLE_DROPOUT + Xcore1000::run_flash_fwd_template(params, launch_params, stream); +#else bool is_dropout = params.p_dropout < 1.f; if (!is_dropout) { Xcore1000::run_flash_fwd_template(params, launch_params, stream); @@ -509,6 +512,7 @@ inline void run_mha_fwd_dispatch<256, Arch::xcore1000>(Flash_fwd_params ¶ms, else { Xcore1000::run_flash_fwd_template(params, launch_params, stream); } +#endif }); } diff --git a/csrc/flash_attn/flash_run/run_mha_bwd.cpp b/csrc/flash_attn/flash_run/run_mha_bwd.cpp index 45d82cc..89ccb8f 100644 --- a/csrc/flash_attn/flash_run/run_mha_bwd.cpp +++ b/csrc/flash_attn/flash_run/run_mha_bwd.cpp @@ -2,12 +2,20 @@ #include "run_mha.h" #include "hdim_switch.h" #include "static_switch.h" +#ifdef FLASHATTENTION_DISABLE_BACKWARD +#include +#else #include "flash_bwd_dispatch_template.h" +#endif void run_mha_bwd(mcFlashAttn::Flash_bwd_params ¶ms, cudaStream_t stream) { +#ifdef FLASHATTENTION_DISABLE_BACKWARD + throw std::runtime_error("This flash attention build does not support backward."); +#else HEADDIM_SWITCH(params.d, { ARCH_SWITCH(params.arch, kArch, [&] { run_mha_bwd_dispatch(params,stream); }); }); +#endif } diff --git a/setup.py b/setup.py index a3ef671..e963f8d 100644 --- a/setup.py +++ b/setup.py @@ -48,6 +48,14 @@ SKIP_CUDA_BUILD = os.getenv("FLASH_ATTENTION_SKIP_CUDA_BUILD", "FALSE") == "TRUE" # For CI, we want the option to build with C++11 ABI since the nvcr images use C++11 ABI FORCE_CXX11_ABI = os.getenv("FLASH_ATTENTION_FORCE_CXX11_ABI", "FALSE") == "TRUE" +BUILD_WITH_BWD_KERNEL = os.getenv("BUILD_WITH_BWD_KERNEL", "FALSE").upper() not in ("0", "FALSE", "OFF", "NO") +DISABLE_BACKWARD = not BUILD_WITH_BWD_KERNEL +DISABLE_DROPOUT = os.getenv("FLASHATTENTION_DISABLE_DROPOUT", "FALSE").upper() in ("1", "TRUE", "ON", "YES") or DISABLE_BACKWARD +FWD_ENABLE_LOCAL = os.getenv("FWD_ENABLE_LOCAL", "FALSE").upper() in ("1", "TRUE", "ON", "YES") +FWD_ENABLE_ALIBI = os.getenv("FWD_ENABLE_ALIBI", "FALSE").upper() in ("1", "TRUE", "ON", "YES") +FWD_ENABLE_SOFTCAP = os.getenv("FWD_ENABLE_SOFTCAP", "FALSE").upper() in ("1", "TRUE", "ON", "YES") +FWD_ENABLE_APPENDKV = os.getenv("FWD_ENABLE_APPENDKV", "FALSE").upper() in ("1", "TRUE", "ON", "YES") +FWD_ENABLE_CAUSAL = os.getenv("FWD_ENABLE_CAUSAL", "FALSE").upper() in ("1", "TRUE", "ON", "YES") def get_platform(): @@ -153,6 +161,19 @@ def append_nvcc_threads(nvcc_extra_args): # https://github.com/pytorch/pytorch/blob/8472c24e3b5b60150096486616d98b7bea01500b/torch/utils/cpp_extension.py#L920 if FORCE_CXX11_ABI: torch._C._GLIBCXX_USE_CXX11_ABI = True + backward_flag = ["-DFLASHATTENTION_DISABLE_BACKWARD"] if DISABLE_BACKWARD else [] + dropout_flag = ["-DFLASHATTENTION_DISABLE_DROPOUT", "-DDROPOUT_FALSE"] if DISABLE_DROPOUT else [] + feature_flags = [] + if not FWD_ENABLE_LOCAL: + feature_flags += ["-DFLASHATTENTION_DISABLE_LOCAL", "-DLOCAL_FALSE"] + if not FWD_ENABLE_ALIBI: + feature_flags += ["-DFLASHATTENTION_DISABLE_ALIBI", "-DALIBI_FALSE"] + if not FWD_ENABLE_SOFTCAP: + feature_flags += ["-DFLASHATTENTION_DISABLE_SOFTCAP", "-DSOFTCAP_FALSE"] + if not FWD_ENABLE_APPENDKV: + feature_flags += ["-DFLASHATTENTION_DISABLE_APPENDKV", "-DAPPENDKV_FALSE"] + if not FWD_ENABLE_CAUSAL: + feature_flags += ["-DFLASHATTENTION_DISABLE_CAUSAL", "-DCAUSAL_FALSE"] ext_modules.append( CUDAExtension( name="flash_attn_2_cuda", @@ -165,7 +186,7 @@ def append_nvcc_threads(nvcc_extra_args): "csrc/flash_attn/flash_api/flash_splitkv.cpp", ], extra_compile_args={ - "cxx": ["-O3", "-std=c++17", "-w"] + generator_flag, + "cxx": ["-O3", "-std=c++17", "-w"] + generator_flag + backward_flag + dropout_flag + feature_flags, "nvcc": append_nvcc_threads( [ "-O3", @@ -203,6 +224,9 @@ def append_nvcc_threads(nvcc_extra_args): ] + generator_flag + cc_flag + + backward_flag + + dropout_flag + + feature_flags ), }, extra_link_args=["-T", "a.lds"], diff --git a/test_jobs.sh b/test_jobs.sh new file mode 100644 index 0000000..3e9c231 --- /dev/null +++ b/test_jobs.sh @@ -0,0 +1,5 @@ +#!/bin/bash +CURRENT_CPU_NUM=$(grep "cpu cores" /proc/cpuinfo | head -1 | awk '{print $4}') +MHA_NUM_JOBS=$(awk -v n=$CURRENT_CPU_NUM 'BEGIN {n=int(n*0.8); if(n>12) print 12; else print n}') +echo "CURRENT_CPU_NUM=$CURRENT_CPU_NUM" +echo "MHA_NUM_JOBS=$MHA_NUM_JOBS" \ No newline at end of file diff --git a/tools/build_scripts/build_projects_related.sh b/tools/build_scripts/build_projects_related.sh old mode 100644 new mode 100755