diff --git a/llama/patches/0032-ggml-enable-MLA-flash-attention-for-GLM-4.7-flash.patch b/llama/patches/0032-ggml-enable-MLA-flash-attention-for-GLM-4.7-flash.patch index 0d59e8fbd..abd7df930 100644 --- a/llama/patches/0032-ggml-enable-MLA-flash-attention-for-GLM-4.7-flash.patch +++ b/llama/patches/0032-ggml-enable-MLA-flash-attention-for-GLM-4.7-flash.patch @@ -17,21 +17,22 @@ CUDA changes: - Add tile configs for (576, 512, 4) and (576, 512, 8) - Add MMA config cases for ncols 4 - Add template instances for ncols2=4 +- Fix nbatch_fa values in nvidia_fp32 config (32->64) --- - ggml/src/ggml-cuda/fattn-mma-f16.cuh | 15 ++++++++++++--- - ggml/src/ggml-cuda/fattn-tile.cuh | 18 +++++++++++++++++- - ggml/src/ggml-cuda/fattn.cu | 12 ++++++++---- - ...attn-mma-f16-instance-ncols1_16-ncols2_4.cu | 1 + - ...fattn-mma-f16-instance-ncols1_2-ncols2_4.cu | 1 + - ...fattn-mma-f16-instance-ncols1_4-ncols2_4.cu | 1 + - ...fattn-mma-f16-instance-ncols1_8-ncols2_4.cu | 1 + - ggml/src/ggml-metal/ggml-metal-device.m | 8 ++------ - ggml/src/ggml-metal/ggml-metal-ops.cpp | 2 +- - ggml/src/ggml-metal/ggml-metal.metal | 1 + - 10 files changed, 45 insertions(+), 15 deletions(-) + ggml/src/ggml-cuda/fattn-mma-f16.cuh | 40 +++++++++++++++---- + ggml/src/ggml-cuda/fattn-tile.cuh | 16 ++++++++ + ggml/src/ggml-cuda/fattn.cu | 12 ++++-- + ...ttn-mma-f16-instance-ncols1_16-ncols2_4.cu | 1 + + ...attn-mma-f16-instance-ncols1_2-ncols2_4.cu | 1 + + ...attn-mma-f16-instance-ncols1_4-ncols2_4.cu | 1 + + ...attn-mma-f16-instance-ncols1_8-ncols2_4.cu | 1 + + ggml/src/ggml-metal/ggml-metal-device.m | 8 +--- + ggml/src/ggml-metal/ggml-metal-ops.cpp | 2 +- + ggml/src/ggml-metal/ggml-metal.metal | 1 + + 10 files changed, 64 insertions(+), 19 deletions(-) diff --git a/ggml/src/ggml-cuda/fattn-mma-f16.cuh b/ggml/src/ggml-cuda/fattn-mma-f16.cuh -index 7bd1044c1..a627302f9 100644 +index 7bd1044c1..3dea2205e 100644 --- a/ggml/src/ggml-cuda/fattn-mma-f16.cuh +++ b/ggml/src/ggml-cuda/fattn-mma-f16.cuh @@ -66,7 +66,8 @@ static constexpr __host__ __device__ fattn_mma_config ggml_cuda_fattn_mma_get_co @@ -64,18 +65,78 @@ index 7bd1044c1..a627302f9 100644 GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 16, 64, 4, 32, 288, 256, 64, 1, false); GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 32, 128, 2, 32, 160, 128, 64, 1, false); GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 64, 256, 1, 32, 160, 128, 64, 1, false); -@@ -1585,3 +1588,9 @@ DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(256, 256, 64) +@@ -397,7 +400,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter( + constexpr int ncols = ncols1 * ncols2; + constexpr int cols_per_warp = T_B_KQ::I; + constexpr int cols_per_thread = 2; // This is specifically KQ columns, Volta only has a single VKQ column. +- constexpr int np = nwarps * (cols_per_warp/ncols2) / ncols1; // Number of parallel CUDA warps per Q column. ++ constexpr int np = cols_per_warp > ncols ? nwarps : nwarps * cols_per_warp/ncols; // Number of parallel CUDA warps per Q column. 
+ constexpr int nbatch_fa = ggml_cuda_fattn_mma_get_nbatch_fa(DKQ, DV, ncols); + constexpr int nbatch_K2 = ggml_cuda_fattn_mma_get_nbatch_K2(DKQ, DV, ncols); + constexpr int nbatch_V2 = ggml_cuda_fattn_mma_get_nbatch_V2(DKQ, DV, ncols); +@@ -467,7 +470,6 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter( + } + } + } else { +- static_assert(cols_per_warp != 8, "cols_per_warp == 8 not implemented"); + #pragma unroll + for (int k_KQ_0 = k0_start; k_KQ_0 < k0_stop; k_KQ_0 += T_A_KQ::J) { + load_ldmatrix(Q_B[0], tile_Q + (threadIdx.y / np)*(T_B_KQ::I*stride_tile_Q) + k_KQ_0, stride_tile_Q); +@@ -479,8 +481,18 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter( + T_A_KQ K_A; + load_ldmatrix(K_A, tile_K + i_KQ_0*stride_tile_K + (k_KQ_0 - k0_start), stride_tile_K); + +- // Wide version of KQ_C is column-major => swap A and B. +- mma(KQ_C[i_KQ_00/(np*T_A_KQ::I)], Q_B[0], K_A); ++ if constexpr (cols_per_warp == 8) { ++ mma(KQ_C[i_KQ_00/(np*T_A_KQ::I)], K_A, Q_B[0]); ++ } else { ++ // Wide version of KQ_C is column-major ++#if defined(AMD_WMMA_AVAILABLE) ++ // RDNA matrix C is column-major. ++ mma(KQ_C[i_KQ_00/(np*T_A_KQ::I)], K_A, Q_B[0]); ++#else ++ // swap A and B for CUDA. ++ mma(KQ_C[i_KQ_00/(np*T_A_KQ::I)], Q_B[0], K_A); ++#endif // defined(AMD_WMMA_AVAILABLE) ++ } + } + } + } +@@ -841,7 +853,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( + + constexpr int cols_per_warp = T_B_KQ::I; + constexpr int cols_per_thread = 2; // This is specifically KQ columns, Volta only has a single VKQ column. +- constexpr int np = nwarps * (cols_per_warp/ncols2) / ncols1; // Number of parallel CUDA warps per Q column. ++ constexpr int np = cols_per_warp > ncols ? nwarps : nwarps * cols_per_warp/ncols; // Number of parallel CUDA warps per Q column. 
+ constexpr int nbatch_fa = ggml_cuda_fattn_mma_get_nbatch_fa (DKQ, DV, ncols);
+ constexpr int nbatch_K2 = ggml_cuda_fattn_mma_get_nbatch_K2 (DKQ, DV, ncols);
+ constexpr int nbatch_V2 = ggml_cuda_fattn_mma_get_nbatch_V2 (DKQ, DV, ncols);
+@@ -1353,6 +1365,13 @@ static __global__ void flash_attn_ext_f16(
+ NO_DEVICE_CODE;
+ return;
+ }
++#ifdef VOLTA_MMA_AVAILABLE
++ if (ncols1*ncols2 < 32) {
++ NO_DEVICE_CODE;
++ return;
++ }
++#endif // VOLTA_MMA_AVAILABLE
++
+ #if __CUDA_ARCH__ == GGML_CUDA_CC_TURING
+ if (ncols1*ncols2 > 32) {
+ NO_DEVICE_CODE;
+@@ -1585,3 +1604,8 @@ DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(256, 256, 64)
 extern DECL_FATTN_MMA_F16_CASE(576, 512, 1, 16);
 extern DECL_FATTN_MMA_F16_CASE(576, 512, 2, 16);
 extern DECL_FATTN_MMA_F16_CASE(576, 512, 4, 16);
 +
-+// GLM 4.7 Flash uses gqa_ratio 4:
-+extern DECL_FATTN_MMA_F16_CASE(576, 512, 2, 4);
-+extern DECL_FATTN_MMA_F16_CASE(576, 512, 4, 4);
-+extern DECL_FATTN_MMA_F16_CASE(576, 512, 8, 4);
-+extern DECL_FATTN_MMA_F16_CASE(576, 512, 16, 4);
++// For GLM 4.7 Flash
++extern DECL_FATTN_MMA_F16_CASE(576, 512, 4, 4);
++extern DECL_FATTN_MMA_F16_CASE(576, 512, 8, 4);
++extern DECL_FATTN_MMA_F16_CASE(576, 512, 16, 4);
 diff --git a/ggml/src/ggml-cuda/fattn-tile.cuh b/ggml/src/ggml-cuda/fattn-tile.cuh
-index 7c4d6fe67..6389ba5c4 100644
+index 7c4d6fe67..371be7442 100644
 --- a/ggml/src/ggml-cuda/fattn-tile.cuh
 +++ b/ggml/src/ggml-cuda/fattn-tile.cuh
@@ -68,6 +68,8 @@ static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_nv
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 4, 128, 2, 64, 64)
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 8, 256, 2, 64, 64)
 GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 16, 256, 2, 64, 64)
 
 return 0;
@@ -122,6 +124,8 @@ static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_nv
 GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 16, 256, 2, 32, 128)
 GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 32, 256, 2, 32, 64)
 
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 4, 128, 2, 32, 64)
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 8, 256, 2, 32, 64)
 GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 16, 256, 2, 32, 64)
 
 return 0;
@@ -183,6 +187,8 @@ static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_am
 GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 16, 256, 2, 32, 128)
 GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 32, 256, 2, 32, 128)
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 4, 128, 2, 64, 64)
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 8, 256, 2, 64, 64)
 GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 16, 256, 2, 64, 64)
 GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 32, 512, 1, 128, 64)
 
@@ -245,6 +251,8 @@ static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_am
 GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 16, 256, 5, 32, 256)
 GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 32, 256, 3, 64, 128)
 
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 4, 128, 2, 64, 64)
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 8, 256, 2, 64, 64)
 GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 16, 256, 4, 64, 64)
diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/fattn-mma-f16.cuh b/ml/backend/ggml/ggml/src/ggml-cuda/fattn-mma-f16.cuh
index a627302f9..3dea2205e 100644
--- a/ml/backend/ggml/ggml/src/ggml-cuda/fattn-mma-f16.cuh
+++ b/ml/backend/ggml/ggml/src/ggml-cuda/fattn-mma-f16.cuh
@@ -400,7 +400,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
 constexpr int ncols = ncols1 * ncols2;
 constexpr int cols_per_warp = T_B_KQ::I;
 constexpr int cols_per_thread = 2; // This is specifically KQ columns, Volta only has a single VKQ column.
- constexpr int np = nwarps * (cols_per_warp/ncols2) / ncols1; // Number of parallel CUDA warps per Q column.
+ constexpr int np = cols_per_warp > ncols ? nwarps : nwarps * cols_per_warp/ncols; // Number of parallel CUDA warps per Q column.
 constexpr int nbatch_fa = ggml_cuda_fattn_mma_get_nbatch_fa(DKQ, DV, ncols);
 constexpr int nbatch_K2 = ggml_cuda_fattn_mma_get_nbatch_K2(DKQ, DV, ncols);
 constexpr int nbatch_V2 = ggml_cuda_fattn_mma_get_nbatch_V2(DKQ, DV, ncols);
@@ -470,7 +470,6 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
 }
 }
 } else {
- static_assert(cols_per_warp != 8, "cols_per_warp == 8 not implemented");
 #pragma unroll
 for (int k_KQ_0 = k0_start; k_KQ_0 < k0_stop; k_KQ_0 += T_A_KQ::J) {
 load_ldmatrix(Q_B[0], tile_Q + (threadIdx.y / np)*(T_B_KQ::I*stride_tile_Q) + k_KQ_0, stride_tile_Q);
@@ -482,8 +481,18 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
 T_A_KQ K_A;
 load_ldmatrix(K_A, tile_K + i_KQ_0*stride_tile_K + (k_KQ_0 - k0_start), stride_tile_K);
 
- // Wide version of KQ_C is column-major => swap A and B.
- mma(KQ_C[i_KQ_00/(np*T_A_KQ::I)], Q_B[0], K_A);
+ if constexpr (cols_per_warp == 8) {
+ mma(KQ_C[i_KQ_00/(np*T_A_KQ::I)], K_A, Q_B[0]);
+ } else {
+ // Wide version of KQ_C is column-major
+#if defined(AMD_WMMA_AVAILABLE)
+ // RDNA matrix C is column-major.
+ mma(KQ_C[i_KQ_00/(np*T_A_KQ::I)], K_A, Q_B[0]);
+#else
+ // swap A and B for CUDA.
+ mma(KQ_C[i_KQ_00/(np*T_A_KQ::I)], Q_B[0], K_A);
+#endif // defined(AMD_WMMA_AVAILABLE)
+ }
 }
 }
 }
@@ -844,7 +853,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
 
 constexpr int cols_per_warp = T_B_KQ::I;
 constexpr int cols_per_thread = 2; // This is specifically KQ columns, Volta only has a single VKQ column.
- constexpr int np = nwarps * (cols_per_warp/ncols2) / ncols1; // Number of parallel CUDA warps per Q column.
+ constexpr int np = cols_per_warp > ncols ? nwarps : nwarps * cols_per_warp/ncols; // Number of parallel CUDA warps per Q column.
 constexpr int nbatch_fa = ggml_cuda_fattn_mma_get_nbatch_fa (DKQ, DV, ncols);
 constexpr int nbatch_K2 = ggml_cuda_fattn_mma_get_nbatch_K2 (DKQ, DV, ncols);
 constexpr int nbatch_V2 = ggml_cuda_fattn_mma_get_nbatch_V2 (DKQ, DV, ncols);
@@ -1356,6 +1365,13 @@ static __global__ void flash_attn_ext_f16(
 NO_DEVICE_CODE;
 return;
 }
+#ifdef VOLTA_MMA_AVAILABLE
+ if (ncols1*ncols2 < 32) {
+ NO_DEVICE_CODE;
+ return;
+ }
+#endif // VOLTA_MMA_AVAILABLE
+
 #if __CUDA_ARCH__ == GGML_CUDA_CC_TURING
 if (ncols1*ncols2 > 32) {
 NO_DEVICE_CODE;
@@ -1589,8 +1605,7 @@
 extern DECL_FATTN_MMA_F16_CASE(576, 512, 1, 16);
 extern DECL_FATTN_MMA_F16_CASE(576, 512, 2, 16);
 extern DECL_FATTN_MMA_F16_CASE(576, 512, 4, 16);
-// GLM 4.7 Flash uses gqa_ratio 4:
-extern DECL_FATTN_MMA_F16_CASE(576, 512, 2, 4);
-extern DECL_FATTN_MMA_F16_CASE(576, 512, 4, 4);
-extern DECL_FATTN_MMA_F16_CASE(576, 512, 8, 4);
-extern DECL_FATTN_MMA_F16_CASE(576, 512, 16, 4);
+// For GLM 4.7 Flash
+extern DECL_FATTN_MMA_F16_CASE(576, 512, 4, 4);
+extern DECL_FATTN_MMA_F16_CASE(576, 512, 8, 4);
+extern DECL_FATTN_MMA_F16_CASE(576, 512, 16, 4);