diff --git a/convert/convert.go b/convert/convert.go index b2e6f5e37..df4359224 100644 --- a/convert/convert.go +++ b/convert/convert.go @@ -313,6 +313,8 @@ func LoadModelMetadata(fsys fs.FS) (ModelKV, *Tokenizer, error) { conv = &deepseek2Model{} case "Glm4MoeLiteForCausalLM": conv = &glm4MoeLiteModel{} + case "Lfm2ForCausalLM": + conv = &lfm2Model{} default: return nil, nil, fmt.Errorf("unsupported architecture %q", p.Architectures[0]) } diff --git a/convert/convert_lfm2.go b/convert/convert_lfm2.go new file mode 100644 index 000000000..fdae1074c --- /dev/null +++ b/convert/convert_lfm2.go @@ -0,0 +1,100 @@ +package convert + +import ( + "slices" + "strings" + + "github.com/ollama/ollama/fs/ggml" +) + +type lfm2Model struct { + ModelParameters + HiddenSize uint32 `json:"hidden_size"` + NumHiddenLayers uint32 `json:"num_hidden_layers"` + MaxPositionEmbeddings uint32 `json:"max_position_embeddings"` + IntermediateSize uint32 `json:"intermediate_size"` + NumAttentionHeads uint32 `json:"num_attention_heads"` + NumKeyValueHeads uint32 `json:"num_key_value_heads"` + RopeTheta float32 `json:"rope_theta"` + NormEps float32 `json:"norm_eps"` + ConvLCache uint32 `json:"conv_L_cache"` + LayerTypes []string `json:"layer_types"` + TieEmbedding bool `json:"tie_embedding"` +} + +var _ ModelConverter = (*lfm2Model)(nil) + +func (p *lfm2Model) KV(t *Tokenizer) KV { + kv := p.ModelParameters.KV(t) + kv["general.architecture"] = "lfm2" + kv["lfm2.vocab_size"] = p.VocabSize + kv["lfm2.block_count"] = p.NumHiddenLayers + kv["lfm2.embedding_length"] = p.HiddenSize + kv["lfm2.feed_forward_length"] = p.IntermediateSize + kv["lfm2.context_length"] = p.MaxPositionEmbeddings + + // Build per-layer KV head count array based on layer_types + // (0 = shortconv layer, non-zero = attention layer with that many KV heads) + kvHeadCounts := make([]uint32, p.NumHiddenLayers) + for i := range p.NumHiddenLayers { + if int(i) < len(p.LayerTypes) && p.LayerTypes[i] == "full_attention" { + kvHeadCounts[i] = p.NumKeyValueHeads + } + } + + kv["lfm2.attention.head_count"] = p.NumAttentionHeads + kv["lfm2.attention.head_count_kv"] = kvHeadCounts + kv["lfm2.attention.key_length"] = p.HiddenSize / p.NumAttentionHeads + kv["lfm2.attention.value_length"] = p.HiddenSize / p.NumAttentionHeads + kv["lfm2.attention.layer_norm_rms_epsilon"] = p.NormEps + kv["lfm2.rope.freq_base"] = p.RopeTheta + kv["lfm2.shortconv.l_cache"] = p.ConvLCache + + return kv +} + +func (p *lfm2Model) Tensors(ts []Tensor) []*ggml.Tensor { + var out []*ggml.Tensor + + for _, t := range ts { + shape := t.Shape() + + // Squeeze conv weights: [D, 1, K] -> [D, K] + if strings.HasSuffix(t.Name(), "shortconv.conv.weight") { + if len(shape) == 3 && shape[1] == 1 { + shape = []uint64{shape[0], shape[2]} + } + } + + out = append(out, &ggml.Tensor{ + Name: t.Name(), + Kind: t.Kind(), + Shape: slices.Clone(shape), + WriterTo: t, + }) + } + + return out +} + +func (p *lfm2Model) Replacements() []string { + return []string{ + "model.embed_tokens", "token_embd", + "model.embedding_norm", "output_norm", + "model.layers", "blk", + "operator_norm", "attn_norm", + "self_attn.q_proj", "attn_q", + "self_attn.k_proj", "attn_k", + "self_attn.v_proj", "attn_v", + "self_attn.out_proj", "attn_output", + "self_attn.q_layernorm", "attn_q_norm", + "self_attn.k_layernorm", "attn_k_norm", + "conv.conv", "shortconv.conv", + "conv.in_proj", "shortconv.in_proj", + "conv.out_proj", "shortconv.out_proj", + "feed_forward.w1", "ffn_gate", + "feed_forward.w2", "ffn_down", + "feed_forward.w3", "ffn_up", 
+ "ffn_norm", "ffn_norm", + } +} diff --git a/convert/reader.go b/convert/reader.go index 75764f018..a2ac41dc9 100644 --- a/convert/reader.go +++ b/convert/reader.go @@ -40,6 +40,7 @@ const ( func (t tensorBase) Kind() uint32 { if strings.HasSuffix(t.name, ".ffn_gate_inp.weight") || strings.HasSuffix(t.name, ".bias") || + strings.HasSuffix(t.name, ".shortconv.conv.weight") || t.name == "token_types.weight" || t.name == "v.positional_embedding_vlm" || t.name == "v.tile_position_embd.weight" || diff --git a/fs/ggml/ggml.go b/fs/ggml/ggml.go index 6db305f77..4f31221f9 100644 --- a/fs/ggml/ggml.go +++ b/fs/ggml/ggml.go @@ -270,6 +270,7 @@ func (kv KV) OllamaEngineRequired() bool { "qwen3", "qwen3moe", "qwen3vl", "qwen3vlmoe", "glm4moelite", + "lfm2", }, kv.Architecture()) } @@ -859,6 +860,7 @@ func (f GGML) FlashAttention() bool { "gemma3", "glm4moelite", "gptoss", "gpt-oss", + "lfm2", "mistral3", "olmo3", "qwen3", "qwen3moe", diff --git a/ml/backend.go b/ml/backend.go index f287db6af..fa1f32b69 100644 --- a/ml/backend.go +++ b/ml/backend.go @@ -162,6 +162,7 @@ type Tensor interface { AvgPool2D(ctx Context, k, s int, p float32) Tensor Conv2D(ctx Context, weight Tensor, s0, s1, p0, p1, d0, d1 int) Tensor Conv3D(ctx Context, weight Tensor, c, s0, s1, s2, p0, p1, p2, d0, d1, d2 int) Tensor + SSMConv(ctx Context, kernel Tensor) Tensor IM2Col(ctx Context, weight Tensor, s0, s1, p0, p1, d0, d1 int) Tensor diff --git a/ml/backend/ggml/ggml.go b/ml/backend/ggml/ggml.go index ebcc1d86f..138c646bd 100644 --- a/ml/backend/ggml/ggml.go +++ b/ml/backend/ggml/ggml.go @@ -1641,6 +1641,13 @@ func (t *Tensor) Conv3D(ctx ml.Context, t2 ml.Tensor, c, s0, s1, s2, p0, p1, p2, return tt } +func (t *Tensor) SSMConv(ctx ml.Context, kernel ml.Tensor) ml.Tensor { + return &Tensor{ + b: t.b, + t: C.ggml_ssm_conv(ctx.(*Context).ctx, t.t, kernel.(*Tensor).t), + } +} + func (t *Tensor) AvgPool2D(ctx ml.Context, k, s int, p float32) ml.Tensor { return &Tensor{ b: t.b, diff --git a/model/models/lfm2/cache.go b/model/models/lfm2/cache.go new file mode 100644 index 000000000..7e9d35f5f --- /dev/null +++ b/model/models/lfm2/cache.go @@ -0,0 +1,410 @@ +package lfm2 + +import ( + "slices" + + "github.com/ollama/ollama/kvcache" + "github.com/ollama/ollama/ml" + "github.com/ollama/ollama/model/input" +) + +var _ kvcache.Cache = (*HybridCache)(nil) + +// HybridCache stores: +// - a standard causal KV cache for attention layers +// - a per-sequence recurrent conv state for shortconv layers +// +// Conv state shape (per layer, per sequence): [dConv, hiddenSize] where dConv = L_cache - 1. +// Stored internally as a tensor of shape [dConv * hiddenSize, maxSlots]. 
+type HybridCache struct { + kv *kvcache.Causal + + backend ml.Backend + dtype ml.DType + maxSequences int + + hiddenSize int + dConv int + + // slot mapping for recurrent state + slotForSeq map[int]int + refCount []int + freeSlots []int + + // per-layer conv state buffers (allocated lazily) + convCtxs map[int]ml.Context + convStates map[int]ml.Tensor // [dConv*hiddenSize, maxSlots] + + // current forward batch (derived in StartForward) + curSeqs []int + curSlots []int + curSlotsInput ml.Tensor + curSeqTokens int + + // track if EnsureWritable has been called for this forward pass + writableEnsured bool + // track any error from EnsureWritable to propagate later + writableError error +} + +func NewHybridCache(shift func(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error), hiddenSize, dConv int) *HybridCache { + return &HybridCache{ + kv: kvcache.NewCausalCache(shift), + hiddenSize: hiddenSize, + dConv: dConv, + slotForSeq: make(map[int]int), + convCtxs: make(map[int]ml.Context), + convStates: make(map[int]ml.Tensor), + } +} + +func (c *HybridCache) Init(backend ml.Backend, dtype ml.DType, maxSequences, capacity, maxBatch int) { + c.backend = backend + c.dtype = dtype + c.maxSequences = maxSequences + + // initialize slot allocator + c.refCount = make([]int, maxSequences) + c.freeSlots = c.freeSlots[:0] + for i := maxSequences - 1; i >= 0; i-- { + c.freeSlots = append(c.freeSlots, i) + } + + c.kv.Init(backend, dtype, maxSequences, capacity, maxBatch) +} + +func (c *HybridCache) Close() { + for _, ctx := range c.convCtxs { + ctx.Close() + } + c.kv.Close() +} + +func (c *HybridCache) SetConfig(config ml.CacheConfig) { + c.kv.SetConfig(config) +} + +func (c *HybridCache) SetLayer(layer int) { + c.kv.SetLayer(layer) +} + +func (c *HybridCache) Get(ctx ml.Context) (ml.Tensor, ml.Tensor, ml.Tensor) { + return c.kv.Get(ctx) +} + +func (c *HybridCache) Put(ctx ml.Context, key, value ml.Tensor) { + c.kv.Put(ctx, key, value) +} + +func (c *HybridCache) StartForward(ctx ml.Context, batch input.Batch, reserve bool) error { + if err := c.kv.StartForward(ctx, batch, reserve); err != nil { + return err + } + + // Derive equal-length sequence layout for shortconv. + // LFM2 shortconv assumes tokens form a [seq_tokens, seqs] grid. 
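+	// For example, a batch with Sequences = [0, 0, 0, 1, 1, 1] yields curSeqs = [0, 1] and
+	// curSeqTokens = 3, while a ragged batch such as [0, 0, 1] cannot be laid out as a grid
+	// and is rejected below with kvcache.ErrNotSupported.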
+ seqCounts := make(map[int]int) + c.curSeqs = c.curSeqs[:0] + for _, s := range batch.Sequences { + if _, ok := seqCounts[s]; !ok { + c.curSeqs = append(c.curSeqs, s) + } + seqCounts[s]++ + } + + if len(c.curSeqs) == 0 { + return nil + } + + nTokens := len(batch.Sequences) + nSeqs := len(c.curSeqs) + want := nTokens / nSeqs + for _, s := range c.curSeqs { + if seqCounts[s] != want { + return kvcache.ErrNotSupported + } + } + + c.curSeqTokens = want + + // When reserving memory for estimation, use fake slot assignments + // without modifying permanent state (slotForSeq, refCount) + if reserve { + c.curSlots = c.curSlots[:0] + slots := make([]int32, nSeqs) + for i := range nSeqs { + c.curSlots = append(c.curSlots, i) + slots[i] = int32(i) + } + c.curSlotsInput = ctx.Input().FromInts(slots, len(slots)) + return nil + } + + // Ensure slots exist for sequences in this batch + c.curSlots = c.curSlots[:0] + var newSlots []int // track newly allocated slots that need zeroing + for _, s := range c.curSeqs { + slot, ok := c.slotForSeq[s] + if !ok { + var err error + slot, err = c.allocSlot() + if err != nil { + return err + } + c.slotForSeq[s] = slot + c.refCount[slot] = 1 + newSlots = append(newSlots, slot) + } + c.curSlots = append(c.curSlots, slot) + } + + // Zero conv state for newly allocated slots to clear stale data from previous sequences + if len(newSlots) > 0 { + c.zeroConvSlots(ctx, newSlots) + } + + // Create a tensor for the current slots + slots := make([]int32, len(c.curSlots)) + for i, v := range c.curSlots { + slots[i] = int32(v) + } + c.curSlotsInput = ctx.Input().FromInts(slots, len(slots)) + + // Reset writable state for new forward pass + c.writableEnsured = false + c.writableError = nil + + return nil +} + +func (c *HybridCache) allocSlot() (int, error) { + if len(c.freeSlots) == 0 { + return 0, kvcache.ErrKvCacheFull + } + slot := c.freeSlots[len(c.freeSlots)-1] + c.freeSlots = c.freeSlots[:len(c.freeSlots)-1] + return slot, nil +} + +func (c *HybridCache) freeSlot(slot int) { + // Bounds check before freeing + if slot >= 0 && slot < c.maxSequences { + c.freeSlots = append(c.freeSlots, slot) + } +} + +// zeroConvSlots zeros the conv state for the given slots across all layers. +// This must be called when recycling slots to prevent stale state from affecting new sequences. +func (c *HybridCache) zeroConvSlots(ctx ml.Context, slots []int) { + if len(slots) == 0 || len(c.convStates) == 0 { + return + } + + // Use input context for creating tensors + inputCtx := ctx.Input() + + // Create slot indices tensor + slotIndices := make([]int32, len(slots)) + for i, s := range slots { + slotIndices[i] = int32(s) + } + slotsTensor := inputCtx.FromInts(slotIndices, len(slotIndices)) + + // Create zero tensor for the slots (SetRows requires F32 source) + zeros := inputCtx.Zeros(ml.DTypeF32, c.dConv*c.hiddenSize, len(slots)) + + // Zero each layer's conv state for these slots + for _, buf := range c.convStates { + ctx.Forward(buf.SetRows(ctx, zeros, slotsTensor)) + } +} + +// EnsureWritable ensures that sequences in the current batch have private (non-shared) conv slots. +// Returns an error if slot allocation fails. 
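+//
+// Example of the copy-on-write flow: after CopyPrefix(0, 1, n) both sequences reference the
+// same slot with refCount == 2; the next forward pass that touches either sequence allocates
+// a fresh slot here, copies the shared conv rows into it, and drops the shared slot's
+// refCount back to 1 so the two sequences can diverge safely.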
+func (c *HybridCache) EnsureWritable(ctx ml.Context) error { + for i, seq := range c.curSeqs { + slot, ok := c.slotForSeq[seq] + if !ok { + continue + } + + // Bounds check + if slot < 0 || slot >= len(c.refCount) { + continue + } + + if c.refCount[slot] <= 1 { + continue + } + + newSlot, err := c.allocSlot() + if err != nil { + return err + } + c.refCount[slot]-- + c.refCount[newSlot] = 1 + c.slotForSeq[seq] = newSlot + c.curSlots[i] = newSlot + + // Copy existing conv state for all initialized layers + for _, buf := range c.convStates { + // buf: [dConv*hiddenSize, maxSlots] + src := buf.Rows(ctx, ctx.Input().FromInts([]int32{int32(slot)}, 1)) + // SetRows requires F32 source + srcF32 := src.Cast(ctx, ml.DTypeF32) + ctx.Forward(buf.SetRows(ctx, srcF32, ctx.Input().FromInts([]int32{int32(newSlot)}, 1))) + } + } + + // Rebuild current slots tensor + slots := make([]int32, len(c.curSlots)) + for i, v := range c.curSlots { + slots[i] = int32(v) + } + c.curSlotsInput = ctx.Input().FromInts(slots, len(slots)) + + return nil +} + +func (c *HybridCache) CopyPrefix(srcSeq, dstSeq int, prefixLen int32) { + // KV cache shares prefix metadata (no copy) which is correct for prefix reuse. + c.kv.CopyPrefix(srcSeq, dstSeq, prefixLen) + + // For shortconv state we implement copy-on-write: dst shares the same slot as src. + // On the first write to dst, EnsureWritable will create a private slot. + if dstSlot, ok := c.slotForSeq[dstSeq]; ok { + // Bounds check before decrementing + if dstSlot >= 0 && dstSlot < len(c.refCount) { + c.refCount[dstSlot]-- + if c.refCount[dstSlot] <= 0 { + c.refCount[dstSlot] = 0 + c.freeSlot(dstSlot) + } + } + delete(c.slotForSeq, dstSeq) + } + + srcSlot, ok := c.slotForSeq[srcSeq] + if !ok { + // src may not have a slot yet; dst will allocate on demand + return + } + + // Bounds check before incrementing + if srcSlot >= 0 && srcSlot < len(c.refCount) { + c.slotForSeq[dstSeq] = srcSlot + c.refCount[srcSlot]++ + } +} + +func (c *HybridCache) CanResume(seq int, pos int32) bool { + return c.kv.CanResume(seq, pos) +} + +func (c *HybridCache) Remove(seq int, beginIndex, endIndex int32) error { + if err := c.kv.Remove(seq, beginIndex, endIndex); err != nil { + return err + } + + // For recurrent state, any removal invalidates the state because + // the state at position N depends on all previous positions. + // Drop the slot mapping so it resets on next use. + slot, ok := c.slotForSeq[seq] + if !ok { + return nil + } + + // Bounds check + if slot < 0 || slot >= len(c.refCount) { + delete(c.slotForSeq, seq) + return nil + } + + c.refCount[slot]-- + if c.refCount[slot] <= 0 { + c.refCount[slot] = 0 + c.freeSlot(slot) + } + delete(c.slotForSeq, seq) + + return nil +} + +func (c *HybridCache) slotsTensor() ml.Tensor { + return c.curSlotsInput +} + +func (c *HybridCache) seqTokens() int { + return c.curSeqTokens +} + +func (c *HybridCache) numSeqs() int { + return len(c.curSeqs) +} + +func (c *HybridCache) convBuffer(ctx ml.Context, layer int) ml.Tensor { + if buf, ok := c.convStates[layer]; ok { + return buf + } + + if _, ok := c.convCtxs[layer]; !ok { + c.convCtxs[layer] = c.backend.NewContextSize(1).Layer(layer) + } + + buf := c.convCtxs[layer].Zeros(c.dtype, c.dConv*c.hiddenSize, c.maxSequences) + c.convStates[layer] = buf + return buf +} + +// ConvState returns the conv state for current batch sequences as shape [dConv, hiddenSize, nSeqs]. +// Returns an error if copy-on-write allocation fails. 
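+//
+// Sketch of the intended call pattern (mirrors ShortConv.Forward in shortconv.go; newInputs,
+// kernel and dConv stand in for that layer's tensors):
+//
+//	state, err := cache.ConvState(ctx, layer) // [dConv, hiddenSize, nSeqs]; err on CoW failure
+//	sx := state.Concat(ctx, newInputs, 0)     // prepend cached history to the new tokens
+//	convOut := sx.SSMConv(ctx, kernel)        // causal short convolution over dim 0
+//	cache.UpdateConvState(ctx, layer, sx.Slice(ctx, 0, sx.Dim(0)-dConv, sx.Dim(0), 1))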
+func (c *HybridCache) ConvState(ctx ml.Context, layer int) (ml.Tensor, error) { + if !c.writableEnsured { + needsWritable := false + for _, seq := range c.curSeqs { + slot, ok := c.slotForSeq[seq] + if !ok { + continue + } + if slot >= 0 && slot < len(c.refCount) && c.refCount[slot] > 1 { + needsWritable = true + break + } + } + + if needsWritable { + if err := c.EnsureWritable(ctx); err != nil { + c.writableError = err + } + } + c.writableEnsured = true + } + + if c.writableError != nil { + return nil, c.writableError + } + + buf := c.convBuffer(ctx, layer) + cur := buf.Rows(ctx, c.slotsTensor()) + return cur.Reshape(ctx, c.dConv, c.hiddenSize, c.numSeqs()), nil +} + +// UpdateConvState writes a new conv state for current batch sequences. +// newState must have shape [dConv, hiddenSize, nSeqs]. +func (c *HybridCache) UpdateConvState(ctx ml.Context, layer int, newState ml.Tensor) { + buf := c.convBuffer(ctx, layer) + src := newState.Reshape(ctx, c.dConv*c.hiddenSize, c.numSeqs()) + // SetRows requires F32 source + srcF32 := src.Cast(ctx, ml.DTypeF32) + ctx.Forward(buf.SetRows(ctx, srcF32, c.slotsTensor())) +} + +// IsSupportedForBatch returns true if the current batch layout supports shortconv. +func (c *HybridCache) IsSupportedForBatch() bool { + return c.curSeqTokens > 0 && len(c.curSeqs) > 0 +} + +// Seqs returns the ordered unique sequences for the current forward pass. +func (c *HybridCache) Seqs() []int { + return slices.Clone(c.curSeqs) +} diff --git a/model/models/lfm2/cache_test.go b/model/models/lfm2/cache_test.go new file mode 100644 index 000000000..f4c493c20 --- /dev/null +++ b/model/models/lfm2/cache_test.go @@ -0,0 +1,444 @@ +package lfm2 + +import ( + "testing" + + "github.com/ollama/ollama/kvcache" + "github.com/ollama/ollama/ml" +) + +// TestHybridCache tests verify the slot management logic of HybridCache. +// These tests focus on the recurrent state slot allocation, reference counting, +// and copy-on-write semantics without requiring a full ML backend. + +// createSlotOnlyCache creates a HybridCache with only the slot management +// fields initialized. Used to test slot logic in isolation. 
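+//
+// The free-slot list is seeded the same way Init seeds it, in descending order, so the LIFO
+// allocSlot hands out slot 0 first and tests such as TestHybridCache_SlotReuse can rely on a
+// freed slot being returned by the next allocation.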
+func createSlotOnlyCache(maxSequences int) *HybridCache { + return &HybridCache{ + hiddenSize: 256, + dConv: 3, + maxSequences: maxSequences, + refCount: make([]int, maxSequences), + freeSlots: initFreeSlots(maxSequences), + slotForSeq: make(map[int]int), + convCtxs: make(map[int]ml.Context), + convStates: make(map[int]ml.Tensor), + } +} + +func initFreeSlots(n int) []int { + slots := make([]int, 0, n) + for i := n - 1; i >= 0; i-- { + slots = append(slots, i) + } + return slots +} + +func TestHybridCache_SlotAllocation(t *testing.T) { + cache := createSlotOnlyCache(4) + + // Verify initial state + if len(cache.freeSlots) != 4 { + t.Errorf("expected 4 free slots, got %d", len(cache.freeSlots)) + } + + // Allocate all slots + for range 4 { + slot, err := cache.allocSlot() + if err != nil { + t.Fatalf("allocSlot failed: %v", err) + } + cache.refCount[slot] = 1 + } + + // Should be full now + if len(cache.freeSlots) != 0 { + t.Errorf("expected 0 free slots, got %d", len(cache.freeSlots)) + } + + // Trying to allocate another should fail + _, err := cache.allocSlot() + if err != kvcache.ErrKvCacheFull { + t.Errorf("expected ErrKvCacheFull, got %v", err) + } +} + +func TestHybridCache_SlotReuse(t *testing.T) { + cache := createSlotOnlyCache(4) + + // Allocate a slot + slot1, _ := cache.allocSlot() + cache.refCount[slot1] = 1 + + // Free it + cache.refCount[slot1] = 0 + cache.freeSlot(slot1) + + // Allocate again - should get the same slot back (LIFO) + slot2, _ := cache.allocSlot() + if slot2 != slot1 { + t.Errorf("expected slot %d to be reused, got %d", slot1, slot2) + } +} + +func TestHybridCache_SlotRefCounting_ShareSlot(t *testing.T) { + cache := createSlotOnlyCache(4) + + // Allocate slot for seq 1 + slot1, _ := cache.allocSlot() + cache.slotForSeq[1] = slot1 + cache.refCount[slot1] = 1 + + // Simulate sharing slot with seq 2 (copy-on-write style) + cache.slotForSeq[2] = slot1 + cache.refCount[slot1]++ + + // Should share the same slot + if cache.slotForSeq[2] != slot1 { + t.Errorf("expected seq 2 to share slot %d, got %d", slot1, cache.slotForSeq[2]) + } + + // Ref count should be 2 + if cache.refCount[slot1] != 2 { + t.Errorf("expected refCount 2, got %d", cache.refCount[slot1]) + } +} + +func TestHybridCache_SlotRefCounting_DecRef(t *testing.T) { + cache := createSlotOnlyCache(4) + + // Allocate slot for seq 1 + slot1, _ := cache.allocSlot() + cache.slotForSeq[1] = slot1 + cache.refCount[slot1] = 1 + + // Share with seq 2 + cache.slotForSeq[2] = slot1 + cache.refCount[slot1]++ + + // Unshare seq 2 + cache.refCount[slot1]-- + delete(cache.slotForSeq, 2) + + // Ref count should be back to 1 + if cache.refCount[slot1] != 1 { + t.Errorf("expected refCount 1 after unshare, got %d", cache.refCount[slot1]) + } + + // Seq 2 should no longer have a slot + if _, ok := cache.slotForSeq[2]; ok { + t.Error("seq 2 should not have a slot after unshare") + } +} + +func TestHybridCache_SlotFreeWhenUnused(t *testing.T) { + cache := createSlotOnlyCache(4) + + initialFreeSlots := len(cache.freeSlots) + + // Allocate slot for seq 1 + slot1, _ := cache.allocSlot() + cache.slotForSeq[1] = slot1 + cache.refCount[slot1] = 1 + + // Free the slot when refCount drops to 0 + cache.refCount[slot1]-- + if cache.refCount[slot1] <= 0 { + cache.refCount[slot1] = 0 + cache.freeSlot(slot1) + } + delete(cache.slotForSeq, 1) + + // Slot should be freed + if len(cache.freeSlots) != initialFreeSlots { + t.Errorf("expected %d free slots, got %d", initialFreeSlots, len(cache.freeSlots)) + } + + // Ref count should be 0 + if 
cache.refCount[slot1] != 0 { + t.Errorf("expected refCount 0, got %d", cache.refCount[slot1]) + } +} + +func TestHybridCache_SlotOverwrite(t *testing.T) { + cache := createSlotOnlyCache(4) + + // Allocate slots for seq 1 and seq 2 + slot1, _ := cache.allocSlot() + cache.slotForSeq[1] = slot1 + cache.refCount[slot1] = 1 + + slot2, _ := cache.allocSlot() + cache.slotForSeq[2] = slot2 + cache.refCount[slot2] = 1 + + initialFreeSlots := len(cache.freeSlots) + + // Simulate overwriting seq 2's slot with slot1 (sharing) + // First free the old slot + cache.refCount[slot2]-- + if cache.refCount[slot2] <= 0 { + cache.refCount[slot2] = 0 + cache.freeSlot(slot2) + } + // Then share slot1 + cache.slotForSeq[2] = slot1 + cache.refCount[slot1]++ + + // Seq 2 should now share slot1 + if cache.slotForSeq[2] != slot1 { + t.Errorf("expected seq 2 to share slot %d, got %d", slot1, cache.slotForSeq[2]) + } + + // Old slot2 should be freed + if len(cache.freeSlots) != initialFreeSlots+1 { + t.Errorf("expected %d free slots, got %d", initialFreeSlots+1, len(cache.freeSlots)) + } +} + +func TestHybridCache_BoundsChecking(t *testing.T) { + cache := createSlotOnlyCache(4) + + // Test freeing invalid slot (should not panic) + cache.freeSlot(-1) + cache.freeSlot(100) // out of bounds + + // freeSlot does bounds checking, so invalid slots should be ignored + if len(cache.freeSlots) != 4 { + t.Errorf("invalid slots should not affect free list, got %d slots", len(cache.freeSlots)) + } +} + +func TestHybridCache_MultipleSequences_RefCounting(t *testing.T) { + cache := createSlotOnlyCache(8) + + // Allocate slot for seq 1 + slot1, _ := cache.allocSlot() + cache.slotForSeq[1] = slot1 + cache.refCount[slot1] = 1 + + // Fork to seq 2, 3, 4 (all share slot1) + for _, seq := range []int{2, 3, 4} { + cache.slotForSeq[seq] = slot1 + cache.refCount[slot1]++ + } + + // Ref count should be 4 + if cache.refCount[slot1] != 4 { + t.Errorf("expected refCount 4, got %d", cache.refCount[slot1]) + } + + // Remove seq 2, 3 + for _, seq := range []int{2, 3} { + delete(cache.slotForSeq, seq) + cache.refCount[slot1]-- + } + + if cache.refCount[slot1] != 2 { + t.Errorf("expected refCount 2, got %d", cache.refCount[slot1]) + } + + // Slot should still be allocated (not in free list) + found := false + for _, s := range cache.freeSlots { + if s == slot1 { + found = true + break + } + } + if found { + t.Error("slot1 should not be in free list yet") + } + + // Remove remaining sequences + for _, seq := range []int{1, 4} { + delete(cache.slotForSeq, seq) + cache.refCount[slot1]-- + } + + if cache.refCount[slot1] != 0 { + t.Errorf("expected refCount 0, got %d", cache.refCount[slot1]) + } +} + +func TestHybridCache_ChainedSharing(t *testing.T) { + cache := createSlotOnlyCache(8) + + // Create seq 1 + slot1, _ := cache.allocSlot() + cache.slotForSeq[1] = slot1 + cache.refCount[slot1] = 1 + + // Share 1 -> 2 + cache.slotForSeq[2] = slot1 + cache.refCount[slot1]++ + + // Share 2 -> 3 (should still share slot1) + cache.slotForSeq[3] = cache.slotForSeq[2] // which is slot1 + cache.refCount[slot1]++ + + // All should share slot1 + if cache.slotForSeq[1] != slot1 || cache.slotForSeq[2] != slot1 || cache.slotForSeq[3] != slot1 { + t.Error("all sequences should share slot1") + } + + if cache.refCount[slot1] != 3 { + t.Errorf("expected refCount 3, got %d", cache.refCount[slot1]) + } +} + +func TestHybridCache_CacheParameters(t *testing.T) { + cache := NewHybridCache(nil, 512, 5) // hiddenSize=512, dConv=5 + + if cache.hiddenSize != 512 { + t.Errorf("expected 
hiddenSize 512, got %d", cache.hiddenSize) + } + if cache.dConv != 5 { + t.Errorf("expected dConv 5, got %d", cache.dConv) + } +} + +func TestHybridCache_NumSeqs(t *testing.T) { + cache := createSlotOnlyCache(4) + + // Initially no sequences + if cache.numSeqs() != 0 { + t.Errorf("expected 0 seqs, got %d", cache.numSeqs()) + } + + // Manually set up current batch state + cache.curSeqs = []int{1, 2, 3} + + if cache.numSeqs() != 3 { + t.Errorf("expected 3 seqs, got %d", cache.numSeqs()) + } +} + +func TestHybridCache_SeqTokens(t *testing.T) { + cache := createSlotOnlyCache(4) + + // Initially 0 + if cache.seqTokens() != 0 { + t.Errorf("expected 0 seqTokens, got %d", cache.seqTokens()) + } + + // Manually set up current batch state + cache.curSeqTokens = 16 + + if cache.seqTokens() != 16 { + t.Errorf("expected 16 seqTokens, got %d", cache.seqTokens()) + } +} + +// Test that Seqs returns a clone of curSeqs +func TestHybridCache_Seqs_ReturnsClone(t *testing.T) { + cache := createSlotOnlyCache(4) + + cache.curSeqs = []int{1, 2, 3} + + seqs := cache.Seqs() + + // Modify returned slice + seqs[0] = 999 + + // Original should be unchanged + if cache.curSeqs[0] != 1 { + t.Error("Seqs should return a clone, not the original slice") + } +} + +func TestHybridCache_IsSupportedForBatch(t *testing.T) { + cache := createSlotOnlyCache(4) + + // Initially not supported (no batch set up) + if cache.IsSupportedForBatch() { + t.Error("expected IsSupportedForBatch to be false initially") + } + + // Set up a valid batch + cache.curSeqTokens = 1 + cache.curSeqs = []int{1} + + if !cache.IsSupportedForBatch() { + t.Error("expected IsSupportedForBatch to be true with valid batch") + } +} + +func TestHybridCache_ZeroConvSlots_EmptyInputs(t *testing.T) { + cache := createSlotOnlyCache(4) + + // zeroConvSlots should handle empty slots without panicking + cache.zeroConvSlots(nil, nil) + cache.zeroConvSlots(nil, []int{}) + + // zeroConvSlots should handle empty convStates without panicking + cache.zeroConvSlots(nil, []int{0, 1, 2}) +} + +func TestHybridCache_SlotRecycling_TracksNewSlots(t *testing.T) { + cache := createSlotOnlyCache(4) + + // Allocate slot for seq 1 + slot1, _ := cache.allocSlot() + cache.slotForSeq[1] = slot1 + cache.refCount[slot1] = 1 + + // Free the slot (simulating sequence removal) + cache.refCount[slot1]-- + cache.freeSlot(slot1) + delete(cache.slotForSeq, 1) + + // Verify slot is in free list + if len(cache.freeSlots) != 4 { + t.Errorf("expected 4 free slots after freeing, got %d", len(cache.freeSlots)) + } + + // Allocate for new seq 2 - should get recycled slot + slot2, _ := cache.allocSlot() + if slot2 != slot1 { + t.Errorf("expected recycled slot %d, got %d", slot1, slot2) + } + + // This recycled slot would need zeroing in the real implementation + // The actual zeroing is tested via integration tests since it requires ML context +} + +func TestHybridCache_NewSequence_GetsTrackedForZeroing(t *testing.T) { + cache := createSlotOnlyCache(4) + + // Simulate the slot allocation flow from StartForward + // When a sequence doesn't have a slot, it gets allocated and tracked as "new" + + newSlots := []int{} + + // Seq 1 doesn't have a slot - allocate and track + seq := 1 + if _, ok := cache.slotForSeq[seq]; !ok { + slot, err := cache.allocSlot() + if err != nil { + t.Fatalf("allocSlot failed: %v", err) + } + cache.slotForSeq[seq] = slot + cache.refCount[slot] = 1 + newSlots = append(newSlots, slot) + } + + // Verify newSlots contains the allocated slot + if len(newSlots) != 1 { + t.Errorf("expected 1 
new slot, got %d", len(newSlots)) + } + + // Seq 1 already has a slot - should NOT be tracked as new + newSlots2 := []int{} + if _, ok := cache.slotForSeq[seq]; !ok { + slot, _ := cache.allocSlot() + cache.slotForSeq[seq] = slot + cache.refCount[slot] = 1 + newSlots2 = append(newSlots2, slot) + } + + // Verify no new slots for existing sequence + if len(newSlots2) != 0 { + t.Errorf("expected 0 new slots for existing sequence, got %d", len(newSlots2)) + } +} diff --git a/model/models/lfm2/model.go b/model/models/lfm2/model.go new file mode 100644 index 000000000..8ebeda1b8 --- /dev/null +++ b/model/models/lfm2/model.go @@ -0,0 +1,253 @@ +package lfm2 + +import ( + "cmp" + "math" + + "github.com/ollama/ollama/fs" + "github.com/ollama/ollama/ml" + "github.com/ollama/ollama/ml/nn" + "github.com/ollama/ollama/ml/nn/rope" + "github.com/ollama/ollama/model" + "github.com/ollama/ollama/model/input" +) + +type Options struct { + hiddenSize int + headDim, ropeDim int + + eps, ropeBase, ropeScale float32 + + ropeType string + originalContextLength int + + // per-layer head counts (LFM2 alternates attention and recurrent layers) + numHeadsByLayer []int + numKVHeadsByLayer []int +} + +func (o Options) headDimValue() int { + // Head dim is shared across layers; fall back to first attention layer head count. + for _, h := range o.numHeadsByLayer { + if h > 0 { + return cmp.Or(o.headDim, o.hiddenSize/h) + } + } + return cmp.Or(o.headDim, o.hiddenSize) +} + +func (o Options) applyRotaryPositionEmbeddings(ctx ml.Context, states, positions ml.Tensor) ml.Tensor { + opts := []func(*rope.Options){rope.WithTypeNeoX()} + if o.ropeType == "yarn" { + attnFactor := float32(1.0 / (1.0 + 0.1*math.Log(float64(o.ropeScale)))) + opts = append(opts, + rope.WithOriginalContextLength(o.originalContextLength), + rope.WithExtrapolationFactor(1.), + rope.WithAttentionFactor(attnFactor), + ) + } + + headCount := 1 + for _, h := range o.numHeadsByLayer { + if h > 0 { + headCount = h + break + } + } + return nn.RoPE(ctx, states, positions, cmp.Or(o.ropeDim, o.headDim, o.hiddenSize/headCount), o.ropeBase, 1./o.ropeScale, opts...) 
+} + +type Model struct { + model.Base + model.TextProcessor + + TokenEmbedding *nn.Embedding `gguf:"token_embd"` + Layers []Layer `gguf:"blk"` + OutputNorm *nn.RMSNorm `gguf:"output_norm,alt:token_embd_norm"` + Output *nn.Linear `gguf:"output,alt:token_embd"` + + Options +} + +func New(c fs.Config) (model.Model, error) { + if c.Uint("expert_count") > 0 { + return nil, model.ErrUnsupportedModel + } + + if c.String("tokenizer.ggml.model") != "gpt2" { + return nil, model.ErrUnsupportedTokenizer + } + + vocabulary := model.Vocabulary{ + Values: c.Strings("tokenizer.ggml.tokens"), + Scores: c.Floats("tokenizer.ggml.scores"), + Types: c.Ints("tokenizer.ggml.token_type"), + Merges: c.Strings("tokenizer.ggml.merges"), + AddBOS: c.Bool("tokenizer.ggml.add_bos_token", true), + BOS: []int32{int32(c.Uint("tokenizer.ggml.bos_token_id"))}, + AddEOS: c.Bool("tokenizer.ggml.add_eos_token", false), + EOS: append( + []int32{int32(c.Uint("tokenizer.ggml.eos_token_id"))}, + c.Ints("tokenizer.ggml.eos_token_ids")..., + ), + } + + var pretokenizers []string + switch c.String("tokenizer.ggml.pre") { + case "default": + // use default BPE pretokenizer + default: + // llama-bpe style (default for LFM2) + pretokenizers = []string{ + `(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`, + } + } + + m := Model{ + TextProcessor: model.NewBytePairEncoding(&vocabulary, pretokenizers...), + Layers: make([]Layer, c.Uint("block_count")), + Options: Options{ + hiddenSize: int(c.Uint("embedding_length")), + headDim: int(c.Uint("attention.key_length")), + ropeDim: int(c.Uint("rope.dimension_count")), + eps: c.Float("attention.layer_norm_rms_epsilon"), + ropeType: c.String("rope.scaling.type"), + ropeBase: c.Float("rope.freq_base"), + ropeScale: c.Float("rope.scaling.factor", 1), + originalContextLength: int(c.Uint("rope.scaling.original_context_length")), + }, + } + + type headCounts interface { + HeadCount() []uint64 + HeadCountKV() []uint64 + } + hc, ok := c.(headCounts) + if !ok { + return nil, model.ErrUnsupportedModel + } + + headCount := hc.HeadCount() + headCountKV := hc.HeadCountKV() + + m.numHeadsByLayer = make([]int, len(m.Layers)) + m.numKVHeadsByLayer = make([]int, len(m.Layers)) + for i := range m.Layers { + m.numHeadsByLayer[i] = int(headCount[i]) + m.numKVHeadsByLayer[i] = int(headCountKV[i]) + + if m.numKVHeadsByLayer[i] == 0 { + m.Layers[i].Operator = &ShortConv{} + } else { + m.Layers[i].Operator = &Attention{} + } + } + + lCache := int(c.Uint("shortconv.l_cache")) + dConv := max(0, lCache-1) + m.Cache = NewHybridCache(m.Shift, m.hiddenSize, dConv) + return &m, nil +} + +type Operator interface { + Forward(ctx ml.Context, hiddenStates, positions ml.Tensor, cache *HybridCache, layer int, opts *Options) ml.Tensor +} + +type Attention struct { + Query *nn.Linear `gguf:"attn_q"` + QueryNorm *nn.RMSNorm `gguf:"attn_q_norm"` + Key *nn.Linear `gguf:"attn_k"` + KeyNorm *nn.RMSNorm `gguf:"attn_k_norm"` + Value *nn.Linear `gguf:"attn_v"` + Output *nn.Linear `gguf:"attn_output,alt:attn_out"` +} + +func (sa *Attention) Forward(ctx ml.Context, hiddenStates, positions ml.Tensor, cache *HybridCache, layer int, opts *Options) ml.Tensor { + batchSize := hiddenStates.Dim(1) + headDim := opts.headDimValue() + numHeads := opts.numHeadsByLayer[layer] + numKVHeads := opts.numKVHeadsByLayer[layer] + + query := sa.Query.Forward(ctx, hiddenStates) + key := sa.Key.Forward(ctx, hiddenStates) + value := sa.Value.Forward(ctx, hiddenStates) + + query = query.Reshape(ctx, 
headDim, numHeads, batchSize) + key = key.Reshape(ctx, headDim, numKVHeads, batchSize) + value = value.Reshape(ctx, headDim, numKVHeads, batchSize) + + query = sa.QueryNorm.Forward(ctx, query, opts.eps) + key = sa.KeyNorm.Forward(ctx, key, opts.eps) + + query = opts.applyRotaryPositionEmbeddings(ctx, query, positions) + key = opts.applyRotaryPositionEmbeddings(ctx, key, positions) + + attention := nn.Attention(ctx, query, key, value, 1./math.Sqrt(float64(headDim)), cache) + attention = attention.Reshape(ctx, attention.Dim(0)*attention.Dim(1), batchSize) + return sa.Output.Forward(ctx, attention) +} + +type MLP struct { + Up *nn.Linear `gguf:"ffn_up"` + Down *nn.Linear `gguf:"ffn_down"` + Gate *nn.Linear `gguf:"ffn_gate"` +} + +func (mlp *MLP) Forward(ctx ml.Context, hiddenState ml.Tensor, opts *Options) ml.Tensor { + hiddenState = mlp.Gate.Forward(ctx, hiddenState).SILU(ctx, mlp.Up.Forward(ctx, hiddenState)) + return mlp.Down.Forward(ctx, hiddenState) +} + +type Layer struct { + AttentionNorm *nn.RMSNorm `gguf:"attn_norm"` + Operator Operator + MLPNorm *nn.RMSNorm `gguf:"ffn_norm"` + MLP *MLP +} + +func (l *Layer) Forward(ctx ml.Context, layer int, hiddenState, positions, outputs ml.Tensor, cache *HybridCache, opts *Options) ml.Tensor { + residual := hiddenState + + hiddenState = l.AttentionNorm.Forward(ctx, hiddenState, opts.eps) + hiddenState = l.Operator.Forward(ctx, hiddenState, positions, cache, layer, opts) + + if outputs != nil { + hiddenState = hiddenState.Rows(ctx, outputs) + residual = residual.Rows(ctx, outputs) + } + + hiddenState = hiddenState.Add(ctx, residual) + residual = hiddenState + + hiddenState = l.MLPNorm.Forward(ctx, hiddenState, opts.eps) + hiddenState = l.MLP.Forward(ctx, hiddenState, opts) + return hiddenState.Add(ctx, residual) +} + +func (m *Model) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) { + return m.applyRotaryPositionEmbeddings(ctx, key, shift), nil +} + +func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) { + positions := ctx.Input().FromInts(batch.Positions, len(batch.Positions)) + + hiddenState := m.TokenEmbedding.Forward(ctx, batch.Inputs) + + for i, layer := range m.Layers { + m.Cache.SetLayer(i) + + var outputs ml.Tensor + if i == len(m.Layers)-1 { + outputs = batch.Outputs + } + + hiddenState = layer.Forward(ctx, i, hiddenState, positions, outputs, m.Cache.(*HybridCache), &m.Options) + } + + hiddenState = m.OutputNorm.Forward(ctx, hiddenState, m.eps) + return m.Output.Forward(ctx, hiddenState), nil +} + +func init() { + model.Register("lfm2", New) +} diff --git a/model/models/lfm2/shortconv.go b/model/models/lfm2/shortconv.go new file mode 100644 index 000000000..d1f6c15fe --- /dev/null +++ b/model/models/lfm2/shortconv.go @@ -0,0 +1,50 @@ +package lfm2 + +import ( + "github.com/ollama/ollama/ml" + "github.com/ollama/ollama/ml/nn" +) + +type shortConvKernel struct { + Weight ml.Tensor `gguf:"weight"` +} + +// ShortConv implements the LFM2 short-convolution block (GGML_OP_SSM_CONV) with a recurrent +// state stored in the HybridCache. 
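+//
+// Data flow, as implemented in Forward below (sizes in the example are hypothetical): in_proj
+// expands the hidden states to 3*hiddenSize channels which are split into b, c and x; b*x is
+// concatenated with the cached dConv history, run through ggml_ssm_conv, gated by c and
+// projected back down by out_proj. With hiddenSize = 2048, seqTokens = 8 and nSeqs = 2, bcx is
+// [6144, 8, 2] and each of b, c and x is a [2048, 8, 2] view into it.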
+type ShortConv struct { + Conv *shortConvKernel `gguf:"shortconv.conv"` + InProj *nn.Linear `gguf:"shortconv.in_proj"` + OutProj *nn.Linear `gguf:"shortconv.out_proj"` +} + +func (sc *ShortConv) Forward(ctx ml.Context, hiddenStates ml.Tensor, _ ml.Tensor, cache *HybridCache, layer int, opts *Options) ml.Tensor { + nSeqs := cache.numSeqs() + seqTokens := cache.seqTokens() + hiddenSize := hiddenStates.Dim(0) + if nSeqs <= 0 || seqTokens <= 0 || hiddenStates.Dim(1) != nSeqs*seqTokens { + panic("lfm2: unsupported batch layout for shortconv") + } + + bcx := sc.InProj.Forward(ctx, hiddenStates).Reshape(ctx, 3*hiddenSize, seqTokens, nSeqs) + + elementSize := bcx.Stride(0) + b := bcx.View(ctx, 0*hiddenSize*elementSize, hiddenSize, bcx.Stride(1), seqTokens, bcx.Stride(2), nSeqs) + c := bcx.View(ctx, 1*hiddenSize*elementSize, hiddenSize, bcx.Stride(1), seqTokens, bcx.Stride(2), nSeqs) + x := bcx.View(ctx, 2*hiddenSize*elementSize, hiddenSize, bcx.Stride(1), seqTokens, bcx.Stride(2), nSeqs) + + bx := b.Mul(ctx, x).Permute(ctx, 1, 0, 2, 3) + + state, err := cache.ConvState(ctx, layer) + if err != nil { + panic("lfm2: failed to get conv state: " + err.Error()) + } + sx := state.Concat(ctx, bx, 0) + + convOut := sx.SSMConv(ctx, sc.Conv.Weight) + y := c.Mul(ctx, convOut) + + dConv := sx.Dim(0) - seqTokens + cache.UpdateConvState(ctx, layer, sx.Slice(ctx, 0, sx.Dim(0)-dConv, sx.Dim(0), 1)) + + return sc.OutProj.Forward(ctx, y.Reshape(ctx, hiddenSize, seqTokens*nSeqs)) +} diff --git a/model/models/models.go b/model/models/models.go index d900f7cc3..bf5daea7b 100644 --- a/model/models/models.go +++ b/model/models/models.go @@ -9,6 +9,7 @@ import ( _ "github.com/ollama/ollama/model/models/gemma3n" _ "github.com/ollama/ollama/model/models/glm4moelite" _ "github.com/ollama/ollama/model/models/gptoss" + _ "github.com/ollama/ollama/model/models/lfm2" _ "github.com/ollama/ollama/model/models/llama" _ "github.com/ollama/ollama/model/models/llama4" _ "github.com/ollama/ollama/model/models/mistral3" diff --git a/model/parsers/lfm2.go b/model/parsers/lfm2.go new file mode 100644 index 000000000..4aade6926 --- /dev/null +++ b/model/parsers/lfm2.go @@ -0,0 +1,498 @@ +package parsers + +import ( + "encoding/json" + "errors" + "log/slog" + "strconv" + "strings" + "unicode" + + "github.com/ollama/ollama/api" +) + +type LFM2ParserState int + +const ( + LFM2CollectingThinking LFM2ParserState = iota + LFM2CollectingContent + LFM2CollectingToolCalls +) + +const ( + lfm2ThinkingOpenTag = "" + lfm2ThinkingCloseTag = "" + lfm2ToolCallStartTag = "<|tool_call_start|>" + lfm2ToolCallEndTag = "<|tool_call_end|>" +) + +type LFM2Parser struct { + state LFM2ParserState + buffer strings.Builder + hasThinkingSupport bool + needsThinkingLeadingTrim bool // trim leading whitespace after tag + needsContentLeadingTrim bool // trim leading whitespace after tag +} + +func (p *LFM2Parser) HasToolSupport() bool { + return true +} + +func (p *LFM2Parser) HasThinkingSupport() bool { + return p.hasThinkingSupport +} + +func (p *LFM2Parser) setInitialState(lastMessage *api.Message, thinkValue *api.ThinkValue) { + prefill := lastMessage != nil && lastMessage.Role == "assistant" + + // Check both model capability AND request preference + thinkingEnabled := p.HasThinkingSupport() && (thinkValue != nil && thinkValue.Bool()) + + if !thinkingEnabled { + p.state = LFM2CollectingContent + return + } + + if prefill && lastMessage.Content != "" { + p.state = LFM2CollectingContent + return + } + + p.state = LFM2CollectingThinking + p.needsThinkingLeadingTrim 
= true +} + +func (p *LFM2Parser) Init(tools []api.Tool, lastMessage *api.Message, thinkValue *api.ThinkValue) []api.Tool { + p.setInitialState(lastMessage, thinkValue) + return tools +} + +type lfm2Event interface { + isLFM2Event() +} + +type lfm2EventThinkingContent struct { + content string +} + +type lfm2EventContent struct { + content string +} + +type lfm2EventToolCall struct { + toolCall api.ToolCall +} + +func (lfm2EventThinkingContent) isLFM2Event() {} +func (lfm2EventContent) isLFM2Event() {} +func (lfm2EventToolCall) isLFM2Event() {} + +func (p *LFM2Parser) Add(s string, done bool) (content string, thinking string, calls []api.ToolCall, err error) { + p.buffer.WriteString(s) + events := p.parseEvents() + + var toolCalls []api.ToolCall + var contentSb strings.Builder + var thinkingSb strings.Builder + for _, event := range events { + switch event := event.(type) { + case lfm2EventToolCall: + toolCalls = append(toolCalls, event.toolCall) + case lfm2EventThinkingContent: + thinkingSb.WriteString(event.content) + case lfm2EventContent: + contentSb.WriteString(event.content) + } + } + + return contentSb.String(), thinkingSb.String(), toolCalls, nil +} + +func (p *LFM2Parser) parseEvents() []lfm2Event { + var all []lfm2Event + + keepLooping := true + for keepLooping { + var events []lfm2Event + events, keepLooping = p.eat() + if len(events) > 0 { + all = append(all, events...) + } + } + + return all +} + +func (p *LFM2Parser) eat() ([]lfm2Event, bool) { + var events []lfm2Event + bufStr := p.buffer.String() + if bufStr == "" { + return events, false + } + + switch p.state { + case LFM2CollectingThinking: + // Strip opening tag if present + if strings.HasPrefix(bufStr, lfm2ThinkingOpenTag) { + bufStr = bufStr[len(lfm2ThinkingOpenTag):] + p.needsThinkingLeadingTrim = true + p.buffer.Reset() + p.buffer.WriteString(bufStr) + } + + // Trim leading whitespace after tag (may span multiple chunks) + if p.needsThinkingLeadingTrim { + if trimmed := strings.TrimLeftFunc(bufStr, unicode.IsSpace); trimmed != bufStr { + bufStr = trimmed + p.buffer.Reset() + p.buffer.WriteString(bufStr) + } + // Clear flag once we have non-whitespace content or buffer is empty + if len(bufStr) > 0 { + p.needsThinkingLeadingTrim = false + } + } + + if strings.Contains(bufStr, lfm2ThinkingCloseTag) { // thinking[] -> content + split := strings.SplitN(bufStr, lfm2ThinkingCloseTag, 2) + thinking := split[0] + thinking = strings.TrimRightFunc(thinking, unicode.IsSpace) + + remaining := split[1] + remaining = strings.TrimLeftFunc(remaining, unicode.IsSpace) + + p.buffer.Reset() + p.buffer.WriteString(remaining) + p.state = LFM2CollectingContent + p.needsThinkingLeadingTrim = false + // Set flag to trim any additional whitespace that may arrive in later chunks + p.needsContentLeadingTrim = len(remaining) == 0 + + if len(thinking) > 0 { + events = append(events, lfm2EventThinkingContent{content: thinking}) + } + return events, true + } else if overlapLen := overlap(bufStr, lfm2ThinkingCloseTag); overlapLen > 0 { // partial + beforePartialTag := bufStr[:len(bufStr)-overlapLen] + trailingLen := trailingWhitespaceLen(beforePartialTag) + ambiguousStart := len(beforePartialTag) - trailingLen + + unambiguous := bufStr[:ambiguousStart] + ambiguous := bufStr[ambiguousStart:] + p.buffer.Reset() + p.buffer.WriteString(ambiguous) + if len(unambiguous) > 0 { + events = append(events, lfm2EventThinkingContent{content: unambiguous}) + } + return events, false + } else { // otherwise its thinking content + whitespaceLen := 
trailingWhitespaceLen(bufStr) + ambiguousStart := len(bufStr) - whitespaceLen + + unambiguous := bufStr[:ambiguousStart] + ambiguous := bufStr[ambiguousStart:] + p.buffer.Reset() + p.buffer.WriteString(ambiguous) + if len(unambiguous) > 0 { + events = append(events, lfm2EventThinkingContent{content: unambiguous}) + } + return events, false + } + + case LFM2CollectingContent: + // Trim leading whitespace after tag (may span multiple chunks) + if p.needsContentLeadingTrim { + if trimmed := strings.TrimLeftFunc(bufStr, unicode.IsSpace); trimmed != bufStr { + bufStr = trimmed + p.buffer.Reset() + p.buffer.WriteString(bufStr) + } + // Clear flag once we have non-whitespace content + if len(bufStr) > 0 { + p.needsContentLeadingTrim = false + } + } + + if strings.Contains(bufStr, lfm2ToolCallStartTag) { // content[<|tool_call_start|>] -> tool calls + split := strings.SplitN(bufStr, lfm2ToolCallStartTag, 2) + contentBefore := strings.TrimRightFunc(split[0], unicode.IsSpace) + remaining := split[1] + + p.buffer.Reset() + p.buffer.WriteString(remaining) + p.state = LFM2CollectingToolCalls + + if len(contentBefore) > 0 { + events = append(events, lfm2EventContent{content: contentBefore}) + } + return events, true + } else { // otherwise its content + p.buffer.Reset() + if len(bufStr) > 0 { + events = append(events, lfm2EventContent{content: bufStr}) + } + return events, false + } + + case LFM2CollectingToolCalls: + // Look for complete tool call JSON between tags + if idx := strings.Index(bufStr, lfm2ToolCallEndTag); idx != -1 { + toolCallContent := bufStr[:idx] + + if toolCalls, err := p.parseToolCallsContent(toolCallContent); err == nil && len(toolCalls) > 0 { + remaining := bufStr[idx+len(lfm2ToolCallEndTag):] + + // Check if there's another tool call + if strings.HasPrefix(remaining, lfm2ToolCallStartTag) { + remaining = remaining[len(lfm2ToolCallStartTag):] + } else { + // No more tool calls, go back to content + remaining = strings.TrimLeftFunc(remaining, unicode.IsSpace) + p.state = LFM2CollectingContent + } + + p.buffer.Reset() + p.buffer.WriteString(remaining) + + for _, tc := range toolCalls { + events = append(events, lfm2EventToolCall{toolCall: tc}) + } + return events, true + } else if err != nil { + slog.Warn("lfm2 tool call parsing failed", "error", err, "content", toolCallContent) + } + } + + return events, false + } + + return events, false +} + +// parseToolCallsContent parses one or more tool calls from content +// Supports JSON format and Python-style format including multiple calls: [func1(...),func2(...)] +func (p *LFM2Parser) parseToolCallsContent(content string) ([]api.ToolCall, error) { + content = strings.TrimSpace(content) + + // Try JSON format first: {"name": "func", "arguments": {...}} + var parsed struct { + Name string `json:"name"` + Arguments json.RawMessage `json:"arguments"` + } + + if err := json.Unmarshal([]byte(content), &parsed); err == nil && parsed.Name != "" { + var args api.ToolCallFunctionArguments + if len(parsed.Arguments) > 0 { + if err := json.Unmarshal(parsed.Arguments, &args); err != nil { + return nil, err + } + } else { + args = api.NewToolCallFunctionArguments() + } + + return []api.ToolCall{{ + Function: api.ToolCallFunction{ + Name: parsed.Name, + Arguments: args, + }, + }}, nil + } + + // Try Python-style format: [func(arg1='val1'),func2(arg2='val2')] or func(arg1='val1') + return p.parsePythonStyleToolCalls(content) +} + +// parsePythonStyleToolCalls parses one or more Python-style tool calls +// Examples: 
[bash(command='ls'),bash(command='pwd')] or bash(command='ls') +func (p *LFM2Parser) parsePythonStyleToolCalls(content string) ([]api.ToolCall, error) { + content = strings.TrimSpace(content) + + // Strip outer brackets if present: [func(...)] -> func(...) + if strings.HasPrefix(content, "[") && strings.HasSuffix(content, "]") { + content = content[1 : len(content)-1] + } + + var toolCalls []api.ToolCall + + // Parse multiple function calls separated by commas at the top level + for len(content) > 0 { + content = strings.TrimSpace(content) + if content == "" { + break + } + + // Skip leading comma from previous iteration + if strings.HasPrefix(content, ",") { + content = strings.TrimSpace(content[1:]) + if content == "" { + break + } + } + + // Find function name + parenIdx := strings.Index(content, "(") + if parenIdx == -1 { + return nil, errors.New("invalid tool call: no opening parenthesis") + } + + funcName := strings.TrimSpace(content[:parenIdx]) + if funcName == "" { + return nil, errors.New("invalid tool call: empty function name") + } + + // Find matching closing parenthesis + closeIdx := findMatchingParen(content, parenIdx) + if closeIdx == -1 { + return nil, errors.New("invalid tool call: no matching closing parenthesis") + } + + argsStr := content[parenIdx+1 : closeIdx] + args := api.NewToolCallFunctionArguments() + + if argsStr != "" { + if err := parsePythonArgs(argsStr, &args); err != nil { + return nil, err + } + } + + toolCalls = append(toolCalls, api.ToolCall{ + Function: api.ToolCallFunction{ + Name: funcName, + Arguments: args, + }, + }) + + // Move past this function call + content = content[closeIdx+1:] + } + + if len(toolCalls) == 0 { + return nil, errors.New("no tool calls found") + } + + return toolCalls, nil +} + +// findMatchingParen finds the index of the closing parenthesis matching the one at openIdx +// Returns -1 if not found. Handles nested parentheses and quoted strings. 
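+//
+// For example, in `bash(command='echo (hi) done')` the parentheses inside the single-quoted
+// string are skipped, so the match for the '(' at index 4 is the final ')'.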
+func findMatchingParen(s string, openIdx int) int { + depth := 1 + i := openIdx + 1 + for i < len(s) && depth > 0 { + switch s[i] { + case '(': + depth++ + case ')': + depth-- + if depth == 0 { + return i + } + case '\'', '"': + // Skip quoted string + quote := s[i] + i++ + for i < len(s) && s[i] != quote { + if s[i] == '\\' && i+1 < len(s) { + i++ // skip escaped char + } + i++ + } + } + i++ + } + return -1 +} + +// parseToolCallContent parses a single tool call (for backward compatibility with tests) +func (p *LFM2Parser) parseToolCallContent(content string) (api.ToolCall, error) { + calls, err := p.parseToolCallsContent(content) + if err != nil { + return api.ToolCall{}, err + } + if len(calls) == 0 { + return api.ToolCall{}, errors.New("no tool call found") + } + return calls[0], nil +} + +// parsePythonArgs parses Python-style keyword arguments: key='value', key2="value2" +func parsePythonArgs(argsStr string, args *api.ToolCallFunctionArguments) error { + // Simple state machine to parse key='value' pairs + // Handles: command='ls', flag="-la", count=42, enabled=true + var key string + i := 0 + + for i < len(argsStr) { + // Skip whitespace + for i < len(argsStr) && (argsStr[i] == ' ' || argsStr[i] == '\t' || argsStr[i] == '\n') { + i++ + } + if i >= len(argsStr) { + break + } + + // Parse key + keyStart := i + for i < len(argsStr) && argsStr[i] != '=' && argsStr[i] != ',' { + i++ + } + if i >= len(argsStr) || argsStr[i] != '=' { + return errors.New("invalid argument: expected '='") + } + key = strings.TrimSpace(argsStr[keyStart:i]) + i++ // skip '=' + + // Skip whitespace after = + for i < len(argsStr) && (argsStr[i] == ' ' || argsStr[i] == '\t') { + i++ + } + + // Parse value + var value string + if i < len(argsStr) && (argsStr[i] == '\'' || argsStr[i] == '"') { + // Quoted string + quote := argsStr[i] + i++ + valueStart := i + for i < len(argsStr) && argsStr[i] != quote { + if argsStr[i] == '\\' && i+1 < len(argsStr) { + i += 2 // skip escaped char + } else { + i++ + } + } + value = argsStr[valueStart:i] + if i < len(argsStr) { + i++ // skip closing quote + } + args.Set(key, value) + } else { + // Unquoted value (number, bool, etc) + valueStart := i + for i < len(argsStr) && argsStr[i] != ',' { + i++ + } + value = strings.TrimSpace(argsStr[valueStart:i]) + + // Try to parse as number or bool + if v, err := strconv.ParseInt(value, 10, 64); err == nil { + args.Set(key, v) + } else if v, err := strconv.ParseFloat(value, 64); err == nil { + args.Set(key, v) + } else if value == "true" { + args.Set(key, true) + } else if value == "false" { + args.Set(key, false) + } else { + args.Set(key, value) + } + } + + // Skip comma and whitespace + for i < len(argsStr) && (argsStr[i] == ',' || argsStr[i] == ' ' || argsStr[i] == '\t' || argsStr[i] == '\n') { + i++ + } + } + + return nil +} diff --git a/model/parsers/lfm2_test.go b/model/parsers/lfm2_test.go new file mode 100644 index 000000000..3e139b811 --- /dev/null +++ b/model/parsers/lfm2_test.go @@ -0,0 +1,1088 @@ +package parsers + +import ( + "testing" + + "github.com/google/go-cmp/cmp" + + "github.com/ollama/ollama/api" +) + +func TestLFM2Parser(t *testing.T) { + tests := []struct { + name string + input string + expectedContent string + expectedThinking string + expectedCalls []api.ToolCall + hasThinking bool + }{ + { + name: "simple_content", + input: "Hello, how are you?", + expectedContent: "Hello, how are you?", + hasThinking: false, + }, + { + name: "thinking_content", + input: "I need to think about this...The answer is 42.", + 
expectedThinking: "I need to think about this...", + expectedContent: "The answer is 42.", + hasThinking: true, + }, + { + name: "thinking_with_newlines", + input: "Let me think:\n- Point 1\n- Point 2\n\nHere's my answer.", + expectedThinking: "Let me think:\n- Point 1\n- Point 2", + expectedContent: "Here's my answer.", + hasThinking: true, + }, + { + name: "tool_call_simple", + input: "I'll check the weather.<|tool_call_start|>{\"name\":\"get_weather\",\"arguments\":{\"location\":\"Paris\"}}<|tool_call_end|>", + expectedContent: "I'll check the weather.", + expectedCalls: []api.ToolCall{ + { + Function: api.ToolCallFunction{ + Name: "get_weather", + Arguments: testArgs(map[string]any{ + "location": "Paris", + }), + }, + }, + }, + hasThinking: false, + }, + { + name: "multiple_tool_calls", + input: "Getting weather for both cities.<|tool_call_start|>{\"name\":\"get_weather\",\"arguments\":{\"location\":\"Paris\"}}<|tool_call_end|><|tool_call_start|>{\"name\":\"get_weather\",\"arguments\":{\"location\":\"London\"}}<|tool_call_end|>", + expectedContent: "Getting weather for both cities.", + expectedCalls: []api.ToolCall{ + { + Function: api.ToolCallFunction{ + Name: "get_weather", + Arguments: testArgs(map[string]any{ + "location": "Paris", + }), + }, + }, + { + Function: api.ToolCallFunction{ + Name: "get_weather", + Arguments: testArgs(map[string]any{ + "location": "London", + }), + }, + }, + }, + hasThinking: false, + }, + { + name: "complex_tool_arguments", + input: "Processing data.<|tool_call_start|>{\"name\":\"process_data\",\"arguments\":{\"items\":[\"item1\",\"item2\"],\"config\":{\"enabled\":true,\"threshold\":0.95}}}<|tool_call_end|>", + expectedContent: "Processing data.", + expectedCalls: []api.ToolCall{ + { + Function: api.ToolCallFunction{ + Name: "process_data", + Arguments: testArgs(map[string]any{ + "items": []interface{}{"item1", "item2"}, + "config": map[string]interface{}{"enabled": true, "threshold": 0.95}, + }), + }, + }, + }, + hasThinking: false, + }, + { + name: "thinking_with_tool_call", + input: "Let me check the weather...I'll get that for you.<|tool_call_start|>{\"name\":\"get_weather\",\"arguments\":{\"location\":\"Paris\"}}<|tool_call_end|>", + expectedThinking: "Let me check the weather...", + expectedContent: "I'll get that for you.", + expectedCalls: []api.ToolCall{ + { + Function: api.ToolCallFunction{ + Name: "get_weather", + Arguments: testArgs(map[string]any{ + "location": "Paris", + }), + }, + }, + }, + hasThinking: true, + }, + { + name: "empty_content", + input: "", + expectedContent: "", + hasThinking: false, + }, + { + name: "only_thinking", + input: "Just thinking content", + expectedThinking: "Just thinking content", + expectedContent: "", + hasThinking: true, + }, + { + name: "unicode_content", + input: "مرحبا بالعالم! 你好世界! 🌍", + expectedContent: "مرحبا بالعالم! 你好世界! 
🌍", + hasThinking: false, + }, + { + name: "newlines_and_whitespace", + input: "Line 1\n\nLine 3\t\tTabbed content", + expectedContent: "Line 1\n\nLine 3\t\tTabbed content", + hasThinking: false, + }, + { + name: "thinking_with_unicode", + input: "我在思考这个问题...答案是42。", + expectedThinking: "我在思考这个问题...", + expectedContent: "答案是42。", + hasThinking: true, + }, + { + name: "tool_call_with_unicode_args", + input: "Searching for information.<|tool_call_start|>{\"name\":\"search\",\"arguments\":{\"query\":\"北京天气\",\"language\":\"中文\"}}<|tool_call_end|>", + expectedContent: "Searching for information.", + expectedCalls: []api.ToolCall{ + { + Function: api.ToolCallFunction{ + Name: "search", + Arguments: testArgs(map[string]any{ + "query": "北京天气", + "language": "中文", + }), + }, + }, + }, + hasThinking: false, + }, + { + name: "thinking_with_special_chars", + input: "Let me calculate: 2+2=4 & 3*3=9...The results are correct!", + expectedThinking: "Let me calculate: 2+2=4 & 3*3=9...", + expectedContent: "The results are correct!", + hasThinking: true, + }, + { + name: "empty_tool_call_args", + input: "Pinging server.<|tool_call_start|>{\"name\":\"ping\",\"arguments\":{}}<|tool_call_end|>", + expectedContent: "Pinging server.", + expectedCalls: []api.ToolCall{ + { + Function: api.ToolCallFunction{ + Name: "ping", + Arguments: api.NewToolCallFunctionArguments(), + }, + }, + }, + hasThinking: false, + }, + // Python-style tool call tests (from Liquid AI docs) + { + name: "python_style_tool_call", + input: "Let me check that.<|tool_call_start|>[get_candidate_status(candidate_id=\"12345\")]<|tool_call_end|>", + expectedContent: "Let me check that.", + expectedCalls: []api.ToolCall{ + { + Function: api.ToolCallFunction{ + Name: "get_candidate_status", + Arguments: testArgs(map[string]any{ + "candidate_id": "12345", + }), + }, + }, + }, + hasThinking: false, + }, + { + name: "python_style_multiple_calls", + input: "Running commands.<|tool_call_start|>[bash(command='ls'),bash(command='pwd')]<|tool_call_end|>", + expectedContent: "Running commands.", + expectedCalls: []api.ToolCall{ + { + Function: api.ToolCallFunction{ + Name: "bash", + Arguments: testArgs(map[string]any{ + "command": "ls", + }), + }, + }, + { + Function: api.ToolCallFunction{ + Name: "bash", + Arguments: testArgs(map[string]any{ + "command": "pwd", + }), + }, + }, + }, + hasThinking: false, + }, + { + name: "thinking_then_python_tool_call", + input: "I should check the status...Let me look that up.<|tool_call_start|>[get_status(id=\"123\")]<|tool_call_end|>", + expectedThinking: "I should check the status...", + expectedContent: "Let me look that up.", + expectedCalls: []api.ToolCall{ + { + Function: api.ToolCallFunction{ + Name: "get_status", + Arguments: testArgs(map[string]any{ + "id": "123", + }), + }, + }, + }, + hasThinking: true, + }, + { + name: "python_style_no_args", + input: "Pinging.<|tool_call_start|>[ping()]<|tool_call_end|>", + expectedContent: "Pinging.", + expectedCalls: []api.ToolCall{ + { + Function: api.ToolCallFunction{ + Name: "ping", + Arguments: api.NewToolCallFunctionArguments(), + }, + }, + }, + hasThinking: false, + }, + { + name: "python_style_mixed_types", + input: "Processing.<|tool_call_start|>[process(name=\"test\", count=42, enabled=true)]<|tool_call_end|>", + expectedContent: "Processing.", + expectedCalls: []api.ToolCall{ + { + Function: api.ToolCallFunction{ + Name: "process", + Arguments: testArgs(map[string]any{ + "name": "test", + "count": int64(42), + "enabled": true, + }), + }, + }, + }, + hasThinking: 
false, + }, + { + name: "tool_call_only_no_content", + input: "<|tool_call_start|>[check()]<|tool_call_end|>", + expectedContent: "", + expectedCalls: []api.ToolCall{ + { + Function: api.ToolCallFunction{ + Name: "check", + Arguments: api.NewToolCallFunctionArguments(), + }, + }, + }, + hasThinking: false, + }, + { + name: "thinking_directly_to_tool_call", + input: "Let me run this command...<|tool_call_start|>[bash(command='ls')]<|tool_call_end|>", + expectedThinking: "Let me run this command...", + expectedContent: "", + expectedCalls: []api.ToolCall{ + { + Function: api.ToolCallFunction{ + Name: "bash", + Arguments: testArgs(map[string]any{ + "command": "ls", + }), + }, + }, + }, + hasThinking: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + parser := &LFM2Parser{hasThinkingSupport: tt.hasThinking} + parser.Init([]api.Tool{}, nil, &api.ThinkValue{Value: tt.hasThinking}) + + content, thinking, calls, err := parser.Add(tt.input, true) + if err != nil { + t.Fatalf("Add() error = %v", err) + } + + if diff := cmp.Diff(tt.expectedContent, content); diff != "" { + t.Errorf("Content mismatch (-want +got):\n%s", diff) + } + + if diff := cmp.Diff(tt.expectedThinking, thinking); diff != "" { + t.Errorf("Thinking mismatch (-want +got):\n%s", diff) + } + + if diff := cmp.Diff(tt.expectedCalls, calls, argsComparer); diff != "" { + t.Errorf("Tool calls mismatch (-want +got):\n%s", diff) + } + }) + } +} + +func TestLFM2Parser_Streaming(t *testing.T) { + tests := []struct { + name string + chunks []string + expectedContent string + expectedThinking string + expectedCalls []api.ToolCall + hasThinking bool + }{ + { + name: "streaming_simple_content", + chunks: []string{"Hello, ", "how are ", "you?"}, + expectedContent: "Hello, how are you?", + hasThinking: false, + }, + { + name: "streaming_thinking", + chunks: []string{"I need to ", "think about this", "...", "The answer is 42."}, + expectedThinking: "I need to think about this...", + expectedContent: "The answer is 42.", + hasThinking: true, + }, + { + name: "streaming_tool_call", + chunks: []string{"I'll check weather.", "<|tool_call_start|>", "{\"name\":\"get_weather\",", "\"arguments\":{\"location\":\"Paris\"}}", "<|tool_call_end|>"}, + expectedContent: "I'll check weather.", + expectedCalls: []api.ToolCall{ + { + Function: api.ToolCallFunction{ + Name: "get_weather", + Arguments: testArgs(map[string]any{ + "location": "Paris", + }), + }, + }, + }, + hasThinking: false, + }, + { + name: "streaming_thinking_with_partial_tag", + chunks: []string{"Thinking about this", "...", "Done thinking."}, + expectedThinking: "Thinking about this...", + expectedContent: "Done thinking.", + hasThinking: true, + }, + { + name: "streaming_unicode_content", + chunks: []string{"مرحبا ", "بالعالم! ", "你好", "世界!"}, + expectedContent: "مرحبا بالعالم! 
你好世界!", + hasThinking: false, + }, + { + name: "streaming_tool_call_with_split_json", + chunks: []string{"Processing.", "<|tool_call_start|>{\"name\":\"calc\",\"arguments\":{\"x\":", "42,\"y\":", "24}}<|tool_call_end|>"}, + expectedContent: "Processing.", + expectedCalls: []api.ToolCall{ + { + Function: api.ToolCallFunction{ + Name: "calc", + Arguments: testArgs(map[string]any{ + "x": float64(42), + "y": float64(24), + }), + }, + }, + }, + hasThinking: false, + }, + { + // Test that leading whitespace after is trimmed even when in separate chunks + name: "streaming_thinking_whitespace_after_tag", + chunks: []string{"", "\n\n ", "Actual thinking content", "", "Response"}, + expectedThinking: "Actual thinking content", + expectedContent: "Response", + hasThinking: true, + }, + { + // Test whitespace between and content in streaming + name: "streaming_whitespace_after_close_tag", + chunks: []string{"Thinking", "\n\n\n", "Response content"}, + expectedThinking: "Thinking", + expectedContent: "Response content", + hasThinking: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + parser := &LFM2Parser{hasThinkingSupport: tt.hasThinking} + parser.Init([]api.Tool{}, nil, &api.ThinkValue{Value: tt.hasThinking}) + + var allContent, allThinking string + var allCalls []api.ToolCall + + for i, chunk := range tt.chunks { + done := i == len(tt.chunks)-1 + content, thinking, calls, err := parser.Add(chunk, done) + if err != nil { + t.Fatalf("Add() error = %v", err) + } + + allContent += content + allThinking += thinking + allCalls = append(allCalls, calls...) + } + + if diff := cmp.Diff(tt.expectedContent, allContent); diff != "" { + t.Errorf("Content mismatch (-want +got):\n%s", diff) + } + + if diff := cmp.Diff(tt.expectedThinking, allThinking); diff != "" { + t.Errorf("Thinking mismatch (-want +got):\n%s", diff) + } + + if diff := cmp.Diff(tt.expectedCalls, allCalls, argsComparer); diff != "" { + t.Errorf("Tool calls mismatch (-want +got):\n%s", diff) + } + }) + } +} + +func TestLFM2Parser_HasThinkingSupport(t *testing.T) { + tests := []struct { + name string + hasThinking bool + expectedSupport bool + }{ + { + name: "thinking_enabled", + hasThinking: true, + expectedSupport: true, + }, + { + name: "thinking_disabled", + hasThinking: false, + expectedSupport: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + parser := &LFM2Parser{hasThinkingSupport: tt.hasThinking} + if got := parser.HasThinkingSupport(); got != tt.expectedSupport { + t.Errorf("HasThinkingSupport() = %v, want %v", got, tt.expectedSupport) + } + }) + } +} + +func TestLFM2Parser_HasToolSupport(t *testing.T) { + parser := &LFM2Parser{} + if !parser.HasToolSupport() { + t.Error("HasToolSupport() should return true") + } +} + +func TestLFM2Parser_Init(t *testing.T) { + parser := &LFM2Parser{hasThinkingSupport: true} + tools := []api.Tool{ + { + Type: "function", + Function: api.ToolFunction{ + Name: "test_tool", + }, + }, + } + + returnedTools := parser.Init(tools, nil, &api.ThinkValue{Value: true}) + + if diff := cmp.Diff(tools, returnedTools, toolsComparer); diff != "" { + t.Errorf("Init() returned tools mismatch (-want +got):\n%s", diff) + } + + // Test initial state is set to thinking when enabled + if parser.state != LFM2CollectingThinking { + t.Errorf("Expected initial state to be LFM2CollectingThinking, got %v", parser.state) + } +} + +func TestLFM2Parser_parseToolCallContent(t *testing.T) { + tests := []struct { + name string + content string + expected 
api.ToolCall + expectError bool + }{ + { + name: "valid_tool_call", + content: `{"name":"get_weather","arguments":{"location":"Paris"}}`, + expected: api.ToolCall{ + Function: api.ToolCallFunction{ + Name: "get_weather", + Arguments: testArgs(map[string]any{ + "location": "Paris", + }), + }, + }, + }, + { + name: "complex_arguments", + content: `{"name":"process_data","arguments":{"items":["a","b"],"config":{"enabled":true}}}`, + expected: api.ToolCall{ + Function: api.ToolCallFunction{ + Name: "process_data", + Arguments: testArgs(map[string]any{ + "items": []interface{}{"a", "b"}, + "config": map[string]interface{}{"enabled": true}, + }), + }, + }, + }, + { + name: "empty_arguments", + content: `{"name":"ping","arguments":{}}`, + expected: api.ToolCall{ + Function: api.ToolCallFunction{ + Name: "ping", + Arguments: api.NewToolCallFunctionArguments(), + }, + }, + }, + { + name: "unicode_in_tool_name", + content: `{"name":"获取天气","arguments":{"城市":"北京"}}`, + expected: api.ToolCall{ + Function: api.ToolCallFunction{ + Name: "获取天气", + Arguments: testArgs(map[string]any{ + "城市": "北京", + }), + }, + }, + }, + { + name: "numeric_arguments", + content: `{"name":"calculate","arguments":{"x":3.14,"y":42,"enabled":true}}`, + expected: api.ToolCall{ + Function: api.ToolCallFunction{ + Name: "calculate", + Arguments: testArgs(map[string]any{ + "x": 3.14, + "y": float64(42), + "enabled": true, + }), + }, + }, + }, + { + name: "invalid_json", + content: `{invalid json}`, + expectError: true, + }, + { + name: "missing_name", + content: `{"arguments":{"arg":"value"}}`, + expectError: true, + }, + { + name: "empty_name", + content: `{"name":"","arguments":{"arg":"value"}}`, + expectError: true, + }, + } + + parser := &LFM2Parser{} + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result, err := parser.parseToolCallContent(tt.content) + + if tt.expectError { + if err == nil { + t.Error("Expected error but got none") + } + return + } + + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + + if diff := cmp.Diff(tt.expected, result, argsComparer); diff != "" { + t.Errorf("parseToolCallContent() mismatch (-want +got):\n%s", diff) + } + }) + } +} + +func TestLFM2Parser_parseToolCallsContent(t *testing.T) { + tests := []struct { + name string + content string + expected []api.ToolCall + expectError bool + }{ + { + name: "multiple_python_style_calls", + content: `[bash(command='curl google.com'),bash(command='curl example.com')]`, + expected: []api.ToolCall{ + { + Function: api.ToolCallFunction{ + Name: "bash", + Arguments: testArgs(map[string]any{ + "command": "curl google.com", + }), + }, + }, + { + Function: api.ToolCallFunction{ + Name: "bash", + Arguments: testArgs(map[string]any{ + "command": "curl example.com", + }), + }, + }, + }, + }, + { + name: "single_python_style_call", + content: `bash(command='ls -la')`, + expected: []api.ToolCall{ + { + Function: api.ToolCallFunction{ + Name: "bash", + Arguments: testArgs(map[string]any{ + "command": "ls -la", + }), + }, + }, + }, + }, + { + name: "single_bracketed_call", + content: `[bash(command='pwd')]`, + expected: []api.ToolCall{ + { + Function: api.ToolCallFunction{ + Name: "bash", + Arguments: testArgs(map[string]any{ + "command": "pwd", + }), + }, + }, + }, + }, + { + name: "multiple_different_functions", + content: `[get_weather(location='Paris'),search(query='news')]`, + expected: []api.ToolCall{ + { + Function: api.ToolCallFunction{ + Name: "get_weather", + Arguments: testArgs(map[string]any{ + "location": "Paris", + }), + }, 
+ }, + { + Function: api.ToolCallFunction{ + Name: "search", + Arguments: testArgs(map[string]any{ + "query": "news", + }), + }, + }, + }, + }, + { + name: "nested_parentheses_in_arg", + content: `bash(command='echo "(hello)"')`, + expected: []api.ToolCall{ + { + Function: api.ToolCallFunction{ + Name: "bash", + Arguments: testArgs(map[string]any{ + "command": `echo "(hello)"`, + }), + }, + }, + }, + }, + { + name: "comma_inside_quotes", + content: `bash(command='echo "hello, world"')`, + expected: []api.ToolCall{ + { + Function: api.ToolCallFunction{ + Name: "bash", + Arguments: testArgs(map[string]any{ + "command": `echo "hello, world"`, + }), + }, + }, + }, + }, + { + name: "equals_inside_quotes", + content: `bash(command='export FOO=bar')`, + expected: []api.ToolCall{ + { + Function: api.ToolCallFunction{ + Name: "bash", + Arguments: testArgs(map[string]any{ + "command": `export FOO=bar`, + }), + }, + }, + }, + }, + { + name: "double_quotes_with_single_inside", + content: `bash(command="echo 'hello'")`, + expected: []api.ToolCall{ + { + Function: api.ToolCallFunction{ + Name: "bash", + Arguments: testArgs(map[string]any{ + "command": `echo 'hello'`, + }), + }, + }, + }, + }, + { + name: "multiple_args", + content: `bash(command='ls', flag='-la', count=42)`, + expected: []api.ToolCall{ + { + Function: api.ToolCallFunction{ + Name: "bash", + Arguments: testArgs(map[string]any{ + "command": "ls", + "flag": "-la", + "count": int64(42), + }), + }, + }, + }, + }, + { + name: "no_args", + content: `ping()`, + expected: []api.ToolCall{ + { + Function: api.ToolCallFunction{ + Name: "ping", + Arguments: api.NewToolCallFunctionArguments(), + }, + }, + }, + }, + { + name: "three_calls", + content: `[a(x='1'),b(y='2'),c(z='3')]`, + expected: []api.ToolCall{ + { + Function: api.ToolCallFunction{ + Name: "a", + Arguments: testArgs(map[string]any{"x": "1"}), + }, + }, + { + Function: api.ToolCallFunction{ + Name: "b", + Arguments: testArgs(map[string]any{"y": "2"}), + }, + }, + { + Function: api.ToolCallFunction{ + Name: "c", + Arguments: testArgs(map[string]any{"z": "3"}), + }, + }, + }, + }, + { + // Note: backslash escapes are preserved as-is, not processed + name: "escaped_quote_in_value", + content: `bash(command='echo \'hello\'')`, + expected: []api.ToolCall{ + { + Function: api.ToolCallFunction{ + Name: "bash", + Arguments: testArgs(map[string]any{ + "command": `echo \'hello\'`, + }), + }, + }, + }, + }, + // Tests based on Liquid AI documentation examples + { + name: "docs_example_candidate_status", + content: `[get_candidate_status(candidate_id="12345")]`, + expected: []api.ToolCall{ + { + Function: api.ToolCallFunction{ + Name: "get_candidate_status", + Arguments: testArgs(map[string]any{ + "candidate_id": "12345", + }), + }, + }, + }, + }, + { + name: "boolean_true_arg", + content: `configure(enabled=true)`, + expected: []api.ToolCall{ + { + Function: api.ToolCallFunction{ + Name: "configure", + Arguments: testArgs(map[string]any{ + "enabled": true, + }), + }, + }, + }, + }, + { + name: "boolean_false_arg", + content: `configure(enabled=false)`, + expected: []api.ToolCall{ + { + Function: api.ToolCallFunction{ + Name: "configure", + Arguments: testArgs(map[string]any{ + "enabled": false, + }), + }, + }, + }, + }, + { + name: "float_arg", + content: `set_threshold(value=0.95)`, + expected: []api.ToolCall{ + { + Function: api.ToolCallFunction{ + Name: "set_threshold", + Arguments: testArgs(map[string]any{ + "value": 0.95, + }), + }, + }, + }, + }, + { + name: "negative_number_arg", + content: 
`adjust(offset=-10)`, + expected: []api.ToolCall{ + { + Function: api.ToolCallFunction{ + Name: "adjust", + Arguments: testArgs(map[string]any{ + "offset": int64(-10), + }), + }, + }, + }, + }, + { + name: "mixed_arg_types", + content: `process(name="test", count=42, ratio=3.14, active=true)`, + expected: []api.ToolCall{ + { + Function: api.ToolCallFunction{ + Name: "process", + Arguments: testArgs(map[string]any{ + "name": "test", + "count": int64(42), + "ratio": 3.14, + "active": true, + }), + }, + }, + }, + }, + { + name: "newline_in_string_arg", + content: `write_file(content="line1\nline2\nline3")`, + expected: []api.ToolCall{ + { + Function: api.ToolCallFunction{ + Name: "write_file", + Arguments: testArgs(map[string]any{ + "content": "line1\\nline2\\nline3", + }), + }, + }, + }, + }, + { + name: "empty_string_arg", + content: `search(query="")`, + expected: []api.ToolCall{ + { + Function: api.ToolCallFunction{ + Name: "search", + Arguments: testArgs(map[string]any{ + "query": "", + }), + }, + }, + }, + }, + { + name: "underscore_function_name", + content: `get_user_profile(user_id="abc123")`, + expected: []api.ToolCall{ + { + Function: api.ToolCallFunction{ + Name: "get_user_profile", + Arguments: testArgs(map[string]any{ + "user_id": "abc123", + }), + }, + }, + }, + }, + { + name: "whitespace_around_args", + content: `func( arg1 = "value1" , arg2 = 42 )`, + expected: []api.ToolCall{ + { + Function: api.ToolCallFunction{ + Name: "func", + Arguments: testArgs(map[string]any{ + "arg1": "value1", + "arg2": int64(42), + }), + }, + }, + }, + }, + { + name: "json_in_string_arg", + content: `send_data(payload='{"key": "value", "num": 123}')`, + expected: []api.ToolCall{ + { + Function: api.ToolCallFunction{ + Name: "send_data", + Arguments: testArgs(map[string]any{ + "payload": `{"key": "value", "num": 123}`, + }), + }, + }, + }, + }, + { + name: "url_in_arg", + content: `fetch(url="https://example.com/api?foo=bar&baz=qux")`, + expected: []api.ToolCall{ + { + Function: api.ToolCallFunction{ + Name: "fetch", + Arguments: testArgs(map[string]any{ + "url": "https://example.com/api?foo=bar&baz=qux", + }), + }, + }, + }, + }, + { + name: "path_with_spaces", + content: `read_file(path="/home/user/My Documents/file.txt")`, + expected: []api.ToolCall{ + { + Function: api.ToolCallFunction{ + Name: "read_file", + Arguments: testArgs(map[string]any{ + "path": "/home/user/My Documents/file.txt", + }), + }, + }, + }, + }, + } + + parser := &LFM2Parser{} + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result, err := parser.parseToolCallsContent(tt.content) + + if tt.expectError { + if err == nil { + t.Error("Expected error but got none") + } + return + } + + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + + if diff := cmp.Diff(tt.expected, result, argsComparer); diff != "" { + t.Errorf("parseToolCallsContent() mismatch (-want +got):\n%s", diff) + } + }) + } +} + +func TestLFM2Parser_EdgeCases(t *testing.T) { + tests := []struct { + name string + input string + expectedContent string + expectedThinking string + hasThinking bool + }{ + { + name: "multiple_think_close_tags", + input: "First thoughtSecond thoughtFinal content", + expectedThinking: "First thought", + expectedContent: "Second thoughtFinal content", + hasThinking: true, + }, + { + name: "empty_thinking_content", + input: "Just content", + expectedThinking: "", + expectedContent: "Just content", + hasThinking: true, + }, + { + name: "thinking_disabled_with_think_tags", + input: "Some contentMore content", + 
expectedContent: "Some contentMore content", + hasThinking: false, + }, + { + name: "whitespace_only_content", + input: " \n\t ", + expectedContent: " \n\t ", + hasThinking: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + parser := &LFM2Parser{hasThinkingSupport: tt.hasThinking} + parser.Init([]api.Tool{}, nil, &api.ThinkValue{Value: tt.hasThinking}) + + content, thinking, _, err := parser.Add(tt.input, true) + if err != nil { + t.Fatalf("Add() error = %v", err) + } + + if diff := cmp.Diff(tt.expectedContent, content); diff != "" { + t.Errorf("Content mismatch (-want +got):\n%s", diff) + } + + if diff := cmp.Diff(tt.expectedThinking, thinking); diff != "" { + t.Errorf("Thinking mismatch (-want +got):\n%s", diff) + } + }) + } +} diff --git a/model/parsers/parsers.go b/model/parsers/parsers.go index 3a3261a04..c5baabe53 100644 --- a/model/parsers/parsers.go +++ b/model/parsers/parsers.go @@ -70,6 +70,10 @@ func ParserForName(name string) Parser { return &FunctionGemmaParser{} case "glm-4.7": return &GLM47Parser{} + case "lfm2": + return &LFM2Parser{hasThinkingSupport: false} + case "lfm2-thinking": + return &LFM2Parser{hasThinkingSupport: true} default: return nil } diff --git a/model/renderers/lfm2.go b/model/renderers/lfm2.go new file mode 100644 index 000000000..5c046835f --- /dev/null +++ b/model/renderers/lfm2.go @@ -0,0 +1,144 @@ +package renderers + +import ( + "encoding/json" + "strings" + + "github.com/ollama/ollama/api" +) + +type LFM2Renderer struct { + IsThinking bool +} + +func (r *LFM2Renderer) Render(messages []api.Message, tools []api.Tool, thinkValue *api.ThinkValue) (string, error) { + var sb strings.Builder + + // Note: BOS token is added by the tokenizer (add_bos_token: true), not the renderer + + // Extract first system message if present (to combine with tools) + var firstSystemContent string + startIdx := 0 + if len(messages) > 0 && messages[0].Role == "system" { + firstSystemContent = messages[0].Content + startIdx = 1 + } + + // Append tools to first system content + if len(tools) > 0 { + if firstSystemContent != "" { + firstSystemContent += "\n" + } + firstSystemContent += "List of tools: [" + for i, tool := range tools { + toolJSON, err := json.Marshal(tool) + if err != nil { + return "", err + } + firstSystemContent += string(toolJSON) + if i < len(tools)-1 { + firstSystemContent += ", " + } + } + firstSystemContent += "]" + } + + // Output first system block if it has content + if firstSystemContent != "" { + sb.WriteString("<|im_start|>system\n") + sb.WriteString(firstSystemContent) + sb.WriteString("<|im_end|>\n") + } + + // Find the index of the last assistant message for thinking stripping + lastAssistantIndex := -1 + for i := len(messages) - 1; i >= startIdx; i-- { + if messages[i].Role == "assistant" { + lastAssistantIndex = i + break + } + } + + // Track whether we need to add generation prompt + needsGenerationPrompt := len(messages) > 0 + + for i := startIdx; i < len(messages); i++ { + message := messages[i] + switch message.Role { + case "system": + // Additional system messages (after the first) are rendered normally + sb.WriteString("<|im_start|>system\n") + sb.WriteString(message.Content) + sb.WriteString("<|im_end|>\n") + + case "user": + sb.WriteString("<|im_start|>user\n") + sb.WriteString(message.Content) + sb.WriteString("<|im_end|>\n") + needsGenerationPrompt = true + + case "assistant": + sb.WriteString("<|im_start|>assistant\n") + + // Check if this is the last assistant message + isLastAssistant := i == 
lastAssistantIndex + + // Process content (may need thinking stripped) + content := message.Content + + // Handle thinking tags in assistant content + keepPastThinking := r.IsThinking && (thinkValue != nil && thinkValue.Bool()) + if strings.Contains(content, "</think>") { + parts := strings.SplitN(content, "</think>", 2) + if len(parts) > 1 { + if !isLastAssistant && !keepPastThinking { + // Strip thinking entirely for past assistant messages + content = strings.TrimSpace(parts[1]) + } else { + // Preserve thinking but trim whitespace after </think> + content = parts[0] + "</think>" + strings.TrimLeft(parts[1], " \t\n\r") + } + } + } + + if len(message.ToolCalls) > 0 { + // Assistant with tool calls - write content first (if any after stripping) + if content != "" { + sb.WriteString(content) + } + + for _, toolCall := range message.ToolCalls { + sb.WriteString("<|tool_call_start|>") + toolCallJSON := map[string]any{ + "name": toolCall.Function.Name, + "arguments": toolCall.Function.Arguments, + } + callJSON, _ := json.Marshal(toolCallJSON) + sb.WriteString(string(callJSON)) + sb.WriteString("<|tool_call_end|>") + } + } else { + sb.WriteString(content) + } + + sb.WriteString("<|im_end|>\n") + needsGenerationPrompt = true // Always add gen prompt after assistant when add_generation_prompt=true + + case "tool": + // Tool responses are rendered as plain messages per the chat template + sb.WriteString("<|im_start|>tool\n") + sb.WriteString(message.Content) + sb.WriteString("<|im_end|>\n") + needsGenerationPrompt = true + } + } + + // Add generation prompt + if needsGenerationPrompt { + sb.WriteString("<|im_start|>assistant\n") + // Note: Model is a "thinking-only" model - it will output <think> itself + // We don't add a <think> tag to the prompt + } + + return sb.String(), nil +} diff --git a/model/renderers/lfm2_test.go b/model/renderers/lfm2_test.go new file mode 100644 index 000000000..9eb07eea3 --- /dev/null +++ b/model/renderers/lfm2_test.go @@ -0,0 +1,427 @@ +package renderers + +import ( + "testing" + + "github.com/google/go-cmp/cmp" + + "github.com/ollama/ollama/api" +) + +func TestLFM2Renderer(t *testing.T) { + tests := []struct { + name string + messages []api.Message + tools []api.Tool + thinkValue *api.ThinkValue + expected string + }{ + { + name: "basic user message", + messages: []api.Message{ + {Role: "user", Content: "Hello!"}, + }, + thinkValue: &api.ThinkValue{Value: false}, + expected: "<|im_start|>user\nHello!<|im_end|>\n<|im_start|>assistant\n", + }, + { + name: "basic with system message", + messages: []api.Message{ + {Role: "system", Content: "You are a helpful assistant."}, + {Role: "user", Content: "Hello!"}, + }, + thinkValue: &api.ThinkValue{Value: false}, + expected: "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\nHello!<|im_end|>\n<|im_start|>assistant\n", + }, + { + name: "multiple system messages rendered separately", + messages: []api.Message{ + {Role: "system", Content: "First instruction."}, + {Role: "system", Content: "Second instruction."}, + {Role: "user", Content: "Hello!"}, + }, + thinkValue: &api.ThinkValue{Value: false}, + expected: "<|im_start|>system\nFirst instruction.<|im_end|>\n<|im_start|>system\nSecond instruction.<|im_end|>\n<|im_start|>user\nHello!<|im_end|>\n<|im_start|>assistant\n", + }, + { + name: "multi-turn conversation", + messages: []api.Message{ + {Role: "user", Content: "What is 2+2?"}, + {Role: "assistant", Content: "The answer is 4."}, + {Role: "user", Content: "Thanks!"}, + }, + thinkValue: &api.ThinkValue{Value: false}, + expected: "<|im_start|>user\nWhat is 2+2?<|im_end|>\n<|im_start|>assistant\nThe answer is 4.<|im_end|>\n<|im_start|>user\nThanks!<|im_end|>\n<|im_start|>assistant\n", + }, + { + name: "only system message", + messages: []api.Message{ + {Role: "system", Content: "You are helpful."}, + }, + thinkValue: &api.ThinkValue{Value: false}, + expected: "<|im_start|>system\nYou are helpful.<|im_end|>\n<|im_start|>assistant\n", + }, + { + // When assistant is the LAST assistant, thinking is preserved (even with keep_past_thinking=false) + name: "user-assistant-user: last assistant preserves thinking", + messages: []api.Message{ + {Role: "user", Content: "Q1"}, + {Role: "assistant", Content: "<think>reasoning</think>A1"}, + {Role: "user", Content: "Q2"}, + }, + thinkValue: &api.ThinkValue{Value: false}, + expected: "<|im_start|>user\nQ1<|im_end|>\n<|im_start|>assistant\n<think>reasoning</think>A1<|im_end|>\n<|im_start|>user\nQ2<|im_end|>\n<|im_start|>assistant\n", + }, + { + // With two assistants, first is stripped (not last), second preserved (is last) + name: "multi-turn thinking: first stripped, second preserved", + messages: []api.Message{ + {Role: "user", Content: "Q1"}, + {Role: "assistant", Content: "<think>reason1</think>A1"}, + {Role: "user", Content: "Q2"}, + {Role: "assistant", Content: "<think>reason2</think>A2"}, + }, + thinkValue: &api.ThinkValue{Value: false}, + expected: "<|im_start|>user\nQ1<|im_end|>\n<|im_start|>assistant\nA1<|im_end|>\n<|im_start|>user\nQ2<|im_end|>\n<|im_start|>assistant\n<think>reason2</think>A2<|im_end|>\n<|im_start|>assistant\n", + }, + { + // With thinking enabled (keep_past_thinking=true), both preserved + name: "multi-turn thinking: both preserved when thinking enabled", + messages: []api.Message{ + {Role: "user", Content: "Q1"}, + {Role: "assistant", Content: "<think>reason1</think>A1"}, + {Role: "user", Content: "Q2"}, + {Role: "assistant", Content: "<think>reason2</think>A2"}, + }, + thinkValue: &api.ThinkValue{Value: true}, + expected: "<|im_start|>user\nQ1<|im_end|>\n<|im_start|>assistant\n<think>reason1</think>A1<|im_end|>\n<|im_start|>user\nQ2<|im_end|>\n<|im_start|>assistant\n<think>reason2</think>A2<|im_end|>\n<|im_start|>assistant\n", + }, + { + name: "assistant with tool calls", + messages: []api.Message{ + {Role: "user", Content: "What's the weather?"}, + { + Role: "assistant", + ToolCalls: []api.ToolCall{ + { + Function: api.ToolCallFunction{ + Name: "get_weather", + Arguments: testArgs(map[string]any{ + "location": "Paris", + }), + }, + }, + }, + }, + }, + thinkValue: &api.ThinkValue{Value: false}, + expected: `<|im_start|>user` + "\n" + `What's the weather?<|im_end|>` + "\n" + `<|im_start|>assistant` + "\n" + `<|tool_call_start|>{"arguments":{"location":"Paris"},"name":"get_weather"}<|tool_call_end|><|im_end|>` + "\n" + `<|im_start|>assistant` + "\n", + }, + { + name: "assistant with content and tool calls", + messages: []api.Message{ + {Role: "user", Content: "What's the weather in Paris?"}, + { + Role: "assistant", + Content: "Let me check.", + ToolCalls: []api.ToolCall{ + { + Function: api.ToolCallFunction{ + Name: "get_weather", + Arguments: testArgs(map[string]any{ + "location": "Paris", + }), + }, + }, + }, + }, + }, + thinkValue: &api.ThinkValue{Value: false}, + expected: `<|im_start|>user` + "\n" + `What's the weather in Paris?<|im_end|>` + "\n" + `<|im_start|>assistant` + "\n" + `Let me check.<|tool_call_start|>{"arguments":{"location":"Paris"},"name":"get_weather"}<|tool_call_end|><|im_end|>` + "\n" + `<|im_start|>assistant` + "\n", + }, + { + name: "tool response", + messages: []api.Message{ + {Role: "user", Content: "What's the weather?"}, + {Role: "assistant", Content:
"Let me check."}, + {Role: "tool", Content: "22C, Sunny"}, + }, + thinkValue: &api.ThinkValue{Value: false}, + expected: "<|im_start|>user\nWhat's the weather?<|im_end|>\n<|im_start|>assistant\nLet me check.<|im_end|>\n<|im_start|>tool\n22C, Sunny<|im_end|>\n<|im_start|>assistant\n", + }, + { + name: "multiple tool calls", + messages: []api.Message{ + {Role: "user", Content: "Get weather for Paris and London"}, + { + Role: "assistant", + ToolCalls: []api.ToolCall{ + { + Function: api.ToolCallFunction{ + Name: "get_weather", + Arguments: testArgs(map[string]any{ + "location": "Paris", + }), + }, + }, + { + Function: api.ToolCallFunction{ + Name: "get_weather", + Arguments: testArgs(map[string]any{ + "location": "London", + }), + }, + }, + }, + }, + }, + thinkValue: &api.ThinkValue{Value: false}, + expected: `<|im_start|>user` + "\n" + `Get weather for Paris and London<|im_end|>` + "\n" + `<|im_start|>assistant` + "\n" + `<|tool_call_start|>{"arguments":{"location":"Paris"},"name":"get_weather"}<|tool_call_end|><|tool_call_start|>{"arguments":{"location":"London"},"name":"get_weather"}<|tool_call_end|><|im_end|>` + "\n" + `<|im_start|>assistant` + "\n", + }, + { + name: "tools definitions with system message", + messages: []api.Message{ + {Role: "system", Content: "You are helpful."}, + {Role: "user", Content: "What's the weather?"}, + }, + tools: []api.Tool{ + { + Type: "function", + Function: api.ToolFunction{ + Name: "get_weather", + Description: "Get current weather", + Parameters: api.ToolFunctionParameters{ + Type: "object", + Properties: testPropsMap(map[string]api.ToolProperty{ + "location": { + Type: api.PropertyType{"string"}, + Description: "City name", + }, + }), + Required: []string{"location"}, + }, + }, + }, + }, + thinkValue: &api.ThinkValue{Value: false}, + expected: `<|im_start|>system` + "\n" + `You are helpful.` + "\n" + `List of tools: [{"type":"function","function":{"name":"get_weather","description":"Get current weather","parameters":{"type":"object","required":["location"],"properties":{"location":{"type":"string","description":"City name"}}}}}]<|im_end|>` + "\n" + `<|im_start|>user` + "\n" + `What's the weather?<|im_end|>` + "\n" + `<|im_start|>assistant` + "\n", + }, + { + name: "tools definitions without system message", + messages: []api.Message{ + {Role: "user", Content: "What's the weather?"}, + }, + tools: []api.Tool{ + { + Type: "function", + Function: api.ToolFunction{ + Name: "get_weather", + Description: "Get current weather", + Parameters: api.ToolFunctionParameters{ + Type: "object", + Properties: testPropsMap(map[string]api.ToolProperty{ + "location": { + Type: api.PropertyType{"string"}, + Description: "City name", + }, + }), + Required: []string{"location"}, + }, + }, + }, + }, + thinkValue: &api.ThinkValue{Value: false}, + expected: `<|im_start|>system` + "\n" + `List of tools: [{"type":"function","function":{"name":"get_weather","description":"Get current weather","parameters":{"type":"object","required":["location"],"properties":{"location":{"type":"string","description":"City name"}}}}}]<|im_end|>` + "\n" + `<|im_start|>user` + "\n" + `What's the weather?<|im_end|>` + "\n" + `<|im_start|>assistant` + "\n", + }, + { + name: "multiple tools without system message", + messages: []api.Message{ + {Role: "user", Content: "Hello"}, + }, + tools: []api.Tool{ + { + Type: "function", + Function: api.ToolFunction{ + Name: "get_weather", + Description: "Get weather", + }, + }, + { + Type: "function", + Function: api.ToolFunction{ + Name: "get_time", + 
Description: "Get time", + }, + }, + }, + thinkValue: &api.ThinkValue{Value: false}, + expected: "<|im_start|>system\nList of tools: [{\"type\":\"function\",\"function\":{\"name\":\"get_weather\",\"description\":\"Get weather\",\"parameters\":{\"type\":\"\",\"properties\":null}}}, {\"type\":\"function\",\"function\":{\"name\":\"get_time\",\"description\":\"Get time\",\"parameters\":{\"type\":\"\",\"properties\":null}}}]<|im_end|>\n<|im_start|>user\nHello<|im_end|>\n<|im_start|>assistant\n", + }, + { + name: "user-tool sequence", + messages: []api.Message{ + {Role: "user", Content: "Check weather"}, + {Role: "tool", Content: "22C"}, + }, + thinkValue: &api.ThinkValue{Value: false}, + expected: "<|im_start|>user\nCheck weather<|im_end|>\n<|im_start|>tool\n22C<|im_end|>\n<|im_start|>assistant\n", + }, + { + name: "full tool call cycle", + messages: []api.Message{ + {Role: "user", Content: "Check weather"}, + {Role: "assistant", Content: "Let me check"}, + {Role: "tool", Content: "22C"}, + {Role: "assistant", Content: "It's 22C"}, + }, + thinkValue: &api.ThinkValue{Value: false}, + expected: "<|im_start|>user\nCheck weather<|im_end|>\n<|im_start|>assistant\nLet me check<|im_end|>\n<|im_start|>tool\n22C<|im_end|>\n<|im_start|>assistant\nIt's 22C<|im_end|>\n<|im_start|>assistant\n", + }, + { + name: "unicode content", + messages: []api.Message{ + {Role: "user", Content: "你好世界! مرحبا 🌍"}, + {Role: "assistant", Content: "Hello! 👋"}, + }, + thinkValue: &api.ThinkValue{Value: false}, + expected: "<|im_start|>user\n你好世界! مرحبا 🌍<|im_end|>\n<|im_start|>assistant\nHello! 👋<|im_end|>\n<|im_start|>assistant\n", + }, + { + name: "newlines in content", + messages: []api.Message{ + {Role: "user", Content: "Line 1\nLine 2\n\nLine 4"}, + {Role: "assistant", Content: "Response with\nmultiple\nlines"}, + }, + thinkValue: &api.ThinkValue{Value: false}, + expected: "<|im_start|>user\nLine 1\nLine 2\n\nLine 4<|im_end|>\n<|im_start|>assistant\nResponse with\nmultiple\nlines<|im_end|>\n<|im_start|>assistant\n", + }, + { + name: "empty assistant content", + messages: []api.Message{ + {Role: "user", Content: "Hello"}, + {Role: "assistant", Content: ""}, + {Role: "user", Content: "OK"}, + }, + thinkValue: &api.ThinkValue{Value: false}, + expected: "<|im_start|>user\nHello<|im_end|>\n<|im_start|>assistant\n<|im_end|>\n<|im_start|>user\nOK<|im_end|>\n<|im_start|>assistant\n", + }, + { + // Generation prompt does NOT include <think> - model outputs it + name: "generation prompt has no think tag", + messages: []api.Message{ + {Role: "user", Content: "Think hard"}, + }, + thinkValue: &api.ThinkValue{Value: true}, + expected: "<|im_start|>user\nThink hard<|im_end|>\n<|im_start|>assistant\n", + }, + { + // Interleaved: thinking before tool call - last assistant preserves thinking + name: "thinking before tool call (last assistant)", + messages: []api.Message{ + {Role: "user", Content: "What's the weather?"}, + { + Role: "assistant", + Content: "<think>I need to check the weather</think>", + ToolCalls: []api.ToolCall{ + { + Function: api.ToolCallFunction{ + Name: "get_weather", + Arguments: testArgs(map[string]any{ + "location": "Paris", + }), + }, + }, + }, + }, + }, + thinkValue: &api.ThinkValue{Value: false}, + expected: "<|im_start|>user\nWhat's the weather?<|im_end|>\n<|im_start|>assistant\n<think>I need to check the weather</think><|tool_call_start|>{\"arguments\":{\"location\":\"Paris\"},\"name\":\"get_weather\"}<|tool_call_end|><|im_end|>\n<|im_start|>assistant\n", + }, + { + // Two assistants with tool calls - first has thinking stripped + name: "two assistants with tools: first thinking stripped", + messages: []api.Message{ + {Role: "user", Content: "What's the weather?"}, + { + Role: "assistant", + Content: "<think>checking</think>", + ToolCalls: []api.ToolCall{ + { + Function: api.ToolCallFunction{ + Name: "get_weather", + Arguments: testArgs(map[string]any{ + "location": "Paris", + }), + }, + }, + }, + }, + {Role: "tool", Content: "22C"}, + {Role: "assistant", Content: "<think>got result</think>It's 22C!"}, + }, + thinkValue: &api.ThinkValue{Value: false}, + expected: "<|im_start|>user\nWhat's the weather?<|im_end|>\n<|im_start|>assistant\n<|tool_call_start|>{\"arguments\":{\"location\":\"Paris\"},\"name\":\"get_weather\"}<|tool_call_end|><|im_end|>\n<|im_start|>tool\n22C<|im_end|>\n<|im_start|>assistant\n<think>got result</think>It's 22C!<|im_end|>\n<|im_start|>assistant\n", + }, + { + // Two assistants with tools - both preserved when thinking enabled + name: "two assistants with tools: both preserved when thinking enabled", + messages: []api.Message{ + {Role: "user", Content: "What's the weather?"}, + { + Role: "assistant", + Content: "<think>checking</think>", + ToolCalls: []api.ToolCall{ + { + Function: api.ToolCallFunction{ + Name: "get_weather", + Arguments: testArgs(map[string]any{ + "location": "Paris", + }), + }, + }, + }, + }, + {Role: "tool", Content: "22C"}, + {Role: "assistant", Content: "<think>got result</think>It's 22C!"}, + }, + thinkValue: &api.ThinkValue{Value: true}, + expected: "<|im_start|>user\nWhat's the weather?<|im_end|>\n<|im_start|>assistant\n<think>checking</think><|tool_call_start|>{\"arguments\":{\"location\":\"Paris\"},\"name\":\"get_weather\"}<|tool_call_end|><|im_end|>\n<|im_start|>tool\n22C<|im_end|>\n<|im_start|>assistant\n<think>got result</think>It's 22C!<|im_end|>\n<|im_start|>assistant\n", + }, + { + // Content before thinking before tool call + name: "content then thinking then tool call", + messages: []api.Message{ + {Role: "user", Content: "What's the weather?"}, + { + Role: "assistant", + Content: "Let me check.<think>Using weather API</think>", + ToolCalls: []api.ToolCall{ + { + Function: api.ToolCallFunction{ + Name: "get_weather", + Arguments: testArgs(map[string]any{ + "location": "Paris", + }), + }, + }, + }, + }, + }, + thinkValue: &api.ThinkValue{Value: false}, + expected: "<|im_start|>user\nWhat's the weather?<|im_end|>\n<|im_start|>assistant\nLet me check.<think>Using weather API</think><|tool_call_start|>{\"arguments\":{\"location\":\"Paris\"},\"name\":\"get_weather\"}<|tool_call_end|><|im_end|>\n<|im_start|>assistant\n", + }, + } + + renderer := &LFM2Renderer{IsThinking: true} + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + rendered, err := renderer.Render(tt.messages, tt.tools, tt.thinkValue) + if err != nil { + t.Fatalf("Render() error = %v", err) + } + if diff := cmp.Diff(tt.expected, rendered); diff != "" { + t.Errorf("Render() mismatch (-want +got):\n%s", diff) + } + }) + } +} diff --git a/model/renderers/renderer.go b/model/renderers/renderer.go index dbb63b07c..efb966aad 100644 --- a/model/renderers/renderer.go +++ b/model/renderers/renderer.go @@ -82,6 +82,10 @@ func rendererForName(name string) Renderer { return &FunctionGemmaRenderer{} case "glm-4.7": return &GLM47Renderer{} + case "lfm2": + return &LFM2Renderer{IsThinking: false} + case "lfm2-thinking": + return &LFM2Renderer{IsThinking: true} default: return nil } diff --git a/server/quantization.go b/server/quantization.go index b15451d7e..6ecaf583c 100644 --- a/server/quantization.go +++ b/server/quantization.go @@ -198,8 +198,8 @@ func newType(t *fsggml.Tensor, kv fsggml.KV, qs *quantizeState, ftype fsggml.Fil name :=
t.Name quantize := strings.HasSuffix(name, "weight") - // don't quantize vision stuff - quantize = quantize && (!strings.Contains(name, "v.") || strings.Contains(name, "_v.")) + // don't quantize vision encoder tensors (named with "v." prefix) + quantize = quantize && !strings.HasPrefix(name, "v.") quantize = quantize && !strings.Contains(name, "mm.") // quantize only 2D and 3D tensors (experts) @@ -219,6 +219,9 @@ func newType(t *fsggml.Tensor, kv fsggml.KV, qs *quantizeState, ftype fsggml.Fil // NOTE: can't use LLM_TN here because the layer number is not known quantize = quantize && !strings.Contains(name, "ssm_conv1d.weight") + // do not quantize LFM2's shortconv kernel weights + quantize = quantize && !strings.Contains(name, "shortconv.conv.weight") + // do not quantize RWKV's time_mix_first tensors quantize = quantize && !strings.Contains(name, "time_mix_first.weight") quantize = quantize && !strings.Contains(name, "time_mix_w1.weight")
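
Not part of the diff: a minimal usage sketch showing how the renderer and parser added above might be driven together. It assumes the package paths from the diff headers (model/renderers, model/parsers, api) and reuses the Render/Init/Add call shapes exercised by the tests in this change; the driver itself (package main, the example prompt, and the "lfm2" parser name) is hypothetical and for illustration only.

package main

import (
	"fmt"

	"github.com/ollama/ollama/api"
	"github.com/ollama/ollama/model/parsers"
	"github.com/ollama/ollama/model/renderers"
)

func main() {
	// Render a prompt with the non-thinking LFM2 renderer (sketch, not part of the diff).
	r := &renderers.LFM2Renderer{IsThinking: false}
	prompt, err := r.Render([]api.Message{
		{Role: "user", Content: "What's the weather in Paris?"},
	}, nil, &api.ThinkValue{Value: false})
	if err != nil {
		panic(err)
	}
	fmt.Print(prompt) // "<|im_start|>user\n...<|im_end|>\n<|im_start|>assistant\n"

	// Parse a model response containing an LFM2-style tool call,
	// using the same Init/Add signatures the tests above rely on.
	p := parsers.ParserForName("lfm2")
	p.Init([]api.Tool{}, nil, &api.ThinkValue{Value: false})
	content, thinking, calls, err := p.Add(
		`Let me check.<|tool_call_start|>{"name":"get_weather","arguments":{"location":"Paris"}}<|tool_call_end|>`, true)
	if err != nil {
		panic(err)
	}
	fmt.Println(content, thinking, calls)
}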