mirror of
https://github.com/ollama/ollama.git
synced 2026-01-29 07:12:03 +03:00
Re-apply "model: add MLA absorption for glm4moelite" with fix (#13870)
The nvidia_fp32 config for (576, 512) head sizes had nbatch_fa=32, which caused zero-sized arrays when computing array dimensions: nbatch_fa / (np * warp_size) = 32 / (2 * 32) = 0 This resulted in CUDA compilation failures on CUDA 12 (Windows and Linux arm64): - "static assertion failed with nbatch_fa % (np*warp_size) != 0" - "the size of an array must be greater than zero" Fix by changing nbatch_fa from 32 to 64 for all (576, 512) configs in the nvidia_fp32 function, matching the nvidia_fp16 and AMD configs.
This commit is contained in:
@@ -6,6 +6,10 @@ import (
|
||||
"log/slog"
|
||||
"regexp"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/pdevine/tensor"
|
||||
"github.com/pdevine/tensor/native"
|
||||
|
||||
"github.com/ollama/ollama/fs/ggml"
|
||||
)
|
||||
@@ -69,6 +73,9 @@ func (p *glm4MoeLiteModel) KV(t *Tokenizer) KV {
|
||||
kv["glm4moelite.rope.dimension_count"] = p.QKRopeHeadDim
|
||||
kv["glm4moelite.rope.freq_base"] = cmp.Or(p.RopeTheta, float32(1000000.0))
|
||||
|
||||
kv["glm4moelite.attention.key_length_mla"] = p.KVLoraRank + p.QKRopeHeadDim
|
||||
kv["glm4moelite.attention.value_length_mla"] = p.KVLoraRank
|
||||
|
||||
kv["tokenizer.ggml.pre"] = "glm4"
|
||||
|
||||
return kv
|
||||
@@ -100,6 +107,67 @@ func (p *glm4MoeLiteModel) Replacements() []string {
|
||||
}
|
||||
}
|
||||
|
||||
// repackKVB extracts K or V from the combined KV_B tensor for MLA absorption.
|
||||
// K output row-major: [n_head, kv_lora_rank, qk_nope] -> GGML ne[]={qk_nope, kv_lora_rank, n_head}
|
||||
// V output row-major: [n_head, v_head, kv_lora_rank] -> GGML ne[]={kv_lora_rank, v_head, n_head}
|
||||
func (p *glm4MoeLiteModel) repackKVB(extractK bool, kvFirst bool, numHeads int) Repacker {
|
||||
qkNope := int(p.QKNopeHeadDim)
|
||||
vHeadDim := int(p.VHeadDim)
|
||||
kvLoraRank := int(p.KVLoraRank)
|
||||
kvPerHead := qkNope + vHeadDim
|
||||
|
||||
return func(_ string, data []float32, shape []uint64) ([]float32, error) {
|
||||
dims := make([]int, len(shape))
|
||||
for i := range shape {
|
||||
dims[i] = int(shape[i])
|
||||
}
|
||||
|
||||
var tt tensor.Tensor = tensor.New(tensor.WithShape(dims...), tensor.WithBacking(data))
|
||||
var err error
|
||||
|
||||
// Normalize to [n_head * (qk_nope + v_head), kv_lora_rank] layout
|
||||
if kvFirst {
|
||||
tt, err = tensor.Transpose(tt, 1, 0)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
tt = tensor.Materialize(tt)
|
||||
}
|
||||
|
||||
// Reshape to [n_head, qk_nope + v_head, kv_lora_rank]
|
||||
if err := tt.Reshape(numHeads, kvPerHead, kvLoraRank); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if extractK {
|
||||
// Slice K: [n_head, qk_nope, kv_lora_rank]
|
||||
tt, err = tt.Slice(nil, tensor.S(0, qkNope), nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
tt = tensor.Materialize(tt)
|
||||
// Transpose to [n_head, kv_lora_rank, qk_nope]
|
||||
tt, err = tensor.Transpose(tt, 0, 2, 1)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
tt = tensor.Materialize(tt)
|
||||
} else {
|
||||
// Slice V: [n_head, v_head, kv_lora_rank] - already correct layout
|
||||
tt, err = tt.Slice(nil, tensor.S(qkNope, kvPerHead), nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
tt = tensor.Materialize(tt)
|
||||
}
|
||||
|
||||
if err := tt.Reshape(tt.Shape().TotalSize()); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return native.VectorF32(tt.(*tensor.Dense))
|
||||
}
|
||||
}
|
||||
|
||||
func (p *glm4MoeLiteModel) Tensors(s []Tensor) (out []*ggml.Tensor) {
|
||||
merges := make([]merge, p.HiddenLayers*3)
|
||||
for i := range p.HiddenLayers {
|
||||
@@ -139,6 +207,52 @@ func (p *glm4MoeLiteModel) Tensors(s []Tensor) (out []*ggml.Tensor) {
|
||||
slog.Debug("skipping layer", "name", t.Name())
|
||||
continue
|
||||
}
|
||||
|
||||
// Split attn_kv_b into separate attn_k_b and attn_v_b for MLA absorption
|
||||
if strings.HasSuffix(t.Name(), ".attn_kv_b.weight") {
|
||||
qkNope := int(p.QKNopeHeadDim)
|
||||
vHeadDim := int(p.VHeadDim)
|
||||
kvLoraRank := int(p.KVLoraRank)
|
||||
kvPerHead := qkNope + vHeadDim
|
||||
numHeads := int(p.NumAttentionHeads)
|
||||
kvFirst := true
|
||||
if len(t.Shape()) == 2 {
|
||||
switch {
|
||||
case int(t.Shape()[0]) == kvLoraRank:
|
||||
if kvPerHead > 0 && int(t.Shape()[1])%kvPerHead == 0 {
|
||||
numHeads = int(t.Shape()[1]) / kvPerHead
|
||||
}
|
||||
kvFirst = true
|
||||
case int(t.Shape()[1]) == kvLoraRank:
|
||||
if kvPerHead > 0 && int(t.Shape()[0])%kvPerHead == 0 {
|
||||
numHeads = int(t.Shape()[0]) / kvPerHead
|
||||
}
|
||||
kvFirst = false
|
||||
default:
|
||||
slog.Warn("glm4moelite: unexpected attn_kv_b layout", "name", t.Name(), "shape", t.Shape())
|
||||
}
|
||||
}
|
||||
|
||||
kTensor := t.Clone()
|
||||
kTensor.SetRepacker(p.repackKVB(true, kvFirst, numHeads))
|
||||
out = append(out, &ggml.Tensor{
|
||||
Name: strings.Replace(t.Name(), "attn_kv_b", "attn_k_b", 1),
|
||||
Kind: t.Kind(),
|
||||
Shape: []uint64{uint64(numHeads), uint64(kvLoraRank), uint64(qkNope)},
|
||||
WriterTo: kTensor,
|
||||
})
|
||||
|
||||
vTensor := t.Clone()
|
||||
vTensor.SetRepacker(p.repackKVB(false, kvFirst, numHeads))
|
||||
out = append(out, &ggml.Tensor{
|
||||
Name: strings.Replace(t.Name(), "attn_kv_b", "attn_v_b", 1),
|
||||
Kind: t.Kind(),
|
||||
Shape: []uint64{uint64(numHeads), uint64(vHeadDim), uint64(kvLoraRank)},
|
||||
WriterTo: vTensor,
|
||||
})
|
||||
continue
|
||||
}
|
||||
|
||||
out = append(out, &ggml.Tensor{
|
||||
Name: t.Name(),
|
||||
Kind: t.Kind(),
|
||||
|
||||
@@ -0,0 +1,251 @@
|
||||
From 0000000000000000000000000000000000000000 Mon Sep 17 00:00:00 2001
|
||||
From: nobody <>
|
||||
Date: Sat, 24 Jan 2026 02:31:01 +0000
|
||||
Subject: [PATCH] ggml: enable MLA flash attention for GLM-4.7-flash
|
||||
|
||||
Add support for gqa_ratio 4 in MLA flash attention kernels. GLM-4.7-flash
|
||||
uses head size 576 with gqa_ratio 4, which was previously only supported
|
||||
for gqa_ratio 16 (DeepSeek).
|
||||
|
||||
Metal changes:
|
||||
- Enable head size 576 for flash attention
|
||||
- Increase simdgroups to 8 for large heads (>=512)
|
||||
- Add case 8 kernel dispatch for 8 simdgroups
|
||||
|
||||
CUDA changes:
|
||||
- Add gqa_ratio 4 support for head 576/512
|
||||
- Add tile configs for (576, 512, 4) and (576, 512, 8)
|
||||
- Add MMA config cases for ncols 4
|
||||
- Add template instances for ncols2=4
|
||||
- Fix nbatch_fa values in nvidia_fp32 config (32->64)
|
||||
---
|
||||
ggml/src/ggml-cuda/fattn-mma-f16.cuh | 15 ++++++++++++---
|
||||
ggml/src/ggml-cuda/fattn-tile.cuh | 18 +++++++++++++++++-
|
||||
ggml/src/ggml-cuda/fattn.cu | 12 ++++++++----
|
||||
...attn-mma-f16-instance-ncols1_16-ncols2_4.cu | 1 +
|
||||
...fattn-mma-f16-instance-ncols1_2-ncols2_4.cu | 1 +
|
||||
...fattn-mma-f16-instance-ncols1_4-ncols2_4.cu | 1 +
|
||||
...fattn-mma-f16-instance-ncols1_8-ncols2_4.cu | 1 +
|
||||
ggml/src/ggml-metal/ggml-metal-device.m | 8 ++------
|
||||
ggml/src/ggml-metal/ggml-metal-ops.cpp | 2 +-
|
||||
ggml/src/ggml-metal/ggml-metal.metal | 1 +
|
||||
10 files changed, 45 insertions(+), 15 deletions(-)
|
||||
|
||||
diff --git a/ggml/src/ggml-cuda/fattn-mma-f16.cuh b/ggml/src/ggml-cuda/fattn-mma-f16.cuh
|
||||
index 7bd1044c1..a627302f9 100644
|
||||
--- a/ggml/src/ggml-cuda/fattn-mma-f16.cuh
|
||||
+++ b/ggml/src/ggml-cuda/fattn-mma-f16.cuh
|
||||
@@ -66,7 +66,8 @@ static constexpr __host__ __device__ fattn_mma_config ggml_cuda_fattn_mma_get_co
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 32, 128, 2, 32, 128, 128, 128, 2, true);
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 64, 128, 2, 32, 128, 128, 128, 2, true);
|
||||
|
||||
- GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 8, 64, 4, 32, 288, 256, 128, 1, false);
|
||||
+ GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 4, 64, 4, 32, 288, 256, 128, 1, false);
|
||||
+ GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 8, 64, 4, 32, 288, 256, 128, 1, true);
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 16, 64, 4, 32, 288, 256, 128, 1, false);
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 32, 128, 2, 32, 160, 128, 128, 1, false);
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 64, 256, 1, 32, 160, 128, 128, 1, false);
|
||||
@@ -80,7 +81,8 @@ static constexpr __host__ __device__ fattn_mma_config ggml_cuda_fattn_mma_get_co
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 32, 128, 2, 64, 128, 128, 64, 2, true);
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 64, 128, 2, 64, 128, 128, 64, 2, true);
|
||||
|
||||
- GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 8, 64, 4, 32, 96, 64, 128, 1, false);
|
||||
+ GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 4, 64, 4, 32, 96, 64, 128, 1, false);
|
||||
+ GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 8, 64, 4, 32, 96, 64, 128, 1, true);
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 16, 64, 4, 32, 96, 64, 128, 1, false);
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 32, 128, 2, 32, 160, 128, 128, 1, false);
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 64, 256, 1, 32, 160, 128, 128, 1, false);
|
||||
@@ -89,7 +91,8 @@ static constexpr __host__ __device__ fattn_mma_config ggml_cuda_fattn_mma_get_co
|
||||
}
|
||||
|
||||
static constexpr __host__ __device__ fattn_mma_config ggml_cuda_fattn_mma_get_config_volta(const int DKQ, const int DV, const int ncols) {
|
||||
- GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 8, 64, 4, 32, 288, 256, 64, 1, false);
|
||||
+ GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 4, 64, 4, 32, 288, 256, 64, 1, false);
|
||||
+ GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 8, 64, 4, 32, 288, 256, 64, 1, true);
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 16, 64, 4, 32, 288, 256, 64, 1, false);
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 32, 128, 2, 32, 160, 128, 64, 1, false);
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 64, 256, 1, 32, 160, 128, 64, 1, false);
|
||||
@@ -1585,3 +1588,9 @@ DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(256, 256, 64)
|
||||
extern DECL_FATTN_MMA_F16_CASE(576, 512, 1, 16);
|
||||
extern DECL_FATTN_MMA_F16_CASE(576, 512, 2, 16);
|
||||
extern DECL_FATTN_MMA_F16_CASE(576, 512, 4, 16);
|
||||
+
|
||||
+// GLM 4.7 Flash uses gqa_ratio 4:
|
||||
+extern DECL_FATTN_MMA_F16_CASE(576, 512, 2, 4);
|
||||
+extern DECL_FATTN_MMA_F16_CASE(576, 512, 4, 4);
|
||||
+extern DECL_FATTN_MMA_F16_CASE(576, 512, 8, 4);
|
||||
+extern DECL_FATTN_MMA_F16_CASE(576, 512, 16, 4);
|
||||
diff --git a/ggml/src/ggml-cuda/fattn-tile.cuh b/ggml/src/ggml-cuda/fattn-tile.cuh
|
||||
index 7c4d6fe67..6389ba5c4 100644
|
||||
--- a/ggml/src/ggml-cuda/fattn-tile.cuh
|
||||
+++ b/ggml/src/ggml-cuda/fattn-tile.cuh
|
||||
@@ -68,6 +68,8 @@ static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_nv
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 16, 256, 2, 64, 64)
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 32, 256, 2, 64, 64)
|
||||
|
||||
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 4, 256, 2, 64, 64)
|
||||
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 8, 256, 2, 64, 64)
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 16, 256, 2, 64, 64)
|
||||
|
||||
return 0;
|
||||
@@ -122,7 +124,9 @@ static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_nv
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 16, 256, 2, 32, 128)
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 32, 256, 2, 32, 64)
|
||||
|
||||
- GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 16, 256, 2, 32, 64)
|
||||
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 4, 256, 2, 64, 64)
|
||||
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 8, 256, 2, 64, 64)
|
||||
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 16, 256, 2, 64, 64)
|
||||
|
||||
return 0;
|
||||
}
|
||||
@@ -183,6 +187,8 @@ static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_am
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 16, 256, 2, 32, 128)
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 32, 256, 2, 32, 128)
|
||||
|
||||
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 4, 256, 2, 64, 64)
|
||||
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 8, 256, 2, 64, 64)
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 16, 256, 2, 64, 64)
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 32, 512, 1, 128, 64)
|
||||
|
||||
@@ -245,6 +251,8 @@ static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_am
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 16, 256, 5, 32, 256)
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 32, 256, 3, 64, 128)
|
||||
|
||||
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 4, 256, 4, 64, 64)
|
||||
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 8, 256, 4, 64, 64)
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 16, 256, 4, 64, 64)
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 32, 256, 2, 128, 64)
|
||||
|
||||
@@ -1187,6 +1195,14 @@ static void launch_fattn_tile_switch_ncols2(ggml_backend_cuda_context & ctx, ggm
|
||||
launch_fattn_tile_switch_ncols1<DKQ, DV, 16, use_logit_softcap>(ctx, dst);
|
||||
return;
|
||||
}
|
||||
+ if (use_gqa_opt && gqa_ratio % 8 == 0) {
|
||||
+ launch_fattn_tile_switch_ncols1<DKQ, DV, 8, use_logit_softcap>(ctx, dst);
|
||||
+ return;
|
||||
+ }
|
||||
+ if (use_gqa_opt && gqa_ratio % 4 == 0) {
|
||||
+ launch_fattn_tile_switch_ncols1<DKQ, DV, 4, use_logit_softcap>(ctx, dst);
|
||||
+ return;
|
||||
+ }
|
||||
}
|
||||
|
||||
if constexpr (DV <= 256) {
|
||||
diff --git a/ggml/src/ggml-cuda/fattn.cu b/ggml/src/ggml-cuda/fattn.cu
|
||||
index 015540666..1693479cb 100644
|
||||
--- a/ggml/src/ggml-cuda/fattn.cu
|
||||
+++ b/ggml/src/ggml-cuda/fattn.cu
|
||||
@@ -111,7 +111,7 @@ static void ggml_cuda_flash_attn_ext_mma_f16(ggml_backend_cuda_context & ctx, gg
|
||||
ggml_cuda_flash_attn_ext_mma_f16_switch_ncols2<256, 256>(ctx, dst);
|
||||
break;
|
||||
case 576: {
|
||||
- // For Deepseek, go straight to the ncols1 switch to avoid compiling unnecessary kernels.
|
||||
+ // For Deepseek/GLM4, go straight to the ncols1 switch to avoid compiling unnecessary kernels.
|
||||
GGML_ASSERT(V->ne[0] == 512);
|
||||
float max_bias = 0.0f;
|
||||
memcpy(&max_bias, (const float *) KQV->op_params + 1, sizeof(float));
|
||||
@@ -121,8 +121,12 @@ static void ggml_cuda_flash_attn_ext_mma_f16(ggml_backend_cuda_context & ctx, gg
|
||||
|
||||
GGML_ASSERT(Q->ne[2] % K->ne[2] == 0);
|
||||
const int gqa_ratio = Q->ne[2] / K->ne[2];
|
||||
- GGML_ASSERT(gqa_ratio % 16 == 0);
|
||||
- ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<576, 512, 16>(ctx, dst);
|
||||
+ GGML_ASSERT(gqa_ratio % 4 == 0);
|
||||
+ if (gqa_ratio % 16 == 0) {
|
||||
+ ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<576, 512, 16>(ctx, dst);
|
||||
+ } else {
|
||||
+ ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<576, 512, 4>(ctx, dst);
|
||||
+ }
|
||||
} break;
|
||||
default:
|
||||
GGML_ABORT("fatal error");
|
||||
@@ -251,7 +255,7 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const
|
||||
if (V->ne[0] != 512) {
|
||||
return BEST_FATTN_KERNEL_NONE;
|
||||
}
|
||||
- if (!gqa_opt_applies || gqa_ratio % 16 != 0) {
|
||||
+ if (!gqa_opt_applies || gqa_ratio % 4 != 0) {
|
||||
return BEST_FATTN_KERNEL_NONE;
|
||||
}
|
||||
break;
|
||||
diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_4.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_4.cu
|
||||
index 2074e954a..517993cb0 100644
|
||||
--- a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_4.cu
|
||||
+++ b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_4.cu
|
||||
@@ -8,3 +8,4 @@ DECL_FATTN_MMA_F16_CASE(96, 96, 16, 4);
|
||||
DECL_FATTN_MMA_F16_CASE(112, 112, 16, 4);
|
||||
DECL_FATTN_MMA_F16_CASE(128, 128, 16, 4);
|
||||
DECL_FATTN_MMA_F16_CASE(256, 256, 16, 4);
|
||||
+DECL_FATTN_MMA_F16_CASE(576, 512, 16, 4);
|
||||
diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_4.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_4.cu
|
||||
index 24c64cf00..97b19c67a 100644
|
||||
--- a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_4.cu
|
||||
+++ b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_4.cu
|
||||
@@ -8,3 +8,4 @@ DECL_FATTN_MMA_F16_CASE(96, 96, 2, 4);
|
||||
DECL_FATTN_MMA_F16_CASE(112, 112, 2, 4);
|
||||
DECL_FATTN_MMA_F16_CASE(128, 128, 2, 4);
|
||||
DECL_FATTN_MMA_F16_CASE(256, 256, 2, 4);
|
||||
+DECL_FATTN_MMA_F16_CASE(576, 512, 2, 4);
|
||||
diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_4.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_4.cu
|
||||
index 1ada657f1..989626dfa 100644
|
||||
--- a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_4.cu
|
||||
+++ b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_4.cu
|
||||
@@ -8,3 +8,4 @@ DECL_FATTN_MMA_F16_CASE(96, 96, 4, 4);
|
||||
DECL_FATTN_MMA_F16_CASE(112, 112, 4, 4);
|
||||
DECL_FATTN_MMA_F16_CASE(128, 128, 4, 4);
|
||||
DECL_FATTN_MMA_F16_CASE(256, 256, 4, 4);
|
||||
+DECL_FATTN_MMA_F16_CASE(576, 512, 4, 4);
|
||||
diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_4.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_4.cu
|
||||
index 86d4ffae2..173de7aac 100644
|
||||
--- a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_4.cu
|
||||
+++ b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_4.cu
|
||||
@@ -8,3 +8,4 @@ DECL_FATTN_MMA_F16_CASE(96, 96, 8, 4);
|
||||
DECL_FATTN_MMA_F16_CASE(112, 112, 8, 4);
|
||||
DECL_FATTN_MMA_F16_CASE(128, 128, 8, 4);
|
||||
DECL_FATTN_MMA_F16_CASE(256, 256, 8, 4);
|
||||
+DECL_FATTN_MMA_F16_CASE(576, 512, 8, 4);
|
||||
diff --git a/ggml/src/ggml-metal/ggml-metal-device.m b/ggml/src/ggml-metal/ggml-metal-device.m
|
||||
index f24270bb1..7b5ee968c 100644
|
||||
--- a/ggml/src/ggml-metal/ggml-metal-device.m
|
||||
+++ b/ggml/src/ggml-metal/ggml-metal-device.m
|
||||
@@ -1071,12 +1071,8 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
|
||||
op->src[0]->ne[0] != 112 &&
|
||||
op->src[0]->ne[0] != 128 &&
|
||||
op->src[0]->ne[0] != 192 &&
|
||||
- op->src[0]->ne[0] != 256) {
|
||||
- return false;
|
||||
- }
|
||||
- if (op->src[0]->ne[0] == 576) {
|
||||
- // DeepSeek sizes
|
||||
- // TODO: disabled for now, until optmized
|
||||
+ op->src[0]->ne[0] != 256 &&
|
||||
+ op->src[0]->ne[0] != 576) {
|
||||
return false;
|
||||
}
|
||||
if (op->src[1]->type != op->src[2]->type) {
|
||||
diff --git a/ggml/src/ggml-metal/ggml-metal-ops.cpp b/ggml/src/ggml-metal/ggml-metal-ops.cpp
|
||||
index e99c1763f..80864f303 100644
|
||||
--- a/ggml/src/ggml-metal/ggml-metal-ops.cpp
|
||||
+++ b/ggml/src/ggml-metal/ggml-metal-ops.cpp
|
||||
@@ -2456,7 +2456,7 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) {
|
||||
|
||||
// simdgroups per threadgroup (a.k.a. warps)
|
||||
//nsg = ne01 <= nqptg ? MAX(4, MIN(nsgmax, MIN(ne11/ncpsg, (int64_t) pipeline.maxTotalThreadsPerThreadgroup/32))) : 4;
|
||||
- int32_t nsg = 4;
|
||||
+ int32_t nsg = ne00 >= 512 ? 8 : 4;
|
||||
|
||||
const size_t smem = FATTN_SMEM(nsg);
|
||||
|
||||
diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal
|
||||
index c98d269d1..d33c16079 100644
|
||||
--- a/ggml/src/ggml-metal/ggml-metal.metal
|
||||
+++ b/ggml/src/ggml-metal/ggml-metal.metal
|
||||
@@ -6166,6 +6166,7 @@ kernel void kernel_flash_attn_ext(
|
||||
//case 1: kernel_flash_attn_ext_impl<FWD_TMPL, 1>(FWD_ARGS); break;
|
||||
//case 2: kernel_flash_attn_ext_impl<FWD_TMPL, 2>(FWD_ARGS); break;
|
||||
case 4: kernel_flash_attn_ext_impl<FWD_TMPL, 4>(FWD_ARGS); break;
|
||||
+ case 8: kernel_flash_attn_ext_impl<FWD_TMPL, 8>(FWD_ARGS); break;
|
||||
}
|
||||
#undef FWD_TMPL
|
||||
#undef FWD_ARGS
|
||||
@@ -66,7 +66,8 @@ static constexpr __host__ __device__ fattn_mma_config ggml_cuda_fattn_mma_get_co
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 32, 128, 2, 32, 128, 128, 128, 2, true);
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 64, 128, 2, 32, 128, 128, 128, 2, true);
|
||||
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 8, 64, 4, 32, 288, 256, 128, 1, false);
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 4, 64, 4, 32, 288, 256, 128, 1, false);
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 8, 64, 4, 32, 288, 256, 128, 1, true);
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 16, 64, 4, 32, 288, 256, 128, 1, false);
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 32, 128, 2, 32, 160, 128, 128, 1, false);
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 64, 256, 1, 32, 160, 128, 128, 1, false);
|
||||
@@ -80,7 +81,8 @@ static constexpr __host__ __device__ fattn_mma_config ggml_cuda_fattn_mma_get_co
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 32, 128, 2, 64, 128, 128, 64, 2, true);
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 64, 128, 2, 64, 128, 128, 64, 2, true);
|
||||
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 8, 64, 4, 32, 96, 64, 128, 1, false);
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 4, 64, 4, 32, 96, 64, 128, 1, false);
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 8, 64, 4, 32, 96, 64, 128, 1, true);
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 16, 64, 4, 32, 96, 64, 128, 1, false);
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 32, 128, 2, 32, 160, 128, 128, 1, false);
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 64, 256, 1, 32, 160, 128, 128, 1, false);
|
||||
@@ -89,7 +91,8 @@ static constexpr __host__ __device__ fattn_mma_config ggml_cuda_fattn_mma_get_co
|
||||
}
|
||||
|
||||
static constexpr __host__ __device__ fattn_mma_config ggml_cuda_fattn_mma_get_config_volta(const int DKQ, const int DV, const int ncols) {
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 8, 64, 4, 32, 288, 256, 64, 1, false);
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 4, 64, 4, 32, 288, 256, 64, 1, false);
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 8, 64, 4, 32, 288, 256, 64, 1, true);
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 16, 64, 4, 32, 288, 256, 64, 1, false);
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 32, 128, 2, 32, 160, 128, 64, 1, false);
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 64, 256, 1, 32, 160, 128, 64, 1, false);
|
||||
@@ -1585,3 +1588,9 @@ DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(256, 256, 64)
|
||||
extern DECL_FATTN_MMA_F16_CASE(576, 512, 1, 16);
|
||||
extern DECL_FATTN_MMA_F16_CASE(576, 512, 2, 16);
|
||||
extern DECL_FATTN_MMA_F16_CASE(576, 512, 4, 16);
|
||||
|
||||
// GLM 4.7 Flash uses gqa_ratio 4:
|
||||
extern DECL_FATTN_MMA_F16_CASE(576, 512, 2, 4);
|
||||
extern DECL_FATTN_MMA_F16_CASE(576, 512, 4, 4);
|
||||
extern DECL_FATTN_MMA_F16_CASE(576, 512, 8, 4);
|
||||
extern DECL_FATTN_MMA_F16_CASE(576, 512, 16, 4);
|
||||
|
||||
@@ -68,6 +68,8 @@ static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_nv
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 16, 256, 2, 64, 64)
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 32, 256, 2, 64, 64)
|
||||
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 4, 256, 2, 64, 64)
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 8, 256, 2, 64, 64)
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 16, 256, 2, 64, 64)
|
||||
|
||||
return 0;
|
||||
@@ -122,7 +124,9 @@ static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_nv
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 16, 256, 2, 32, 128)
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 32, 256, 2, 32, 64)
|
||||
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 16, 256, 2, 32, 64)
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 4, 256, 2, 64, 64)
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 8, 256, 2, 64, 64)
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 16, 256, 2, 64, 64)
|
||||
|
||||
return 0;
|
||||
}
|
||||
@@ -183,6 +187,8 @@ static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_am
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 16, 256, 2, 32, 128)
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 32, 256, 2, 32, 128)
|
||||
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 4, 256, 2, 64, 64)
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 8, 256, 2, 64, 64)
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 16, 256, 2, 64, 64)
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 32, 512, 1, 128, 64)
|
||||
|
||||
@@ -245,6 +251,8 @@ static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_am
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 16, 256, 5, 32, 256)
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 32, 256, 3, 64, 128)
|
||||
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 4, 256, 4, 64, 64)
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 8, 256, 4, 64, 64)
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 16, 256, 4, 64, 64)
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 32, 256, 2, 128, 64)
|
||||
|
||||
@@ -1187,6 +1195,14 @@ static void launch_fattn_tile_switch_ncols2(ggml_backend_cuda_context & ctx, ggm
|
||||
launch_fattn_tile_switch_ncols1<DKQ, DV, 16, use_logit_softcap>(ctx, dst);
|
||||
return;
|
||||
}
|
||||
if (use_gqa_opt && gqa_ratio % 8 == 0) {
|
||||
launch_fattn_tile_switch_ncols1<DKQ, DV, 8, use_logit_softcap>(ctx, dst);
|
||||
return;
|
||||
}
|
||||
if (use_gqa_opt && gqa_ratio % 4 == 0) {
|
||||
launch_fattn_tile_switch_ncols1<DKQ, DV, 4, use_logit_softcap>(ctx, dst);
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
if constexpr (DV <= 256) {
|
||||
|
||||
12
ml/backend/ggml/ggml/src/ggml-cuda/fattn.cu
vendored
12
ml/backend/ggml/ggml/src/ggml-cuda/fattn.cu
vendored
@@ -111,7 +111,7 @@ static void ggml_cuda_flash_attn_ext_mma_f16(ggml_backend_cuda_context & ctx, gg
|
||||
ggml_cuda_flash_attn_ext_mma_f16_switch_ncols2<256, 256>(ctx, dst);
|
||||
break;
|
||||
case 576: {
|
||||
// For Deepseek, go straight to the ncols1 switch to avoid compiling unnecessary kernels.
|
||||
// For Deepseek/GLM4, go straight to the ncols1 switch to avoid compiling unnecessary kernels.
|
||||
GGML_ASSERT(V->ne[0] == 512);
|
||||
float max_bias = 0.0f;
|
||||
memcpy(&max_bias, (const float *) KQV->op_params + 1, sizeof(float));
|
||||
@@ -121,8 +121,12 @@ static void ggml_cuda_flash_attn_ext_mma_f16(ggml_backend_cuda_context & ctx, gg
|
||||
|
||||
GGML_ASSERT(Q->ne[2] % K->ne[2] == 0);
|
||||
const int gqa_ratio = Q->ne[2] / K->ne[2];
|
||||
GGML_ASSERT(gqa_ratio % 16 == 0);
|
||||
ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<576, 512, 16>(ctx, dst);
|
||||
GGML_ASSERT(gqa_ratio % 4 == 0);
|
||||
if (gqa_ratio % 16 == 0) {
|
||||
ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<576, 512, 16>(ctx, dst);
|
||||
} else {
|
||||
ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<576, 512, 4>(ctx, dst);
|
||||
}
|
||||
} break;
|
||||
default:
|
||||
GGML_ABORT("fatal error");
|
||||
@@ -251,7 +255,7 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const
|
||||
if (V->ne[0] != 512) {
|
||||
return BEST_FATTN_KERNEL_NONE;
|
||||
}
|
||||
if (!gqa_opt_applies || gqa_ratio % 16 != 0) {
|
||||
if (!gqa_opt_applies || gqa_ratio % 4 != 0) {
|
||||
return BEST_FATTN_KERNEL_NONE;
|
||||
}
|
||||
break;
|
||||
|
||||
@@ -8,3 +8,4 @@ DECL_FATTN_MMA_F16_CASE(96, 96, 16, 4);
|
||||
DECL_FATTN_MMA_F16_CASE(112, 112, 16, 4);
|
||||
DECL_FATTN_MMA_F16_CASE(128, 128, 16, 4);
|
||||
DECL_FATTN_MMA_F16_CASE(256, 256, 16, 4);
|
||||
DECL_FATTN_MMA_F16_CASE(576, 512, 16, 4);
|
||||
|
||||
@@ -8,3 +8,4 @@ DECL_FATTN_MMA_F16_CASE(96, 96, 2, 4);
|
||||
DECL_FATTN_MMA_F16_CASE(112, 112, 2, 4);
|
||||
DECL_FATTN_MMA_F16_CASE(128, 128, 2, 4);
|
||||
DECL_FATTN_MMA_F16_CASE(256, 256, 2, 4);
|
||||
DECL_FATTN_MMA_F16_CASE(576, 512, 2, 4);
|
||||
|
||||
@@ -8,3 +8,4 @@ DECL_FATTN_MMA_F16_CASE(96, 96, 4, 4);
|
||||
DECL_FATTN_MMA_F16_CASE(112, 112, 4, 4);
|
||||
DECL_FATTN_MMA_F16_CASE(128, 128, 4, 4);
|
||||
DECL_FATTN_MMA_F16_CASE(256, 256, 4, 4);
|
||||
DECL_FATTN_MMA_F16_CASE(576, 512, 4, 4);
|
||||
|
||||
@@ -8,3 +8,4 @@ DECL_FATTN_MMA_F16_CASE(96, 96, 8, 4);
|
||||
DECL_FATTN_MMA_F16_CASE(112, 112, 8, 4);
|
||||
DECL_FATTN_MMA_F16_CASE(128, 128, 8, 4);
|
||||
DECL_FATTN_MMA_F16_CASE(256, 256, 8, 4);
|
||||
DECL_FATTN_MMA_F16_CASE(576, 512, 8, 4);
|
||||
|
||||
@@ -1071,12 +1071,8 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
|
||||
op->src[0]->ne[0] != 112 &&
|
||||
op->src[0]->ne[0] != 128 &&
|
||||
op->src[0]->ne[0] != 192 &&
|
||||
op->src[0]->ne[0] != 256) {
|
||||
return false;
|
||||
}
|
||||
if (op->src[0]->ne[0] == 576) {
|
||||
// DeepSeek sizes
|
||||
// TODO: disabled for now, until optmized
|
||||
op->src[0]->ne[0] != 256 &&
|
||||
op->src[0]->ne[0] != 576) {
|
||||
return false;
|
||||
}
|
||||
if (op->src[1]->type != op->src[2]->type) {
|
||||
|
||||
@@ -8967,6 +8967,7 @@ kernel void kernel_flash_attn_ext(
|
||||
//case 1: kernel_flash_attn_ext_impl<FWD_TMPL, 1>(FWD_ARGS); break;
|
||||
//case 2: kernel_flash_attn_ext_impl<FWD_TMPL, 2>(FWD_ARGS); break;
|
||||
case 4: kernel_flash_attn_ext_impl<FWD_TMPL, 4>(FWD_ARGS); break;
|
||||
case 8: kernel_flash_attn_ext_impl<FWD_TMPL, 8>(FWD_ARGS); break;
|
||||
}
|
||||
#undef FWD_TMPL
|
||||
#undef FWD_ARGS
|
||||
|
||||
@@ -2456,7 +2456,7 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) {
|
||||
|
||||
// simdgroups per threadgroup (a.k.a. warps)
|
||||
//nsg = ne01 <= nqptg ? MAX(4, MIN(nsgmax, MIN(ne11/ncpsg, (int64_t) pipeline.maxTotalThreadsPerThreadgroup/32))) : 4;
|
||||
int32_t nsg = 4;
|
||||
int32_t nsg = ne00 >= 512 ? 8 : 4;
|
||||
|
||||
const size_t smem = FATTN_SMEM(nsg);
|
||||
|
||||
|
||||
@@ -6166,6 +6166,7 @@ kernel void kernel_flash_attn_ext(
|
||||
//case 1: kernel_flash_attn_ext_impl<FWD_TMPL, 1>(FWD_ARGS); break;
|
||||
//case 2: kernel_flash_attn_ext_impl<FWD_TMPL, 2>(FWD_ARGS); break;
|
||||
case 4: kernel_flash_attn_ext_impl<FWD_TMPL, 4>(FWD_ARGS); break;
|
||||
case 8: kernel_flash_attn_ext_impl<FWD_TMPL, 8>(FWD_ARGS); break;
|
||||
}
|
||||
#undef FWD_TMPL
|
||||
#undef FWD_ARGS
|
||||
|
||||
@@ -39,6 +39,13 @@ type Model interface {
|
||||
Config() config
|
||||
}
|
||||
|
||||
// Validator is an optional interface that models can implement to perform
|
||||
// validation after tensors have been loaded. If validation fails, model
|
||||
// loading will fail with the returned error.
|
||||
type Validator interface {
|
||||
Validate() error
|
||||
}
|
||||
|
||||
// MultimodalProcessor must be implemented by multimodal models.
|
||||
type MultimodalProcessor interface {
|
||||
// EncodeMultimodal processes a single input (such as an image) and
|
||||
@@ -116,6 +123,13 @@ func New(modelPath string, params ml.BackendParams) (Model, error) {
|
||||
base := Base{b: b, config: m.Config()}
|
||||
v := reflect.ValueOf(m)
|
||||
v.Elem().Set(populateFields(base, v.Elem()))
|
||||
|
||||
if validator, ok := m.(Validator); ok {
|
||||
if err := validator.Validate(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
return m, nil
|
||||
}
|
||||
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package glm4moelite
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"math"
|
||||
|
||||
"github.com/ollama/ollama/fs"
|
||||
@@ -11,6 +12,8 @@ import (
|
||||
"github.com/ollama/ollama/model/input"
|
||||
)
|
||||
|
||||
var ErrOldModelFormat = errors.New("this model uses a weight format that is no longer supported; please re-download it")
|
||||
|
||||
type Options struct {
|
||||
numExpertsUsed int
|
||||
numExperts int
|
||||
@@ -47,7 +50,9 @@ type Attention struct {
|
||||
|
||||
KVA *nn.Linear `gguf:"attn_kv_a_mqa"`
|
||||
KVANorm *nn.RMSNorm `gguf:"attn_kv_a_norm"`
|
||||
KVB *nn.Linear `gguf:"attn_kv_b"`
|
||||
|
||||
KB *nn.Linear `gguf:"attn_k_b"`
|
||||
VB *nn.Linear `gguf:"attn_v_b"`
|
||||
|
||||
Output *nn.Linear `gguf:"attn_out,alt:attn_output"`
|
||||
}
|
||||
@@ -78,15 +83,16 @@ func (attn *Attention) Forward(ctx ml.Context, hiddenStates, positions ml.Tensor
|
||||
qRot := opts.applyRotaryPositionEmbeddings(ctx, queryChunks[1], positions)
|
||||
kRot = opts.applyRotaryPositionEmbeddings(ctx, kRot, positions)
|
||||
kPass = attn.KVANorm.Forward(ctx, kPass, opts.eps)
|
||||
kPass = attn.KVB.Forward(ctx, kPass)
|
||||
|
||||
kv := kPass.Reshape(ctx, kPass.Dim(0)/opts.numKVHeads, opts.numKVHeads, seqLength)
|
||||
kvChunks := kv.ChunkSections(ctx, 0, opts.kqNopeHeadDim, opts.vHeadDim)
|
||||
// MLA absorption: absorb K projection into query
|
||||
qPass := queryChunks[0].Permute(ctx, 0, 2, 1, 3)
|
||||
qPassAbsorb := attn.KB.Forward(ctx, qPass).Permute(ctx, 0, 2, 1, 3)
|
||||
query = qRot.Concat(ctx, qPassAbsorb, 0)
|
||||
|
||||
kRot = kRot.Repeat(ctx, 1, queryChunks[0].Dim(1))
|
||||
query = qRot.Concat(ctx, queryChunks[0], 0)
|
||||
key := kRot.Concat(ctx, kvChunks[0], 0)
|
||||
attention := nn.Attention(ctx, query, key, kvChunks[1], opts.kqScale, cache)
|
||||
kPass = kPass.Reshape(ctx, opts.kvLoraRank, 1, seqLength)
|
||||
key := kRot.Concat(ctx, kPass, 0)
|
||||
|
||||
attention := nn.AttentionWithVMLA(ctx, query, key, kPass, nil, attn.VB.Weight, opts.kqScale, cache)
|
||||
|
||||
attention = attention.Reshape(ctx, attention.Dim(0)*attention.Dim(1), seqLength)
|
||||
return attn.Output.Forward(ctx, attention)
|
||||
@@ -217,8 +223,12 @@ func New(c fs.Config) (model.Model, error) {
|
||||
|
||||
keyLength := int(c.Uint("attention.key_length"))
|
||||
valueLength := int(c.Uint("attention.value_length"))
|
||||
kvLoraRank := int(c.Uint("attention.kv_lora_rank"))
|
||||
qkRopeHeadDim := int(c.Uint("rope.dimension_count"))
|
||||
|
||||
kqScale := 1.0 / math.Sqrt(float64(keyLength))
|
||||
// For MLA absorption, the effective key dimension is kvLoraRank + qkRopeHeadDim
|
||||
mlaKeyLength := kvLoraRank + qkRopeHeadDim
|
||||
kqScale := 1.0 / math.Sqrt(float64(mlaKeyLength))
|
||||
|
||||
var pre []string
|
||||
switch c.String("tokenizer.ggml.pre") {
|
||||
@@ -279,6 +289,15 @@ func (m Model) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor
|
||||
return m.applyRotaryPositionEmbeddings(ctx, key, shift), nil
|
||||
}
|
||||
|
||||
func (m *Model) Validate() error {
|
||||
for _, layer := range m.Layers {
|
||||
if layer.Attention != nil && (layer.Attention.KB == nil || layer.Attention.VB == nil) {
|
||||
return ErrOldModelFormat
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
|
||||
positions := ctx.Input().FromInts(batch.Positions, len(batch.Positions))
|
||||
|
||||
|
||||
73
model/models/glm4moelite/model_test.go
Normal file
73
model/models/glm4moelite/model_test.go
Normal file
@@ -0,0 +1,73 @@
|
||||
package glm4moelite
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/ollama/ollama/ml/nn"
|
||||
)
|
||||
|
||||
func TestValidate(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
model *Model
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "valid model with KB and VB",
|
||||
model: &Model{
|
||||
Layers: []Layer{
|
||||
{Attention: &Attention{KB: &nn.Linear{}, VB: &nn.Linear{}}},
|
||||
},
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "missing KB",
|
||||
model: &Model{
|
||||
Layers: []Layer{
|
||||
{Attention: &Attention{VB: &nn.Linear{}}},
|
||||
},
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "missing VB",
|
||||
model: &Model{
|
||||
Layers: []Layer{
|
||||
{Attention: &Attention{KB: &nn.Linear{}}},
|
||||
},
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "missing both KB and VB",
|
||||
model: &Model{
|
||||
Layers: []Layer{
|
||||
{Attention: &Attention{}},
|
||||
},
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "nil Attention is ok",
|
||||
model: &Model{
|
||||
Layers: []Layer{
|
||||
{Attention: nil},
|
||||
},
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := tt.model.Validate()
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("Validate() error = %v, wantErr %v", err, tt.wantErr)
|
||||
}
|
||||
if tt.wantErr && err != ErrOldModelFormat {
|
||||
t.Errorf("Validate() error = %v, want %v", err, ErrOldModelFormat)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user