package lfm2

import (
	"slices"

	"github.com/ollama/ollama/kvcache"
	"github.com/ollama/ollama/ml"
	"github.com/ollama/ollama/model/input"
)

var _ kvcache.Cache = (*HybridCache)(nil)

// HybridCache stores:
//   - a standard causal KV cache for attention layers
//   - a per-sequence recurrent conv state for shortconv layers
//
// Conv state shape (per layer, per sequence): [dConv, hiddenSize] where dConv = L_cache - 1.
// Stored internally as a tensor of shape [dConv * hiddenSize, maxSlots].
type HybridCache struct {
	kv *kvcache.Causal

	backend      ml.Backend
	dtype        ml.DType
	maxSequences int

	hiddenSize int
	dConv      int

	// slot mapping for recurrent state
	slotForSeq map[int]int
	refCount   []int
	freeSlots  []int

	// per-layer conv state buffers (allocated lazily)
	convCtxs   map[int]ml.Context
	convStates map[int]ml.Tensor // [dConv*hiddenSize, maxSlots]

	// current forward batch (derived in StartForward)
	curSeqs       []int
	curSlots      []int
	curSlotsInput ml.Tensor
	curSeqTokens  int

	// track if EnsureWritable has been called for this forward pass
	writableEnsured bool
	// track any error from EnsureWritable to propagate later
	writableError error
}

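// NewHybridCache builds a HybridCache around a standard causal KV cache.
// shift is forwarded to the underlying kvcache.Causal for position shifting;
// hiddenSize and dConv describe the per-sequence shortconv state.
//
// A shortconv layer is expected to use the cache roughly as follows (sketch;
// everything outside the cache calls is up to the caller):
//
//	state, err := cache.ConvState(ctx, layer) // [dConv, hiddenSize, nSeqs]
//	if err != nil {
//		return nil, err
//	}
//	// ... run the convolution against state, producing newState ...
//	cache.UpdateConvState(ctx, layer, newState) // [dConv, hiddenSize, nSeqs]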
func NewHybridCache(shift func(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error), hiddenSize, dConv int) *HybridCache {
	return &HybridCache{
		kv:         kvcache.NewCausalCache(shift),
		hiddenSize: hiddenSize,
		dConv:      dConv,
		slotForSeq: make(map[int]int),
		convCtxs:   make(map[int]ml.Context),
		convStates: make(map[int]ml.Tensor),
	}
}

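// Init allocates the underlying causal KV cache and sets up the recurrent
// slot allocator. Conv state buffers themselves are allocated lazily per
// layer on first access.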
func (c *HybridCache) Init(backend ml.Backend, dtype ml.DType, maxSequences, capacity, maxBatch int) {
	c.backend = backend
	c.dtype = dtype
	c.maxSequences = maxSequences

	// initialize slot allocator
	c.refCount = make([]int, maxSequences)
	c.freeSlots = c.freeSlots[:0]
	for i := maxSequences - 1; i >= 0; i-- {
		c.freeSlots = append(c.freeSlots, i)
	}

	c.kv.Init(backend, dtype, maxSequences, capacity, maxBatch)
}

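// Close releases the per-layer conv state contexts and the underlying KV cache.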
func (c *HybridCache) Close() {
	for _, ctx := range c.convCtxs {
		ctx.Close()
	}
	c.kv.Close()
}

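// SetConfig forwards the backend cache configuration to the underlying causal KV cache.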
func (c *HybridCache) SetConfig(config ml.CacheConfig) {
	c.kv.SetConfig(config)
}

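// SetLayer selects the layer for subsequent Get/Put calls on the underlying KV cache.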
func (c *HybridCache) SetLayer(layer int) {
	c.kv.SetLayer(layer)
}

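// Get returns the cached key, value, and mask tensors for the current layer
// from the underlying KV cache.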
func (c *HybridCache) Get(ctx ml.Context) (ml.Tensor, ml.Tensor, ml.Tensor) {
	return c.kv.Get(ctx)
}

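// Put stores the key and value tensors for the current layer in the underlying KV cache.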
func (c *HybridCache) Put(ctx ml.Context, key, value ml.Tensor) {
	c.kv.Put(ctx, key, value)
}

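// StartForward prepares both caches for a forward pass: it forwards the batch
// to the causal KV cache, derives the equal-length [seqTokens, nSeqs] layout
// required by shortconv, and maps each sequence in the batch to a recurrent
// state slot, allocating and zeroing new slots as needed. When reserve is set,
// temporary slot assignments are used so memory estimation does not touch
// permanent state.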
func (c *HybridCache) StartForward(ctx ml.Context, batch input.Batch, reserve bool) error {
	if err := c.kv.StartForward(ctx, batch, reserve); err != nil {
		return err
	}

	// Derive equal-length sequence layout for shortconv.
	// LFM2 shortconv assumes tokens form a [seq_tokens, seqs] grid.
	seqCounts := make(map[int]int)
	c.curSeqs = c.curSeqs[:0]
	for _, s := range batch.Sequences {
		if _, ok := seqCounts[s]; !ok {
			c.curSeqs = append(c.curSeqs, s)
		}
		seqCounts[s]++
	}

	if len(c.curSeqs) == 0 {
		return nil
	}

	nTokens := len(batch.Sequences)
	nSeqs := len(c.curSeqs)
	want := nTokens / nSeqs
	for _, s := range c.curSeqs {
		if seqCounts[s] != want {
			return kvcache.ErrNotSupported
		}
	}

	c.curSeqTokens = want

	// When reserving memory for estimation, use fake slot assignments
	// without modifying permanent state (slotForSeq, refCount)
	if reserve {
		c.curSlots = c.curSlots[:0]
		slots := make([]int32, nSeqs)
		for i := range nSeqs {
			c.curSlots = append(c.curSlots, i)
			slots[i] = int32(i)
		}
		c.curSlotsInput = ctx.Input().FromInts(slots, len(slots))
		return nil
	}

	// Ensure slots exist for sequences in this batch
	c.curSlots = c.curSlots[:0]
	var newSlots []int // track newly allocated slots that need zeroing
	for _, s := range c.curSeqs {
		slot, ok := c.slotForSeq[s]
		if !ok {
			var err error
			slot, err = c.allocSlot()
			if err != nil {
				return err
			}
			c.slotForSeq[s] = slot
			c.refCount[slot] = 1
			newSlots = append(newSlots, slot)
		}
		c.curSlots = append(c.curSlots, slot)
	}

	// Zero conv state for newly allocated slots to clear stale data from previous sequences
	if len(newSlots) > 0 {
		c.zeroConvSlots(ctx, newSlots)
	}

	// Create a tensor for the current slots
	slots := make([]int32, len(c.curSlots))
	for i, v := range c.curSlots {
		slots[i] = int32(v)
	}
	c.curSlotsInput = ctx.Input().FromInts(slots, len(slots))

	// Reset writable state for new forward pass
	c.writableEnsured = false
	c.writableError = nil

	return nil
}

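// allocSlot pops a free recurrent-state slot, returning kvcache.ErrKvCacheFull
// when all maxSequences slots are in use.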
func (c *HybridCache) allocSlot() (int, error) {
	if len(c.freeSlots) == 0 {
		return 0, kvcache.ErrKvCacheFull
	}
	slot := c.freeSlots[len(c.freeSlots)-1]
	c.freeSlots = c.freeSlots[:len(c.freeSlots)-1]
	return slot, nil
}

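// freeSlot returns a slot to the free list so it can be reused by another sequence.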
func (c *HybridCache) freeSlot(slot int) {
	// Bounds check before freeing
	if slot >= 0 && slot < c.maxSequences {
		c.freeSlots = append(c.freeSlots, slot)
	}
}

// zeroConvSlots zeros the conv state for the given slots across all layers.
// This must be called when recycling slots to prevent stale state from affecting new sequences.
func (c *HybridCache) zeroConvSlots(ctx ml.Context, slots []int) {
	if len(slots) == 0 || len(c.convStates) == 0 {
		return
	}

	// Use input context for creating tensors
	inputCtx := ctx.Input()

	// Create slot indices tensor
	slotIndices := make([]int32, len(slots))
	for i, s := range slots {
		slotIndices[i] = int32(s)
	}
	slotsTensor := inputCtx.FromInts(slotIndices, len(slotIndices))

	// Create zero tensor for the slots (SetRows requires F32 source)
	zeros := inputCtx.Zeros(ml.DTypeF32, c.dConv*c.hiddenSize, len(slots))

	// Zero each layer's conv state for these slots
	for _, buf := range c.convStates {
		ctx.Forward(buf.SetRows(ctx, zeros, slotsTensor))
	}
}

// EnsureWritable ensures that sequences in the current batch have private (non-shared) conv slots.
// Returns an error if slot allocation fails.
func (c *HybridCache) EnsureWritable(ctx ml.Context) error {
	for i, seq := range c.curSeqs {
		slot, ok := c.slotForSeq[seq]
		if !ok {
			continue
		}

		// Bounds check
		if slot < 0 || slot >= len(c.refCount) {
			continue
		}

		if c.refCount[slot] <= 1 {
			continue
		}

		newSlot, err := c.allocSlot()
		if err != nil {
			return err
		}
		c.refCount[slot]--
		c.refCount[newSlot] = 1
		c.slotForSeq[seq] = newSlot
		c.curSlots[i] = newSlot

		// Copy existing conv state for all initialized layers
		for _, buf := range c.convStates {
			// buf: [dConv*hiddenSize, maxSlots]
			src := buf.Rows(ctx, ctx.Input().FromInts([]int32{int32(slot)}, 1))
			// SetRows requires F32 source
			srcF32 := src.Cast(ctx, ml.DTypeF32)
			ctx.Forward(buf.SetRows(ctx, srcF32, ctx.Input().FromInts([]int32{int32(newSlot)}, 1)))
		}
	}

	// Rebuild current slots tensor
	slots := make([]int32, len(c.curSlots))
	for i, v := range c.curSlots {
		slots[i] = int32(v)
	}
	c.curSlotsInput = ctx.Input().FromInts(slots, len(slots))

	return nil
}

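// CopyPrefix makes dstSeq reuse the first prefixLen positions of srcSeq.
// The causal KV cache shares the prefix directly; the shortconv state is
// shared copy-on-write by pointing dstSeq at srcSeq's slot and bumping its
// reference count (EnsureWritable splits the slot on the first write).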
func (c *HybridCache) CopyPrefix(srcSeq, dstSeq int, prefixLen int32) {
	// KV cache shares prefix metadata (no copy) which is correct for prefix reuse.
	c.kv.CopyPrefix(srcSeq, dstSeq, prefixLen)

	// For shortconv state we implement copy-on-write: dst shares the same slot as src.
	// On the first write to dst, EnsureWritable will create a private slot.
	if dstSlot, ok := c.slotForSeq[dstSeq]; ok {
		// Bounds check before decrementing
		if dstSlot >= 0 && dstSlot < len(c.refCount) {
			c.refCount[dstSlot]--
			if c.refCount[dstSlot] <= 0 {
				c.refCount[dstSlot] = 0
				c.freeSlot(dstSlot)
			}
		}
		delete(c.slotForSeq, dstSeq)
	}

	srcSlot, ok := c.slotForSeq[srcSeq]
	if !ok {
		// src may not have a slot yet; dst will allocate on demand
		return
	}

	// Bounds check before incrementing
	if srcSlot >= 0 && srcSlot < len(c.refCount) {
		c.slotForSeq[dstSeq] = srcSlot
		c.refCount[srcSlot]++
	}
}

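// CanResume reports whether seq can resume at pos, as determined by the underlying KV cache.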
func (c *HybridCache) CanResume(seq int, pos int32) bool {
	return c.kv.CanResume(seq, pos)
}

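// Remove removes the given position range for seq from the underlying KV cache
// and invalidates the sequence's recurrent state: conv state at a position
// depends on every earlier position, so it cannot be partially rolled back and
// the slot mapping is dropped instead.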
func (c *HybridCache) Remove(seq int, beginIndex, endIndex int32) error {
	if err := c.kv.Remove(seq, beginIndex, endIndex); err != nil {
		return err
	}

	// For recurrent state, any removal invalidates the state because
	// the state at position N depends on all previous positions.
	// Drop the slot mapping so it resets on next use.
	slot, ok := c.slotForSeq[seq]
	if !ok {
		return nil
	}

	// Bounds check
	if slot < 0 || slot >= len(c.refCount) {
		delete(c.slotForSeq, seq)
		return nil
	}

	c.refCount[slot]--
	if c.refCount[slot] <= 0 {
		c.refCount[slot] = 0
		c.freeSlot(slot)
	}
	delete(c.slotForSeq, seq)

	return nil
}

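// slotsTensor returns the slot indices for the current batch as an input tensor,
// in the same order as the sequences returned by Seqs.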
func (c *HybridCache) slotsTensor() ml.Tensor {
	return c.curSlotsInput
}

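// seqTokens returns the number of tokens per sequence in the current batch.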
func (c *HybridCache) seqTokens() int {
	return c.curSeqTokens
}

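// numSeqs returns the number of distinct sequences in the current batch.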
func (c *HybridCache) numSeqs() int {
	return len(c.curSeqs)
}

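// convBuffer returns the conv state buffer for layer, lazily allocating a
// [dConv*hiddenSize, maxSequences] tensor in a dedicated per-layer context on first use.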
func (c *HybridCache) convBuffer(ctx ml.Context, layer int) ml.Tensor {
	if buf, ok := c.convStates[layer]; ok {
		return buf
	}

	if _, ok := c.convCtxs[layer]; !ok {
		c.convCtxs[layer] = c.backend.NewContextSize(1).Layer(layer)
	}

	buf := c.convCtxs[layer].Zeros(c.dtype, c.dConv*c.hiddenSize, c.maxSequences)
	c.convStates[layer] = buf
	return buf
}

// ConvState returns the conv state for current batch sequences as shape [dConv, hiddenSize, nSeqs].
// Returns an error if copy-on-write allocation fails.
func (c *HybridCache) ConvState(ctx ml.Context, layer int) (ml.Tensor, error) {
	if !c.writableEnsured {
		needsWritable := false
		for _, seq := range c.curSeqs {
			slot, ok := c.slotForSeq[seq]
			if !ok {
				continue
			}
			if slot >= 0 && slot < len(c.refCount) && c.refCount[slot] > 1 {
				needsWritable = true
				break
			}
		}

		if needsWritable {
			if err := c.EnsureWritable(ctx); err != nil {
				c.writableError = err
			}
		}
		c.writableEnsured = true
	}

	if c.writableError != nil {
		return nil, c.writableError
	}

	buf := c.convBuffer(ctx, layer)
	cur := buf.Rows(ctx, c.slotsTensor())
	return cur.Reshape(ctx, c.dConv, c.hiddenSize, c.numSeqs()), nil
}

// UpdateConvState writes a new conv state for current batch sequences.
// newState must have shape [dConv, hiddenSize, nSeqs].
func (c *HybridCache) UpdateConvState(ctx ml.Context, layer int, newState ml.Tensor) {
	buf := c.convBuffer(ctx, layer)
	src := newState.Reshape(ctx, c.dConv*c.hiddenSize, c.numSeqs())
	// SetRows requires F32 source
	srcF32 := src.Cast(ctx, ml.DTypeF32)
	ctx.Forward(buf.SetRows(ctx, srcF32, c.slotsTensor()))
}

// IsSupportedForBatch returns true if the current batch layout supports shortconv.
func (c *HybridCache) IsSupportedForBatch() bool {
	return c.curSeqTokens > 0 && len(c.curSeqs) > 0
}

// Seqs returns the ordered unique sequences for the current forward pass.
func (c *HybridCache) Seqs() []int {
	return slices.Clone(c.curSeqs)
}