// Mirror of https://github.com/ollama/ollama.git
// (synced 2026-01-29 07:12:03 +03:00; 445 lines, 11 KiB, Go)
package lfm2
import (
	"errors"
	"slices"
	"testing"

	"github.com/ollama/ollama/kvcache"
	"github.com/ollama/ollama/ml"
)
// TestHybridCache tests verify the slot management logic of HybridCache.
// These tests focus on the recurrent state slot allocation, reference counting,
// and copy-on-write semantics without requiring a full ML backend.
// createSlotOnlyCache creates a HybridCache with only the slot management
|
|
// fields initialized. Used to test slot logic in isolation.
|
|
func createSlotOnlyCache(maxSequences int) *HybridCache {
|
|
return &HybridCache{
|
|
hiddenSize: 256,
|
|
dConv: 3,
|
|
maxSequences: maxSequences,
|
|
refCount: make([]int, maxSequences),
|
|
freeSlots: initFreeSlots(maxSequences),
|
|
slotForSeq: make(map[int]int),
|
|
convCtxs: make(map[int]ml.Context),
|
|
convStates: make(map[int]ml.Tensor),
|
|
}
|
|
}
|
|
|
|
// initFreeSlots returns the initial free-slot stack for n slots. The stack
// is ordered n-1 .. 0 so that LIFO allocation hands out slot 0 first.
func initFreeSlots(n int) []int {
	slots := make([]int, n)
	for i := range slots {
		slots[i] = n - 1 - i
	}
	return slots
}
func TestHybridCache_SlotAllocation(t *testing.T) {
|
|
cache := createSlotOnlyCache(4)
|
|
|
|
// Verify initial state
|
|
if len(cache.freeSlots) != 4 {
|
|
t.Errorf("expected 4 free slots, got %d", len(cache.freeSlots))
|
|
}
|
|
|
|
// Allocate all slots
|
|
for range 4 {
|
|
slot, err := cache.allocSlot()
|
|
if err != nil {
|
|
t.Fatalf("allocSlot failed: %v", err)
|
|
}
|
|
cache.refCount[slot] = 1
|
|
}
|
|
|
|
// Should be full now
|
|
if len(cache.freeSlots) != 0 {
|
|
t.Errorf("expected 0 free slots, got %d", len(cache.freeSlots))
|
|
}
|
|
|
|
// Trying to allocate another should fail
|
|
_, err := cache.allocSlot()
|
|
if err != kvcache.ErrKvCacheFull {
|
|
t.Errorf("expected ErrKvCacheFull, got %v", err)
|
|
}
|
|
}
|
|
|
|
func TestHybridCache_SlotReuse(t *testing.T) {
|
|
cache := createSlotOnlyCache(4)
|
|
|
|
// Allocate a slot
|
|
slot1, _ := cache.allocSlot()
|
|
cache.refCount[slot1] = 1
|
|
|
|
// Free it
|
|
cache.refCount[slot1] = 0
|
|
cache.freeSlot(slot1)
|
|
|
|
// Allocate again - should get the same slot back (LIFO)
|
|
slot2, _ := cache.allocSlot()
|
|
if slot2 != slot1 {
|
|
t.Errorf("expected slot %d to be reused, got %d", slot1, slot2)
|
|
}
|
|
}
|
|
|
|
func TestHybridCache_SlotRefCounting_ShareSlot(t *testing.T) {
|
|
cache := createSlotOnlyCache(4)
|
|
|
|
// Allocate slot for seq 1
|
|
slot1, _ := cache.allocSlot()
|
|
cache.slotForSeq[1] = slot1
|
|
cache.refCount[slot1] = 1
|
|
|
|
// Simulate sharing slot with seq 2 (copy-on-write style)
|
|
cache.slotForSeq[2] = slot1
|
|
cache.refCount[slot1]++
|
|
|
|
// Should share the same slot
|
|
if cache.slotForSeq[2] != slot1 {
|
|
t.Errorf("expected seq 2 to share slot %d, got %d", slot1, cache.slotForSeq[2])
|
|
}
|
|
|
|
// Ref count should be 2
|
|
if cache.refCount[slot1] != 2 {
|
|
t.Errorf("expected refCount 2, got %d", cache.refCount[slot1])
|
|
}
|
|
}
|
|
|
|
func TestHybridCache_SlotRefCounting_DecRef(t *testing.T) {
|
|
cache := createSlotOnlyCache(4)
|
|
|
|
// Allocate slot for seq 1
|
|
slot1, _ := cache.allocSlot()
|
|
cache.slotForSeq[1] = slot1
|
|
cache.refCount[slot1] = 1
|
|
|
|
// Share with seq 2
|
|
cache.slotForSeq[2] = slot1
|
|
cache.refCount[slot1]++
|
|
|
|
// Unshare seq 2
|
|
cache.refCount[slot1]--
|
|
delete(cache.slotForSeq, 2)
|
|
|
|
// Ref count should be back to 1
|
|
if cache.refCount[slot1] != 1 {
|
|
t.Errorf("expected refCount 1 after unshare, got %d", cache.refCount[slot1])
|
|
}
|
|
|
|
// Seq 2 should no longer have a slot
|
|
if _, ok := cache.slotForSeq[2]; ok {
|
|
t.Error("seq 2 should not have a slot after unshare")
|
|
}
|
|
}
|
|
|
|
func TestHybridCache_SlotFreeWhenUnused(t *testing.T) {
|
|
cache := createSlotOnlyCache(4)
|
|
|
|
initialFreeSlots := len(cache.freeSlots)
|
|
|
|
// Allocate slot for seq 1
|
|
slot1, _ := cache.allocSlot()
|
|
cache.slotForSeq[1] = slot1
|
|
cache.refCount[slot1] = 1
|
|
|
|
// Free the slot when refCount drops to 0
|
|
cache.refCount[slot1]--
|
|
if cache.refCount[slot1] <= 0 {
|
|
cache.refCount[slot1] = 0
|
|
cache.freeSlot(slot1)
|
|
}
|
|
delete(cache.slotForSeq, 1)
|
|
|
|
// Slot should be freed
|
|
if len(cache.freeSlots) != initialFreeSlots {
|
|
t.Errorf("expected %d free slots, got %d", initialFreeSlots, len(cache.freeSlots))
|
|
}
|
|
|
|
// Ref count should be 0
|
|
if cache.refCount[slot1] != 0 {
|
|
t.Errorf("expected refCount 0, got %d", cache.refCount[slot1])
|
|
}
|
|
}
|
|
|
|
func TestHybridCache_SlotOverwrite(t *testing.T) {
|
|
cache := createSlotOnlyCache(4)
|
|
|
|
// Allocate slots for seq 1 and seq 2
|
|
slot1, _ := cache.allocSlot()
|
|
cache.slotForSeq[1] = slot1
|
|
cache.refCount[slot1] = 1
|
|
|
|
slot2, _ := cache.allocSlot()
|
|
cache.slotForSeq[2] = slot2
|
|
cache.refCount[slot2] = 1
|
|
|
|
initialFreeSlots := len(cache.freeSlots)
|
|
|
|
// Simulate overwriting seq 2's slot with slot1 (sharing)
|
|
// First free the old slot
|
|
cache.refCount[slot2]--
|
|
if cache.refCount[slot2] <= 0 {
|
|
cache.refCount[slot2] = 0
|
|
cache.freeSlot(slot2)
|
|
}
|
|
// Then share slot1
|
|
cache.slotForSeq[2] = slot1
|
|
cache.refCount[slot1]++
|
|
|
|
// Seq 2 should now share slot1
|
|
if cache.slotForSeq[2] != slot1 {
|
|
t.Errorf("expected seq 2 to share slot %d, got %d", slot1, cache.slotForSeq[2])
|
|
}
|
|
|
|
// Old slot2 should be freed
|
|
if len(cache.freeSlots) != initialFreeSlots+1 {
|
|
t.Errorf("expected %d free slots, got %d", initialFreeSlots+1, len(cache.freeSlots))
|
|
}
|
|
}
|
|
|
|
func TestHybridCache_BoundsChecking(t *testing.T) {
|
|
cache := createSlotOnlyCache(4)
|
|
|
|
// Test freeing invalid slot (should not panic)
|
|
cache.freeSlot(-1)
|
|
cache.freeSlot(100) // out of bounds
|
|
|
|
// freeSlot does bounds checking, so invalid slots should be ignored
|
|
if len(cache.freeSlots) != 4 {
|
|
t.Errorf("invalid slots should not affect free list, got %d slots", len(cache.freeSlots))
|
|
}
|
|
}
|
|
|
|
func TestHybridCache_MultipleSequences_RefCounting(t *testing.T) {
|
|
cache := createSlotOnlyCache(8)
|
|
|
|
// Allocate slot for seq 1
|
|
slot1, _ := cache.allocSlot()
|
|
cache.slotForSeq[1] = slot1
|
|
cache.refCount[slot1] = 1
|
|
|
|
// Fork to seq 2, 3, 4 (all share slot1)
|
|
for _, seq := range []int{2, 3, 4} {
|
|
cache.slotForSeq[seq] = slot1
|
|
cache.refCount[slot1]++
|
|
}
|
|
|
|
// Ref count should be 4
|
|
if cache.refCount[slot1] != 4 {
|
|
t.Errorf("expected refCount 4, got %d", cache.refCount[slot1])
|
|
}
|
|
|
|
// Remove seq 2, 3
|
|
for _, seq := range []int{2, 3} {
|
|
delete(cache.slotForSeq, seq)
|
|
cache.refCount[slot1]--
|
|
}
|
|
|
|
if cache.refCount[slot1] != 2 {
|
|
t.Errorf("expected refCount 2, got %d", cache.refCount[slot1])
|
|
}
|
|
|
|
// Slot should still be allocated (not in free list)
|
|
found := false
|
|
for _, s := range cache.freeSlots {
|
|
if s == slot1 {
|
|
found = true
|
|
break
|
|
}
|
|
}
|
|
if found {
|
|
t.Error("slot1 should not be in free list yet")
|
|
}
|
|
|
|
// Remove remaining sequences
|
|
for _, seq := range []int{1, 4} {
|
|
delete(cache.slotForSeq, seq)
|
|
cache.refCount[slot1]--
|
|
}
|
|
|
|
if cache.refCount[slot1] != 0 {
|
|
t.Errorf("expected refCount 0, got %d", cache.refCount[slot1])
|
|
}
|
|
}
|
|
|
|
func TestHybridCache_ChainedSharing(t *testing.T) {
|
|
cache := createSlotOnlyCache(8)
|
|
|
|
// Create seq 1
|
|
slot1, _ := cache.allocSlot()
|
|
cache.slotForSeq[1] = slot1
|
|
cache.refCount[slot1] = 1
|
|
|
|
// Share 1 -> 2
|
|
cache.slotForSeq[2] = slot1
|
|
cache.refCount[slot1]++
|
|
|
|
// Share 2 -> 3 (should still share slot1)
|
|
cache.slotForSeq[3] = cache.slotForSeq[2] // which is slot1
|
|
cache.refCount[slot1]++
|
|
|
|
// All should share slot1
|
|
if cache.slotForSeq[1] != slot1 || cache.slotForSeq[2] != slot1 || cache.slotForSeq[3] != slot1 {
|
|
t.Error("all sequences should share slot1")
|
|
}
|
|
|
|
if cache.refCount[slot1] != 3 {
|
|
t.Errorf("expected refCount 3, got %d", cache.refCount[slot1])
|
|
}
|
|
}
|
|
|
|
func TestHybridCache_CacheParameters(t *testing.T) {
|
|
cache := NewHybridCache(nil, 512, 5) // hiddenSize=512, dConv=5
|
|
|
|
if cache.hiddenSize != 512 {
|
|
t.Errorf("expected hiddenSize 512, got %d", cache.hiddenSize)
|
|
}
|
|
if cache.dConv != 5 {
|
|
t.Errorf("expected dConv 5, got %d", cache.dConv)
|
|
}
|
|
}
|
|
|
|
func TestHybridCache_NumSeqs(t *testing.T) {
|
|
cache := createSlotOnlyCache(4)
|
|
|
|
// Initially no sequences
|
|
if cache.numSeqs() != 0 {
|
|
t.Errorf("expected 0 seqs, got %d", cache.numSeqs())
|
|
}
|
|
|
|
// Manually set up current batch state
|
|
cache.curSeqs = []int{1, 2, 3}
|
|
|
|
if cache.numSeqs() != 3 {
|
|
t.Errorf("expected 3 seqs, got %d", cache.numSeqs())
|
|
}
|
|
}
|
|
|
|
func TestHybridCache_SeqTokens(t *testing.T) {
|
|
cache := createSlotOnlyCache(4)
|
|
|
|
// Initially 0
|
|
if cache.seqTokens() != 0 {
|
|
t.Errorf("expected 0 seqTokens, got %d", cache.seqTokens())
|
|
}
|
|
|
|
// Manually set up current batch state
|
|
cache.curSeqTokens = 16
|
|
|
|
if cache.seqTokens() != 16 {
|
|
t.Errorf("expected 16 seqTokens, got %d", cache.seqTokens())
|
|
}
|
|
}
|
|
|
|
// Test that Seqs returns a clone of curSeqs
|
|
func TestHybridCache_Seqs_ReturnsClone(t *testing.T) {
|
|
cache := createSlotOnlyCache(4)
|
|
|
|
cache.curSeqs = []int{1, 2, 3}
|
|
|
|
seqs := cache.Seqs()
|
|
|
|
// Modify returned slice
|
|
seqs[0] = 999
|
|
|
|
// Original should be unchanged
|
|
if cache.curSeqs[0] != 1 {
|
|
t.Error("Seqs should return a clone, not the original slice")
|
|
}
|
|
}
|
|
|
|
func TestHybridCache_IsSupportedForBatch(t *testing.T) {
|
|
cache := createSlotOnlyCache(4)
|
|
|
|
// Initially not supported (no batch set up)
|
|
if cache.IsSupportedForBatch() {
|
|
t.Error("expected IsSupportedForBatch to be false initially")
|
|
}
|
|
|
|
// Set up a valid batch
|
|
cache.curSeqTokens = 1
|
|
cache.curSeqs = []int{1}
|
|
|
|
if !cache.IsSupportedForBatch() {
|
|
t.Error("expected IsSupportedForBatch to be true with valid batch")
|
|
}
|
|
}
|
|
|
|
func TestHybridCache_ZeroConvSlots_EmptyInputs(t *testing.T) {
|
|
cache := createSlotOnlyCache(4)
|
|
|
|
// zeroConvSlots should handle empty slots without panicking
|
|
cache.zeroConvSlots(nil, nil)
|
|
cache.zeroConvSlots(nil, []int{})
|
|
|
|
// zeroConvSlots should handle empty convStates without panicking
|
|
cache.zeroConvSlots(nil, []int{0, 1, 2})
|
|
}
|
|
|
|
func TestHybridCache_SlotRecycling_TracksNewSlots(t *testing.T) {
|
|
cache := createSlotOnlyCache(4)
|
|
|
|
// Allocate slot for seq 1
|
|
slot1, _ := cache.allocSlot()
|
|
cache.slotForSeq[1] = slot1
|
|
cache.refCount[slot1] = 1
|
|
|
|
// Free the slot (simulating sequence removal)
|
|
cache.refCount[slot1]--
|
|
cache.freeSlot(slot1)
|
|
delete(cache.slotForSeq, 1)
|
|
|
|
// Verify slot is in free list
|
|
if len(cache.freeSlots) != 4 {
|
|
t.Errorf("expected 4 free slots after freeing, got %d", len(cache.freeSlots))
|
|
}
|
|
|
|
// Allocate for new seq 2 - should get recycled slot
|
|
slot2, _ := cache.allocSlot()
|
|
if slot2 != slot1 {
|
|
t.Errorf("expected recycled slot %d, got %d", slot1, slot2)
|
|
}
|
|
|
|
// This recycled slot would need zeroing in the real implementation
|
|
// The actual zeroing is tested via integration tests since it requires ML context
|
|
}
|
|
|
|
func TestHybridCache_NewSequence_GetsTrackedForZeroing(t *testing.T) {
|
|
cache := createSlotOnlyCache(4)
|
|
|
|
// Simulate the slot allocation flow from StartForward
|
|
// When a sequence doesn't have a slot, it gets allocated and tracked as "new"
|
|
|
|
newSlots := []int{}
|
|
|
|
// Seq 1 doesn't have a slot - allocate and track
|
|
seq := 1
|
|
if _, ok := cache.slotForSeq[seq]; !ok {
|
|
slot, err := cache.allocSlot()
|
|
if err != nil {
|
|
t.Fatalf("allocSlot failed: %v", err)
|
|
}
|
|
cache.slotForSeq[seq] = slot
|
|
cache.refCount[slot] = 1
|
|
newSlots = append(newSlots, slot)
|
|
}
|
|
|
|
// Verify newSlots contains the allocated slot
|
|
if len(newSlots) != 1 {
|
|
t.Errorf("expected 1 new slot, got %d", len(newSlots))
|
|
}
|
|
|
|
// Seq 1 already has a slot - should NOT be tracked as new
|
|
newSlots2 := []int{}
|
|
if _, ok := cache.slotForSeq[seq]; !ok {
|
|
slot, _ := cache.allocSlot()
|
|
cache.slotForSeq[seq] = slot
|
|
cache.refCount[slot] = 1
|
|
newSlots2 = append(newSlots2, slot)
|
|
}
|
|
|
|
// Verify no new slots for existing sequence
|
|
if len(newSlots2) != 0 {
|
|
t.Errorf("expected 0 new slots for existing sequence, got %d", len(newSlots2))
|
|
}
|
|
}
|