diff --git a/convert/convert_glm4moelite.go b/convert/convert_glm4moelite.go index 492266e6c..a74a2fee6 100644 --- a/convert/convert_glm4moelite.go +++ b/convert/convert_glm4moelite.go @@ -6,10 +6,6 @@ import ( "log/slog" "regexp" "strconv" - "strings" - - "github.com/pdevine/tensor" - "github.com/pdevine/tensor/native" "github.com/ollama/ollama/fs/ggml" ) @@ -73,9 +69,6 @@ func (p *glm4MoeLiteModel) KV(t *Tokenizer) KV { kv["glm4moelite.rope.dimension_count"] = p.QKRopeHeadDim kv["glm4moelite.rope.freq_base"] = cmp.Or(p.RopeTheta, float32(1000000.0)) - kv["glm4moelite.attention.key_length_mla"] = p.KVLoraRank + p.QKRopeHeadDim - kv["glm4moelite.attention.value_length_mla"] = p.KVLoraRank - kv["tokenizer.ggml.pre"] = "glm4" return kv @@ -107,67 +100,6 @@ func (p *glm4MoeLiteModel) Replacements() []string { } } -// repackKVB extracts K or V from the combined KV_B tensor for MLA absorption. -// K output row-major: [n_head, kv_lora_rank, qk_nope] -> GGML ne[]={qk_nope, kv_lora_rank, n_head} -// V output row-major: [n_head, v_head, kv_lora_rank] -> GGML ne[]={kv_lora_rank, v_head, n_head} -func (p *glm4MoeLiteModel) repackKVB(extractK bool, kvFirst bool, numHeads int) Repacker { - qkNope := int(p.QKNopeHeadDim) - vHeadDim := int(p.VHeadDim) - kvLoraRank := int(p.KVLoraRank) - kvPerHead := qkNope + vHeadDim - - return func(_ string, data []float32, shape []uint64) ([]float32, error) { - dims := make([]int, len(shape)) - for i := range shape { - dims[i] = int(shape[i]) - } - - var tt tensor.Tensor = tensor.New(tensor.WithShape(dims...), tensor.WithBacking(data)) - var err error - - // Normalize to [n_head * (qk_nope + v_head), kv_lora_rank] layout - if kvFirst { - tt, err = tensor.Transpose(tt, 1, 0) - if err != nil { - return nil, err - } - tt = tensor.Materialize(tt) - } - - // Reshape to [n_head, qk_nope + v_head, kv_lora_rank] - if err := tt.Reshape(numHeads, kvPerHead, kvLoraRank); err != nil { - return nil, err - } - - if extractK { - // Slice K: [n_head, qk_nope, kv_lora_rank] - tt, err = tt.Slice(nil, tensor.S(0, qkNope), nil) - if err != nil { - return nil, err - } - tt = tensor.Materialize(tt) - // Transpose to [n_head, kv_lora_rank, qk_nope] - tt, err = tensor.Transpose(tt, 0, 2, 1) - if err != nil { - return nil, err - } - tt = tensor.Materialize(tt) - } else { - // Slice V: [n_head, v_head, kv_lora_rank] - already correct layout - tt, err = tt.Slice(nil, tensor.S(qkNope, kvPerHead), nil) - if err != nil { - return nil, err - } - tt = tensor.Materialize(tt) - } - - if err := tt.Reshape(tt.Shape().TotalSize()); err != nil { - return nil, err - } - return native.VectorF32(tt.(*tensor.Dense)) - } -} - func (p *glm4MoeLiteModel) Tensors(s []Tensor) (out []*ggml.Tensor) { merges := make([]merge, p.HiddenLayers*3) for i := range p.HiddenLayers { @@ -207,52 +139,6 @@ func (p *glm4MoeLiteModel) Tensors(s []Tensor) (out []*ggml.Tensor) { slog.Debug("skipping layer", "name", t.Name()) continue } - - // Split attn_kv_b into separate attn_k_b and attn_v_b for MLA absorption - if strings.HasSuffix(t.Name(), ".attn_kv_b.weight") { - qkNope := int(p.QKNopeHeadDim) - vHeadDim := int(p.VHeadDim) - kvLoraRank := int(p.KVLoraRank) - kvPerHead := qkNope + vHeadDim - numHeads := int(p.NumAttentionHeads) - kvFirst := true - if len(t.Shape()) == 2 { - switch { - case int(t.Shape()[0]) == kvLoraRank: - if kvPerHead > 0 && int(t.Shape()[1])%kvPerHead == 0 { - numHeads = int(t.Shape()[1]) / kvPerHead - } - kvFirst = true - case int(t.Shape()[1]) == kvLoraRank: - if kvPerHead > 0 && int(t.Shape()[0])%kvPerHead 
== 0 { - numHeads = int(t.Shape()[0]) / kvPerHead - } - kvFirst = false - default: - slog.Warn("glm4moelite: unexpected attn_kv_b layout", "name", t.Name(), "shape", t.Shape()) - } - } - - kTensor := t.Clone() - kTensor.SetRepacker(p.repackKVB(true, kvFirst, numHeads)) - out = append(out, &ggml.Tensor{ - Name: strings.Replace(t.Name(), "attn_kv_b", "attn_k_b", 1), - Kind: t.Kind(), - Shape: []uint64{uint64(numHeads), uint64(kvLoraRank), uint64(qkNope)}, - WriterTo: kTensor, - }) - - vTensor := t.Clone() - vTensor.SetRepacker(p.repackKVB(false, kvFirst, numHeads)) - out = append(out, &ggml.Tensor{ - Name: strings.Replace(t.Name(), "attn_kv_b", "attn_v_b", 1), - Kind: t.Kind(), - Shape: []uint64{uint64(numHeads), uint64(vHeadDim), uint64(kvLoraRank)}, - WriterTo: vTensor, - }) - continue - } - out = append(out, &ggml.Tensor{ Name: t.Name(), Kind: t.Kind(), diff --git a/llama/patches/0032-ggml-enable-MLA-flash-attention-for-GLM-4.7-flash.patch b/llama/patches/0032-ggml-enable-MLA-flash-attention-for-GLM-4.7-flash.patch deleted file mode 100644 index 27a4a42f4..000000000 --- a/llama/patches/0032-ggml-enable-MLA-flash-attention-for-GLM-4.7-flash.patch +++ /dev/null @@ -1,248 +0,0 @@ -From 0000000000000000000000000000000000000000 Mon Sep 17 00:00:00 2001 -From: nobody <> -Date: Fri, 23 Jan 2026 12:42:53 -0800 -Subject: [PATCH] ggml: enable MLA flash attention for GLM-4.7-flash - -Add support for gqa_ratio 4 in MLA flash attention kernels. GLM-4.7-flash -uses head size 576 with gqa_ratio 4, which was previously only supported -for gqa_ratio 16 (DeepSeek). - -Metal changes: -- Enable head size 576 for flash attention -- Increase simdgroups to 8 for large heads (>=512) -- Add case 8 kernel dispatch for 8 simdgroups - -CUDA changes: -- Add gqa_ratio 4 support for head 576/512 -- Add tile configs for (576, 512, 4) and (576, 512, 8) -- Add MMA config cases for ncols 4 -- Add template instances for ncols2=4 ---- - ggml/src/ggml-cuda/fattn-mma-f16.cuh | 15 ++++++++++++--- - ggml/src/ggml-cuda/fattn-tile.cuh | 16 ++++++++++++++++ - ggml/src/ggml-cuda/fattn.cu | 12 ++++++++---- - .../fattn-mma-f16-instance-ncols1_16-ncols2_4.cu | 1 + - .../fattn-mma-f16-instance-ncols1_2-ncols2_4.cu | 1 + - .../fattn-mma-f16-instance-ncols1_4-ncols2_4.cu | 1 + - .../fattn-mma-f16-instance-ncols1_8-ncols2_4.cu | 1 + - ggml/src/ggml-metal/ggml-metal-device.m | 8 ++------ - ggml/src/ggml-metal/ggml-metal-ops.cpp | 2 +- - ggml/src/ggml-metal/ggml-metal.metal | 1 + - 10 files changed, 44 insertions(+), 14 deletions(-) - -diff --git a/ggml/src/ggml-cuda/fattn-mma-f16.cuh b/ggml/src/ggml-cuda/fattn-mma-f16.cuh -index 7bd1044c1..a627302f9 100644 ---- a/ggml/src/ggml-cuda/fattn-mma-f16.cuh -+++ b/ggml/src/ggml-cuda/fattn-mma-f16.cuh -@@ -66,7 +66,8 @@ static constexpr __host__ __device__ fattn_mma_config ggml_cuda_fattn_mma_get_co - GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 32, 128, 2, 32, 128, 128, 128, 2, true); - GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 64, 128, 2, 32, 128, 128, 128, 2, true); - -- GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 8, 64, 4, 32, 288, 256, 128, 1, false); -+ GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 4, 64, 4, 32, 288, 256, 128, 1, false); -+ GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 8, 64, 4, 32, 288, 256, 128, 1, true); - GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 16, 64, 4, 32, 288, 256, 128, 1, false); - GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 32, 128, 2, 32, 160, 128, 128, 1, false); - GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 64, 256, 1, 32, 160, 128, 128, 1, false); -@@ -80,7 +81,8 @@ static 
constexpr __host__ __device__ fattn_mma_config ggml_cuda_fattn_mma_get_co - GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 32, 128, 2, 64, 128, 128, 64, 2, true); - GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 64, 128, 2, 64, 128, 128, 64, 2, true); - -- GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 8, 64, 4, 32, 96, 64, 128, 1, false); -+ GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 4, 64, 4, 32, 96, 64, 128, 1, false); -+ GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 8, 64, 4, 32, 96, 64, 128, 1, true); - GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 16, 64, 4, 32, 96, 64, 128, 1, false); - GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 32, 128, 2, 32, 160, 128, 128, 1, false); - GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 64, 256, 1, 32, 160, 128, 128, 1, false); -@@ -89,7 +91,8 @@ static constexpr __host__ __device__ fattn_mma_config ggml_cuda_fattn_mma_get_co - } - - static constexpr __host__ __device__ fattn_mma_config ggml_cuda_fattn_mma_get_config_volta(const int DKQ, const int DV, const int ncols) { -- GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 8, 64, 4, 32, 288, 256, 64, 1, false); -+ GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 4, 64, 4, 32, 288, 256, 64, 1, false); -+ GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 8, 64, 4, 32, 288, 256, 64, 1, true); - GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 16, 64, 4, 32, 288, 256, 64, 1, false); - GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 32, 128, 2, 32, 160, 128, 64, 1, false); - GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 64, 256, 1, 32, 160, 128, 64, 1, false); -@@ -1585,3 +1588,9 @@ DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(256, 256, 64) - extern DECL_FATTN_MMA_F16_CASE(576, 512, 1, 16); - extern DECL_FATTN_MMA_F16_CASE(576, 512, 2, 16); - extern DECL_FATTN_MMA_F16_CASE(576, 512, 4, 16); -+ -+// GLM 4.7 Flash uses gqa_ratio 4: -+extern DECL_FATTN_MMA_F16_CASE(576, 512, 2, 4); -+extern DECL_FATTN_MMA_F16_CASE(576, 512, 4, 4); -+extern DECL_FATTN_MMA_F16_CASE(576, 512, 8, 4); -+extern DECL_FATTN_MMA_F16_CASE(576, 512, 16, 4); -diff --git a/ggml/src/ggml-cuda/fattn-tile.cuh b/ggml/src/ggml-cuda/fattn-tile.cuh -index 7c4d6fe67..682fb366e 100644 ---- a/ggml/src/ggml-cuda/fattn-tile.cuh -+++ b/ggml/src/ggml-cuda/fattn-tile.cuh -@@ -68,6 +68,8 @@ static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_nv - GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 16, 256, 2, 64, 64) - GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 32, 256, 2, 64, 64) - -+ GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 4, 256, 2, 64, 64) -+ GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 8, 256, 2, 64, 64) - GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 16, 256, 2, 64, 64) - - return 0; -@@ -122,6 +124,8 @@ static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_nv - GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 16, 256, 2, 32, 128) - GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 32, 256, 2, 32, 64) - -+ GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 4, 256, 2, 32, 64) -+ GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 8, 256, 2, 32, 64) - GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 16, 256, 2, 32, 64) - - return 0; -@@ -183,6 +187,8 @@ static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_am - GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 16, 256, 2, 32, 128) - GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 32, 256, 2, 32, 128) - -+ GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 4, 256, 2, 64, 64) -+ GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 8, 256, 2, 64, 64) - GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 16, 256, 2, 64, 64) - GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 32, 512, 1, 128, 64) - -@@ 
-245,6 +251,8 @@ static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_am - GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 16, 256, 5, 32, 256) - GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 32, 256, 3, 64, 128) - -+ GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 4, 256, 4, 64, 64) -+ GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 8, 256, 4, 64, 64) - GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 16, 256, 4, 64, 64) - GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 32, 256, 2, 128, 64) - -@@ -1187,6 +1195,14 @@ static void launch_fattn_tile_switch_ncols2(ggml_backend_cuda_context & ctx, ggm - launch_fattn_tile_switch_ncols1(ctx, dst); - return; - } -+ if (use_gqa_opt && gqa_ratio % 8 == 0) { -+ launch_fattn_tile_switch_ncols1(ctx, dst); -+ return; -+ } -+ if (use_gqa_opt && gqa_ratio % 4 == 0) { -+ launch_fattn_tile_switch_ncols1(ctx, dst); -+ return; -+ } - } - - if constexpr (DV <= 256) { -diff --git a/ggml/src/ggml-cuda/fattn.cu b/ggml/src/ggml-cuda/fattn.cu -index 015540666..1693479cb 100644 ---- a/ggml/src/ggml-cuda/fattn.cu -+++ b/ggml/src/ggml-cuda/fattn.cu -@@ -111,7 +111,7 @@ static void ggml_cuda_flash_attn_ext_mma_f16(ggml_backend_cuda_context & ctx, gg - ggml_cuda_flash_attn_ext_mma_f16_switch_ncols2<256, 256>(ctx, dst); - break; - case 576: { -- // For Deepseek, go straight to the ncols1 switch to avoid compiling unnecessary kernels. -+ // For Deepseek/GLM4, go straight to the ncols1 switch to avoid compiling unnecessary kernels. - GGML_ASSERT(V->ne[0] == 512); - float max_bias = 0.0f; - memcpy(&max_bias, (const float *) KQV->op_params + 1, sizeof(float)); -@@ -121,8 +121,12 @@ static void ggml_cuda_flash_attn_ext_mma_f16(ggml_backend_cuda_context & ctx, gg - - GGML_ASSERT(Q->ne[2] % K->ne[2] == 0); - const int gqa_ratio = Q->ne[2] / K->ne[2]; -- GGML_ASSERT(gqa_ratio % 16 == 0); -- ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<576, 512, 16>(ctx, dst); -+ GGML_ASSERT(gqa_ratio % 4 == 0); -+ if (gqa_ratio % 16 == 0) { -+ ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<576, 512, 16>(ctx, dst); -+ } else { -+ ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<576, 512, 4>(ctx, dst); -+ } - } break; - default: - GGML_ABORT("fatal error"); -@@ -251,7 +255,7 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const - if (V->ne[0] != 512) { - return BEST_FATTN_KERNEL_NONE; - } -- if (!gqa_opt_applies || gqa_ratio % 16 != 0) { -+ if (!gqa_opt_applies || gqa_ratio % 4 != 0) { - return BEST_FATTN_KERNEL_NONE; - } - break; -diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_4.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_4.cu -index 2074e954a..517993cb0 100644 ---- a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_4.cu -+++ b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_4.cu -@@ -8,3 +8,4 @@ DECL_FATTN_MMA_F16_CASE(96, 96, 16, 4); - DECL_FATTN_MMA_F16_CASE(112, 112, 16, 4); - DECL_FATTN_MMA_F16_CASE(128, 128, 16, 4); - DECL_FATTN_MMA_F16_CASE(256, 256, 16, 4); -+DECL_FATTN_MMA_F16_CASE(576, 512, 16, 4); -diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_4.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_4.cu -index 24c64cf00..97b19c67a 100644 ---- a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_4.cu -+++ b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_4.cu -@@ -8,3 +8,4 @@ 
DECL_FATTN_MMA_F16_CASE(96, 96, 2, 4); - DECL_FATTN_MMA_F16_CASE(112, 112, 2, 4); - DECL_FATTN_MMA_F16_CASE(128, 128, 2, 4); - DECL_FATTN_MMA_F16_CASE(256, 256, 2, 4); -+DECL_FATTN_MMA_F16_CASE(576, 512, 2, 4); -diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_4.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_4.cu -index 1ada657f1..989626dfa 100644 ---- a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_4.cu -+++ b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_4.cu -@@ -8,3 +8,4 @@ DECL_FATTN_MMA_F16_CASE(96, 96, 4, 4); - DECL_FATTN_MMA_F16_CASE(112, 112, 4, 4); - DECL_FATTN_MMA_F16_CASE(128, 128, 4, 4); - DECL_FATTN_MMA_F16_CASE(256, 256, 4, 4); -+DECL_FATTN_MMA_F16_CASE(576, 512, 4, 4); -diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_4.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_4.cu -index 86d4ffae2..173de7aac 100644 ---- a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_4.cu -+++ b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_4.cu -@@ -8,3 +8,4 @@ DECL_FATTN_MMA_F16_CASE(96, 96, 8, 4); - DECL_FATTN_MMA_F16_CASE(112, 112, 8, 4); - DECL_FATTN_MMA_F16_CASE(128, 128, 8, 4); - DECL_FATTN_MMA_F16_CASE(256, 256, 8, 4); -+DECL_FATTN_MMA_F16_CASE(576, 512, 8, 4); -diff --git a/ggml/src/ggml-metal/ggml-metal-device.m b/ggml/src/ggml-metal/ggml-metal-device.m -index f24270bb1..7b5ee968c 100644 ---- a/ggml/src/ggml-metal/ggml-metal-device.m -+++ b/ggml/src/ggml-metal/ggml-metal-device.m -@@ -1071,12 +1071,8 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te - op->src[0]->ne[0] != 112 && - op->src[0]->ne[0] != 128 && - op->src[0]->ne[0] != 192 && -- op->src[0]->ne[0] != 256) { -- return false; -- } -- if (op->src[0]->ne[0] == 576) { -- // DeepSeek sizes -- // TODO: disabled for now, until optmized -+ op->src[0]->ne[0] != 256 && -+ op->src[0]->ne[0] != 576) { - return false; - } - if (op->src[1]->type != op->src[2]->type) { -diff --git a/ggml/src/ggml-metal/ggml-metal-ops.cpp b/ggml/src/ggml-metal/ggml-metal-ops.cpp -index e99c1763f..80864f303 100644 ---- a/ggml/src/ggml-metal/ggml-metal-ops.cpp -+++ b/ggml/src/ggml-metal/ggml-metal-ops.cpp -@@ -2456,7 +2456,7 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) { - - // simdgroups per threadgroup (a.k.a. warps) - //nsg = ne01 <= nqptg ? MAX(4, MIN(nsgmax, MIN(ne11/ncpsg, (int64_t) pipeline.maxTotalThreadsPerThreadgroup/32))) : 4; -- int32_t nsg = 4; -+ int32_t nsg = ne00 >= 512 ? 
8 : 4; - - const size_t smem = FATTN_SMEM(nsg); - -diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal -index c98d269d1..d33c16079 100644 ---- a/ggml/src/ggml-metal/ggml-metal.metal -+++ b/ggml/src/ggml-metal/ggml-metal.metal -@@ -6166,6 +6166,7 @@ kernel void kernel_flash_attn_ext( - //case 1: kernel_flash_attn_ext_impl(FWD_ARGS); break; - //case 2: kernel_flash_attn_ext_impl(FWD_ARGS); break; - case 4: kernel_flash_attn_ext_impl(FWD_ARGS); break; -+ case 8: kernel_flash_attn_ext_impl(FWD_ARGS); break; - } - #undef FWD_TMPL - #undef FWD_ARGS diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/fattn-mma-f16.cuh b/ml/backend/ggml/ggml/src/ggml-cuda/fattn-mma-f16.cuh index a627302f9..7bd1044c1 100644 --- a/ml/backend/ggml/ggml/src/ggml-cuda/fattn-mma-f16.cuh +++ b/ml/backend/ggml/ggml/src/ggml-cuda/fattn-mma-f16.cuh @@ -66,8 +66,7 @@ static constexpr __host__ __device__ fattn_mma_config ggml_cuda_fattn_mma_get_co GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 32, 128, 2, 32, 128, 128, 128, 2, true); GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 64, 128, 2, 32, 128, 128, 128, 2, true); - GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 4, 64, 4, 32, 288, 256, 128, 1, false); - GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 8, 64, 4, 32, 288, 256, 128, 1, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 8, 64, 4, 32, 288, 256, 128, 1, false); GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 16, 64, 4, 32, 288, 256, 128, 1, false); GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 32, 128, 2, 32, 160, 128, 128, 1, false); GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 64, 256, 1, 32, 160, 128, 128, 1, false); @@ -81,8 +80,7 @@ static constexpr __host__ __device__ fattn_mma_config ggml_cuda_fattn_mma_get_co GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 32, 128, 2, 64, 128, 128, 64, 2, true); GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 64, 128, 2, 64, 128, 128, 64, 2, true); - GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 4, 64, 4, 32, 96, 64, 128, 1, false); - GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 8, 64, 4, 32, 96, 64, 128, 1, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 8, 64, 4, 32, 96, 64, 128, 1, false); GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 16, 64, 4, 32, 96, 64, 128, 1, false); GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 32, 128, 2, 32, 160, 128, 128, 1, false); GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 64, 256, 1, 32, 160, 128, 128, 1, false); @@ -91,8 +89,7 @@ static constexpr __host__ __device__ fattn_mma_config ggml_cuda_fattn_mma_get_co } static constexpr __host__ __device__ fattn_mma_config ggml_cuda_fattn_mma_get_config_volta(const int DKQ, const int DV, const int ncols) { - GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 4, 64, 4, 32, 288, 256, 64, 1, false); - GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 8, 64, 4, 32, 288, 256, 64, 1, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 8, 64, 4, 32, 288, 256, 64, 1, false); GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 16, 64, 4, 32, 288, 256, 64, 1, false); GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 32, 128, 2, 32, 160, 128, 64, 1, false); GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 64, 256, 1, 32, 160, 128, 64, 1, false); @@ -1588,9 +1585,3 @@ DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(256, 256, 64) extern DECL_FATTN_MMA_F16_CASE(576, 512, 1, 16); extern DECL_FATTN_MMA_F16_CASE(576, 512, 2, 16); extern DECL_FATTN_MMA_F16_CASE(576, 512, 4, 16); - -// GLM 4.7 Flash uses gqa_ratio 4: -extern DECL_FATTN_MMA_F16_CASE(576, 512, 2, 4); -extern DECL_FATTN_MMA_F16_CASE(576, 512, 4, 4); -extern DECL_FATTN_MMA_F16_CASE(576, 512, 8, 4); -extern 
DECL_FATTN_MMA_F16_CASE(576, 512, 16, 4); diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/fattn-tile.cuh b/ml/backend/ggml/ggml/src/ggml-cuda/fattn-tile.cuh index 682fb366e..7c4d6fe67 100644 --- a/ml/backend/ggml/ggml/src/ggml-cuda/fattn-tile.cuh +++ b/ml/backend/ggml/ggml/src/ggml-cuda/fattn-tile.cuh @@ -68,8 +68,6 @@ static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_nv GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 16, 256, 2, 64, 64) GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 32, 256, 2, 64, 64) - GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 4, 256, 2, 64, 64) - GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 8, 256, 2, 64, 64) GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 16, 256, 2, 64, 64) return 0; @@ -124,8 +122,6 @@ static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_nv GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 16, 256, 2, 32, 128) GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 32, 256, 2, 32, 64) - GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 4, 256, 2, 32, 64) - GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 8, 256, 2, 32, 64) GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 16, 256, 2, 32, 64) return 0; @@ -187,8 +183,6 @@ static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_am GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 16, 256, 2, 32, 128) GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 32, 256, 2, 32, 128) - GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 4, 256, 2, 64, 64) - GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 8, 256, 2, 64, 64) GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 16, 256, 2, 64, 64) GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 32, 512, 1, 128, 64) @@ -251,8 +245,6 @@ static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_am GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 16, 256, 5, 32, 256) GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 32, 256, 3, 64, 128) - GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 4, 256, 4, 64, 64) - GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 8, 256, 4, 64, 64) GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 16, 256, 4, 64, 64) GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 32, 256, 2, 128, 64) @@ -1195,14 +1187,6 @@ static void launch_fattn_tile_switch_ncols2(ggml_backend_cuda_context & ctx, ggm launch_fattn_tile_switch_ncols1(ctx, dst); return; } - if (use_gqa_opt && gqa_ratio % 8 == 0) { - launch_fattn_tile_switch_ncols1(ctx, dst); - return; - } - if (use_gqa_opt && gqa_ratio % 4 == 0) { - launch_fattn_tile_switch_ncols1(ctx, dst); - return; - } } if constexpr (DV <= 256) { diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/fattn.cu b/ml/backend/ggml/ggml/src/ggml-cuda/fattn.cu index 1693479cb..015540666 100644 --- a/ml/backend/ggml/ggml/src/ggml-cuda/fattn.cu +++ b/ml/backend/ggml/ggml/src/ggml-cuda/fattn.cu @@ -111,7 +111,7 @@ static void ggml_cuda_flash_attn_ext_mma_f16(ggml_backend_cuda_context & ctx, gg ggml_cuda_flash_attn_ext_mma_f16_switch_ncols2<256, 256>(ctx, dst); break; case 576: { - // For Deepseek/GLM4, go straight to the ncols1 switch to avoid compiling unnecessary kernels. + // For Deepseek, go straight to the ncols1 switch to avoid compiling unnecessary kernels. 
GGML_ASSERT(V->ne[0] == 512); float max_bias = 0.0f; memcpy(&max_bias, (const float *) KQV->op_params + 1, sizeof(float)); @@ -121,12 +121,8 @@ static void ggml_cuda_flash_attn_ext_mma_f16(ggml_backend_cuda_context & ctx, gg GGML_ASSERT(Q->ne[2] % K->ne[2] == 0); const int gqa_ratio = Q->ne[2] / K->ne[2]; - GGML_ASSERT(gqa_ratio % 4 == 0); - if (gqa_ratio % 16 == 0) { - ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<576, 512, 16>(ctx, dst); - } else { - ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<576, 512, 4>(ctx, dst); - } + GGML_ASSERT(gqa_ratio % 16 == 0); + ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<576, 512, 16>(ctx, dst); } break; default: GGML_ABORT("fatal error"); @@ -255,7 +251,7 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const if (V->ne[0] != 512) { return BEST_FATTN_KERNEL_NONE; } - if (!gqa_opt_applies || gqa_ratio % 4 != 0) { + if (!gqa_opt_applies || gqa_ratio % 16 != 0) { return BEST_FATTN_KERNEL_NONE; } break; diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_4.cu b/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_4.cu index 517993cb0..2074e954a 100644 --- a/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_4.cu +++ b/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_4.cu @@ -8,4 +8,3 @@ DECL_FATTN_MMA_F16_CASE(96, 96, 16, 4); DECL_FATTN_MMA_F16_CASE(112, 112, 16, 4); DECL_FATTN_MMA_F16_CASE(128, 128, 16, 4); DECL_FATTN_MMA_F16_CASE(256, 256, 16, 4); -DECL_FATTN_MMA_F16_CASE(576, 512, 16, 4); diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_4.cu b/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_4.cu index 97b19c67a..24c64cf00 100644 --- a/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_4.cu +++ b/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_4.cu @@ -8,4 +8,3 @@ DECL_FATTN_MMA_F16_CASE(96, 96, 2, 4); DECL_FATTN_MMA_F16_CASE(112, 112, 2, 4); DECL_FATTN_MMA_F16_CASE(128, 128, 2, 4); DECL_FATTN_MMA_F16_CASE(256, 256, 2, 4); -DECL_FATTN_MMA_F16_CASE(576, 512, 2, 4); diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_4.cu b/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_4.cu index 989626dfa..1ada657f1 100644 --- a/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_4.cu +++ b/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_4.cu @@ -8,4 +8,3 @@ DECL_FATTN_MMA_F16_CASE(96, 96, 4, 4); DECL_FATTN_MMA_F16_CASE(112, 112, 4, 4); DECL_FATTN_MMA_F16_CASE(128, 128, 4, 4); DECL_FATTN_MMA_F16_CASE(256, 256, 4, 4); -DECL_FATTN_MMA_F16_CASE(576, 512, 4, 4); diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_4.cu b/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_4.cu index 173de7aac..86d4ffae2 100644 --- a/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_4.cu +++ b/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_4.cu @@ -8,4 +8,3 @@ DECL_FATTN_MMA_F16_CASE(96, 96, 8, 4); DECL_FATTN_MMA_F16_CASE(112, 112, 8, 
4); DECL_FATTN_MMA_F16_CASE(128, 128, 8, 4); DECL_FATTN_MMA_F16_CASE(256, 256, 8, 4); -DECL_FATTN_MMA_F16_CASE(576, 512, 8, 4); diff --git a/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal-device.m b/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal-device.m index 7b5ee968c..f24270bb1 100644 --- a/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal-device.m +++ b/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal-device.m @@ -1071,8 +1071,12 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te op->src[0]->ne[0] != 112 && op->src[0]->ne[0] != 128 && op->src[0]->ne[0] != 192 && - op->src[0]->ne[0] != 256 && - op->src[0]->ne[0] != 576) { + op->src[0]->ne[0] != 256) { + return false; + } + if (op->src[0]->ne[0] == 576) { + // DeepSeek sizes + // TODO: disabled for now, until optmized return false; } if (op->src[1]->type != op->src[2]->type) { diff --git a/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal-embed.metal b/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal-embed.metal index 3235a18eb..13c6715ba 100644 --- a/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal-embed.metal +++ b/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal-embed.metal @@ -8967,7 +8967,6 @@ kernel void kernel_flash_attn_ext( //case 1: kernel_flash_attn_ext_impl(FWD_ARGS); break; //case 2: kernel_flash_attn_ext_impl(FWD_ARGS); break; case 4: kernel_flash_attn_ext_impl(FWD_ARGS); break; - case 8: kernel_flash_attn_ext_impl(FWD_ARGS); break; } #undef FWD_TMPL #undef FWD_ARGS diff --git a/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal-ops.cpp b/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal-ops.cpp index 80864f303..e99c1763f 100644 --- a/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal-ops.cpp +++ b/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal-ops.cpp @@ -2456,7 +2456,7 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) { // simdgroups per threadgroup (a.k.a. warps) //nsg = ne01 <= nqptg ? MAX(4, MIN(nsgmax, MIN(ne11/ncpsg, (int64_t) pipeline.maxTotalThreadsPerThreadgroup/32))) : 4; - int32_t nsg = ne00 >= 512 ? 8 : 4; + int32_t nsg = 4; const size_t smem = FATTN_SMEM(nsg); diff --git a/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal.metal b/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal.metal index d33c16079..c98d269d1 100644 --- a/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal.metal +++ b/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal.metal @@ -6166,7 +6166,6 @@ kernel void kernel_flash_attn_ext( //case 1: kernel_flash_attn_ext_impl(FWD_ARGS); break; //case 2: kernel_flash_attn_ext_impl(FWD_ARGS); break; case 4: kernel_flash_attn_ext_impl(FWD_ARGS); break; - case 8: kernel_flash_attn_ext_impl(FWD_ARGS); break; } #undef FWD_TMPL #undef FWD_ARGS diff --git a/model/model.go b/model/model.go index b77d7175d..0af16da80 100644 --- a/model/model.go +++ b/model/model.go @@ -39,13 +39,6 @@ type Model interface { Config() config } -// Validator is an optional interface that models can implement to perform -// validation after tensors have been loaded. If validation fails, model -// loading will fail with the returned error. -type Validator interface { - Validate() error -} - // MultimodalProcessor must be implemented by multimodal models. 
type MultimodalProcessor interface { // EncodeMultimodal processes a single input (such as an image) and @@ -123,13 +116,6 @@ func New(modelPath string, params ml.BackendParams) (Model, error) { base := Base{b: b, config: m.Config()} v := reflect.ValueOf(m) v.Elem().Set(populateFields(base, v.Elem())) - - if validator, ok := m.(Validator); ok { - if err := validator.Validate(); err != nil { - return nil, err - } - } - return m, nil } diff --git a/model/models/glm4moelite/model.go b/model/models/glm4moelite/model.go index 2c55bbfb2..2e51f7d56 100644 --- a/model/models/glm4moelite/model.go +++ b/model/models/glm4moelite/model.go @@ -1,7 +1,6 @@ package glm4moelite import ( - "errors" "math" "github.com/ollama/ollama/fs" @@ -12,8 +11,6 @@ import ( "github.com/ollama/ollama/model/input" ) -var ErrOldModelFormat = errors.New("this model uses a weight format that is no longer supported; please re-download it") - type Options struct { numExpertsUsed int numExperts int @@ -50,9 +47,7 @@ type Attention struct { KVA *nn.Linear `gguf:"attn_kv_a_mqa"` KVANorm *nn.RMSNorm `gguf:"attn_kv_a_norm"` - - KB *nn.Linear `gguf:"attn_k_b"` - VB *nn.Linear `gguf:"attn_v_b"` + KVB *nn.Linear `gguf:"attn_kv_b"` Output *nn.Linear `gguf:"attn_out,alt:attn_output"` } @@ -83,16 +78,15 @@ func (attn *Attention) Forward(ctx ml.Context, hiddenStates, positions ml.Tensor qRot := opts.applyRotaryPositionEmbeddings(ctx, queryChunks[1], positions) kRot = opts.applyRotaryPositionEmbeddings(ctx, kRot, positions) kPass = attn.KVANorm.Forward(ctx, kPass, opts.eps) + kPass = attn.KVB.Forward(ctx, kPass) - // MLA absorption: absorb K projection into query - qPass := queryChunks[0].Permute(ctx, 0, 2, 1, 3) - qPassAbsorb := attn.KB.Forward(ctx, qPass).Permute(ctx, 0, 2, 1, 3) - query = qRot.Concat(ctx, qPassAbsorb, 0) + kv := kPass.Reshape(ctx, kPass.Dim(0)/opts.numKVHeads, opts.numKVHeads, seqLength) + kvChunks := kv.ChunkSections(ctx, 0, opts.kqNopeHeadDim, opts.vHeadDim) - kPass = kPass.Reshape(ctx, opts.kvLoraRank, 1, seqLength) - key := kRot.Concat(ctx, kPass, 0) - - attention := nn.AttentionWithVMLA(ctx, query, key, kPass, nil, attn.VB.Weight, opts.kqScale, cache) + kRot = kRot.Repeat(ctx, 1, queryChunks[0].Dim(1)) + query = qRot.Concat(ctx, queryChunks[0], 0) + key := kRot.Concat(ctx, kvChunks[0], 0) + attention := nn.Attention(ctx, query, key, kvChunks[1], opts.kqScale, cache) attention = attention.Reshape(ctx, attention.Dim(0)*attention.Dim(1), seqLength) return attn.Output.Forward(ctx, attention) @@ -223,12 +217,8 @@ func New(c fs.Config) (model.Model, error) { keyLength := int(c.Uint("attention.key_length")) valueLength := int(c.Uint("attention.value_length")) - kvLoraRank := int(c.Uint("attention.kv_lora_rank")) - qkRopeHeadDim := int(c.Uint("rope.dimension_count")) - // For MLA absorption, the effective key dimension is kvLoraRank + qkRopeHeadDim - mlaKeyLength := kvLoraRank + qkRopeHeadDim - kqScale := 1.0 / math.Sqrt(float64(mlaKeyLength)) + kqScale := 1.0 / math.Sqrt(float64(keyLength)) var pre []string switch c.String("tokenizer.ggml.pre") { @@ -289,15 +279,6 @@ func (m Model) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor return m.applyRotaryPositionEmbeddings(ctx, key, shift), nil } -func (m *Model) Validate() error { - for _, layer := range m.Layers { - if layer.Attention != nil && (layer.Attention.KB == nil || layer.Attention.VB == nil) { - return ErrOldModelFormat - } - } - return nil -} - func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) { positions := 
ctx.Input().FromInts(batch.Positions, len(batch.Positions)) diff --git a/model/models/glm4moelite/model_test.go b/model/models/glm4moelite/model_test.go deleted file mode 100644 index fbc3b460d..000000000 --- a/model/models/glm4moelite/model_test.go +++ /dev/null @@ -1,73 +0,0 @@ -package glm4moelite - -import ( - "testing" - - "github.com/ollama/ollama/ml/nn" -) - -func TestValidate(t *testing.T) { - tests := []struct { - name string - model *Model - wantErr bool - }{ - { - name: "valid model with KB and VB", - model: &Model{ - Layers: []Layer{ - {Attention: &Attention{KB: &nn.Linear{}, VB: &nn.Linear{}}}, - }, - }, - wantErr: false, - }, - { - name: "missing KB", - model: &Model{ - Layers: []Layer{ - {Attention: &Attention{VB: &nn.Linear{}}}, - }, - }, - wantErr: true, - }, - { - name: "missing VB", - model: &Model{ - Layers: []Layer{ - {Attention: &Attention{KB: &nn.Linear{}}}, - }, - }, - wantErr: true, - }, - { - name: "missing both KB and VB", - model: &Model{ - Layers: []Layer{ - {Attention: &Attention{}}, - }, - }, - wantErr: true, - }, - { - name: "nil Attention is ok", - model: &Model{ - Layers: []Layer{ - {Attention: nil}, - }, - }, - wantErr: false, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - err := tt.model.Validate() - if (err != nil) != tt.wantErr { - t.Errorf("Validate() error = %v, wantErr %v", err, tt.wantErr) - } - if tt.wantErr && err != ErrOldModelFormat { - t.Errorf("Validate() error = %v, want %v", err, ErrOldModelFormat) - } - }) - } -}
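
Reviewer note (not part of the diff): the fattn.cu hunks above restore the requirement that the 576/512 MLA flash-attention path is selected only when gqa_ratio is a multiple of 16. Below is a minimal Go sketch of that restored eligibility check, for illustration only; the head sizes (576/512) come from the hunks themselves, and the two gqa_ratio values (16 for DeepSeek, 4 for GLM-4.7-flash) are the cases named in the deleted patch description.

    package main

    import "fmt"

    // eligible mirrors the check restored in ggml-cuda/fattn.cu: the DKQ=576 /
    // DV=512 MLA kernel is used only when the GQA optimization applies and
    // gqa_ratio is a multiple of 16. After the revert, GLM-4.7-flash's
    // gqa_ratio of 4 no longer qualifies, so no flash-attention kernel is
    // selected for that shape.
    func eligible(dkq, dv, gqaRatio int, gqaOptApplies bool) bool {
        if dkq != 576 || dv != 512 {
            return false
        }
        return gqaOptApplies && gqaRatio%16 == 0
    }

    func main() {
        fmt.Println("DeepSeek, gqa_ratio 16:", eligible(576, 512, 16, true)) // true
        fmt.Println("GLM-4.7-flash, gqa_ratio 4:", eligible(576, 512, 4, true)) // false
    }

This matches the Metal hunks in the same revert, which re-disable head size 576 in ggml_metal_device_supports_op.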
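
Reviewer note (not part of the diff): the model.go revert goes back to projecting the compressed KV latent through the combined attn_kv_b tensor and splitting it per head with ChunkSections, instead of the absorbed attn_k_b/attn_v_b pair. The sketch below only illustrates the dimension bookkeeping of that restored path; the per-head sizes are hypothetical placeholders, not values read from a real GLM checkpoint, and the key_length assumption is labeled in the comments.

    package main

    import (
        "fmt"
        "math"
    )

    func main() {
        // Hypothetical per-head dimensions for illustration; the real values
        // come from the GGUF metadata (attention.key_length,
        // attention.value_length, rope.dimension_count).
        const (
            qkNopeHeadDim = 128
            qkRopeHeadDim = 64
            vHeadDim      = 128
            numKVHeads    = 4
        )

        // attn_kv_b produces, per KV head, a concatenated [K_nope | V] block;
        // ChunkSections(0, qkNopeHeadDim, vHeadDim) splits it back apart.
        perHead := qkNopeHeadDim + vHeadDim
        fmt.Println("attn_kv_b output per token:", numKVHeads*perHead)

        // The full key per head is the RoPE part concatenated with the nope
        // part. The restored kqScale is 1/sqrt(attention.key_length); this
        // sketch assumes key_length equals that concatenated size.
        keyLength := qkRopeHeadDim + qkNopeHeadDim
        fmt.Println("key length:", keyLength, "kqScale:", 1.0/math.Sqrt(float64(keyLength)))
    }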