diff --git a/convert/convert_glm4moelite.go b/convert/convert_glm4moelite.go index a74a2fee6..492266e6c 100644 --- a/convert/convert_glm4moelite.go +++ b/convert/convert_glm4moelite.go @@ -6,6 +6,10 @@ import ( "log/slog" "regexp" "strconv" + "strings" + + "github.com/pdevine/tensor" + "github.com/pdevine/tensor/native" "github.com/ollama/ollama/fs/ggml" ) @@ -69,6 +73,9 @@ func (p *glm4MoeLiteModel) KV(t *Tokenizer) KV { kv["glm4moelite.rope.dimension_count"] = p.QKRopeHeadDim kv["glm4moelite.rope.freq_base"] = cmp.Or(p.RopeTheta, float32(1000000.0)) + kv["glm4moelite.attention.key_length_mla"] = p.KVLoraRank + p.QKRopeHeadDim + kv["glm4moelite.attention.value_length_mla"] = p.KVLoraRank + kv["tokenizer.ggml.pre"] = "glm4" return kv @@ -100,6 +107,67 @@ func (p *glm4MoeLiteModel) Replacements() []string { } } +// repackKVB extracts K or V from the combined KV_B tensor for MLA absorption. +// K output row-major: [n_head, kv_lora_rank, qk_nope] -> GGML ne[]={qk_nope, kv_lora_rank, n_head} +// V output row-major: [n_head, v_head, kv_lora_rank] -> GGML ne[]={kv_lora_rank, v_head, n_head} +func (p *glm4MoeLiteModel) repackKVB(extractK bool, kvFirst bool, numHeads int) Repacker { + qkNope := int(p.QKNopeHeadDim) + vHeadDim := int(p.VHeadDim) + kvLoraRank := int(p.KVLoraRank) + kvPerHead := qkNope + vHeadDim + + return func(_ string, data []float32, shape []uint64) ([]float32, error) { + dims := make([]int, len(shape)) + for i := range shape { + dims[i] = int(shape[i]) + } + + var tt tensor.Tensor = tensor.New(tensor.WithShape(dims...), tensor.WithBacking(data)) + var err error + + // Normalize to [n_head * (qk_nope + v_head), kv_lora_rank] layout + if kvFirst { + tt, err = tensor.Transpose(tt, 1, 0) + if err != nil { + return nil, err + } + tt = tensor.Materialize(tt) + } + + // Reshape to [n_head, qk_nope + v_head, kv_lora_rank] + if err := tt.Reshape(numHeads, kvPerHead, kvLoraRank); err != nil { + return nil, err + } + + if extractK { + // Slice K: [n_head, qk_nope, kv_lora_rank] + tt, err = tt.Slice(nil, tensor.S(0, qkNope), nil) + if err != nil { + return nil, err + } + tt = tensor.Materialize(tt) + // Transpose to [n_head, kv_lora_rank, qk_nope] + tt, err = tensor.Transpose(tt, 0, 2, 1) + if err != nil { + return nil, err + } + tt = tensor.Materialize(tt) + } else { + // Slice V: [n_head, v_head, kv_lora_rank] - already correct layout + tt, err = tt.Slice(nil, tensor.S(qkNope, kvPerHead), nil) + if err != nil { + return nil, err + } + tt = tensor.Materialize(tt) + } + + if err := tt.Reshape(tt.Shape().TotalSize()); err != nil { + return nil, err + } + return native.VectorF32(tt.(*tensor.Dense)) + } +} + func (p *glm4MoeLiteModel) Tensors(s []Tensor) (out []*ggml.Tensor) { merges := make([]merge, p.HiddenLayers*3) for i := range p.HiddenLayers { @@ -139,6 +207,52 @@ func (p *glm4MoeLiteModel) Tensors(s []Tensor) (out []*ggml.Tensor) { slog.Debug("skipping layer", "name", t.Name()) continue } + + // Split attn_kv_b into separate attn_k_b and attn_v_b for MLA absorption + if strings.HasSuffix(t.Name(), ".attn_kv_b.weight") { + qkNope := int(p.QKNopeHeadDim) + vHeadDim := int(p.VHeadDim) + kvLoraRank := int(p.KVLoraRank) + kvPerHead := qkNope + vHeadDim + numHeads := int(p.NumAttentionHeads) + kvFirst := true + if len(t.Shape()) == 2 { + switch { + case int(t.Shape()[0]) == kvLoraRank: + if kvPerHead > 0 && int(t.Shape()[1])%kvPerHead == 0 { + numHeads = int(t.Shape()[1]) / kvPerHead + } + kvFirst = true + case int(t.Shape()[1]) == kvLoraRank: + if kvPerHead > 0 && int(t.Shape()[0])%kvPerHead 
== 0 { + numHeads = int(t.Shape()[0]) / kvPerHead + } + kvFirst = false + default: + slog.Warn("glm4moelite: unexpected attn_kv_b layout", "name", t.Name(), "shape", t.Shape()) + } + } + + kTensor := t.Clone() + kTensor.SetRepacker(p.repackKVB(true, kvFirst, numHeads)) + out = append(out, &ggml.Tensor{ + Name: strings.Replace(t.Name(), "attn_kv_b", "attn_k_b", 1), + Kind: t.Kind(), + Shape: []uint64{uint64(numHeads), uint64(kvLoraRank), uint64(qkNope)}, + WriterTo: kTensor, + }) + + vTensor := t.Clone() + vTensor.SetRepacker(p.repackKVB(false, kvFirst, numHeads)) + out = append(out, &ggml.Tensor{ + Name: strings.Replace(t.Name(), "attn_kv_b", "attn_v_b", 1), + Kind: t.Kind(), + Shape: []uint64{uint64(numHeads), uint64(vHeadDim), uint64(kvLoraRank)}, + WriterTo: vTensor, + }) + continue + } + out = append(out, &ggml.Tensor{ Name: t.Name(), Kind: t.Kind(), diff --git a/llama/patches/0032-ggml-enable-MLA-flash-attention-for-GLM-4.7-flash.patch b/llama/patches/0032-ggml-enable-MLA-flash-attention-for-GLM-4.7-flash.patch new file mode 100644 index 000000000..27a4a42f4 --- /dev/null +++ b/llama/patches/0032-ggml-enable-MLA-flash-attention-for-GLM-4.7-flash.patch @@ -0,0 +1,248 @@ +From 0000000000000000000000000000000000000000 Mon Sep 17 00:00:00 2001 +From: nobody <> +Date: Fri, 23 Jan 2026 12:42:53 -0800 +Subject: [PATCH] ggml: enable MLA flash attention for GLM-4.7-flash + +Add support for gqa_ratio 4 in MLA flash attention kernels. GLM-4.7-flash +uses head size 576 with gqa_ratio 4, which was previously only supported +for gqa_ratio 16 (DeepSeek). + +Metal changes: +- Enable head size 576 for flash attention +- Increase simdgroups to 8 for large heads (>=512) +- Add case 8 kernel dispatch for 8 simdgroups + +CUDA changes: +- Add gqa_ratio 4 support for head 576/512 +- Add tile configs for (576, 512, 4) and (576, 512, 8) +- Add MMA config cases for ncols 4 +- Add template instances for ncols2=4 +--- + ggml/src/ggml-cuda/fattn-mma-f16.cuh | 15 ++++++++++++--- + ggml/src/ggml-cuda/fattn-tile.cuh | 16 ++++++++++++++++ + ggml/src/ggml-cuda/fattn.cu | 12 ++++++++---- + .../fattn-mma-f16-instance-ncols1_16-ncols2_4.cu | 1 + + .../fattn-mma-f16-instance-ncols1_2-ncols2_4.cu | 1 + + .../fattn-mma-f16-instance-ncols1_4-ncols2_4.cu | 1 + + .../fattn-mma-f16-instance-ncols1_8-ncols2_4.cu | 1 + + ggml/src/ggml-metal/ggml-metal-device.m | 8 ++------ + ggml/src/ggml-metal/ggml-metal-ops.cpp | 2 +- + ggml/src/ggml-metal/ggml-metal.metal | 1 + + 10 files changed, 44 insertions(+), 14 deletions(-) + +diff --git a/ggml/src/ggml-cuda/fattn-mma-f16.cuh b/ggml/src/ggml-cuda/fattn-mma-f16.cuh +index 7bd1044c1..a627302f9 100644 +--- a/ggml/src/ggml-cuda/fattn-mma-f16.cuh ++++ b/ggml/src/ggml-cuda/fattn-mma-f16.cuh +@@ -66,7 +66,8 @@ static constexpr __host__ __device__ fattn_mma_config ggml_cuda_fattn_mma_get_co + GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 32, 128, 2, 32, 128, 128, 128, 2, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 64, 128, 2, 32, 128, 128, 128, 2, true); + +- GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 8, 64, 4, 32, 288, 256, 128, 1, false); ++ GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 4, 64, 4, 32, 288, 256, 128, 1, false); ++ GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 8, 64, 4, 32, 288, 256, 128, 1, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 16, 64, 4, 32, 288, 256, 128, 1, false); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 32, 128, 2, 32, 160, 128, 128, 1, false); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 64, 256, 1, 32, 160, 128, 128, 1, false); +@@ -80,7 +81,8 @@ static 
constexpr __host__ __device__ fattn_mma_config ggml_cuda_fattn_mma_get_co + GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 32, 128, 2, 64, 128, 128, 64, 2, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 64, 128, 2, 64, 128, 128, 64, 2, true); + +- GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 8, 64, 4, 32, 96, 64, 128, 1, false); ++ GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 4, 64, 4, 32, 96, 64, 128, 1, false); ++ GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 8, 64, 4, 32, 96, 64, 128, 1, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 16, 64, 4, 32, 96, 64, 128, 1, false); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 32, 128, 2, 32, 160, 128, 128, 1, false); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 64, 256, 1, 32, 160, 128, 128, 1, false); +@@ -89,7 +91,8 @@ static constexpr __host__ __device__ fattn_mma_config ggml_cuda_fattn_mma_get_co + } + + static constexpr __host__ __device__ fattn_mma_config ggml_cuda_fattn_mma_get_config_volta(const int DKQ, const int DV, const int ncols) { +- GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 8, 64, 4, 32, 288, 256, 64, 1, false); ++ GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 4, 64, 4, 32, 288, 256, 64, 1, false); ++ GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 8, 64, 4, 32, 288, 256, 64, 1, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 16, 64, 4, 32, 288, 256, 64, 1, false); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 32, 128, 2, 32, 160, 128, 64, 1, false); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 64, 256, 1, 32, 160, 128, 64, 1, false); +@@ -1585,3 +1588,9 @@ DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(256, 256, 64) + extern DECL_FATTN_MMA_F16_CASE(576, 512, 1, 16); + extern DECL_FATTN_MMA_F16_CASE(576, 512, 2, 16); + extern DECL_FATTN_MMA_F16_CASE(576, 512, 4, 16); ++ ++// GLM 4.7 Flash uses gqa_ratio 4: ++extern DECL_FATTN_MMA_F16_CASE(576, 512, 2, 4); ++extern DECL_FATTN_MMA_F16_CASE(576, 512, 4, 4); ++extern DECL_FATTN_MMA_F16_CASE(576, 512, 8, 4); ++extern DECL_FATTN_MMA_F16_CASE(576, 512, 16, 4); +diff --git a/ggml/src/ggml-cuda/fattn-tile.cuh b/ggml/src/ggml-cuda/fattn-tile.cuh +index 7c4d6fe67..682fb366e 100644 +--- a/ggml/src/ggml-cuda/fattn-tile.cuh ++++ b/ggml/src/ggml-cuda/fattn-tile.cuh +@@ -68,6 +68,8 @@ static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_nv + GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 16, 256, 2, 64, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 32, 256, 2, 64, 64) + ++ GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 4, 256, 2, 64, 64) ++ GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 8, 256, 2, 64, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 16, 256, 2, 64, 64) + + return 0; +@@ -122,6 +124,8 @@ static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_nv + GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 16, 256, 2, 32, 128) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 32, 256, 2, 32, 64) + ++ GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 4, 256, 2, 32, 64) ++ GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 8, 256, 2, 32, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 16, 256, 2, 32, 64) + + return 0; +@@ -183,6 +187,8 @@ static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_am + GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 16, 256, 2, 32, 128) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 32, 256, 2, 32, 128) + ++ GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 4, 256, 2, 64, 64) ++ GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 8, 256, 2, 64, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 16, 256, 2, 64, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 32, 512, 1, 128, 64) + +@@ 
-245,6 +251,8 @@ static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_am + GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 16, 256, 5, 32, 256) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 32, 256, 3, 64, 128) + ++ GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 4, 256, 4, 64, 64) ++ GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 8, 256, 4, 64, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 16, 256, 4, 64, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 32, 256, 2, 128, 64) + +@@ -1187,6 +1195,14 @@ static void launch_fattn_tile_switch_ncols2(ggml_backend_cuda_context & ctx, ggm + launch_fattn_tile_switch_ncols1(ctx, dst); + return; + } ++ if (use_gqa_opt && gqa_ratio % 8 == 0) { ++ launch_fattn_tile_switch_ncols1(ctx, dst); ++ return; ++ } ++ if (use_gqa_opt && gqa_ratio % 4 == 0) { ++ launch_fattn_tile_switch_ncols1(ctx, dst); ++ return; ++ } + } + + if constexpr (DV <= 256) { +diff --git a/ggml/src/ggml-cuda/fattn.cu b/ggml/src/ggml-cuda/fattn.cu +index 015540666..1693479cb 100644 +--- a/ggml/src/ggml-cuda/fattn.cu ++++ b/ggml/src/ggml-cuda/fattn.cu +@@ -111,7 +111,7 @@ static void ggml_cuda_flash_attn_ext_mma_f16(ggml_backend_cuda_context & ctx, gg + ggml_cuda_flash_attn_ext_mma_f16_switch_ncols2<256, 256>(ctx, dst); + break; + case 576: { +- // For Deepseek, go straight to the ncols1 switch to avoid compiling unnecessary kernels. ++ // For Deepseek/GLM4, go straight to the ncols1 switch to avoid compiling unnecessary kernels. + GGML_ASSERT(V->ne[0] == 512); + float max_bias = 0.0f; + memcpy(&max_bias, (const float *) KQV->op_params + 1, sizeof(float)); +@@ -121,8 +121,12 @@ static void ggml_cuda_flash_attn_ext_mma_f16(ggml_backend_cuda_context & ctx, gg + + GGML_ASSERT(Q->ne[2] % K->ne[2] == 0); + const int gqa_ratio = Q->ne[2] / K->ne[2]; +- GGML_ASSERT(gqa_ratio % 16 == 0); +- ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<576, 512, 16>(ctx, dst); ++ GGML_ASSERT(gqa_ratio % 4 == 0); ++ if (gqa_ratio % 16 == 0) { ++ ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<576, 512, 16>(ctx, dst); ++ } else { ++ ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<576, 512, 4>(ctx, dst); ++ } + } break; + default: + GGML_ABORT("fatal error"); +@@ -251,7 +255,7 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const + if (V->ne[0] != 512) { + return BEST_FATTN_KERNEL_NONE; + } +- if (!gqa_opt_applies || gqa_ratio % 16 != 0) { ++ if (!gqa_opt_applies || gqa_ratio % 4 != 0) { + return BEST_FATTN_KERNEL_NONE; + } + break; +diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_4.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_4.cu +index 2074e954a..517993cb0 100644 +--- a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_4.cu ++++ b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_4.cu +@@ -8,3 +8,4 @@ DECL_FATTN_MMA_F16_CASE(96, 96, 16, 4); + DECL_FATTN_MMA_F16_CASE(112, 112, 16, 4); + DECL_FATTN_MMA_F16_CASE(128, 128, 16, 4); + DECL_FATTN_MMA_F16_CASE(256, 256, 16, 4); ++DECL_FATTN_MMA_F16_CASE(576, 512, 16, 4); +diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_4.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_4.cu +index 24c64cf00..97b19c67a 100644 +--- a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_4.cu ++++ b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_4.cu +@@ -8,3 +8,4 @@ 
DECL_FATTN_MMA_F16_CASE(96, 96, 2, 4); + DECL_FATTN_MMA_F16_CASE(112, 112, 2, 4); + DECL_FATTN_MMA_F16_CASE(128, 128, 2, 4); + DECL_FATTN_MMA_F16_CASE(256, 256, 2, 4); ++DECL_FATTN_MMA_F16_CASE(576, 512, 2, 4); +diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_4.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_4.cu +index 1ada657f1..989626dfa 100644 +--- a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_4.cu ++++ b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_4.cu +@@ -8,3 +8,4 @@ DECL_FATTN_MMA_F16_CASE(96, 96, 4, 4); + DECL_FATTN_MMA_F16_CASE(112, 112, 4, 4); + DECL_FATTN_MMA_F16_CASE(128, 128, 4, 4); + DECL_FATTN_MMA_F16_CASE(256, 256, 4, 4); ++DECL_FATTN_MMA_F16_CASE(576, 512, 4, 4); +diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_4.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_4.cu +index 86d4ffae2..173de7aac 100644 +--- a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_4.cu ++++ b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_4.cu +@@ -8,3 +8,4 @@ DECL_FATTN_MMA_F16_CASE(96, 96, 8, 4); + DECL_FATTN_MMA_F16_CASE(112, 112, 8, 4); + DECL_FATTN_MMA_F16_CASE(128, 128, 8, 4); + DECL_FATTN_MMA_F16_CASE(256, 256, 8, 4); ++DECL_FATTN_MMA_F16_CASE(576, 512, 8, 4); +diff --git a/ggml/src/ggml-metal/ggml-metal-device.m b/ggml/src/ggml-metal/ggml-metal-device.m +index f24270bb1..7b5ee968c 100644 +--- a/ggml/src/ggml-metal/ggml-metal-device.m ++++ b/ggml/src/ggml-metal/ggml-metal-device.m +@@ -1071,12 +1071,8 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te + op->src[0]->ne[0] != 112 && + op->src[0]->ne[0] != 128 && + op->src[0]->ne[0] != 192 && +- op->src[0]->ne[0] != 256) { +- return false; +- } +- if (op->src[0]->ne[0] == 576) { +- // DeepSeek sizes +- // TODO: disabled for now, until optmized ++ op->src[0]->ne[0] != 256 && ++ op->src[0]->ne[0] != 576) { + return false; + } + if (op->src[1]->type != op->src[2]->type) { +diff --git a/ggml/src/ggml-metal/ggml-metal-ops.cpp b/ggml/src/ggml-metal/ggml-metal-ops.cpp +index e99c1763f..80864f303 100644 +--- a/ggml/src/ggml-metal/ggml-metal-ops.cpp ++++ b/ggml/src/ggml-metal/ggml-metal-ops.cpp +@@ -2456,7 +2456,7 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) { + + // simdgroups per threadgroup (a.k.a. warps) + //nsg = ne01 <= nqptg ? MAX(4, MIN(nsgmax, MIN(ne11/ncpsg, (int64_t) pipeline.maxTotalThreadsPerThreadgroup/32))) : 4; +- int32_t nsg = 4; ++ int32_t nsg = ne00 >= 512 ? 
8 : 4; + + const size_t smem = FATTN_SMEM(nsg); + +diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal +index c98d269d1..d33c16079 100644 +--- a/ggml/src/ggml-metal/ggml-metal.metal ++++ b/ggml/src/ggml-metal/ggml-metal.metal +@@ -6166,6 +6166,7 @@ kernel void kernel_flash_attn_ext( + //case 1: kernel_flash_attn_ext_impl(FWD_ARGS); break; + //case 2: kernel_flash_attn_ext_impl(FWD_ARGS); break; + case 4: kernel_flash_attn_ext_impl(FWD_ARGS); break; ++ case 8: kernel_flash_attn_ext_impl(FWD_ARGS); break; + } + #undef FWD_TMPL + #undef FWD_ARGS diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/fattn-mma-f16.cuh b/ml/backend/ggml/ggml/src/ggml-cuda/fattn-mma-f16.cuh index 7bd1044c1..a627302f9 100644 --- a/ml/backend/ggml/ggml/src/ggml-cuda/fattn-mma-f16.cuh +++ b/ml/backend/ggml/ggml/src/ggml-cuda/fattn-mma-f16.cuh @@ -66,7 +66,8 @@ static constexpr __host__ __device__ fattn_mma_config ggml_cuda_fattn_mma_get_co GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 32, 128, 2, 32, 128, 128, 128, 2, true); GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 64, 128, 2, 32, 128, 128, 128, 2, true); - GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 8, 64, 4, 32, 288, 256, 128, 1, false); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 4, 64, 4, 32, 288, 256, 128, 1, false); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 8, 64, 4, 32, 288, 256, 128, 1, true); GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 16, 64, 4, 32, 288, 256, 128, 1, false); GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 32, 128, 2, 32, 160, 128, 128, 1, false); GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 64, 256, 1, 32, 160, 128, 128, 1, false); @@ -80,7 +81,8 @@ static constexpr __host__ __device__ fattn_mma_config ggml_cuda_fattn_mma_get_co GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 32, 128, 2, 64, 128, 128, 64, 2, true); GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 64, 128, 2, 64, 128, 128, 64, 2, true); - GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 8, 64, 4, 32, 96, 64, 128, 1, false); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 4, 64, 4, 32, 96, 64, 128, 1, false); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 8, 64, 4, 32, 96, 64, 128, 1, true); GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 16, 64, 4, 32, 96, 64, 128, 1, false); GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 32, 128, 2, 32, 160, 128, 128, 1, false); GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 64, 256, 1, 32, 160, 128, 128, 1, false); @@ -89,7 +91,8 @@ static constexpr __host__ __device__ fattn_mma_config ggml_cuda_fattn_mma_get_co } static constexpr __host__ __device__ fattn_mma_config ggml_cuda_fattn_mma_get_config_volta(const int DKQ, const int DV, const int ncols) { - GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 8, 64, 4, 32, 288, 256, 64, 1, false); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 4, 64, 4, 32, 288, 256, 64, 1, false); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 8, 64, 4, 32, 288, 256, 64, 1, true); GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 16, 64, 4, 32, 288, 256, 64, 1, false); GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 32, 128, 2, 32, 160, 128, 64, 1, false); GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 64, 256, 1, 32, 160, 128, 64, 1, false); @@ -1585,3 +1588,9 @@ DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(256, 256, 64) extern DECL_FATTN_MMA_F16_CASE(576, 512, 1, 16); extern DECL_FATTN_MMA_F16_CASE(576, 512, 2, 16); extern DECL_FATTN_MMA_F16_CASE(576, 512, 4, 16); + +// GLM 4.7 Flash uses gqa_ratio 4: +extern DECL_FATTN_MMA_F16_CASE(576, 512, 2, 4); +extern DECL_FATTN_MMA_F16_CASE(576, 512, 4, 4); +extern DECL_FATTN_MMA_F16_CASE(576, 512, 8, 4); +extern 
DECL_FATTN_MMA_F16_CASE(576, 512, 16, 4); diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/fattn-tile.cuh b/ml/backend/ggml/ggml/src/ggml-cuda/fattn-tile.cuh index 7c4d6fe67..682fb366e 100644 --- a/ml/backend/ggml/ggml/src/ggml-cuda/fattn-tile.cuh +++ b/ml/backend/ggml/ggml/src/ggml-cuda/fattn-tile.cuh @@ -68,6 +68,8 @@ static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_nv GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 16, 256, 2, 64, 64) GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 32, 256, 2, 64, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 4, 256, 2, 64, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 8, 256, 2, 64, 64) GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 16, 256, 2, 64, 64) return 0; @@ -122,6 +124,8 @@ static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_nv GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 16, 256, 2, 32, 128) GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 32, 256, 2, 32, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 4, 256, 2, 32, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 8, 256, 2, 32, 64) GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 16, 256, 2, 32, 64) return 0; @@ -183,6 +187,8 @@ static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_am GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 16, 256, 2, 32, 128) GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 32, 256, 2, 32, 128) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 4, 256, 2, 64, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 8, 256, 2, 64, 64) GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 16, 256, 2, 64, 64) GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 32, 512, 1, 128, 64) @@ -245,6 +251,8 @@ static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_am GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 16, 256, 5, 32, 256) GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 32, 256, 3, 64, 128) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 4, 256, 4, 64, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 8, 256, 4, 64, 64) GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 16, 256, 4, 64, 64) GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 32, 256, 2, 128, 64) @@ -1187,6 +1195,14 @@ static void launch_fattn_tile_switch_ncols2(ggml_backend_cuda_context & ctx, ggm launch_fattn_tile_switch_ncols1(ctx, dst); return; } + if (use_gqa_opt && gqa_ratio % 8 == 0) { + launch_fattn_tile_switch_ncols1(ctx, dst); + return; + } + if (use_gqa_opt && gqa_ratio % 4 == 0) { + launch_fattn_tile_switch_ncols1(ctx, dst); + return; + } } if constexpr (DV <= 256) { diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/fattn.cu b/ml/backend/ggml/ggml/src/ggml-cuda/fattn.cu index 015540666..1693479cb 100644 --- a/ml/backend/ggml/ggml/src/ggml-cuda/fattn.cu +++ b/ml/backend/ggml/ggml/src/ggml-cuda/fattn.cu @@ -111,7 +111,7 @@ static void ggml_cuda_flash_attn_ext_mma_f16(ggml_backend_cuda_context & ctx, gg ggml_cuda_flash_attn_ext_mma_f16_switch_ncols2<256, 256>(ctx, dst); break; case 576: { - // For Deepseek, go straight to the ncols1 switch to avoid compiling unnecessary kernels. + // For Deepseek/GLM4, go straight to the ncols1 switch to avoid compiling unnecessary kernels. 
GGML_ASSERT(V->ne[0] == 512); float max_bias = 0.0f; memcpy(&max_bias, (const float *) KQV->op_params + 1, sizeof(float)); @@ -121,8 +121,12 @@ static void ggml_cuda_flash_attn_ext_mma_f16(ggml_backend_cuda_context & ctx, gg GGML_ASSERT(Q->ne[2] % K->ne[2] == 0); const int gqa_ratio = Q->ne[2] / K->ne[2]; - GGML_ASSERT(gqa_ratio % 16 == 0); - ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<576, 512, 16>(ctx, dst); + GGML_ASSERT(gqa_ratio % 4 == 0); + if (gqa_ratio % 16 == 0) { + ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<576, 512, 16>(ctx, dst); + } else { + ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<576, 512, 4>(ctx, dst); + } } break; default: GGML_ABORT("fatal error"); @@ -251,7 +255,7 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const if (V->ne[0] != 512) { return BEST_FATTN_KERNEL_NONE; } - if (!gqa_opt_applies || gqa_ratio % 16 != 0) { + if (!gqa_opt_applies || gqa_ratio % 4 != 0) { return BEST_FATTN_KERNEL_NONE; } break; diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_4.cu b/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_4.cu index 2074e954a..517993cb0 100644 --- a/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_4.cu +++ b/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_4.cu @@ -8,3 +8,4 @@ DECL_FATTN_MMA_F16_CASE(96, 96, 16, 4); DECL_FATTN_MMA_F16_CASE(112, 112, 16, 4); DECL_FATTN_MMA_F16_CASE(128, 128, 16, 4); DECL_FATTN_MMA_F16_CASE(256, 256, 16, 4); +DECL_FATTN_MMA_F16_CASE(576, 512, 16, 4); diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_4.cu b/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_4.cu index 24c64cf00..97b19c67a 100644 --- a/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_4.cu +++ b/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_4.cu @@ -8,3 +8,4 @@ DECL_FATTN_MMA_F16_CASE(96, 96, 2, 4); DECL_FATTN_MMA_F16_CASE(112, 112, 2, 4); DECL_FATTN_MMA_F16_CASE(128, 128, 2, 4); DECL_FATTN_MMA_F16_CASE(256, 256, 2, 4); +DECL_FATTN_MMA_F16_CASE(576, 512, 2, 4); diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_4.cu b/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_4.cu index 1ada657f1..989626dfa 100644 --- a/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_4.cu +++ b/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_4.cu @@ -8,3 +8,4 @@ DECL_FATTN_MMA_F16_CASE(96, 96, 4, 4); DECL_FATTN_MMA_F16_CASE(112, 112, 4, 4); DECL_FATTN_MMA_F16_CASE(128, 128, 4, 4); DECL_FATTN_MMA_F16_CASE(256, 256, 4, 4); +DECL_FATTN_MMA_F16_CASE(576, 512, 4, 4); diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_4.cu b/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_4.cu index 86d4ffae2..173de7aac 100644 --- a/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_4.cu +++ b/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_4.cu @@ -8,3 +8,4 @@ DECL_FATTN_MMA_F16_CASE(96, 96, 8, 4); DECL_FATTN_MMA_F16_CASE(112, 112, 8, 
4); DECL_FATTN_MMA_F16_CASE(128, 128, 8, 4); DECL_FATTN_MMA_F16_CASE(256, 256, 8, 4); +DECL_FATTN_MMA_F16_CASE(576, 512, 8, 4); diff --git a/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal-device.m b/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal-device.m index f24270bb1..7b5ee968c 100644 --- a/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal-device.m +++ b/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal-device.m @@ -1071,12 +1071,8 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te op->src[0]->ne[0] != 112 && op->src[0]->ne[0] != 128 && op->src[0]->ne[0] != 192 && - op->src[0]->ne[0] != 256) { - return false; - } - if (op->src[0]->ne[0] == 576) { - // DeepSeek sizes - // TODO: disabled for now, until optmized + op->src[0]->ne[0] != 256 && + op->src[0]->ne[0] != 576) { return false; } if (op->src[1]->type != op->src[2]->type) { diff --git a/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal-embed.metal b/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal-embed.metal index 13c6715ba..3235a18eb 100644 --- a/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal-embed.metal +++ b/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal-embed.metal @@ -8967,6 +8967,7 @@ kernel void kernel_flash_attn_ext( //case 1: kernel_flash_attn_ext_impl(FWD_ARGS); break; //case 2: kernel_flash_attn_ext_impl(FWD_ARGS); break; case 4: kernel_flash_attn_ext_impl(FWD_ARGS); break; + case 8: kernel_flash_attn_ext_impl(FWD_ARGS); break; } #undef FWD_TMPL #undef FWD_ARGS diff --git a/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal-ops.cpp b/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal-ops.cpp index e99c1763f..80864f303 100644 --- a/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal-ops.cpp +++ b/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal-ops.cpp @@ -2456,7 +2456,7 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) { // simdgroups per threadgroup (a.k.a. warps) //nsg = ne01 <= nqptg ? MAX(4, MIN(nsgmax, MIN(ne11/ncpsg, (int64_t) pipeline.maxTotalThreadsPerThreadgroup/32))) : 4; - int32_t nsg = 4; + int32_t nsg = ne00 >= 512 ? 8 : 4; const size_t smem = FATTN_SMEM(nsg); diff --git a/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal.metal b/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal.metal index c98d269d1..d33c16079 100644 --- a/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal.metal +++ b/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal.metal @@ -6166,6 +6166,7 @@ kernel void kernel_flash_attn_ext( //case 1: kernel_flash_attn_ext_impl(FWD_ARGS); break; //case 2: kernel_flash_attn_ext_impl(FWD_ARGS); break; case 4: kernel_flash_attn_ext_impl(FWD_ARGS); break; + case 8: kernel_flash_attn_ext_impl(FWD_ARGS); break; } #undef FWD_TMPL #undef FWD_ARGS diff --git a/model/model.go b/model/model.go index 0af16da80..b77d7175d 100644 --- a/model/model.go +++ b/model/model.go @@ -39,6 +39,13 @@ type Model interface { Config() config } +// Validator is an optional interface that models can implement to perform +// validation after tensors have been loaded. If validation fails, model +// loading will fail with the returned error. +type Validator interface { + Validate() error +} + // MultimodalProcessor must be implemented by multimodal models. 
type MultimodalProcessor interface { // EncodeMultimodal processes a single input (such as an image) and @@ -116,6 +123,13 @@ func New(modelPath string, params ml.BackendParams) (Model, error) { base := Base{b: b, config: m.Config()} v := reflect.ValueOf(m) v.Elem().Set(populateFields(base, v.Elem())) + + if validator, ok := m.(Validator); ok { + if err := validator.Validate(); err != nil { + return nil, err + } + } + return m, nil } diff --git a/model/models/glm4moelite/model.go b/model/models/glm4moelite/model.go index 2e51f7d56..2c55bbfb2 100644 --- a/model/models/glm4moelite/model.go +++ b/model/models/glm4moelite/model.go @@ -1,6 +1,7 @@ package glm4moelite import ( + "errors" "math" "github.com/ollama/ollama/fs" @@ -11,6 +12,8 @@ import ( "github.com/ollama/ollama/model/input" ) +var ErrOldModelFormat = errors.New("this model uses a weight format that is no longer supported; please re-download it") + type Options struct { numExpertsUsed int numExperts int @@ -47,7 +50,9 @@ type Attention struct { KVA *nn.Linear `gguf:"attn_kv_a_mqa"` KVANorm *nn.RMSNorm `gguf:"attn_kv_a_norm"` - KVB *nn.Linear `gguf:"attn_kv_b"` + + KB *nn.Linear `gguf:"attn_k_b"` + VB *nn.Linear `gguf:"attn_v_b"` Output *nn.Linear `gguf:"attn_out,alt:attn_output"` } @@ -78,15 +83,16 @@ func (attn *Attention) Forward(ctx ml.Context, hiddenStates, positions ml.Tensor qRot := opts.applyRotaryPositionEmbeddings(ctx, queryChunks[1], positions) kRot = opts.applyRotaryPositionEmbeddings(ctx, kRot, positions) kPass = attn.KVANorm.Forward(ctx, kPass, opts.eps) - kPass = attn.KVB.Forward(ctx, kPass) - kv := kPass.Reshape(ctx, kPass.Dim(0)/opts.numKVHeads, opts.numKVHeads, seqLength) - kvChunks := kv.ChunkSections(ctx, 0, opts.kqNopeHeadDim, opts.vHeadDim) + // MLA absorption: absorb K projection into query + qPass := queryChunks[0].Permute(ctx, 0, 2, 1, 3) + qPassAbsorb := attn.KB.Forward(ctx, qPass).Permute(ctx, 0, 2, 1, 3) + query = qRot.Concat(ctx, qPassAbsorb, 0) - kRot = kRot.Repeat(ctx, 1, queryChunks[0].Dim(1)) - query = qRot.Concat(ctx, queryChunks[0], 0) - key := kRot.Concat(ctx, kvChunks[0], 0) - attention := nn.Attention(ctx, query, key, kvChunks[1], opts.kqScale, cache) + kPass = kPass.Reshape(ctx, opts.kvLoraRank, 1, seqLength) + key := kRot.Concat(ctx, kPass, 0) + + attention := nn.AttentionWithVMLA(ctx, query, key, kPass, nil, attn.VB.Weight, opts.kqScale, cache) attention = attention.Reshape(ctx, attention.Dim(0)*attention.Dim(1), seqLength) return attn.Output.Forward(ctx, attention) @@ -217,8 +223,12 @@ func New(c fs.Config) (model.Model, error) { keyLength := int(c.Uint("attention.key_length")) valueLength := int(c.Uint("attention.value_length")) + kvLoraRank := int(c.Uint("attention.kv_lora_rank")) + qkRopeHeadDim := int(c.Uint("rope.dimension_count")) - kqScale := 1.0 / math.Sqrt(float64(keyLength)) + // For MLA absorption, the effective key dimension is kvLoraRank + qkRopeHeadDim + mlaKeyLength := kvLoraRank + qkRopeHeadDim + kqScale := 1.0 / math.Sqrt(float64(mlaKeyLength)) var pre []string switch c.String("tokenizer.ggml.pre") { @@ -279,6 +289,15 @@ func (m Model) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor return m.applyRotaryPositionEmbeddings(ctx, key, shift), nil } +func (m *Model) Validate() error { + for _, layer := range m.Layers { + if layer.Attention != nil && (layer.Attention.KB == nil || layer.Attention.VB == nil) { + return ErrOldModelFormat + } + } + return nil +} + func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) { positions := 
ctx.Input().FromInts(batch.Positions, len(batch.Positions)) diff --git a/model/models/glm4moelite/model_test.go b/model/models/glm4moelite/model_test.go new file mode 100644 index 000000000..fbc3b460d --- /dev/null +++ b/model/models/glm4moelite/model_test.go @@ -0,0 +1,73 @@ +package glm4moelite + +import ( + "testing" + + "github.com/ollama/ollama/ml/nn" +) + +func TestValidate(t *testing.T) { + tests := []struct { + name string + model *Model + wantErr bool + }{ + { + name: "valid model with KB and VB", + model: &Model{ + Layers: []Layer{ + {Attention: &Attention{KB: &nn.Linear{}, VB: &nn.Linear{}}}, + }, + }, + wantErr: false, + }, + { + name: "missing KB", + model: &Model{ + Layers: []Layer{ + {Attention: &Attention{VB: &nn.Linear{}}}, + }, + }, + wantErr: true, + }, + { + name: "missing VB", + model: &Model{ + Layers: []Layer{ + {Attention: &Attention{KB: &nn.Linear{}}}, + }, + }, + wantErr: true, + }, + { + name: "missing both KB and VB", + model: &Model{ + Layers: []Layer{ + {Attention: &Attention{}}, + }, + }, + wantErr: true, + }, + { + name: "nil Attention is ok", + model: &Model{ + Layers: []Layer{ + {Attention: nil}, + }, + }, + wantErr: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := tt.model.Validate() + if (err != nil) != tt.wantErr { + t.Errorf("Validate() error = %v, wantErr %v", err, tt.wantErr) + } + if tt.wantErr && err != ErrOldModelFormat { + t.Errorf("Validate() error = %v, want %v", err, ErrOldModelFormat) + } + }) + } +}
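
For reference, the K/V extraction that repackKVB performs on the combined attn_kv_b tensor can also be written with plain slices. The sketch below is illustrative only and not part of the change; it uses toy dimensions and mirrors the layouts documented in the converter comments (input row-major [n_head, qk_nope + v_head, kv_lora_rank]; K output [n_head, kv_lora_rank, qk_nope]; V output [n_head, v_head, kv_lora_rank]).

    // Illustrative sketch, not part of the diff: the same K/V split that
    // repackKVB expresses via the tensor library, done with index arithmetic.
    package main

    import "fmt"

    // splitKVB takes kvb laid out row-major as [nHead, qkNope+vHead, kvLoraRank]
    // and returns K as [nHead, kvLoraRank, qkNope] and V as [nHead, vHead, kvLoraRank].
    func splitKVB(kvb []float32, nHead, qkNope, vHead, kvLoraRank int) (k, v []float32) {
        kvPerHead := qkNope + vHead
        k = make([]float32, nHead*kvLoraRank*qkNope)
        v = make([]float32, nHead*vHead*kvLoraRank)
        for h := 0; h < nHead; h++ {
            // K: slice the first qkNope rows per head, then swap the last two dims.
            for r := 0; r < kvLoraRank; r++ {
                for n := 0; n < qkNope; n++ {
                    k[(h*kvLoraRank+r)*qkNope+n] = kvb[(h*kvPerHead+n)*kvLoraRank+r]
                }
            }
            // V: the remaining vHead rows per head are already in the target layout.
            for d := 0; d < vHead; d++ {
                for r := 0; r < kvLoraRank; r++ {
                    v[(h*vHead+d)*kvLoraRank+r] = kvb[(h*kvPerHead+qkNope+d)*kvLoraRank+r]
                }
            }
        }
        return k, v
    }

    func main() {
        // Toy sizes (hypothetical, chosen only for the example):
        // 2 heads, qk_nope=3, v_head=2, kv_lora_rank=4.
        nHead, qkNope, vHead, kvLoraRank := 2, 3, 2, 4
        kvb := make([]float32, nHead*(qkNope+vHead)*kvLoraRank)
        for i := range kvb {
            kvb[i] = float32(i)
        }
        k, v := splitKVB(kvb, nHead, qkNope, vHead, kvLoraRank)
        fmt.Println(len(k), len(v)) // 24 16
    }

The kvFirst branch in the converter corresponds to transposing the source matrix before applying this split; the arithmetic afterwards is the same.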