mirror of
https://github.com/ollama/ollama.git
synced 2026-01-29 07:12:03 +03:00
runner: add sync between computeBatch and completion
This commit is contained in:
@@ -62,6 +62,11 @@ type Sequence struct {
|
||||
// tokens that have been generated but not returned yet (e.g. for stop sequences)
|
||||
pendingResponses []string
|
||||
|
||||
// startGate
|
||||
startGate *sync.Mutex
|
||||
|
||||
grammarReady bool
|
||||
|
||||
// input cache being used by this sequence
|
||||
cache *InputCacheSlot
|
||||
|
||||
@@ -164,6 +169,7 @@ func (s *Server) NewSequence(prompt string, images []llm.ImageData, params NewSe
|
||||
|
||||
// TODO(jessegross): Ingest cached history for grammar
|
||||
|
||||
startGate := &sync.Mutex{}
|
||||
return &Sequence{
|
||||
ctxs: ctxs,
|
||||
mmStore: mmStore,
|
||||
@@ -179,6 +185,8 @@ func (s *Server) NewSequence(prompt string, images []llm.ImageData, params NewSe
|
||||
embeddingOnly: params.embedding,
|
||||
stop: params.stop,
|
||||
numKeep: params.numKeep,
|
||||
startGate: startGate,
|
||||
grammarReady: false,
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -707,11 +715,18 @@ func (s *Server) computeBatch(activeBatch batchState) {
|
||||
// sample a token
|
||||
vocabSize := len(outputs) / len(activeBatch.batch.Outputs)
|
||||
logutil.Trace("computeBatch: vocab details", "batchID", activeBatch.id, "seqIdx", i, "len(logits)", len(outputs), "len(activeBatch.batch.Outputs)", len(activeBatch.batch.Outputs), "vocabSize", vocabSize, "iBatches", iBatches)
|
||||
|
||||
if !seq.grammarReady {
|
||||
seq.startGate.Lock()
|
||||
}
|
||||
token, err := seq.sampler.Sample(outputs[iBatches[i]*vocabSize : (iBatches[i]+1)*vocabSize])
|
||||
if err != nil {
|
||||
s.hardErrCh <- fmt.Errorf("failed to sample token: %w", err)
|
||||
return
|
||||
}
|
||||
if !seq.grammarReady {
|
||||
seq.startGate.Unlock()
|
||||
}
|
||||
|
||||
nextBatchTokens[i].Token = token
|
||||
|
||||
@@ -830,11 +845,9 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
tokenParser := parser.NewTokenParser(req.ParserType, req.PrefillString)
|
||||
switch req.ParserType {
|
||||
case parser.TokenParserTypeHarmony:
|
||||
// Do not set grammar until model allows constraining
|
||||
default:
|
||||
seq.sampler.SetGrammar(grammar)
|
||||
// this accounts for the default case and also the case where there is a prefill which moves the state of the parser to allow for constraints
|
||||
if tokenParser.ConstraintsAllowed() {
|
||||
seq.grammarReady = true
|
||||
}
|
||||
|
||||
// Ensure there is a place to put the sequence, released when removed from s.seqs
|
||||
@@ -873,7 +886,6 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
grammarSet := false
|
||||
for {
|
||||
select {
|
||||
case <-r.Context().Done():
|
||||
@@ -881,6 +893,9 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
case content, ok := <-seq.responses:
|
||||
if ok {
|
||||
if !seq.grammarReady {
|
||||
seq.startGate.Lock()
|
||||
}
|
||||
var thinking string
|
||||
var err error
|
||||
content, thinking, err = tokenParser.AddContent(content)
|
||||
@@ -890,9 +905,11 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
if !grammarSet && grammar != nil && tokenParser.ConstraintsAllowed() {
|
||||
seq.sampler.SetGrammar(grammar)
|
||||
grammarSet = true
|
||||
// only apply the grammar once
|
||||
if tokenParser.ConstraintsAllowed() && !seq.grammarReady {
|
||||
seq.sampler.SetGrammar(grammar, &s.mu)
|
||||
seq.grammarReady = true
|
||||
seq.startGate.Unlock()
|
||||
}
|
||||
|
||||
if err := json.NewEncoder(w).Encode(&llm.CompletionResponse{
|
||||
@@ -921,6 +938,9 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
return
|
||||
}
|
||||
if !seq.grammarReady {
|
||||
seq.startGate.Unlock()
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
"math"
|
||||
"math/rand/v2"
|
||||
"slices"
|
||||
"sync"
|
||||
|
||||
"github.com/ollama/ollama/llama"
|
||||
"github.com/ollama/ollama/model"
|
||||
@@ -25,7 +26,9 @@ type Sampler struct {
|
||||
grammar *GrammarSampler
|
||||
}
|
||||
|
||||
func (s *Sampler) SetGrammar(grammar *GrammarSampler) {
|
||||
func (s *Sampler) SetGrammar(grammar *GrammarSampler, mutex *sync.Mutex) {
|
||||
mutex.Lock()
|
||||
defer mutex.Unlock()
|
||||
s.grammar = grammar
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user