mirror of
https://github.com/ollama/ollama.git
synced 2026-01-29 07:12:03 +03:00
llama: fix CUDA MMA errors in release build (#13874)
This commit is contained in:
@@ -17,21 +17,22 @@ CUDA changes:
|
|||||||
- Add tile configs for (576, 512, 4) and (576, 512, 8)
|
- Add tile configs for (576, 512, 4) and (576, 512, 8)
|
||||||
- Add MMA config cases for ncols 4
|
- Add MMA config cases for ncols 4
|
||||||
- Add template instances for ncols2=4
|
- Add template instances for ncols2=4
|
||||||
|
- Fix nbatch_fa values in nvidia_fp32 config (32->64)
|
||||||
---
|
---
|
||||||
ggml/src/ggml-cuda/fattn-mma-f16.cuh | 15 ++++++++++++---
|
ggml/src/ggml-cuda/fattn-mma-f16.cuh | 40 +++++++++++++++----
|
||||||
ggml/src/ggml-cuda/fattn-tile.cuh | 18 +++++++++++++++++-
|
ggml/src/ggml-cuda/fattn-tile.cuh | 16 ++++++++
|
||||||
ggml/src/ggml-cuda/fattn.cu | 12 ++++++++----
|
ggml/src/ggml-cuda/fattn.cu | 12 ++++--
|
||||||
...attn-mma-f16-instance-ncols1_16-ncols2_4.cu | 1 +
|
...ttn-mma-f16-instance-ncols1_16-ncols2_4.cu | 1 +
|
||||||
...fattn-mma-f16-instance-ncols1_2-ncols2_4.cu | 1 +
|
...attn-mma-f16-instance-ncols1_2-ncols2_4.cu | 1 +
|
||||||
...fattn-mma-f16-instance-ncols1_4-ncols2_4.cu | 1 +
|
...attn-mma-f16-instance-ncols1_4-ncols2_4.cu | 1 +
|
||||||
...fattn-mma-f16-instance-ncols1_8-ncols2_4.cu | 1 +
|
...attn-mma-f16-instance-ncols1_8-ncols2_4.cu | 1 +
|
||||||
ggml/src/ggml-metal/ggml-metal-device.m | 8 ++------
|
ggml/src/ggml-metal/ggml-metal-device.m | 8 +---
|
||||||
ggml/src/ggml-metal/ggml-metal-ops.cpp | 2 +-
|
ggml/src/ggml-metal/ggml-metal-ops.cpp | 2 +-
|
||||||
ggml/src/ggml-metal/ggml-metal.metal | 1 +
|
ggml/src/ggml-metal/ggml-metal.metal | 1 +
|
||||||
10 files changed, 45 insertions(+), 15 deletions(-)
|
10 files changed, 64 insertions(+), 19 deletions(-)
|
||||||
|
|
||||||
diff --git a/ggml/src/ggml-cuda/fattn-mma-f16.cuh b/ggml/src/ggml-cuda/fattn-mma-f16.cuh
|
diff --git a/ggml/src/ggml-cuda/fattn-mma-f16.cuh b/ggml/src/ggml-cuda/fattn-mma-f16.cuh
|
||||||
index 7bd1044c1..a627302f9 100644
|
index 7bd1044c1..3dea2205e 100644
|
||||||
--- a/ggml/src/ggml-cuda/fattn-mma-f16.cuh
|
--- a/ggml/src/ggml-cuda/fattn-mma-f16.cuh
|
||||||
+++ b/ggml/src/ggml-cuda/fattn-mma-f16.cuh
|
+++ b/ggml/src/ggml-cuda/fattn-mma-f16.cuh
|
||||||
@@ -66,7 +66,8 @@ static constexpr __host__ __device__ fattn_mma_config ggml_cuda_fattn_mma_get_co
|
@@ -66,7 +66,8 @@ static constexpr __host__ __device__ fattn_mma_config ggml_cuda_fattn_mma_get_co
|
||||||
@@ -64,18 +65,78 @@ index 7bd1044c1..a627302f9 100644
|
|||||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 16, 64, 4, 32, 288, 256, 64, 1, false);
|
GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 16, 64, 4, 32, 288, 256, 64, 1, false);
|
||||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 32, 128, 2, 32, 160, 128, 64, 1, false);
|
GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 32, 128, 2, 32, 160, 128, 64, 1, false);
|
||||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 64, 256, 1, 32, 160, 128, 64, 1, false);
|
GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 64, 256, 1, 32, 160, 128, 64, 1, false);
|
||||||
@@ -1585,3 +1588,9 @@ DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(256, 256, 64)
|
@@ -397,7 +400,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
|
||||||
|
constexpr int ncols = ncols1 * ncols2;
|
||||||
|
constexpr int cols_per_warp = T_B_KQ::I;
|
||||||
|
constexpr int cols_per_thread = 2; // This is specifically KQ columns, Volta only has a single VKQ column.
|
||||||
|
- constexpr int np = nwarps * (cols_per_warp/ncols2) / ncols1; // Number of parallel CUDA warps per Q column.
|
||||||
|
+ constexpr int np = cols_per_warp > ncols ? nwarps : nwarps * cols_per_warp/ncols; // Number of parallel CUDA warps per Q column.
|
||||||
|
constexpr int nbatch_fa = ggml_cuda_fattn_mma_get_nbatch_fa(DKQ, DV, ncols);
|
||||||
|
constexpr int nbatch_K2 = ggml_cuda_fattn_mma_get_nbatch_K2(DKQ, DV, ncols);
|
||||||
|
constexpr int nbatch_V2 = ggml_cuda_fattn_mma_get_nbatch_V2(DKQ, DV, ncols);
|
||||||
|
@@ -467,7 +470,6 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
- static_assert(cols_per_warp != 8, "cols_per_warp == 8 not implemented");
|
||||||
|
#pragma unroll
|
||||||
|
for (int k_KQ_0 = k0_start; k_KQ_0 < k0_stop; k_KQ_0 += T_A_KQ::J) {
|
||||||
|
load_ldmatrix(Q_B[0], tile_Q + (threadIdx.y / np)*(T_B_KQ::I*stride_tile_Q) + k_KQ_0, stride_tile_Q);
|
||||||
|
@@ -479,8 +481,18 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
|
||||||
|
T_A_KQ K_A;
|
||||||
|
load_ldmatrix(K_A, tile_K + i_KQ_0*stride_tile_K + (k_KQ_0 - k0_start), stride_tile_K);
|
||||||
|
|
||||||
|
- // Wide version of KQ_C is column-major => swap A and B.
|
||||||
|
- mma(KQ_C[i_KQ_00/(np*T_A_KQ::I)], Q_B[0], K_A);
|
||||||
|
+ if constexpr (cols_per_warp == 8) {
|
||||||
|
+ mma(KQ_C[i_KQ_00/(np*T_A_KQ::I)], K_A, Q_B[0]);
|
||||||
|
+ } else {
|
||||||
|
+ // Wide version of KQ_C is column-major
|
||||||
|
+#if defined(AMD_WMMA_AVAILABLE)
|
||||||
|
+ // RDNA matrix C is column-major.
|
||||||
|
+ mma(KQ_C[i_KQ_00/(np*T_A_KQ::I)], K_A, Q_B[0]);
|
||||||
|
+#else
|
||||||
|
+ // swap A and B for CUDA.
|
||||||
|
+ mma(KQ_C[i_KQ_00/(np*T_A_KQ::I)], Q_B[0], K_A);
|
||||||
|
+#endif // defined(AMD_WMMA_AVAILABLE)
|
||||||
|
+ }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
@@ -841,7 +853,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
|
||||||
|
|
||||||
|
constexpr int cols_per_warp = T_B_KQ::I;
|
||||||
|
constexpr int cols_per_thread = 2; // This is specifically KQ columns, Volta only has a single VKQ column.
|
||||||
|
- constexpr int np = nwarps * (cols_per_warp/ncols2) / ncols1; // Number of parallel CUDA warps per Q column.
|
||||||
|
+ constexpr int np = cols_per_warp > ncols ? nwarps : nwarps * cols_per_warp/ncols; // Number of parallel CUDA warps per Q column.
|
||||||
|
constexpr int nbatch_fa = ggml_cuda_fattn_mma_get_nbatch_fa (DKQ, DV, ncols);
|
||||||
|
constexpr int nbatch_K2 = ggml_cuda_fattn_mma_get_nbatch_K2 (DKQ, DV, ncols);
|
||||||
|
constexpr int nbatch_V2 = ggml_cuda_fattn_mma_get_nbatch_V2 (DKQ, DV, ncols);
|
||||||
|
@@ -1353,6 +1365,13 @@ static __global__ void flash_attn_ext_f16(
|
||||||
|
NO_DEVICE_CODE;
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
+#ifdef VOLTA_MMA_AVAILABLE
|
||||||
|
+ if (ncols1*ncols2 < 32) {
|
||||||
|
+ NO_DEVICE_CODE;
|
||||||
|
+ return;
|
||||||
|
+ }
|
||||||
|
+#endif // VOLTA_MMA_AVAILABLE
|
||||||
|
+
|
||||||
|
#if __CUDA_ARCH__ == GGML_CUDA_CC_TURING
|
||||||
|
if (ncols1*ncols2 > 32) {
|
||||||
|
NO_DEVICE_CODE;
|
||||||
|
@@ -1585,3 +1604,8 @@ DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(256, 256, 64)
|
||||||
extern DECL_FATTN_MMA_F16_CASE(576, 512, 1, 16);
|
extern DECL_FATTN_MMA_F16_CASE(576, 512, 1, 16);
|
||||||
extern DECL_FATTN_MMA_F16_CASE(576, 512, 2, 16);
|
extern DECL_FATTN_MMA_F16_CASE(576, 512, 2, 16);
|
||||||
extern DECL_FATTN_MMA_F16_CASE(576, 512, 4, 16);
|
extern DECL_FATTN_MMA_F16_CASE(576, 512, 4, 16);
|
||||||
+
|
+
|
||||||
+// GLM 4.7 Flash uses gqa_ratio 4:
|
+// For GLM 4.7 Flash
|
||||||
+extern DECL_FATTN_MMA_F16_CASE(576, 512, 2, 4);
|
+extern DECL_FATTN_MMA_F16_CASE(576, 512, 4, 4);
|
||||||
+extern DECL_FATTN_MMA_F16_CASE(576, 512, 4, 4);
|
+extern DECL_FATTN_MMA_F16_CASE(576, 512, 8, 4);
|
||||||
+extern DECL_FATTN_MMA_F16_CASE(576, 512, 8, 4);
|
+extern DECL_FATTN_MMA_F16_CASE(576, 512, 16, 4);
|
||||||
+extern DECL_FATTN_MMA_F16_CASE(576, 512, 16, 4);
|
|
||||||
diff --git a/ggml/src/ggml-cuda/fattn-tile.cuh b/ggml/src/ggml-cuda/fattn-tile.cuh
|
diff --git a/ggml/src/ggml-cuda/fattn-tile.cuh b/ggml/src/ggml-cuda/fattn-tile.cuh
|
||||||
index 7c4d6fe67..6389ba5c4 100644
|
index 7c4d6fe67..371be7442 100644
|
||||||
--- a/ggml/src/ggml-cuda/fattn-tile.cuh
|
--- a/ggml/src/ggml-cuda/fattn-tile.cuh
|
||||||
+++ b/ggml/src/ggml-cuda/fattn-tile.cuh
|
+++ b/ggml/src/ggml-cuda/fattn-tile.cuh
|
||||||
@@ -68,6 +68,8 @@ static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_nv
|
@@ -68,6 +68,8 @@ static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_nv
|
||||||
@@ -85,19 +146,17 @@ index 7c4d6fe67..6389ba5c4 100644
|
|||||||
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 4, 128, 2, 64, 64)
|
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 4, 128, 2, 64, 64)
|
||||||
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 8, 256, 2, 64, 64)
|
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 8, 256, 2, 64, 64)
|
||||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 16, 256, 2, 64, 64)
|
GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 16, 256, 2, 64, 64)
|
||||||
|
|
||||||
return 0;
|
|
||||||
@@ -122,7 +124,9 @@ static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_nv
|
|
||||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 16, 256, 2, 32, 128)
|
|
||||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 32, 256, 2, 32, 64)
|
|
||||||
|
|
||||||
- GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 16, 256, 2, 32, 64)
|
|
||||||
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 4, 128, 2, 32, 64)
|
|
||||||
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 8, 256, 2, 32, 64)
|
|
||||||
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 16, 256, 2, 32, 64)
|
|
||||||
|
|
||||||
return 0;
|
return 0;
|
||||||
}
|
@@ -122,6 +124,8 @@ static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_nv
|
||||||
|
GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 16, 256, 2, 32, 128)
|
||||||
|
GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 32, 256, 2, 32, 64)
|
||||||
|
|
||||||
|
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 4, 128, 2, 32, 64)
|
||||||
|
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 8, 256, 2, 32, 64)
|
||||||
|
GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 16, 256, 2, 32, 64)
|
||||||
|
|
||||||
|
return 0;
|
||||||
@@ -183,6 +187,8 @@ static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_am
|
@@ -183,6 +187,8 @@ static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_am
|
||||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 16, 256, 2, 32, 128)
|
GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 16, 256, 2, 32, 128)
|
||||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 32, 256, 2, 32, 128)
|
GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 32, 256, 2, 32, 128)
|
||||||
@@ -106,11 +165,11 @@ index 7c4d6fe67..6389ba5c4 100644
|
|||||||
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 8, 256, 2, 64, 64)
|
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 8, 256, 2, 64, 64)
|
||||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 16, 256, 2, 64, 64)
|
GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 16, 256, 2, 64, 64)
|
||||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 32, 512, 1, 128, 64)
|
GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 32, 512, 1, 128, 64)
|
||||||
|
|
||||||
@@ -245,6 +251,8 @@ static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_am
|
@@ -245,6 +251,8 @@ static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_am
|
||||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 16, 256, 5, 32, 256)
|
GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 16, 256, 5, 32, 256)
|
||||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 32, 256, 3, 64, 128)
|
GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 32, 256, 3, 64, 128)
|
||||||
|
|
||||||
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 4, 128, 2, 64, 64)
|
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 4, 128, 2, 64, 64)
|
||||||
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 8, 256, 2, 64, 64)
|
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 8, 256, 2, 64, 64)
|
||||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 16, 256, 4, 64, 64)
|
GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 16, 256, 4, 64, 64)
|
||||||
|
|||||||
@@ -400,7 +400,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
|
|||||||
constexpr int ncols = ncols1 * ncols2;
|
constexpr int ncols = ncols1 * ncols2;
|
||||||
constexpr int cols_per_warp = T_B_KQ::I;
|
constexpr int cols_per_warp = T_B_KQ::I;
|
||||||
constexpr int cols_per_thread = 2; // This is specifically KQ columns, Volta only has a single VKQ column.
|
constexpr int cols_per_thread = 2; // This is specifically KQ columns, Volta only has a single VKQ column.
|
||||||
constexpr int np = nwarps * (cols_per_warp/ncols2) / ncols1; // Number of parallel CUDA warps per Q column.
|
constexpr int np = cols_per_warp > ncols ? nwarps : nwarps * cols_per_warp/ncols; // Number of parallel CUDA warps per Q column.
|
||||||
constexpr int nbatch_fa = ggml_cuda_fattn_mma_get_nbatch_fa(DKQ, DV, ncols);
|
constexpr int nbatch_fa = ggml_cuda_fattn_mma_get_nbatch_fa(DKQ, DV, ncols);
|
||||||
constexpr int nbatch_K2 = ggml_cuda_fattn_mma_get_nbatch_K2(DKQ, DV, ncols);
|
constexpr int nbatch_K2 = ggml_cuda_fattn_mma_get_nbatch_K2(DKQ, DV, ncols);
|
||||||
constexpr int nbatch_V2 = ggml_cuda_fattn_mma_get_nbatch_V2(DKQ, DV, ncols);
|
constexpr int nbatch_V2 = ggml_cuda_fattn_mma_get_nbatch_V2(DKQ, DV, ncols);
|
||||||
@@ -470,7 +470,6 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
static_assert(cols_per_warp != 8, "cols_per_warp == 8 not implemented");
|
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (int k_KQ_0 = k0_start; k_KQ_0 < k0_stop; k_KQ_0 += T_A_KQ::J) {
|
for (int k_KQ_0 = k0_start; k_KQ_0 < k0_stop; k_KQ_0 += T_A_KQ::J) {
|
||||||
load_ldmatrix(Q_B[0], tile_Q + (threadIdx.y / np)*(T_B_KQ::I*stride_tile_Q) + k_KQ_0, stride_tile_Q);
|
load_ldmatrix(Q_B[0], tile_Q + (threadIdx.y / np)*(T_B_KQ::I*stride_tile_Q) + k_KQ_0, stride_tile_Q);
|
||||||
@@ -482,8 +481,18 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
|
|||||||
T_A_KQ K_A;
|
T_A_KQ K_A;
|
||||||
load_ldmatrix(K_A, tile_K + i_KQ_0*stride_tile_K + (k_KQ_0 - k0_start), stride_tile_K);
|
load_ldmatrix(K_A, tile_K + i_KQ_0*stride_tile_K + (k_KQ_0 - k0_start), stride_tile_K);
|
||||||
|
|
||||||
// Wide version of KQ_C is column-major => swap A and B.
|
if constexpr (cols_per_warp == 8) {
|
||||||
mma(KQ_C[i_KQ_00/(np*T_A_KQ::I)], Q_B[0], K_A);
|
mma(KQ_C[i_KQ_00/(np*T_A_KQ::I)], K_A, Q_B[0]);
|
||||||
|
} else {
|
||||||
|
// Wide version of KQ_C is column-major
|
||||||
|
#if defined(AMD_WMMA_AVAILABLE)
|
||||||
|
// RDNA matrix C is column-major.
|
||||||
|
mma(KQ_C[i_KQ_00/(np*T_A_KQ::I)], K_A, Q_B[0]);
|
||||||
|
#else
|
||||||
|
// swap A and B for CUDA.
|
||||||
|
mma(KQ_C[i_KQ_00/(np*T_A_KQ::I)], Q_B[0], K_A);
|
||||||
|
#endif // defined(AMD_WMMA_AVAILABLE)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -844,7 +853,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
|
|||||||
|
|
||||||
constexpr int cols_per_warp = T_B_KQ::I;
|
constexpr int cols_per_warp = T_B_KQ::I;
|
||||||
constexpr int cols_per_thread = 2; // This is specifically KQ columns, Volta only has a single VKQ column.
|
constexpr int cols_per_thread = 2; // This is specifically KQ columns, Volta only has a single VKQ column.
|
||||||
constexpr int np = nwarps * (cols_per_warp/ncols2) / ncols1; // Number of parallel CUDA warps per Q column.
|
constexpr int np = cols_per_warp > ncols ? nwarps : nwarps * cols_per_warp/ncols; // Number of parallel CUDA warps per Q column.
|
||||||
constexpr int nbatch_fa = ggml_cuda_fattn_mma_get_nbatch_fa (DKQ, DV, ncols);
|
constexpr int nbatch_fa = ggml_cuda_fattn_mma_get_nbatch_fa (DKQ, DV, ncols);
|
||||||
constexpr int nbatch_K2 = ggml_cuda_fattn_mma_get_nbatch_K2 (DKQ, DV, ncols);
|
constexpr int nbatch_K2 = ggml_cuda_fattn_mma_get_nbatch_K2 (DKQ, DV, ncols);
|
||||||
constexpr int nbatch_V2 = ggml_cuda_fattn_mma_get_nbatch_V2 (DKQ, DV, ncols);
|
constexpr int nbatch_V2 = ggml_cuda_fattn_mma_get_nbatch_V2 (DKQ, DV, ncols);
|
||||||
@@ -1356,6 +1365,13 @@ static __global__ void flash_attn_ext_f16(
|
|||||||
NO_DEVICE_CODE;
|
NO_DEVICE_CODE;
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
#ifdef VOLTA_MMA_AVAILABLE
|
||||||
|
if (ncols1*ncols2 < 32) {
|
||||||
|
NO_DEVICE_CODE;
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
#endif // VOLTA_MMA_AVAILABLE
|
||||||
|
|
||||||
#if __CUDA_ARCH__ == GGML_CUDA_CC_TURING
|
#if __CUDA_ARCH__ == GGML_CUDA_CC_TURING
|
||||||
if (ncols1*ncols2 > 32) {
|
if (ncols1*ncols2 > 32) {
|
||||||
NO_DEVICE_CODE;
|
NO_DEVICE_CODE;
|
||||||
@@ -1589,8 +1605,7 @@ extern DECL_FATTN_MMA_F16_CASE(576, 512, 1, 16);
|
|||||||
extern DECL_FATTN_MMA_F16_CASE(576, 512, 2, 16);
|
extern DECL_FATTN_MMA_F16_CASE(576, 512, 2, 16);
|
||||||
extern DECL_FATTN_MMA_F16_CASE(576, 512, 4, 16);
|
extern DECL_FATTN_MMA_F16_CASE(576, 512, 4, 16);
|
||||||
|
|
||||||
// GLM 4.7 Flash uses gqa_ratio 4:
|
// For GLM 4.7 Flash
|
||||||
extern DECL_FATTN_MMA_F16_CASE(576, 512, 2, 4);
|
extern DECL_FATTN_MMA_F16_CASE(576, 512, 4, 4);
|
||||||
extern DECL_FATTN_MMA_F16_CASE(576, 512, 4, 4);
|
extern DECL_FATTN_MMA_F16_CASE(576, 512, 8, 4);
|
||||||
extern DECL_FATTN_MMA_F16_CASE(576, 512, 8, 4);
|
extern DECL_FATTN_MMA_F16_CASE(576, 512, 16, 4);
|
||||||
extern DECL_FATTN_MMA_F16_CASE(576, 512, 16, 4);
|
|
||||||
|
|||||||
Reference in New Issue
Block a user