mirror of
https://github.com/ollama/ollama.git
synced 2026-01-29 07:12:03 +03:00
Compare commits
26 Commits
brucemacd/
...
f0d0cd8731
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
f0d0cd8731 | ||
|
|
9638fda956 | ||
|
|
1a95093e0a | ||
|
|
6c32c58194 | ||
|
|
e7d9a389f9 | ||
|
|
04a31bc31d | ||
|
|
465d124183 | ||
|
|
d310e56fa3 | ||
|
|
a1ca428c90 | ||
|
|
16750865d1 | ||
|
|
f3b476c592 | ||
|
|
5267d31d56 | ||
|
|
b44f56319f | ||
|
|
0209c268bb | ||
|
|
912d984346 | ||
|
|
aae6ecbaff | ||
|
|
64737330a4 | ||
|
|
2eda97f1c3 | ||
|
|
66831dcf70 | ||
|
|
1044b0419a | ||
|
|
771d9280ec | ||
|
|
862bc0a3bf | ||
|
|
c01608b6a1 | ||
|
|
199c41e16e | ||
|
|
3b3bf6c217 | ||
|
|
f52c21f457 |
@@ -169,8 +169,10 @@ COPY . .
|
||||
RUN git clone --depth 1 --branch "$(cat MLX_VERSION)" https://github.com/ml-explore/mlx-c.git build/_deps/mlx-c-src
|
||||
ARG GOFLAGS="'-ldflags=-w -s'"
|
||||
ENV CGO_ENABLED=1
|
||||
ENV CGO_CFLAGS="-I/go/src/github.com/ollama/ollama/build/_deps/mlx-c-src"
|
||||
ARG CGO_CFLAGS
|
||||
ARG CGO_CXXFLAGS
|
||||
ENV CGO_CFLAGS="${CGO_CFLAGS} -I/go/src/github.com/ollama/ollama/build/_deps/mlx-c-src"
|
||||
ENV CGO_CXXFLAGS="${CGO_CXXFLAGS}"
|
||||
RUN --mount=type=cache,target=/root/.cache/go-build \
|
||||
go build -tags mlx -trimpath -buildmode=pie -o /bin/ollama .
|
||||
|
||||
|
||||
@@ -558,7 +558,7 @@ See the [API documentation](./docs/api.md) for all endpoints.
|
||||
- [LiteLLM](https://github.com/BerriAI/litellm)
|
||||
- [OllamaFarm for Go](https://github.com/presbrey/ollamafarm)
|
||||
- [OllamaSharp for .NET](https://github.com/awaescher/OllamaSharp)
|
||||
- [Ollama for Ruby](https://github.com/gbaptista/ollama-ai)
|
||||
- [Ollama for Ruby](https://github.com/crmne/ruby_llm)
|
||||
- [Ollama-rs for Rust](https://github.com/pepperoni21/ollama-rs)
|
||||
- [Ollama-hpp for C++](https://github.com/jmont-dev/ollama-hpp)
|
||||
- [Ollama4j for Java](https://github.com/ollama4j/ollama4j)
|
||||
|
||||
@@ -75,9 +75,9 @@ The `-dev` flag enables:
|
||||
CI builds with Xcode 14.1 for OS compatibility prior to v13. If you want to manually build v11+ support, you can download the older Xcode [here](https://developer.apple.com/services-account/download?path=/Developer_Tools/Xcode_14.1/Xcode_14.1.xip), extract, then `mv ./Xcode.app /Applications/Xcode_14.1.0.app` then activate with:
|
||||
|
||||
```
|
||||
export CGO_CFLAGS=-mmacosx-version-min=12.0
|
||||
export CGO_CXXFLAGS=-mmacosx-version-min=12.0
|
||||
export CGO_LDFLAGS=-mmacosx-version-min=12.0
|
||||
export CGO_CFLAGS="-O3 -mmacosx-version-min=12.0"
|
||||
export CGO_CXXFLAGS="-O3 -mmacosx-version-min=12.0"
|
||||
export CGO_LDFLAGS="-mmacosx-version-min=12.0"
|
||||
export SDKROOT=/Applications/Xcode_14.1.0.app/Contents/Developer/Platforms/MacOSX.platform/Developer/SDKs/MacOSX.sdk
|
||||
export DEVELOPER_DIR=/Applications/Xcode_14.1.0.app/Contents/Developer
|
||||
```
|
||||
|
||||
12
cmd/cmd.go
12
cmd/cmd.go
@@ -35,6 +35,7 @@ import (
|
||||
"golang.org/x/term"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/cmd/config"
|
||||
"github.com/ollama/ollama/envconfig"
|
||||
"github.com/ollama/ollama/format"
|
||||
"github.com/ollama/ollama/parser"
|
||||
@@ -1018,8 +1019,10 @@ func showInfo(resp *api.ShowResponse, verbose bool, w io.Writer) error {
|
||||
}
|
||||
|
||||
if resp.ModelInfo != nil {
|
||||
arch := resp.ModelInfo["general.architecture"].(string)
|
||||
rows = append(rows, []string{"", "architecture", arch})
|
||||
arch, _ := resp.ModelInfo["general.architecture"].(string)
|
||||
if arch != "" {
|
||||
rows = append(rows, []string{"", "architecture", arch})
|
||||
}
|
||||
|
||||
var paramStr string
|
||||
if resp.Details.ParameterSize != "" {
|
||||
@@ -1029,7 +1032,9 @@ func showInfo(resp *api.ShowResponse, verbose bool, w io.Writer) error {
|
||||
paramStr = format.HumanNumber(uint64(f))
|
||||
}
|
||||
}
|
||||
rows = append(rows, []string{"", "parameters", paramStr})
|
||||
if paramStr != "" {
|
||||
rows = append(rows, []string{"", "parameters", paramStr})
|
||||
}
|
||||
|
||||
if v, ok := resp.ModelInfo[fmt.Sprintf("%s.context_length", arch)]; ok {
|
||||
if f, ok := v.(float64); ok {
|
||||
@@ -2026,6 +2031,7 @@ func NewCLI() *cobra.Command {
|
||||
copyCmd,
|
||||
deleteCmd,
|
||||
runnerCmd,
|
||||
config.LaunchCmd(checkServerHeartbeat),
|
||||
)
|
||||
|
||||
return rootCmd
|
||||
|
||||
@@ -1553,7 +1553,7 @@ func TestShowInfoImageGen(t *testing.T) {
|
||||
Details: api.ModelDetails{
|
||||
Family: "ZImagePipeline",
|
||||
ParameterSize: "10.3B",
|
||||
QuantizationLevel: "FP8",
|
||||
QuantizationLevel: "Q8",
|
||||
},
|
||||
Capabilities: []model.Capability{model.CapabilityImage},
|
||||
Requires: "0.14.0",
|
||||
@@ -1565,7 +1565,7 @@ func TestShowInfoImageGen(t *testing.T) {
|
||||
expect := " Model\n" +
|
||||
" architecture ZImagePipeline \n" +
|
||||
" parameters 10.3B \n" +
|
||||
" quantization FP8 \n" +
|
||||
" quantization Q8 \n" +
|
||||
" requires 0.14.0 \n" +
|
||||
"\n" +
|
||||
" Capabilities\n" +
|
||||
|
||||
58
cmd/config/claude.go
Normal file
58
cmd/config/claude.go
Normal file
@@ -0,0 +1,58 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
)
|
||||
|
||||
// Claude implements Runner for Claude Code integration
|
||||
type Claude struct{}
|
||||
|
||||
func (c *Claude) String() string { return "Claude Code" }
|
||||
|
||||
func (c *Claude) args(model string) []string {
|
||||
if model != "" {
|
||||
return []string{"--model", model}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Claude) findPath() (string, error) {
|
||||
if p, err := exec.LookPath("claude"); err == nil {
|
||||
return p, nil
|
||||
}
|
||||
home, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
name := "claude"
|
||||
if runtime.GOOS == "windows" {
|
||||
name = "claude.exe"
|
||||
}
|
||||
fallback := filepath.Join(home, ".claude", "local", name)
|
||||
if _, err := os.Stat(fallback); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return fallback, nil
|
||||
}
|
||||
|
||||
func (c *Claude) Run(model string) error {
|
||||
claudePath, err := c.findPath()
|
||||
if err != nil {
|
||||
return fmt.Errorf("claude is not installed, install from https://code.claude.com/docs/en/quickstart")
|
||||
}
|
||||
|
||||
cmd := exec.Command(claudePath, c.args(model)...)
|
||||
cmd.Stdin = os.Stdin
|
||||
cmd.Stdout = os.Stdout
|
||||
cmd.Stderr = os.Stderr
|
||||
cmd.Env = append(os.Environ(),
|
||||
"ANTHROPIC_BASE_URL=http://localhost:11434",
|
||||
"ANTHROPIC_API_KEY=",
|
||||
"ANTHROPIC_AUTH_TOKEN=ollama",
|
||||
)
|
||||
return cmd.Run()
|
||||
}
|
||||
101
cmd/config/claude_test.go
Normal file
101
cmd/config/claude_test.go
Normal file
@@ -0,0 +1,101 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"slices"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestClaudeIntegration(t *testing.T) {
|
||||
c := &Claude{}
|
||||
|
||||
t.Run("String", func(t *testing.T) {
|
||||
if got := c.String(); got != "Claude Code" {
|
||||
t.Errorf("String() = %q, want %q", got, "Claude Code")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("implements Runner", func(t *testing.T) {
|
||||
var _ Runner = c
|
||||
})
|
||||
}
|
||||
|
||||
func TestClaudeFindPath(t *testing.T) {
|
||||
c := &Claude{}
|
||||
|
||||
t.Run("finds claude in PATH", func(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
name := "claude"
|
||||
if runtime.GOOS == "windows" {
|
||||
name = "claude.exe"
|
||||
}
|
||||
fakeBin := filepath.Join(tmpDir, name)
|
||||
os.WriteFile(fakeBin, []byte("#!/bin/sh\n"), 0o755)
|
||||
t.Setenv("PATH", tmpDir)
|
||||
|
||||
got, err := c.findPath()
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if got != fakeBin {
|
||||
t.Errorf("findPath() = %q, want %q", got, fakeBin)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("falls back to ~/.claude/local/claude", func(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
t.Setenv("PATH", t.TempDir()) // empty dir, no claude binary
|
||||
|
||||
name := "claude"
|
||||
if runtime.GOOS == "windows" {
|
||||
name = "claude.exe"
|
||||
}
|
||||
fallback := filepath.Join(tmpDir, ".claude", "local", name)
|
||||
os.MkdirAll(filepath.Dir(fallback), 0o755)
|
||||
os.WriteFile(fallback, []byte("#!/bin/sh\n"), 0o755)
|
||||
|
||||
got, err := c.findPath()
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if got != fallback {
|
||||
t.Errorf("findPath() = %q, want %q", got, fallback)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("returns error when neither PATH nor fallback exists", func(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
t.Setenv("PATH", t.TempDir()) // empty dir, no claude binary
|
||||
|
||||
_, err := c.findPath()
|
||||
if err == nil {
|
||||
t.Fatal("expected error, got nil")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestClaudeArgs(t *testing.T) {
|
||||
c := &Claude{}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
model string
|
||||
want []string
|
||||
}{
|
||||
{"with model", "llama3.2", []string{"--model", "llama3.2"}},
|
||||
{"empty model", "", nil},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := c.args(tt.model)
|
||||
if !slices.Equal(got, tt.want) {
|
||||
t.Errorf("args(%q) = %v, want %v", tt.model, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
61
cmd/config/codex.go
Normal file
61
cmd/config/codex.go
Normal file
@@ -0,0 +1,61 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"os/exec"
|
||||
"strings"
|
||||
|
||||
"golang.org/x/mod/semver"
|
||||
)
|
||||
|
||||
// Codex implements Runner for Codex integration
|
||||
type Codex struct{}
|
||||
|
||||
func (c *Codex) String() string { return "Codex" }
|
||||
|
||||
func (c *Codex) args(model string) []string {
|
||||
args := []string{"--oss"}
|
||||
if model != "" {
|
||||
args = append(args, "-m", model)
|
||||
}
|
||||
return args
|
||||
}
|
||||
|
||||
func (c *Codex) Run(model string) error {
|
||||
if err := checkCodexVersion(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
cmd := exec.Command("codex", c.args(model)...)
|
||||
cmd.Stdin = os.Stdin
|
||||
cmd.Stdout = os.Stdout
|
||||
cmd.Stderr = os.Stderr
|
||||
return cmd.Run()
|
||||
}
|
||||
|
||||
func checkCodexVersion() error {
|
||||
if _, err := exec.LookPath("codex"); err != nil {
|
||||
return fmt.Errorf("codex is not installed, install with: npm install -g @openai/codex")
|
||||
}
|
||||
|
||||
out, err := exec.Command("codex", "--version").Output()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get codex version: %w", err)
|
||||
}
|
||||
|
||||
// Parse output like "codex-cli 0.87.0"
|
||||
fields := strings.Fields(strings.TrimSpace(string(out)))
|
||||
if len(fields) < 2 {
|
||||
return fmt.Errorf("unexpected codex version output: %s", string(out))
|
||||
}
|
||||
|
||||
version := "v" + fields[len(fields)-1]
|
||||
minVersion := "v0.81.0"
|
||||
|
||||
if semver.Compare(version, minVersion) < 0 {
|
||||
return fmt.Errorf("codex version %s is too old, minimum required is %s, update with: npm update -g @openai/codex", fields[len(fields)-1], "0.81.0")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
28
cmd/config/codex_test.go
Normal file
28
cmd/config/codex_test.go
Normal file
@@ -0,0 +1,28 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"slices"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestCodexArgs(t *testing.T) {
|
||||
c := &Codex{}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
model string
|
||||
want []string
|
||||
}{
|
||||
{"with model", "llama3.2", []string{"--oss", "-m", "llama3.2"}},
|
||||
{"empty model", "", []string{"--oss"}},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := c.args(tt.model)
|
||||
if !slices.Equal(got, tt.want) {
|
||||
t.Errorf("args(%q) = %v, want %v", tt.model, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
115
cmd/config/config.go
Normal file
115
cmd/config/config.go
Normal file
@@ -0,0 +1,115 @@
|
||||
// Package config provides integration configuration for external coding tools
|
||||
// (Claude Code, Codex, Droid, OpenCode) to use Ollama models.
|
||||
package config
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type integration struct {
|
||||
Models []string `json:"models"`
|
||||
}
|
||||
|
||||
type config struct {
|
||||
Integrations map[string]*integration `json:"integrations"`
|
||||
}
|
||||
|
||||
func configPath() (string, error) {
|
||||
home, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return filepath.Join(home, ".ollama", "config", "config.json"), nil
|
||||
}
|
||||
|
||||
func load() (*config, error) {
|
||||
path, err := configPath()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
data, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
return &config{Integrations: make(map[string]*integration)}, nil
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var cfg config
|
||||
if err := json.Unmarshal(data, &cfg); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse config: %w, at: %s", err, path)
|
||||
}
|
||||
if cfg.Integrations == nil {
|
||||
cfg.Integrations = make(map[string]*integration)
|
||||
}
|
||||
return &cfg, nil
|
||||
}
|
||||
|
||||
func save(cfg *config) error {
|
||||
path, err := configPath()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := os.MkdirAll(filepath.Dir(path), 0o755); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
data, err := json.MarshalIndent(cfg, "", " ")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return writeWithBackup(path, data)
|
||||
}
|
||||
|
||||
func saveIntegration(appName string, models []string) error {
|
||||
if appName == "" {
|
||||
return errors.New("app name cannot be empty")
|
||||
}
|
||||
|
||||
cfg, err := load()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
cfg.Integrations[strings.ToLower(appName)] = &integration{
|
||||
Models: models,
|
||||
}
|
||||
|
||||
return save(cfg)
|
||||
}
|
||||
|
||||
func loadIntegration(appName string) (*integration, error) {
|
||||
cfg, err := load()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
ic, ok := cfg.Integrations[strings.ToLower(appName)]
|
||||
if !ok {
|
||||
return nil, os.ErrNotExist
|
||||
}
|
||||
|
||||
return ic, nil
|
||||
}
|
||||
|
||||
func listIntegrations() ([]integration, error) {
|
||||
cfg, err := load()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
result := make([]integration, 0, len(cfg.Integrations))
|
||||
for _, ic := range cfg.Integrations {
|
||||
result = append(result, *ic)
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
373
cmd/config/config_test.go
Normal file
373
cmd/config/config_test.go
Normal file
@@ -0,0 +1,373 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// setTestHome sets both HOME (Unix) and USERPROFILE (Windows) for cross-platform tests
|
||||
func setTestHome(t *testing.T, dir string) {
|
||||
t.Setenv("HOME", dir)
|
||||
t.Setenv("USERPROFILE", dir)
|
||||
}
|
||||
|
||||
// editorPaths is a test helper that safely calls Paths if the runner implements Editor
|
||||
func editorPaths(r Runner) []string {
|
||||
if editor, ok := r.(Editor); ok {
|
||||
return editor.Paths()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func TestIntegrationConfig(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
t.Run("save and load round-trip", func(t *testing.T) {
|
||||
models := []string{"llama3.2", "mistral", "qwen2.5"}
|
||||
if err := saveIntegration("claude", models); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
config, err := loadIntegration("claude")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if len(config.Models) != len(models) {
|
||||
t.Errorf("expected %d models, got %d", len(models), len(config.Models))
|
||||
}
|
||||
for i, m := range models {
|
||||
if config.Models[i] != m {
|
||||
t.Errorf("model %d: expected %s, got %s", i, m, config.Models[i])
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("defaultModel returns first model", func(t *testing.T) {
|
||||
saveIntegration("codex", []string{"model-a", "model-b"})
|
||||
|
||||
config, _ := loadIntegration("codex")
|
||||
defaultModel := ""
|
||||
if len(config.Models) > 0 {
|
||||
defaultModel = config.Models[0]
|
||||
}
|
||||
if defaultModel != "model-a" {
|
||||
t.Errorf("expected model-a, got %s", defaultModel)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("defaultModel returns empty for no models", func(t *testing.T) {
|
||||
config := &integration{Models: []string{}}
|
||||
defaultModel := ""
|
||||
if len(config.Models) > 0 {
|
||||
defaultModel = config.Models[0]
|
||||
}
|
||||
if defaultModel != "" {
|
||||
t.Errorf("expected empty string, got %s", defaultModel)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("app name is case-insensitive", func(t *testing.T) {
|
||||
saveIntegration("Claude", []string{"model-x"})
|
||||
|
||||
config, err := loadIntegration("claude")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defaultModel := ""
|
||||
if len(config.Models) > 0 {
|
||||
defaultModel = config.Models[0]
|
||||
}
|
||||
if defaultModel != "model-x" {
|
||||
t.Errorf("expected model-x, got %s", defaultModel)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("multiple integrations in single file", func(t *testing.T) {
|
||||
saveIntegration("app1", []string{"model-1"})
|
||||
saveIntegration("app2", []string{"model-2"})
|
||||
|
||||
config1, _ := loadIntegration("app1")
|
||||
config2, _ := loadIntegration("app2")
|
||||
|
||||
defaultModel1 := ""
|
||||
if len(config1.Models) > 0 {
|
||||
defaultModel1 = config1.Models[0]
|
||||
}
|
||||
defaultModel2 := ""
|
||||
if len(config2.Models) > 0 {
|
||||
defaultModel2 = config2.Models[0]
|
||||
}
|
||||
if defaultModel1 != "model-1" {
|
||||
t.Errorf("expected model-1, got %s", defaultModel1)
|
||||
}
|
||||
if defaultModel2 != "model-2" {
|
||||
t.Errorf("expected model-2, got %s", defaultModel2)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestListIntegrations(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
t.Run("returns empty when no integrations", func(t *testing.T) {
|
||||
configs, err := listIntegrations()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if len(configs) != 0 {
|
||||
t.Errorf("expected 0 integrations, got %d", len(configs))
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("returns all saved integrations", func(t *testing.T) {
|
||||
saveIntegration("claude", []string{"model-1"})
|
||||
saveIntegration("droid", []string{"model-2"})
|
||||
|
||||
configs, err := listIntegrations()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if len(configs) != 2 {
|
||||
t.Errorf("expected 2 integrations, got %d", len(configs))
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestEditorPaths(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
t.Run("returns empty for claude (no Editor)", func(t *testing.T) {
|
||||
r := integrations["claude"]
|
||||
paths := editorPaths(r)
|
||||
if len(paths) != 0 {
|
||||
t.Errorf("expected no paths for claude, got %v", paths)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("returns empty for codex (no Editor)", func(t *testing.T) {
|
||||
r := integrations["codex"]
|
||||
paths := editorPaths(r)
|
||||
if len(paths) != 0 {
|
||||
t.Errorf("expected no paths for codex, got %v", paths)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("returns empty for droid when no config exists", func(t *testing.T) {
|
||||
r := integrations["droid"]
|
||||
paths := editorPaths(r)
|
||||
if len(paths) != 0 {
|
||||
t.Errorf("expected no paths, got %v", paths)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("returns path for droid when config exists", func(t *testing.T) {
|
||||
settingsDir, _ := os.UserHomeDir()
|
||||
settingsDir = filepath.Join(settingsDir, ".factory")
|
||||
os.MkdirAll(settingsDir, 0o755)
|
||||
os.WriteFile(filepath.Join(settingsDir, "settings.json"), []byte(`{}`), 0o644)
|
||||
|
||||
r := integrations["droid"]
|
||||
paths := editorPaths(r)
|
||||
if len(paths) != 1 {
|
||||
t.Errorf("expected 1 path, got %d", len(paths))
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("returns paths for opencode when configs exist", func(t *testing.T) {
|
||||
home, _ := os.UserHomeDir()
|
||||
configDir := filepath.Join(home, ".config", "opencode")
|
||||
stateDir := filepath.Join(home, ".local", "state", "opencode")
|
||||
os.MkdirAll(configDir, 0o755)
|
||||
os.MkdirAll(stateDir, 0o755)
|
||||
os.WriteFile(filepath.Join(configDir, "opencode.json"), []byte(`{}`), 0o644)
|
||||
os.WriteFile(filepath.Join(stateDir, "model.json"), []byte(`{}`), 0o644)
|
||||
|
||||
r := integrations["opencode"]
|
||||
paths := editorPaths(r)
|
||||
if len(paths) != 2 {
|
||||
t.Errorf("expected 2 paths, got %d: %v", len(paths), paths)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestLoadIntegration_CorruptedJSON(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
// Create corrupted config.json file
|
||||
dir := filepath.Join(tmpDir, ".ollama", "config")
|
||||
os.MkdirAll(dir, 0o755)
|
||||
os.WriteFile(filepath.Join(dir, "config.json"), []byte(`{corrupted json`), 0o644)
|
||||
|
||||
// Corrupted file is treated as empty, so loadIntegration returns not found
|
||||
_, err := loadIntegration("test")
|
||||
if err == nil {
|
||||
t.Error("expected error for nonexistent integration in corrupted file")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSaveIntegration_NilModels(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
if err := saveIntegration("test", nil); err != nil {
|
||||
t.Fatalf("saveIntegration with nil models failed: %v", err)
|
||||
}
|
||||
|
||||
config, err := loadIntegration("test")
|
||||
if err != nil {
|
||||
t.Fatalf("loadIntegration failed: %v", err)
|
||||
}
|
||||
|
||||
if config.Models == nil {
|
||||
// nil is acceptable
|
||||
} else if len(config.Models) != 0 {
|
||||
t.Errorf("expected empty or nil models, got %v", config.Models)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSaveIntegration_EmptyAppName(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
err := saveIntegration("", []string{"model"})
|
||||
if err == nil {
|
||||
t.Error("expected error for empty app name, got nil")
|
||||
}
|
||||
if err != nil && !strings.Contains(err.Error(), "app name cannot be empty") {
|
||||
t.Errorf("expected 'app name cannot be empty' error, got: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadIntegration_NonexistentIntegration(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
_, err := loadIntegration("nonexistent")
|
||||
if err == nil {
|
||||
t.Error("expected error for nonexistent integration, got nil")
|
||||
}
|
||||
if !os.IsNotExist(err) {
|
||||
t.Logf("error type is os.ErrNotExist as expected: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestConfigPath(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
path, err := configPath()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
expected := filepath.Join(tmpDir, ".ollama", "config", "config.json")
|
||||
if path != expected {
|
||||
t.Errorf("expected %s, got %s", expected, path)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoad(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
t.Run("returns empty config when file does not exist", func(t *testing.T) {
|
||||
cfg, err := load()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if cfg == nil {
|
||||
t.Fatal("expected non-nil config")
|
||||
}
|
||||
if cfg.Integrations == nil {
|
||||
t.Error("expected non-nil Integrations map")
|
||||
}
|
||||
if len(cfg.Integrations) != 0 {
|
||||
t.Errorf("expected empty Integrations, got %d", len(cfg.Integrations))
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("loads existing config", func(t *testing.T) {
|
||||
path, _ := configPath()
|
||||
os.MkdirAll(filepath.Dir(path), 0o755)
|
||||
os.WriteFile(path, []byte(`{"integrations":{"test":{"models":["model-a"]}}}`), 0o644)
|
||||
|
||||
cfg, err := load()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if cfg.Integrations["test"] == nil {
|
||||
t.Fatal("expected test integration")
|
||||
}
|
||||
if len(cfg.Integrations["test"].Models) != 1 {
|
||||
t.Errorf("expected 1 model, got %d", len(cfg.Integrations["test"].Models))
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("returns error for corrupted JSON", func(t *testing.T) {
|
||||
path, _ := configPath()
|
||||
os.MkdirAll(filepath.Dir(path), 0o755)
|
||||
os.WriteFile(path, []byte(`{corrupted`), 0o644)
|
||||
|
||||
_, err := load()
|
||||
if err == nil {
|
||||
t.Error("expected error for corrupted JSON")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestSave(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
t.Run("creates config file", func(t *testing.T) {
|
||||
cfg := &config{
|
||||
Integrations: map[string]*integration{
|
||||
"test": {Models: []string{"model-a", "model-b"}},
|
||||
},
|
||||
}
|
||||
|
||||
if err := save(cfg); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
path, _ := configPath()
|
||||
if _, err := os.Stat(path); os.IsNotExist(err) {
|
||||
t.Error("config file was not created")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("round-trip preserves data", func(t *testing.T) {
|
||||
cfg := &config{
|
||||
Integrations: map[string]*integration{
|
||||
"claude": {Models: []string{"llama3.2", "mistral"}},
|
||||
"codex": {Models: []string{"qwen2.5"}},
|
||||
},
|
||||
}
|
||||
|
||||
if err := save(cfg); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
loaded, err := load()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if len(loaded.Integrations) != 2 {
|
||||
t.Errorf("expected 2 integrations, got %d", len(loaded.Integrations))
|
||||
}
|
||||
if loaded.Integrations["claude"] == nil {
|
||||
t.Error("missing claude integration")
|
||||
}
|
||||
if len(loaded.Integrations["claude"].Models) != 2 {
|
||||
t.Errorf("expected 2 models for claude, got %d", len(loaded.Integrations["claude"].Models))
|
||||
}
|
||||
})
|
||||
}
|
||||
184
cmd/config/droid.go
Normal file
184
cmd/config/droid.go
Normal file
@@ -0,0 +1,184 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"slices"
|
||||
)
|
||||
|
||||
// Droid implements Runner and Editor for Droid integration
|
||||
type Droid struct{}
|
||||
|
||||
// droidSettings represents the Droid settings.json file (only fields we use)
|
||||
type droidSettings struct {
|
||||
CustomModels []modelEntry `json:"customModels"`
|
||||
SessionDefaultSettings sessionSettings `json:"sessionDefaultSettings"`
|
||||
}
|
||||
|
||||
type sessionSettings struct {
|
||||
Model string `json:"model"`
|
||||
ReasoningEffort string `json:"reasoningEffort"`
|
||||
}
|
||||
|
||||
type modelEntry struct {
|
||||
Model string `json:"model"`
|
||||
DisplayName string `json:"displayName"`
|
||||
BaseURL string `json:"baseUrl"`
|
||||
APIKey string `json:"apiKey"`
|
||||
Provider string `json:"provider"`
|
||||
MaxOutputTokens int `json:"maxOutputTokens"`
|
||||
SupportsImages bool `json:"supportsImages"`
|
||||
ID string `json:"id"`
|
||||
Index int `json:"index"`
|
||||
}
|
||||
|
||||
func (d *Droid) String() string { return "Droid" }
|
||||
|
||||
func (d *Droid) Run(model string) error {
|
||||
if _, err := exec.LookPath("droid"); err != nil {
|
||||
return fmt.Errorf("droid is not installed, install from https://docs.factory.ai/cli/getting-started/quickstart")
|
||||
}
|
||||
|
||||
// Call Edit() to ensure config is up-to-date before launch
|
||||
models := []string{model}
|
||||
if config, err := loadIntegration("droid"); err == nil && len(config.Models) > 0 {
|
||||
models = config.Models
|
||||
}
|
||||
if err := d.Edit(models); err != nil {
|
||||
return fmt.Errorf("setup failed: %w", err)
|
||||
}
|
||||
|
||||
cmd := exec.Command("droid")
|
||||
cmd.Stdin = os.Stdin
|
||||
cmd.Stdout = os.Stdout
|
||||
cmd.Stderr = os.Stderr
|
||||
return cmd.Run()
|
||||
}
|
||||
|
||||
func (d *Droid) Paths() []string {
|
||||
home, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
p := filepath.Join(home, ".factory", "settings.json")
|
||||
if _, err := os.Stat(p); err == nil {
|
||||
return []string{p}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (d *Droid) Edit(models []string) error {
|
||||
if len(models) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
home, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
settingsPath := filepath.Join(home, ".factory", "settings.json")
|
||||
if err := os.MkdirAll(filepath.Dir(settingsPath), 0o755); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Read file once, unmarshal twice:
|
||||
// map preserves unknown fields for writing back (including extra fields in model entries)
|
||||
settingsMap := make(map[string]any)
|
||||
var settings droidSettings
|
||||
if data, err := os.ReadFile(settingsPath); err == nil {
|
||||
if err := json.Unmarshal(data, &settingsMap); err != nil {
|
||||
return fmt.Errorf("failed to parse settings file: %w, at: %s", err, settingsPath)
|
||||
}
|
||||
json.Unmarshal(data, &settings) // ignore error, zero values are fine
|
||||
}
|
||||
|
||||
// Keep only non-Ollama models from the raw map (preserves extra fields)
|
||||
// Rebuild Ollama models
|
||||
var nonOllamaModels []any
|
||||
if rawModels, ok := settingsMap["customModels"].([]any); ok {
|
||||
for _, raw := range rawModels {
|
||||
if m, ok := raw.(map[string]any); ok {
|
||||
if m["apiKey"] != "ollama" {
|
||||
nonOllamaModels = append(nonOllamaModels, raw)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Build new Ollama model entries with sequential indices (0, 1, 2, ...)
|
||||
var newModels []any
|
||||
var defaultModelID string
|
||||
for i, model := range models {
|
||||
modelID := fmt.Sprintf("custom:%s-%d", model, i)
|
||||
newModels = append(newModels, modelEntry{
|
||||
Model: model,
|
||||
DisplayName: model,
|
||||
BaseURL: "http://localhost:11434/v1",
|
||||
APIKey: "ollama",
|
||||
Provider: "generic-chat-completion-api",
|
||||
MaxOutputTokens: 64000,
|
||||
SupportsImages: false,
|
||||
ID: modelID,
|
||||
Index: i,
|
||||
})
|
||||
if i == 0 {
|
||||
defaultModelID = modelID
|
||||
}
|
||||
}
|
||||
|
||||
settingsMap["customModels"] = append(newModels, nonOllamaModels...)
|
||||
|
||||
// Update session default settings (preserve unknown fields in the nested object)
|
||||
sessionSettings, ok := settingsMap["sessionDefaultSettings"].(map[string]any)
|
||||
if !ok {
|
||||
sessionSettings = make(map[string]any)
|
||||
}
|
||||
sessionSettings["model"] = defaultModelID
|
||||
|
||||
if !isValidReasoningEffort(settings.SessionDefaultSettings.ReasoningEffort) {
|
||||
sessionSettings["reasoningEffort"] = "none"
|
||||
}
|
||||
|
||||
settingsMap["sessionDefaultSettings"] = sessionSettings
|
||||
|
||||
data, err := json.MarshalIndent(settingsMap, "", " ")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return writeWithBackup(settingsPath, data)
|
||||
}
|
||||
|
||||
func (d *Droid) Models() []string {
|
||||
home, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
data, err := os.ReadFile(filepath.Join(home, ".factory", "settings.json"))
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
var settings droidSettings
|
||||
if err := json.Unmarshal(data, &settings); err != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
var result []string
|
||||
for _, m := range settings.CustomModels {
|
||||
if m.APIKey == "ollama" {
|
||||
result = append(result, m.Model)
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
var validReasoningEfforts = []string{"high", "medium", "low", "none"}
|
||||
|
||||
func isValidReasoningEffort(effort string) bool {
|
||||
return slices.Contains(validReasoningEfforts, effort)
|
||||
}
|
||||
1302
cmd/config/droid_test.go
Normal file
1302
cmd/config/droid_test.go
Normal file
File diff suppressed because it is too large
Load Diff
99
cmd/config/files.go
Normal file
99
cmd/config/files.go
Normal file
@@ -0,0 +1,99 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"time"
|
||||
)
|
||||
|
||||
func readJSONFile(path string) (map[string]any, error) {
|
||||
data, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var result map[string]any
|
||||
if err := json.Unmarshal(data, &result); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func copyFile(src, dst string) error {
|
||||
info, err := os.Stat(src)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
data, err := os.ReadFile(src)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return os.WriteFile(dst, data, info.Mode().Perm())
|
||||
}
|
||||
|
||||
func backupDir() string {
|
||||
return filepath.Join(os.TempDir(), "ollama-backups")
|
||||
}
|
||||
|
||||
func backupToTmp(srcPath string) (string, error) {
|
||||
dir := backupDir()
|
||||
if err := os.MkdirAll(dir, 0o755); err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
backupPath := filepath.Join(dir, fmt.Sprintf("%s.%d", filepath.Base(srcPath), time.Now().Unix()))
|
||||
if err := copyFile(srcPath, backupPath); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return backupPath, nil
|
||||
}
|
||||
|
||||
// writeWithBackup writes data to path via temp file + rename, backing up any existing file first
|
||||
func writeWithBackup(path string, data []byte) error {
|
||||
var backupPath string
|
||||
// backup must be created before any writes to the target file
|
||||
if existingContent, err := os.ReadFile(path); err == nil {
|
||||
if !bytes.Equal(existingContent, data) {
|
||||
backupPath, err = backupToTmp(path)
|
||||
if err != nil {
|
||||
return fmt.Errorf("backup failed: %w", err)
|
||||
}
|
||||
}
|
||||
} else if !os.IsNotExist(err) {
|
||||
return fmt.Errorf("read existing file: %w", err)
|
||||
}
|
||||
|
||||
dir := filepath.Dir(path)
|
||||
tmp, err := os.CreateTemp(dir, ".tmp-*")
|
||||
if err != nil {
|
||||
return fmt.Errorf("create temp failed: %w", err)
|
||||
}
|
||||
tmpPath := tmp.Name()
|
||||
|
||||
if _, err := tmp.Write(data); err != nil {
|
||||
_ = tmp.Close()
|
||||
_ = os.Remove(tmpPath)
|
||||
return fmt.Errorf("write failed: %w", err)
|
||||
}
|
||||
if err := tmp.Sync(); err != nil {
|
||||
_ = tmp.Close()
|
||||
_ = os.Remove(tmpPath)
|
||||
return fmt.Errorf("sync failed: %w", err)
|
||||
}
|
||||
if err := tmp.Close(); err != nil {
|
||||
_ = os.Remove(tmpPath)
|
||||
return fmt.Errorf("close failed: %w", err)
|
||||
}
|
||||
|
||||
if err := os.Rename(tmpPath, path); err != nil {
|
||||
_ = os.Remove(tmpPath)
|
||||
if backupPath != "" {
|
||||
_ = copyFile(backupPath, path)
|
||||
}
|
||||
return fmt.Errorf("rename failed: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
502
cmd/config/files_test.go
Normal file
502
cmd/config/files_test.go
Normal file
@@ -0,0 +1,502 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func mustMarshal(t *testing.T, v any) []byte {
|
||||
t.Helper()
|
||||
data, err := json.MarshalIndent(v, "", " ")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
return data
|
||||
}
|
||||
|
||||
func TestWriteWithBackup(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
|
||||
t.Run("creates file", func(t *testing.T) {
|
||||
path := filepath.Join(tmpDir, "new.json")
|
||||
data := mustMarshal(t, map[string]string{"key": "value"})
|
||||
|
||||
if err := writeWithBackup(path, data); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
content, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
var result map[string]string
|
||||
if err := json.Unmarshal(content, &result); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if result["key"] != "value" {
|
||||
t.Errorf("expected value, got %s", result["key"])
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("creates backup in /tmp/ollama-backups", func(t *testing.T) {
|
||||
path := filepath.Join(tmpDir, "backup.json")
|
||||
|
||||
os.WriteFile(path, []byte(`{"original": true}`), 0o644)
|
||||
|
||||
data := mustMarshal(t, map[string]bool{"updated": true})
|
||||
if err := writeWithBackup(path, data); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
entries, err := os.ReadDir(backupDir())
|
||||
if err != nil {
|
||||
t.Fatal("backup directory not created")
|
||||
}
|
||||
|
||||
var foundBackup bool
|
||||
for _, entry := range entries {
|
||||
if filepath.Ext(entry.Name()) != ".json" {
|
||||
name := entry.Name()
|
||||
if len(name) > len("backup.json.") && name[:len("backup.json.")] == "backup.json." {
|
||||
backupPath := filepath.Join(backupDir(), name)
|
||||
backup, err := os.ReadFile(backupPath)
|
||||
if err == nil {
|
||||
var backupData map[string]bool
|
||||
json.Unmarshal(backup, &backupData)
|
||||
if backupData["original"] {
|
||||
foundBackup = true
|
||||
os.Remove(backupPath)
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if !foundBackup {
|
||||
t.Error("backup file not created in /tmp/ollama-backups")
|
||||
}
|
||||
|
||||
current, _ := os.ReadFile(path)
|
||||
var currentData map[string]bool
|
||||
json.Unmarshal(current, ¤tData)
|
||||
if !currentData["updated"] {
|
||||
t.Error("file doesn't contain updated data")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("no backup for new file", func(t *testing.T) {
|
||||
path := filepath.Join(tmpDir, "nobak.json")
|
||||
|
||||
data := mustMarshal(t, map[string]string{"new": "file"})
|
||||
if err := writeWithBackup(path, data); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
entries, _ := os.ReadDir(backupDir())
|
||||
for _, entry := range entries {
|
||||
if len(entry.Name()) > len("nobak.json.") && entry.Name()[:len("nobak.json.")] == "nobak.json." {
|
||||
t.Error("backup should not exist for new file")
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("no backup when content unchanged", func(t *testing.T) {
|
||||
path := filepath.Join(tmpDir, "unchanged.json")
|
||||
|
||||
data := mustMarshal(t, map[string]string{"key": "value"})
|
||||
|
||||
if err := writeWithBackup(path, data); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
entries1, _ := os.ReadDir(backupDir())
|
||||
countBefore := 0
|
||||
for _, e := range entries1 {
|
||||
if len(e.Name()) > len("unchanged.json.") && e.Name()[:len("unchanged.json.")] == "unchanged.json." {
|
||||
countBefore++
|
||||
}
|
||||
}
|
||||
|
||||
if err := writeWithBackup(path, data); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
entries2, _ := os.ReadDir(backupDir())
|
||||
countAfter := 0
|
||||
for _, e := range entries2 {
|
||||
if len(e.Name()) > len("unchanged.json.") && e.Name()[:len("unchanged.json.")] == "unchanged.json." {
|
||||
countAfter++
|
||||
}
|
||||
}
|
||||
|
||||
if countAfter != countBefore {
|
||||
t.Errorf("backup was created when content unchanged (before=%d, after=%d)", countBefore, countAfter)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("backup filename contains unix timestamp", func(t *testing.T) {
|
||||
path := filepath.Join(tmpDir, "timestamped.json")
|
||||
|
||||
os.WriteFile(path, []byte(`{"v": 1}`), 0o644)
|
||||
data := mustMarshal(t, map[string]int{"v": 2})
|
||||
if err := writeWithBackup(path, data); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
entries, _ := os.ReadDir(backupDir())
|
||||
var found bool
|
||||
for _, entry := range entries {
|
||||
name := entry.Name()
|
||||
if len(name) > len("timestamped.json.") && name[:len("timestamped.json.")] == "timestamped.json." {
|
||||
timestamp := name[len("timestamped.json."):]
|
||||
for _, c := range timestamp {
|
||||
if c < '0' || c > '9' {
|
||||
t.Errorf("backup filename timestamp contains non-numeric character: %s", name)
|
||||
}
|
||||
}
|
||||
found = true
|
||||
os.Remove(filepath.Join(backupDir(), name))
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
t.Error("backup file with timestamp not found")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// Edge case tests for files.go
|
||||
|
||||
// TestWriteWithBackup_FailsIfBackupFails documents critical behavior: if backup fails, we must not proceed.
|
||||
// User could lose their config with no way to recover.
|
||||
func TestWriteWithBackup_FailsIfBackupFails(t *testing.T) {
|
||||
if runtime.GOOS == "windows" {
|
||||
t.Skip("permission tests unreliable on Windows")
|
||||
}
|
||||
|
||||
tmpDir := t.TempDir()
|
||||
path := filepath.Join(tmpDir, "config.json")
|
||||
|
||||
// Create original file
|
||||
originalContent := []byte(`{"original": true}`)
|
||||
os.WriteFile(path, originalContent, 0o644)
|
||||
|
||||
// Make backup directory read-only to force backup failure
|
||||
backupDir := backupDir()
|
||||
os.MkdirAll(backupDir, 0o755)
|
||||
os.Chmod(backupDir, 0o444) // Read-only
|
||||
defer os.Chmod(backupDir, 0o755)
|
||||
|
||||
newContent := []byte(`{"updated": true}`)
|
||||
err := writeWithBackup(path, newContent)
|
||||
|
||||
// Should fail because backup couldn't be created
|
||||
if err == nil {
|
||||
t.Error("expected error when backup fails, got nil")
|
||||
}
|
||||
|
||||
// Original file should be preserved
|
||||
current, _ := os.ReadFile(path)
|
||||
if string(current) != string(originalContent) {
|
||||
t.Errorf("original file was modified despite backup failure: got %s", string(current))
|
||||
}
|
||||
}
|
||||
|
||||
// TestWriteWithBackup_PermissionDenied verifies clear error when target file has wrong permissions.
|
||||
// Common issue when config owned by root or wrong perms.
|
||||
func TestWriteWithBackup_PermissionDenied(t *testing.T) {
|
||||
if runtime.GOOS == "windows" {
|
||||
t.Skip("permission tests unreliable on Windows")
|
||||
}
|
||||
|
||||
tmpDir := t.TempDir()
|
||||
|
||||
// Create a read-only directory
|
||||
readOnlyDir := filepath.Join(tmpDir, "readonly")
|
||||
os.MkdirAll(readOnlyDir, 0o755)
|
||||
os.Chmod(readOnlyDir, 0o444)
|
||||
defer os.Chmod(readOnlyDir, 0o755)
|
||||
|
||||
path := filepath.Join(readOnlyDir, "config.json")
|
||||
err := writeWithBackup(path, []byte(`{"test": true}`))
|
||||
|
||||
if err == nil {
|
||||
t.Error("expected permission error, got nil")
|
||||
}
|
||||
}
|
||||
|
||||
// TestWriteWithBackup_DirectoryDoesNotExist verifies behavior when target directory doesn't exist.
|
||||
// writeWithBackup doesn't create directories - caller is responsible.
|
||||
func TestWriteWithBackup_DirectoryDoesNotExist(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
path := filepath.Join(tmpDir, "nonexistent", "subdir", "config.json")
|
||||
|
||||
err := writeWithBackup(path, []byte(`{"test": true}`))
|
||||
|
||||
// Should fail because directory doesn't exist
|
||||
if err == nil {
|
||||
t.Error("expected error for nonexistent directory, got nil")
|
||||
}
|
||||
}
|
||||
|
||||
// TestWriteWithBackup_SymlinkTarget documents behavior when target is a symlink.
|
||||
// Documents what happens if user symlinks their config file.
|
||||
func TestWriteWithBackup_SymlinkTarget(t *testing.T) {
|
||||
if runtime.GOOS == "windows" {
|
||||
t.Skip("symlink tests may require admin on Windows")
|
||||
}
|
||||
|
||||
tmpDir := t.TempDir()
|
||||
realFile := filepath.Join(tmpDir, "real.json")
|
||||
symlink := filepath.Join(tmpDir, "link.json")
|
||||
|
||||
// Create real file and symlink
|
||||
os.WriteFile(realFile, []byte(`{"v": 1}`), 0o644)
|
||||
os.Symlink(realFile, symlink)
|
||||
|
||||
// Write through symlink
|
||||
err := writeWithBackup(symlink, []byte(`{"v": 2}`))
|
||||
if err != nil {
|
||||
t.Fatalf("writeWithBackup through symlink failed: %v", err)
|
||||
}
|
||||
|
||||
// The real file should be updated (symlink followed for temp file creation)
|
||||
content, _ := os.ReadFile(symlink)
|
||||
if string(content) != `{"v": 2}` {
|
||||
t.Errorf("symlink target not updated correctly: got %s", string(content))
|
||||
}
|
||||
}
|
||||
|
||||
// TestBackupToTmp_SpecialCharsInFilename verifies backup works with special characters.
|
||||
// User may have config files with unusual names.
|
||||
func TestBackupToTmp_SpecialCharsInFilename(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
|
||||
// File with spaces and special chars
|
||||
path := filepath.Join(tmpDir, "my config (backup).json")
|
||||
os.WriteFile(path, []byte(`{"test": true}`), 0o644)
|
||||
|
||||
backupPath, err := backupToTmp(path)
|
||||
if err != nil {
|
||||
t.Fatalf("backupToTmp with special chars failed: %v", err)
|
||||
}
|
||||
|
||||
// Verify backup exists and has correct content
|
||||
content, err := os.ReadFile(backupPath)
|
||||
if err != nil {
|
||||
t.Fatalf("could not read backup: %v", err)
|
||||
}
|
||||
if string(content) != `{"test": true}` {
|
||||
t.Errorf("backup content mismatch: got %s", string(content))
|
||||
}
|
||||
|
||||
os.Remove(backupPath)
|
||||
}
|
||||
|
||||
// TestCopyFile_PreservesPermissions verifies that copyFile preserves file permissions.
|
||||
func TestCopyFile_PreservesPermissions(t *testing.T) {
|
||||
if runtime.GOOS == "windows" {
|
||||
t.Skip("permission preservation tests unreliable on Windows")
|
||||
}
|
||||
|
||||
tmpDir := t.TempDir()
|
||||
src := filepath.Join(tmpDir, "src.json")
|
||||
dst := filepath.Join(tmpDir, "dst.json")
|
||||
|
||||
// Create source with specific permissions
|
||||
os.WriteFile(src, []byte(`{"test": true}`), 0o600)
|
||||
|
||||
err := copyFile(src, dst)
|
||||
if err != nil {
|
||||
t.Fatalf("copyFile failed: %v", err)
|
||||
}
|
||||
|
||||
srcInfo, _ := os.Stat(src)
|
||||
dstInfo, _ := os.Stat(dst)
|
||||
|
||||
if srcInfo.Mode().Perm() != dstInfo.Mode().Perm() {
|
||||
t.Errorf("permissions not preserved: src=%v, dst=%v", srcInfo.Mode().Perm(), dstInfo.Mode().Perm())
|
||||
}
|
||||
}
|
||||
|
||||
// TestCopyFile_SourceNotFound verifies clear error when source doesn't exist.
|
||||
func TestCopyFile_SourceNotFound(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
src := filepath.Join(tmpDir, "nonexistent.json")
|
||||
dst := filepath.Join(tmpDir, "dst.json")
|
||||
|
||||
err := copyFile(src, dst)
|
||||
if err == nil {
|
||||
t.Error("expected error for nonexistent source, got nil")
|
||||
}
|
||||
}
|
||||
|
||||
// TestWriteWithBackup_TargetIsDirectory verifies error when path points to a directory.
|
||||
func TestWriteWithBackup_TargetIsDirectory(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
dirPath := filepath.Join(tmpDir, "actualdir")
|
||||
os.MkdirAll(dirPath, 0o755)
|
||||
|
||||
err := writeWithBackup(dirPath, []byte(`{"test": true}`))
|
||||
if err == nil {
|
||||
t.Error("expected error when target is a directory, got nil")
|
||||
}
|
||||
}
|
||||
|
||||
// TestWriteWithBackup_EmptyData verifies writing zero bytes works correctly.
|
||||
func TestWriteWithBackup_EmptyData(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
path := filepath.Join(tmpDir, "empty.json")
|
||||
|
||||
err := writeWithBackup(path, []byte{})
|
||||
if err != nil {
|
||||
t.Fatalf("writeWithBackup with empty data failed: %v", err)
|
||||
}
|
||||
|
||||
content, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
t.Fatalf("could not read file: %v", err)
|
||||
}
|
||||
if len(content) != 0 {
|
||||
t.Errorf("expected empty file, got %d bytes", len(content))
|
||||
}
|
||||
}
|
||||
|
||||
// TestWriteWithBackup_FileUnreadableButDirWritable verifies behavior when existing file
|
||||
// cannot be read (for backup comparison) but directory is writable.
|
||||
func TestWriteWithBackup_FileUnreadableButDirWritable(t *testing.T) {
|
||||
if runtime.GOOS == "windows" {
|
||||
t.Skip("permission tests unreliable on Windows")
|
||||
}
|
||||
|
||||
tmpDir := t.TempDir()
|
||||
path := filepath.Join(tmpDir, "unreadable.json")
|
||||
|
||||
// Create file and make it unreadable
|
||||
os.WriteFile(path, []byte(`{"original": true}`), 0o644)
|
||||
os.Chmod(path, 0o000)
|
||||
defer os.Chmod(path, 0o644)
|
||||
|
||||
// Should fail because we can't read the file to compare/backup
|
||||
err := writeWithBackup(path, []byte(`{"updated": true}`))
|
||||
if err == nil {
|
||||
t.Error("expected error when file is unreadable, got nil")
|
||||
}
|
||||
}
|
||||
|
||||
// TestWriteWithBackup_RapidSuccessiveWrites verifies backup works with multiple writes
|
||||
// within the same second (timestamp collision scenario).
|
||||
func TestWriteWithBackup_RapidSuccessiveWrites(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
path := filepath.Join(tmpDir, "rapid.json")
|
||||
|
||||
// Create initial file
|
||||
os.WriteFile(path, []byte(`{"v": 0}`), 0o644)
|
||||
|
||||
// Rapid successive writes
|
||||
for i := 1; i <= 3; i++ {
|
||||
data := []byte(fmt.Sprintf(`{"v": %d}`, i))
|
||||
if err := writeWithBackup(path, data); err != nil {
|
||||
t.Fatalf("write %d failed: %v", i, err)
|
||||
}
|
||||
}
|
||||
|
||||
// Verify final content
|
||||
content, _ := os.ReadFile(path)
|
||||
if string(content) != `{"v": 3}` {
|
||||
t.Errorf("expected final content {\"v\": 3}, got %s", string(content))
|
||||
}
|
||||
|
||||
// Verify at least one backup exists
|
||||
entries, _ := os.ReadDir(backupDir())
|
||||
var backupCount int
|
||||
for _, e := range entries {
|
||||
if len(e.Name()) > len("rapid.json.") && e.Name()[:len("rapid.json.")] == "rapid.json." {
|
||||
backupCount++
|
||||
}
|
||||
}
|
||||
if backupCount == 0 {
|
||||
t.Error("expected at least one backup file from rapid writes")
|
||||
}
|
||||
}
|
||||
|
||||
// TestWriteWithBackup_BackupDirIsFile verifies error when backup directory path is a file.
|
||||
func TestWriteWithBackup_BackupDirIsFile(t *testing.T) {
|
||||
if runtime.GOOS == "windows" {
|
||||
t.Skip("test modifies system temp directory")
|
||||
}
|
||||
|
||||
// Create a file at the backup directory path
|
||||
backupPath := backupDir()
|
||||
// Clean up any existing directory first
|
||||
os.RemoveAll(backupPath)
|
||||
// Create a file instead of directory
|
||||
os.WriteFile(backupPath, []byte("not a directory"), 0o644)
|
||||
defer func() {
|
||||
os.Remove(backupPath)
|
||||
os.MkdirAll(backupPath, 0o755)
|
||||
}()
|
||||
|
||||
tmpDir := t.TempDir()
|
||||
path := filepath.Join(tmpDir, "test.json")
|
||||
os.WriteFile(path, []byte(`{"original": true}`), 0o644)
|
||||
|
||||
err := writeWithBackup(path, []byte(`{"updated": true}`))
|
||||
if err == nil {
|
||||
t.Error("expected error when backup dir is a file, got nil")
|
||||
}
|
||||
}
|
||||
|
||||
// TestWriteWithBackup_NoOrphanTempFiles verifies temp files are cleaned up on failure.
|
||||
func TestWriteWithBackup_NoOrphanTempFiles(t *testing.T) {
|
||||
if runtime.GOOS == "windows" {
|
||||
t.Skip("permission tests unreliable on Windows")
|
||||
}
|
||||
|
||||
tmpDir := t.TempDir()
|
||||
|
||||
// Count existing temp files
|
||||
countTempFiles := func() int {
|
||||
entries, _ := os.ReadDir(tmpDir)
|
||||
count := 0
|
||||
for _, e := range entries {
|
||||
if len(e.Name()) > 4 && e.Name()[:4] == ".tmp" {
|
||||
count++
|
||||
}
|
||||
}
|
||||
return count
|
||||
}
|
||||
|
||||
before := countTempFiles()
|
||||
|
||||
// Create a file, then make directory read-only to cause rename failure
|
||||
path := filepath.Join(tmpDir, "orphan.json")
|
||||
os.WriteFile(path, []byte(`{"v": 1}`), 0o644)
|
||||
|
||||
// Make a subdirectory and try to write there after making parent read-only
|
||||
subDir := filepath.Join(tmpDir, "subdir")
|
||||
os.MkdirAll(subDir, 0o755)
|
||||
subPath := filepath.Join(subDir, "config.json")
|
||||
os.WriteFile(subPath, []byte(`{"v": 1}`), 0o644)
|
||||
|
||||
// Make subdir read-only after creating temp file would succeed but rename would fail
|
||||
// This is tricky to test - the temp file is created in the same dir, so if we can't
|
||||
// rename, we also couldn't create. Let's just verify normal failure cleanup works.
|
||||
|
||||
// Force a failure by making the target a directory
|
||||
badPath := filepath.Join(tmpDir, "isdir")
|
||||
os.MkdirAll(badPath, 0o755)
|
||||
|
||||
_ = writeWithBackup(badPath, []byte(`{"test": true}`))
|
||||
|
||||
after := countTempFiles()
|
||||
if after > before {
|
||||
t.Errorf("orphan temp files left behind: before=%d, after=%d", before, after)
|
||||
}
|
||||
}
|
||||
353
cmd/config/integrations.go
Normal file
353
cmd/config/integrations.go
Normal file
@@ -0,0 +1,353 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"maps"
|
||||
"os"
|
||||
"os/exec"
|
||||
"runtime"
|
||||
"slices"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/spf13/cobra"
|
||||
)
|
||||
|
||||
// Runners execute the launching of a model with the integration - claude, codex
|
||||
// Editors can edit config files (supports multi-model selection) - opencode, droid
|
||||
// They are composable interfaces where in some cases an editor is also a runner - opencode, droid
|
||||
// Runner can run an integration with a model.
|
||||
|
||||
type Runner interface {
|
||||
Run(model string) error
|
||||
// String returns the human-readable name of the integration
|
||||
String() string
|
||||
}
|
||||
|
||||
// Editor can edit config files (supports multi-model selection)
|
||||
type Editor interface {
|
||||
// Paths returns the paths to the config files for the integration
|
||||
Paths() []string
|
||||
// Edit updates the config files for the integration with the given models
|
||||
Edit(models []string) error
|
||||
// Models returns the models currently configured for the integration
|
||||
// TODO(parthsareen): add error return to Models()
|
||||
Models() []string
|
||||
}
|
||||
|
||||
// integrations is the registry of available integrations.
|
||||
var integrations = map[string]Runner{
|
||||
"claude": &Claude{},
|
||||
"codex": &Codex{},
|
||||
"droid": &Droid{},
|
||||
"opencode": &OpenCode{},
|
||||
}
|
||||
|
||||
func selectIntegration() (string, error) {
|
||||
if len(integrations) == 0 {
|
||||
return "", fmt.Errorf("no integrations available")
|
||||
}
|
||||
|
||||
names := slices.Sorted(maps.Keys(integrations))
|
||||
var items []selectItem
|
||||
for _, name := range names {
|
||||
r := integrations[name]
|
||||
description := r.String()
|
||||
if conn, err := loadIntegration(name); err == nil && len(conn.Models) > 0 {
|
||||
description = fmt.Sprintf("%s (%s)", r.String(), conn.Models[0])
|
||||
}
|
||||
items = append(items, selectItem{Name: name, Description: description})
|
||||
}
|
||||
|
||||
return selectPrompt("Select integration:", items)
|
||||
}
|
||||
|
||||
// selectModels lets the user select models for an integration
|
||||
func selectModels(ctx context.Context, name, current string) ([]string, error) {
|
||||
r, ok := integrations[name]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("unknown integration: %s", name)
|
||||
}
|
||||
|
||||
client, err := api.ClientFromEnvironment()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
models, err := client.List(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if len(models.Models) == 0 {
|
||||
return nil, fmt.Errorf("no models available, run 'ollama pull <model>' first")
|
||||
}
|
||||
|
||||
var items []selectItem
|
||||
cloudModels := make(map[string]bool)
|
||||
for _, m := range models.Models {
|
||||
if m.RemoteModel != "" {
|
||||
cloudModels[m.Name] = true
|
||||
}
|
||||
items = append(items, selectItem{Name: m.Name})
|
||||
}
|
||||
|
||||
if len(items) == 0 {
|
||||
return nil, fmt.Errorf("no local models available, run 'ollama pull <model>' first")
|
||||
}
|
||||
|
||||
// Get previously configured models (saved config takes precedence)
|
||||
var preChecked []string
|
||||
if saved, err := loadIntegration(name); err == nil {
|
||||
preChecked = saved.Models
|
||||
} else if editor, ok := r.(Editor); ok {
|
||||
preChecked = editor.Models()
|
||||
}
|
||||
checked := make(map[string]bool, len(preChecked))
|
||||
for _, n := range preChecked {
|
||||
checked[n] = true
|
||||
}
|
||||
|
||||
// Resolve current to full name (e.g., "llama3.2" -> "llama3.2:latest")
|
||||
for _, item := range items {
|
||||
if item.Name == current || strings.HasPrefix(item.Name, current+":") {
|
||||
current = item.Name
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
// If current model is configured, move to front of preChecked
|
||||
if checked[current] {
|
||||
preChecked = append([]string{current}, slices.DeleteFunc(preChecked, func(m string) bool { return m == current })...)
|
||||
}
|
||||
|
||||
// Sort: checked first, then alphabetical
|
||||
slices.SortFunc(items, func(a, b selectItem) int {
|
||||
ac, bc := checked[a.Name], checked[b.Name]
|
||||
if ac != bc {
|
||||
if ac {
|
||||
return -1
|
||||
}
|
||||
return 1
|
||||
}
|
||||
return strings.Compare(strings.ToLower(a.Name), strings.ToLower(b.Name))
|
||||
})
|
||||
|
||||
var selected []string
|
||||
// only editors support multi-model selection
|
||||
if _, ok := r.(Editor); ok {
|
||||
selected, err = multiSelectPrompt(fmt.Sprintf("Select models for %s:", r), items, preChecked)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
} else {
|
||||
model, err := selectPrompt(fmt.Sprintf("Select model for %s:", r), items)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
selected = []string{model}
|
||||
}
|
||||
|
||||
// if any model in selected is a cloud model, ensure signed in
|
||||
var selectedCloudModels []string
|
||||
for _, m := range selected {
|
||||
if cloudModels[m] {
|
||||
selectedCloudModels = append(selectedCloudModels, m)
|
||||
}
|
||||
}
|
||||
if len(selectedCloudModels) > 0 {
|
||||
// ensure user is signed in
|
||||
user, err := client.Whoami(ctx)
|
||||
if err == nil && user != nil && user.Name != "" {
|
||||
return selected, nil
|
||||
}
|
||||
|
||||
var aErr api.AuthorizationError
|
||||
if !errors.As(err, &aErr) || aErr.SigninURL == "" {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
modelList := strings.Join(selectedCloudModels, ", ")
|
||||
yes, err := confirmPrompt(fmt.Sprintf("sign in to use %s?", modelList))
|
||||
if err != nil || !yes {
|
||||
return nil, fmt.Errorf("%s requires sign in", modelList)
|
||||
}
|
||||
|
||||
fmt.Fprintf(os.Stderr, "\nTo sign in, navigate to:\n %s\n\n", aErr.SigninURL)
|
||||
|
||||
// TODO(parthsareen): extract into auth package for cmd
|
||||
// Auto-open browser (best effort, fail silently)
|
||||
switch runtime.GOOS {
|
||||
case "darwin":
|
||||
_ = exec.Command("open", aErr.SigninURL).Start()
|
||||
case "linux":
|
||||
_ = exec.Command("xdg-open", aErr.SigninURL).Start()
|
||||
case "windows":
|
||||
_ = exec.Command("rundll32", "url.dll,FileProtocolHandler", aErr.SigninURL).Start()
|
||||
}
|
||||
|
||||
spinnerFrames := []string{"|", "/", "-", "\\"}
|
||||
frame := 0
|
||||
|
||||
fmt.Fprintf(os.Stderr, "\033[90mwaiting for sign in to complete... %s\033[0m", spinnerFrames[0])
|
||||
|
||||
ticker := time.NewTicker(200 * time.Millisecond)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
fmt.Fprintf(os.Stderr, "\r\033[K")
|
||||
return nil, ctx.Err()
|
||||
case <-ticker.C:
|
||||
frame++
|
||||
fmt.Fprintf(os.Stderr, "\r\033[90mwaiting for sign in to complete... %s\033[0m", spinnerFrames[frame%len(spinnerFrames)])
|
||||
|
||||
// poll every 10th frame (~2 seconds)
|
||||
if frame%10 == 0 {
|
||||
u, err := client.Whoami(ctx)
|
||||
if err == nil && u != nil && u.Name != "" {
|
||||
fmt.Fprintf(os.Stderr, "\r\033[K\033[A\r\033[K\033[1msigned in:\033[0m %s\n", u.Name)
|
||||
return selected, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return selected, nil
|
||||
}
|
||||
|
||||
func runIntegration(name, modelName string) error {
|
||||
r, ok := integrations[name]
|
||||
if !ok {
|
||||
return fmt.Errorf("unknown integration: %s", name)
|
||||
}
|
||||
fmt.Fprintf(os.Stderr, "\nLaunching %s with %s...\n", r, modelName)
|
||||
return r.Run(modelName)
|
||||
}
|
||||
|
||||
// LaunchCmd returns the cobra command for launching integrations.
|
||||
func LaunchCmd(checkServerHeartbeat func(cmd *cobra.Command, args []string) error) *cobra.Command {
|
||||
var modelFlag string
|
||||
var configFlag bool
|
||||
|
||||
cmd := &cobra.Command{
|
||||
Use: "launch [INTEGRATION]",
|
||||
Short: "Launch an integration with Ollama",
|
||||
Long: `Launch an integration configured with Ollama models.
|
||||
|
||||
Supported integrations:
|
||||
claude Claude Code
|
||||
codex Codex
|
||||
droid Droid
|
||||
opencode OpenCode
|
||||
|
||||
Examples:
|
||||
ollama launch
|
||||
ollama launch claude
|
||||
ollama launch claude --model <model>
|
||||
ollama launch droid --config (does not auto-launch)`,
|
||||
Args: cobra.MaximumNArgs(1),
|
||||
PreRunE: checkServerHeartbeat,
|
||||
RunE: func(cmd *cobra.Command, args []string) error {
|
||||
var name string
|
||||
if len(args) > 0 {
|
||||
name = args[0]
|
||||
} else {
|
||||
var err error
|
||||
name, err = selectIntegration()
|
||||
if errors.Is(err, errCancelled) {
|
||||
return nil
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
r, ok := integrations[strings.ToLower(name)]
|
||||
if !ok {
|
||||
return fmt.Errorf("unknown integration: %s", name)
|
||||
}
|
||||
|
||||
// If launching without --model, use saved config if available
|
||||
if !configFlag && modelFlag == "" {
|
||||
if config, err := loadIntegration(name); err == nil && len(config.Models) > 0 {
|
||||
return runIntegration(name, config.Models[0])
|
||||
}
|
||||
}
|
||||
|
||||
var models []string
|
||||
if modelFlag != "" {
|
||||
// When --model is specified, merge with existing models (new model becomes default)
|
||||
models = []string{modelFlag}
|
||||
if existing, err := loadIntegration(name); err == nil && len(existing.Models) > 0 {
|
||||
for _, m := range existing.Models {
|
||||
if m != modelFlag {
|
||||
models = append(models, m)
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
var err error
|
||||
models, err = selectModels(cmd.Context(), name, "")
|
||||
if errors.Is(err, errCancelled) {
|
||||
return nil
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
if editor, isEditor := r.(Editor); isEditor {
|
||||
paths := editor.Paths()
|
||||
if len(paths) > 0 {
|
||||
fmt.Fprintf(os.Stderr, "This will modify your %s configuration:\n", r)
|
||||
for _, p := range paths {
|
||||
fmt.Fprintf(os.Stderr, " %s\n", p)
|
||||
}
|
||||
fmt.Fprintf(os.Stderr, "Backups will be saved to %s/\n\n", backupDir())
|
||||
|
||||
if ok, _ := confirmPrompt("Proceed?"); !ok {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if err := saveIntegration(name, models); err != nil {
|
||||
return fmt.Errorf("failed to save: %w", err)
|
||||
}
|
||||
|
||||
if editor, isEditor := r.(Editor); isEditor {
|
||||
if err := editor.Edit(models); err != nil {
|
||||
return fmt.Errorf("setup failed: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
if _, isEditor := r.(Editor); isEditor {
|
||||
if len(models) == 1 {
|
||||
fmt.Fprintf(os.Stderr, "Added %s to %s\n", models[0], r)
|
||||
} else {
|
||||
fmt.Fprintf(os.Stderr, "Added %d models to %s (default: %s)\n", len(models), r, models[0])
|
||||
}
|
||||
}
|
||||
|
||||
if configFlag {
|
||||
if launch, _ := confirmPrompt(fmt.Sprintf("\nLaunch %s now?", r)); launch {
|
||||
return runIntegration(name, models[0])
|
||||
}
|
||||
fmt.Fprintf(os.Stderr, "Run 'ollama launch %s' to start with %s\n", strings.ToLower(name), models[0])
|
||||
return nil
|
||||
}
|
||||
|
||||
return runIntegration(name, models[0])
|
||||
},
|
||||
}
|
||||
|
||||
cmd.Flags().StringVar(&modelFlag, "model", "", "Model to use")
|
||||
cmd.Flags().BoolVar(&configFlag, "config", false, "Configure without launching")
|
||||
return cmd
|
||||
}
|
||||
188
cmd/config/integrations_test.go
Normal file
188
cmd/config/integrations_test.go
Normal file
@@ -0,0 +1,188 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"slices"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/spf13/cobra"
|
||||
)
|
||||
|
||||
func TestIntegrationLookup(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
wantFound bool
|
||||
wantName string
|
||||
}{
|
||||
{"claude lowercase", "claude", true, "Claude Code"},
|
||||
{"claude uppercase", "CLAUDE", true, "Claude Code"},
|
||||
{"claude mixed case", "Claude", true, "Claude Code"},
|
||||
{"codex", "codex", true, "Codex"},
|
||||
{"droid", "droid", true, "Droid"},
|
||||
{"opencode", "opencode", true, "OpenCode"},
|
||||
{"unknown integration", "unknown", false, ""},
|
||||
{"empty string", "", false, ""},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
r, found := integrations[strings.ToLower(tt.input)]
|
||||
if found != tt.wantFound {
|
||||
t.Errorf("integrations[%q] found = %v, want %v", tt.input, found, tt.wantFound)
|
||||
}
|
||||
if found && r.String() != tt.wantName {
|
||||
t.Errorf("integrations[%q].String() = %q, want %q", tt.input, r.String(), tt.wantName)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestIntegrationRegistry(t *testing.T) {
|
||||
expectedIntegrations := []string{"claude", "codex", "droid", "opencode"}
|
||||
|
||||
for _, name := range expectedIntegrations {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
r, ok := integrations[name]
|
||||
if !ok {
|
||||
t.Fatalf("integration %q not found in registry", name)
|
||||
}
|
||||
if r.String() == "" {
|
||||
t.Error("integration.String() should not be empty")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestHasLocalModel(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
models []string
|
||||
want bool
|
||||
}{
|
||||
{"empty list", []string{}, false},
|
||||
{"single local model", []string{"llama3.2"}, true},
|
||||
{"single cloud model", []string{"cloud-model"}, false},
|
||||
{"mixed models", []string{"cloud-model", "llama3.2"}, true},
|
||||
{"multiple local models", []string{"llama3.2", "qwen2.5"}, true},
|
||||
{"multiple cloud models", []string{"cloud-a", "cloud-b"}, false},
|
||||
{"local model first", []string{"llama3.2", "cloud-model"}, true},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := slices.ContainsFunc(tt.models, func(m string) bool {
|
||||
return !strings.Contains(m, "cloud")
|
||||
})
|
||||
if got != tt.want {
|
||||
t.Errorf("hasLocalModel(%v) = %v, want %v", tt.models, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestLaunchCmd(t *testing.T) {
|
||||
// Mock checkServerHeartbeat that always succeeds
|
||||
mockCheck := func(cmd *cobra.Command, args []string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
cmd := LaunchCmd(mockCheck)
|
||||
|
||||
t.Run("command structure", func(t *testing.T) {
|
||||
if cmd.Use != "launch [INTEGRATION]" {
|
||||
t.Errorf("Use = %q, want %q", cmd.Use, "launch [INTEGRATION]")
|
||||
}
|
||||
if cmd.Short == "" {
|
||||
t.Error("Short description should not be empty")
|
||||
}
|
||||
if cmd.Long == "" {
|
||||
t.Error("Long description should not be empty")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("flags exist", func(t *testing.T) {
|
||||
modelFlag := cmd.Flags().Lookup("model")
|
||||
if modelFlag == nil {
|
||||
t.Error("--model flag should exist")
|
||||
}
|
||||
|
||||
configFlag := cmd.Flags().Lookup("config")
|
||||
if configFlag == nil {
|
||||
t.Error("--config flag should exist")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("PreRunE is set", func(t *testing.T) {
|
||||
if cmd.PreRunE == nil {
|
||||
t.Error("PreRunE should be set to checkServerHeartbeat")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestRunIntegration_UnknownIntegration(t *testing.T) {
|
||||
err := runIntegration("unknown-integration", "model")
|
||||
if err == nil {
|
||||
t.Error("expected error for unknown integration, got nil")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "unknown integration") {
|
||||
t.Errorf("error should mention 'unknown integration', got: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHasLocalModel_DocumentsHeuristic(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
models []string
|
||||
want bool
|
||||
reason string
|
||||
}{
|
||||
{"empty list", []string{}, false, "empty list has no local models"},
|
||||
{"contains-cloud-substring", []string{"deepseek-r1:cloud"}, false, "model with 'cloud' substring is considered cloud"},
|
||||
{"cloud-in-name", []string{"my-cloud-model"}, false, "'cloud' anywhere in name = cloud model"},
|
||||
{"cloudless", []string{"cloudless-model"}, false, "'cloudless' still contains 'cloud'"},
|
||||
{"local-model", []string{"llama3.2"}, true, "no 'cloud' = local"},
|
||||
{"mixed", []string{"cloud-model", "llama3.2"}, true, "one local model = hasLocalModel true"},
|
||||
{"all-cloud", []string{"cloud-a", "cloud-b"}, false, "all contain 'cloud'"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := slices.ContainsFunc(tt.models, func(m string) bool {
|
||||
return !strings.Contains(m, "cloud")
|
||||
})
|
||||
if got != tt.want {
|
||||
t.Errorf("hasLocalModel(%v) = %v, want %v (%s)", tt.models, got, tt.want, tt.reason)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestLaunchCmd_NilHeartbeat(t *testing.T) {
|
||||
// This should not panic - cmd creation should work even with nil
|
||||
cmd := LaunchCmd(nil)
|
||||
if cmd == nil {
|
||||
t.Fatal("LaunchCmd returned nil")
|
||||
}
|
||||
|
||||
// PreRunE should be nil when passed nil
|
||||
if cmd.PreRunE != nil {
|
||||
t.Log("Note: PreRunE is set even when nil is passed (acceptable)")
|
||||
}
|
||||
}
|
||||
|
||||
func TestAllIntegrations_HaveRequiredMethods(t *testing.T) {
|
||||
for name, r := range integrations {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
// Test String() doesn't panic and returns non-empty
|
||||
displayName := r.String()
|
||||
if displayName == "" {
|
||||
t.Error("String() should not return empty")
|
||||
}
|
||||
|
||||
// Test Run() exists (we can't call it without actually running the command)
|
||||
// Just verify the method is available
|
||||
var _ func(string) error = r.Run
|
||||
})
|
||||
}
|
||||
}
|
||||
224
cmd/config/opencode.go
Normal file
224
cmd/config/opencode.go
Normal file
@@ -0,0 +1,224 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"maps"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"slices"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// OpenCode implements Runner and Editor for OpenCode integration
|
||||
type OpenCode struct{}
|
||||
|
||||
func (o *OpenCode) String() string { return "OpenCode" }
|
||||
|
||||
func (o *OpenCode) Run(model string) error {
|
||||
if _, err := exec.LookPath("opencode"); err != nil {
|
||||
return fmt.Errorf("opencode is not installed, install from https://opencode.ai")
|
||||
}
|
||||
|
||||
// Call Edit() to ensure config is up-to-date before launch
|
||||
models := []string{model}
|
||||
if config, err := loadIntegration("opencode"); err == nil && len(config.Models) > 0 {
|
||||
models = config.Models
|
||||
}
|
||||
if err := o.Edit(models); err != nil {
|
||||
return fmt.Errorf("setup failed: %w", err)
|
||||
}
|
||||
|
||||
cmd := exec.Command("opencode")
|
||||
cmd.Stdin = os.Stdin
|
||||
cmd.Stdout = os.Stdout
|
||||
cmd.Stderr = os.Stderr
|
||||
return cmd.Run()
|
||||
}
|
||||
|
||||
func (o *OpenCode) Paths() []string {
|
||||
home, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
var paths []string
|
||||
p := filepath.Join(home, ".config", "opencode", "opencode.json")
|
||||
if _, err := os.Stat(p); err == nil {
|
||||
paths = append(paths, p)
|
||||
}
|
||||
sp := filepath.Join(home, ".local", "state", "opencode", "model.json")
|
||||
if _, err := os.Stat(sp); err == nil {
|
||||
paths = append(paths, sp)
|
||||
}
|
||||
return paths
|
||||
}
|
||||
|
||||
func (o *OpenCode) Edit(modelList []string) error {
|
||||
if len(modelList) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
home, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
configPath := filepath.Join(home, ".config", "opencode", "opencode.json")
|
||||
if err := os.MkdirAll(filepath.Dir(configPath), 0o755); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
config := make(map[string]any)
|
||||
if data, err := os.ReadFile(configPath); err == nil {
|
||||
_ = json.Unmarshal(data, &config) // Ignore parse errors; treat missing/corrupt files as empty
|
||||
}
|
||||
|
||||
config["$schema"] = "https://opencode.ai/config.json"
|
||||
|
||||
provider, ok := config["provider"].(map[string]any)
|
||||
if !ok {
|
||||
provider = make(map[string]any)
|
||||
}
|
||||
|
||||
ollama, ok := provider["ollama"].(map[string]any)
|
||||
if !ok {
|
||||
ollama = map[string]any{
|
||||
"npm": "@ai-sdk/openai-compatible",
|
||||
"name": "Ollama (local)",
|
||||
"options": map[string]any{
|
||||
"baseURL": "http://localhost:11434/v1",
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
models, ok := ollama["models"].(map[string]any)
|
||||
if !ok {
|
||||
models = make(map[string]any)
|
||||
}
|
||||
|
||||
selectedSet := make(map[string]bool)
|
||||
for _, m := range modelList {
|
||||
selectedSet[m] = true
|
||||
}
|
||||
|
||||
for name, cfg := range models {
|
||||
if cfgMap, ok := cfg.(map[string]any); ok {
|
||||
if isOllamaModel(cfgMap) && !selectedSet[name] {
|
||||
delete(models, name)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for _, model := range modelList {
|
||||
if existing, ok := models[model].(map[string]any); ok {
|
||||
// migrate existing models without _launch marker
|
||||
if isOllamaModel(existing) {
|
||||
existing["_launch"] = true
|
||||
if name, ok := existing["name"].(string); ok {
|
||||
existing["name"] = strings.TrimSuffix(name, " [Ollama]")
|
||||
}
|
||||
}
|
||||
continue
|
||||
}
|
||||
models[model] = map[string]any{
|
||||
"name": model,
|
||||
"_launch": true,
|
||||
}
|
||||
}
|
||||
|
||||
ollama["models"] = models
|
||||
provider["ollama"] = ollama
|
||||
config["provider"] = provider
|
||||
|
||||
configData, err := json.MarshalIndent(config, "", " ")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := writeWithBackup(configPath, configData); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
statePath := filepath.Join(home, ".local", "state", "opencode", "model.json")
|
||||
if err := os.MkdirAll(filepath.Dir(statePath), 0o755); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
state := map[string]any{
|
||||
"recent": []any{},
|
||||
"favorite": []any{},
|
||||
"variant": map[string]any{},
|
||||
}
|
||||
if data, err := os.ReadFile(statePath); err == nil {
|
||||
_ = json.Unmarshal(data, &state) // Ignore parse errors; use defaults
|
||||
}
|
||||
|
||||
recent, _ := state["recent"].([]any)
|
||||
|
||||
modelSet := make(map[string]bool)
|
||||
for _, m := range modelList {
|
||||
modelSet[m] = true
|
||||
}
|
||||
|
||||
// Filter out existing Ollama models we're about to re-add
|
||||
newRecent := slices.DeleteFunc(slices.Clone(recent), func(entry any) bool {
|
||||
e, ok := entry.(map[string]any)
|
||||
if !ok || e["providerID"] != "ollama" {
|
||||
return false
|
||||
}
|
||||
modelID, _ := e["modelID"].(string)
|
||||
return modelSet[modelID]
|
||||
})
|
||||
|
||||
// Prepend models in reverse order so first model ends up first
|
||||
for _, model := range slices.Backward(modelList) {
|
||||
newRecent = slices.Insert(newRecent, 0, any(map[string]any{
|
||||
"providerID": "ollama",
|
||||
"modelID": model,
|
||||
}))
|
||||
}
|
||||
|
||||
const maxRecentModels = 10
|
||||
newRecent = newRecent[:min(len(newRecent), maxRecentModels)]
|
||||
|
||||
state["recent"] = newRecent
|
||||
|
||||
stateData, err := json.MarshalIndent(state, "", " ")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return writeWithBackup(statePath, stateData)
|
||||
}
|
||||
|
||||
func (o *OpenCode) Models() []string {
|
||||
home, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
config, err := readJSONFile(filepath.Join(home, ".config", "opencode", "opencode.json"))
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
provider, _ := config["provider"].(map[string]any)
|
||||
ollama, _ := provider["ollama"].(map[string]any)
|
||||
models, _ := ollama["models"].(map[string]any)
|
||||
if len(models) == 0 {
|
||||
return nil
|
||||
}
|
||||
keys := slices.Collect(maps.Keys(models))
|
||||
slices.Sort(keys)
|
||||
return keys
|
||||
}
|
||||
|
||||
// isOllamaModel reports whether a model config entry is managed by us
|
||||
func isOllamaModel(cfg map[string]any) bool {
|
||||
if v, ok := cfg["_launch"].(bool); ok && v {
|
||||
return true
|
||||
}
|
||||
// previously used [Ollama] as a suffix for the model managed by ollama launch
|
||||
if name, ok := cfg["name"].(string); ok {
|
||||
return strings.HasSuffix(name, "[Ollama]")
|
||||
}
|
||||
return false
|
||||
}
|
||||
507
cmd/config/opencode_test.go
Normal file
507
cmd/config/opencode_test.go
Normal file
@@ -0,0 +1,507 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestOpenCodeIntegration(t *testing.T) {
|
||||
o := &OpenCode{}
|
||||
|
||||
t.Run("String", func(t *testing.T) {
|
||||
if got := o.String(); got != "OpenCode" {
|
||||
t.Errorf("String() = %q, want %q", got, "OpenCode")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("implements Runner", func(t *testing.T) {
|
||||
var _ Runner = o
|
||||
})
|
||||
|
||||
t.Run("implements Editor", func(t *testing.T) {
|
||||
var _ Editor = o
|
||||
})
|
||||
}
|
||||
|
||||
func TestOpenCodeEdit(t *testing.T) {
|
||||
o := &OpenCode{}
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
configDir := filepath.Join(tmpDir, ".config", "opencode")
|
||||
configPath := filepath.Join(configDir, "opencode.json")
|
||||
stateDir := filepath.Join(tmpDir, ".local", "state", "opencode")
|
||||
statePath := filepath.Join(stateDir, "model.json")
|
||||
|
||||
cleanup := func() {
|
||||
os.RemoveAll(configDir)
|
||||
os.RemoveAll(stateDir)
|
||||
}
|
||||
|
||||
t.Run("fresh install", func(t *testing.T) {
|
||||
cleanup()
|
||||
if err := o.Edit([]string{"llama3.2"}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
assertOpenCodeModelExists(t, configPath, "llama3.2")
|
||||
assertOpenCodeRecentModel(t, statePath, 0, "ollama", "llama3.2")
|
||||
})
|
||||
|
||||
t.Run("preserve other providers", func(t *testing.T) {
|
||||
cleanup()
|
||||
os.MkdirAll(configDir, 0o755)
|
||||
os.WriteFile(configPath, []byte(`{"provider":{"anthropic":{"apiKey":"xxx"}}}`), 0o644)
|
||||
if err := o.Edit([]string{"llama3.2"}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
data, _ := os.ReadFile(configPath)
|
||||
var cfg map[string]any
|
||||
json.Unmarshal(data, &cfg)
|
||||
provider := cfg["provider"].(map[string]any)
|
||||
if provider["anthropic"] == nil {
|
||||
t.Error("anthropic provider was removed")
|
||||
}
|
||||
assertOpenCodeModelExists(t, configPath, "llama3.2")
|
||||
})
|
||||
|
||||
t.Run("preserve other models", func(t *testing.T) {
|
||||
cleanup()
|
||||
os.MkdirAll(configDir, 0o755)
|
||||
os.WriteFile(configPath, []byte(`{"provider":{"ollama":{"models":{"mistral":{"name":"Mistral"}}}}}`), 0o644)
|
||||
if err := o.Edit([]string{"llama3.2"}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
assertOpenCodeModelExists(t, configPath, "mistral")
|
||||
assertOpenCodeModelExists(t, configPath, "llama3.2")
|
||||
})
|
||||
|
||||
t.Run("update existing model", func(t *testing.T) {
|
||||
cleanup()
|
||||
o.Edit([]string{"llama3.2"})
|
||||
o.Edit([]string{"llama3.2"})
|
||||
assertOpenCodeModelExists(t, configPath, "llama3.2")
|
||||
})
|
||||
|
||||
t.Run("preserve top-level keys", func(t *testing.T) {
|
||||
cleanup()
|
||||
os.MkdirAll(configDir, 0o755)
|
||||
os.WriteFile(configPath, []byte(`{"theme":"dark","keybindings":{}}`), 0o644)
|
||||
if err := o.Edit([]string{"llama3.2"}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
data, _ := os.ReadFile(configPath)
|
||||
var cfg map[string]any
|
||||
json.Unmarshal(data, &cfg)
|
||||
if cfg["theme"] != "dark" {
|
||||
t.Error("theme was removed")
|
||||
}
|
||||
if cfg["keybindings"] == nil {
|
||||
t.Error("keybindings was removed")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("model state - insert at index 0", func(t *testing.T) {
|
||||
cleanup()
|
||||
os.MkdirAll(stateDir, 0o755)
|
||||
os.WriteFile(statePath, []byte(`{"recent":[{"providerID":"anthropic","modelID":"claude"}],"favorite":[],"variant":{}}`), 0o644)
|
||||
if err := o.Edit([]string{"llama3.2"}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
assertOpenCodeRecentModel(t, statePath, 0, "ollama", "llama3.2")
|
||||
assertOpenCodeRecentModel(t, statePath, 1, "anthropic", "claude")
|
||||
})
|
||||
|
||||
t.Run("model state - preserve favorites and variants", func(t *testing.T) {
|
||||
cleanup()
|
||||
os.MkdirAll(stateDir, 0o755)
|
||||
os.WriteFile(statePath, []byte(`{"recent":[],"favorite":[{"providerID":"x","modelID":"y"}],"variant":{"a":"b"}}`), 0o644)
|
||||
if err := o.Edit([]string{"llama3.2"}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
data, _ := os.ReadFile(statePath)
|
||||
var state map[string]any
|
||||
json.Unmarshal(data, &state)
|
||||
if len(state["favorite"].([]any)) != 1 {
|
||||
t.Error("favorite was modified")
|
||||
}
|
||||
if state["variant"].(map[string]any)["a"] != "b" {
|
||||
t.Error("variant was modified")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("model state - deduplicate on re-add", func(t *testing.T) {
|
||||
cleanup()
|
||||
os.MkdirAll(stateDir, 0o755)
|
||||
os.WriteFile(statePath, []byte(`{"recent":[{"providerID":"ollama","modelID":"llama3.2"},{"providerID":"anthropic","modelID":"claude"}],"favorite":[],"variant":{}}`), 0o644)
|
||||
if err := o.Edit([]string{"llama3.2"}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
data, _ := os.ReadFile(statePath)
|
||||
var state map[string]any
|
||||
json.Unmarshal(data, &state)
|
||||
recent := state["recent"].([]any)
|
||||
if len(recent) != 2 {
|
||||
t.Errorf("expected 2 recent entries, got %d", len(recent))
|
||||
}
|
||||
assertOpenCodeRecentModel(t, statePath, 0, "ollama", "llama3.2")
|
||||
})
|
||||
|
||||
t.Run("remove model", func(t *testing.T) {
|
||||
cleanup()
|
||||
// First add two models
|
||||
o.Edit([]string{"llama3.2", "mistral"})
|
||||
assertOpenCodeModelExists(t, configPath, "llama3.2")
|
||||
assertOpenCodeModelExists(t, configPath, "mistral")
|
||||
|
||||
// Then remove one by only selecting the other
|
||||
o.Edit([]string{"llama3.2"})
|
||||
assertOpenCodeModelExists(t, configPath, "llama3.2")
|
||||
assertOpenCodeModelNotExists(t, configPath, "mistral")
|
||||
})
|
||||
|
||||
t.Run("preserve user customizations on managed models", func(t *testing.T) {
|
||||
cleanup()
|
||||
if err := o.Edit([]string{"llama3.2"}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Add custom fields to the model entry (simulating user edits)
|
||||
data, _ := os.ReadFile(configPath)
|
||||
var cfg map[string]any
|
||||
json.Unmarshal(data, &cfg)
|
||||
provider := cfg["provider"].(map[string]any)
|
||||
ollama := provider["ollama"].(map[string]any)
|
||||
models := ollama["models"].(map[string]any)
|
||||
entry := models["llama3.2"].(map[string]any)
|
||||
entry["_myPref"] = "custom-value"
|
||||
entry["_myNum"] = 42
|
||||
configData, _ := json.MarshalIndent(cfg, "", " ")
|
||||
os.WriteFile(configPath, configData, 0o644)
|
||||
|
||||
// Re-run Edit — should preserve custom fields
|
||||
if err := o.Edit([]string{"llama3.2"}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
data, _ = os.ReadFile(configPath)
|
||||
json.Unmarshal(data, &cfg)
|
||||
provider = cfg["provider"].(map[string]any)
|
||||
ollama = provider["ollama"].(map[string]any)
|
||||
models = ollama["models"].(map[string]any)
|
||||
entry = models["llama3.2"].(map[string]any)
|
||||
|
||||
if entry["_myPref"] != "custom-value" {
|
||||
t.Errorf("_myPref was lost: got %v", entry["_myPref"])
|
||||
}
|
||||
if entry["_myNum"] != float64(42) {
|
||||
t.Errorf("_myNum was lost: got %v", entry["_myNum"])
|
||||
}
|
||||
if v, ok := entry["_launch"].(bool); !ok || !v {
|
||||
t.Errorf("_launch marker missing or false: got %v", entry["_launch"])
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("migrate legacy [Ollama] suffix entries", func(t *testing.T) {
|
||||
cleanup()
|
||||
// Write a config with a legacy entry (has [Ollama] suffix but no _launch marker)
|
||||
os.MkdirAll(configDir, 0o755)
|
||||
os.WriteFile(configPath, []byte(`{"provider":{"ollama":{"models":{"llama3.2":{"name":"llama3.2 [Ollama]"}}}}}`), 0o644)
|
||||
|
||||
if err := o.Edit([]string{"llama3.2"}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
data, _ := os.ReadFile(configPath)
|
||||
var cfg map[string]any
|
||||
json.Unmarshal(data, &cfg)
|
||||
provider := cfg["provider"].(map[string]any)
|
||||
ollama := provider["ollama"].(map[string]any)
|
||||
models := ollama["models"].(map[string]any)
|
||||
entry := models["llama3.2"].(map[string]any)
|
||||
|
||||
// _launch marker should be added
|
||||
if v, ok := entry["_launch"].(bool); !ok || !v {
|
||||
t.Errorf("_launch marker not added during migration: got %v", entry["_launch"])
|
||||
}
|
||||
// [Ollama] suffix should be stripped
|
||||
if name, ok := entry["name"].(string); !ok || name != "llama3.2" {
|
||||
t.Errorf("name suffix not stripped: got %q", entry["name"])
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("remove model preserves non-ollama models", func(t *testing.T) {
|
||||
cleanup()
|
||||
os.MkdirAll(configDir, 0o755)
|
||||
// Add a non-Ollama model manually
|
||||
os.WriteFile(configPath, []byte(`{"provider":{"ollama":{"models":{"external":{"name":"External Model"}}}}}`), 0o644)
|
||||
|
||||
o.Edit([]string{"llama3.2"})
|
||||
assertOpenCodeModelExists(t, configPath, "llama3.2")
|
||||
assertOpenCodeModelExists(t, configPath, "external") // Should be preserved
|
||||
})
|
||||
}
|
||||
|
||||
func assertOpenCodeModelExists(t *testing.T, path, model string) {
|
||||
t.Helper()
|
||||
data, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
var cfg map[string]any
|
||||
if err := json.Unmarshal(data, &cfg); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
provider, ok := cfg["provider"].(map[string]any)
|
||||
if !ok {
|
||||
t.Fatal("provider not found")
|
||||
}
|
||||
ollama, ok := provider["ollama"].(map[string]any)
|
||||
if !ok {
|
||||
t.Fatal("ollama provider not found")
|
||||
}
|
||||
models, ok := ollama["models"].(map[string]any)
|
||||
if !ok {
|
||||
t.Fatal("models not found")
|
||||
}
|
||||
if models[model] == nil {
|
||||
t.Errorf("model %s not found", model)
|
||||
}
|
||||
}
|
||||
|
||||
func assertOpenCodeModelNotExists(t *testing.T, path, model string) {
|
||||
t.Helper()
|
||||
data, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
var cfg map[string]any
|
||||
if err := json.Unmarshal(data, &cfg); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
provider, ok := cfg["provider"].(map[string]any)
|
||||
if !ok {
|
||||
return // No provider means no model
|
||||
}
|
||||
ollama, ok := provider["ollama"].(map[string]any)
|
||||
if !ok {
|
||||
return // No ollama means no model
|
||||
}
|
||||
models, ok := ollama["models"].(map[string]any)
|
||||
if !ok {
|
||||
return // No models means no model
|
||||
}
|
||||
if models[model] != nil {
|
||||
t.Errorf("model %s should not exist but was found", model)
|
||||
}
|
||||
}
|
||||
|
||||
func assertOpenCodeRecentModel(t *testing.T, path string, index int, providerID, modelID string) {
|
||||
t.Helper()
|
||||
data, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
var state map[string]any
|
||||
if err := json.Unmarshal(data, &state); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
recent, ok := state["recent"].([]any)
|
||||
if !ok {
|
||||
t.Fatal("recent not found")
|
||||
}
|
||||
if index >= len(recent) {
|
||||
t.Fatalf("index %d out of range (len=%d)", index, len(recent))
|
||||
}
|
||||
entry, ok := recent[index].(map[string]any)
|
||||
if !ok {
|
||||
t.Fatal("entry is not a map")
|
||||
}
|
||||
if entry["providerID"] != providerID {
|
||||
t.Errorf("expected providerID %s, got %s", providerID, entry["providerID"])
|
||||
}
|
||||
if entry["modelID"] != modelID {
|
||||
t.Errorf("expected modelID %s, got %s", modelID, entry["modelID"])
|
||||
}
|
||||
}
|
||||
|
||||
// Edge case tests for opencode.go
|
||||
|
||||
func TestOpenCodeEdit_CorruptedConfigJSON(t *testing.T) {
|
||||
o := &OpenCode{}
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
configDir := filepath.Join(tmpDir, ".config", "opencode")
|
||||
configPath := filepath.Join(configDir, "opencode.json")
|
||||
|
||||
os.MkdirAll(configDir, 0o755)
|
||||
os.WriteFile(configPath, []byte(`{corrupted json content`), 0o644)
|
||||
|
||||
// Should not panic - corrupted JSON should be treated as empty
|
||||
err := o.Edit([]string{"llama3.2"})
|
||||
if err != nil {
|
||||
t.Fatalf("Edit failed with corrupted config: %v", err)
|
||||
}
|
||||
|
||||
// Verify valid JSON was created
|
||||
data, _ := os.ReadFile(configPath)
|
||||
var cfg map[string]any
|
||||
if err := json.Unmarshal(data, &cfg); err != nil {
|
||||
t.Errorf("resulting config is not valid JSON: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpenCodeEdit_CorruptedStateJSON(t *testing.T) {
|
||||
o := &OpenCode{}
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
stateDir := filepath.Join(tmpDir, ".local", "state", "opencode")
|
||||
statePath := filepath.Join(stateDir, "model.json")
|
||||
|
||||
os.MkdirAll(stateDir, 0o755)
|
||||
os.WriteFile(statePath, []byte(`{corrupted state`), 0o644)
|
||||
|
||||
err := o.Edit([]string{"llama3.2"})
|
||||
if err != nil {
|
||||
t.Fatalf("Edit failed with corrupted state: %v", err)
|
||||
}
|
||||
|
||||
// Verify valid state was created
|
||||
data, _ := os.ReadFile(statePath)
|
||||
var state map[string]any
|
||||
if err := json.Unmarshal(data, &state); err != nil {
|
||||
t.Errorf("resulting state is not valid JSON: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpenCodeEdit_WrongTypeProvider(t *testing.T) {
|
||||
o := &OpenCode{}
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
configDir := filepath.Join(tmpDir, ".config", "opencode")
|
||||
configPath := filepath.Join(configDir, "opencode.json")
|
||||
|
||||
os.MkdirAll(configDir, 0o755)
|
||||
os.WriteFile(configPath, []byte(`{"provider": "not a map"}`), 0o644)
|
||||
|
||||
err := o.Edit([]string{"llama3.2"})
|
||||
if err != nil {
|
||||
t.Fatalf("Edit with wrong type provider failed: %v", err)
|
||||
}
|
||||
|
||||
// Verify provider is now correct type
|
||||
data, _ := os.ReadFile(configPath)
|
||||
var cfg map[string]any
|
||||
json.Unmarshal(data, &cfg)
|
||||
|
||||
provider, ok := cfg["provider"].(map[string]any)
|
||||
if !ok {
|
||||
t.Fatalf("provider should be map after setup, got %T", cfg["provider"])
|
||||
}
|
||||
if provider["ollama"] == nil {
|
||||
t.Error("ollama provider should be created")
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpenCodeEdit_WrongTypeRecent(t *testing.T) {
|
||||
o := &OpenCode{}
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
stateDir := filepath.Join(tmpDir, ".local", "state", "opencode")
|
||||
statePath := filepath.Join(stateDir, "model.json")
|
||||
|
||||
os.MkdirAll(stateDir, 0o755)
|
||||
os.WriteFile(statePath, []byte(`{"recent": "not an array", "favorite": [], "variant": {}}`), 0o644)
|
||||
|
||||
err := o.Edit([]string{"llama3.2"})
|
||||
if err != nil {
|
||||
t.Fatalf("Edit with wrong type recent failed: %v", err)
|
||||
}
|
||||
|
||||
// The function should handle this gracefully
|
||||
data, _ := os.ReadFile(statePath)
|
||||
var state map[string]any
|
||||
json.Unmarshal(data, &state)
|
||||
|
||||
// recent should be properly set after setup
|
||||
recent, ok := state["recent"].([]any)
|
||||
if !ok {
|
||||
t.Logf("Note: recent type after setup is %T (documenting behavior)", state["recent"])
|
||||
} else if len(recent) == 0 {
|
||||
t.Logf("Note: recent is empty (documenting behavior)")
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpenCodeEdit_EmptyModels(t *testing.T) {
|
||||
o := &OpenCode{}
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
configDir := filepath.Join(tmpDir, ".config", "opencode")
|
||||
configPath := filepath.Join(configDir, "opencode.json")
|
||||
|
||||
os.MkdirAll(configDir, 0o755)
|
||||
originalContent := `{"provider":{"ollama":{"models":{"existing":{}}}}}`
|
||||
os.WriteFile(configPath, []byte(originalContent), 0o644)
|
||||
|
||||
// Empty models should be no-op
|
||||
err := o.Edit([]string{})
|
||||
if err != nil {
|
||||
t.Fatalf("Edit with empty models failed: %v", err)
|
||||
}
|
||||
|
||||
// Original content should be preserved (file not modified)
|
||||
data, _ := os.ReadFile(configPath)
|
||||
if string(data) != originalContent {
|
||||
t.Errorf("empty models should not modify file, but content changed")
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpenCodeEdit_SpecialCharsInModelName(t *testing.T) {
|
||||
o := &OpenCode{}
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
// Model name with special characters (though unusual)
|
||||
specialModel := `model-with-"quotes"`
|
||||
|
||||
err := o.Edit([]string{specialModel})
|
||||
if err != nil {
|
||||
t.Fatalf("Edit with special chars failed: %v", err)
|
||||
}
|
||||
|
||||
// Verify it was stored correctly
|
||||
configDir := filepath.Join(tmpDir, ".config", "opencode")
|
||||
configPath := filepath.Join(configDir, "opencode.json")
|
||||
data, _ := os.ReadFile(configPath)
|
||||
|
||||
var cfg map[string]any
|
||||
if err := json.Unmarshal(data, &cfg); err != nil {
|
||||
t.Fatalf("resulting config is invalid JSON: %v", err)
|
||||
}
|
||||
|
||||
// Model should be accessible
|
||||
provider, _ := cfg["provider"].(map[string]any)
|
||||
ollama, _ := provider["ollama"].(map[string]any)
|
||||
models, _ := ollama["models"].(map[string]any)
|
||||
|
||||
if models[specialModel] == nil {
|
||||
t.Errorf("model with special chars not found in config")
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpenCodeModels_NoConfig(t *testing.T) {
|
||||
o := &OpenCode{}
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
models := o.Models()
|
||||
if len(models) > 0 {
|
||||
t.Errorf("expected nil/empty for missing config, got %v", models)
|
||||
}
|
||||
}
|
||||
499
cmd/config/selector.go
Normal file
499
cmd/config/selector.go
Normal file
@@ -0,0 +1,499 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"strings"
|
||||
|
||||
"golang.org/x/term"
|
||||
)
|
||||
|
||||
// ANSI escape sequences for terminal formatting.
|
||||
const (
|
||||
ansiHideCursor = "\033[?25l"
|
||||
ansiShowCursor = "\033[?25h"
|
||||
ansiBold = "\033[1m"
|
||||
ansiReset = "\033[0m"
|
||||
ansiGray = "\033[37m"
|
||||
ansiClearDown = "\033[J"
|
||||
)
|
||||
|
||||
const maxDisplayedItems = 10
|
||||
|
||||
var errCancelled = errors.New("cancelled")
|
||||
|
||||
type selectItem struct {
|
||||
Name string
|
||||
Description string
|
||||
}
|
||||
|
||||
type inputEvent int
|
||||
|
||||
const (
|
||||
eventNone inputEvent = iota
|
||||
eventEnter
|
||||
eventEscape
|
||||
eventUp
|
||||
eventDown
|
||||
eventTab
|
||||
eventBackspace
|
||||
eventChar
|
||||
)
|
||||
|
||||
type selectState struct {
|
||||
items []selectItem
|
||||
filter string
|
||||
selected int
|
||||
scrollOffset int
|
||||
}
|
||||
|
||||
func newSelectState(items []selectItem) *selectState {
|
||||
return &selectState{items: items}
|
||||
}
|
||||
|
||||
func (s *selectState) filtered() []selectItem {
|
||||
return filterItems(s.items, s.filter)
|
||||
}
|
||||
|
||||
func (s *selectState) handleInput(event inputEvent, char byte) (done bool, result string, err error) {
|
||||
filtered := s.filtered()
|
||||
|
||||
switch event {
|
||||
case eventEnter:
|
||||
if len(filtered) > 0 && s.selected < len(filtered) {
|
||||
return true, filtered[s.selected].Name, nil
|
||||
}
|
||||
case eventEscape:
|
||||
return true, "", errCancelled
|
||||
case eventBackspace:
|
||||
if len(s.filter) > 0 {
|
||||
s.filter = s.filter[:len(s.filter)-1]
|
||||
s.selected = 0
|
||||
s.scrollOffset = 0
|
||||
}
|
||||
case eventUp:
|
||||
if s.selected > 0 {
|
||||
s.selected--
|
||||
if s.selected < s.scrollOffset {
|
||||
s.scrollOffset = s.selected
|
||||
}
|
||||
}
|
||||
case eventDown:
|
||||
if s.selected < len(filtered)-1 {
|
||||
s.selected++
|
||||
if s.selected >= s.scrollOffset+maxDisplayedItems {
|
||||
s.scrollOffset = s.selected - maxDisplayedItems + 1
|
||||
}
|
||||
}
|
||||
case eventChar:
|
||||
s.filter += string(char)
|
||||
s.selected = 0
|
||||
s.scrollOffset = 0
|
||||
}
|
||||
|
||||
return false, "", nil
|
||||
}
|
||||
|
||||
type multiSelectState struct {
|
||||
items []selectItem
|
||||
itemIndex map[string]int
|
||||
filter string
|
||||
highlighted int
|
||||
scrollOffset int
|
||||
checked map[int]bool
|
||||
checkOrder []int
|
||||
focusOnButton bool
|
||||
}
|
||||
|
||||
func newMultiSelectState(items []selectItem, preChecked []string) *multiSelectState {
|
||||
s := &multiSelectState{
|
||||
items: items,
|
||||
itemIndex: make(map[string]int, len(items)),
|
||||
checked: make(map[int]bool),
|
||||
}
|
||||
|
||||
for i, item := range items {
|
||||
s.itemIndex[item.Name] = i
|
||||
}
|
||||
|
||||
for _, name := range preChecked {
|
||||
if idx, ok := s.itemIndex[name]; ok {
|
||||
s.checked[idx] = true
|
||||
s.checkOrder = append(s.checkOrder, idx)
|
||||
}
|
||||
}
|
||||
|
||||
return s
|
||||
}
|
||||
|
||||
func (s *multiSelectState) filtered() []selectItem {
|
||||
return filterItems(s.items, s.filter)
|
||||
}
|
||||
|
||||
func (s *multiSelectState) toggleItem() {
|
||||
filtered := s.filtered()
|
||||
if len(filtered) == 0 || s.highlighted >= len(filtered) {
|
||||
return
|
||||
}
|
||||
|
||||
item := filtered[s.highlighted]
|
||||
origIdx := s.itemIndex[item.Name]
|
||||
|
||||
if s.checked[origIdx] {
|
||||
delete(s.checked, origIdx)
|
||||
for i, idx := range s.checkOrder {
|
||||
if idx == origIdx {
|
||||
s.checkOrder = append(s.checkOrder[:i], s.checkOrder[i+1:]...)
|
||||
break
|
||||
}
|
||||
}
|
||||
} else {
|
||||
s.checked[origIdx] = true
|
||||
s.checkOrder = append(s.checkOrder, origIdx)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *multiSelectState) handleInput(event inputEvent, char byte) (done bool, result []string, err error) {
|
||||
filtered := s.filtered()
|
||||
|
||||
switch event {
|
||||
case eventEnter:
|
||||
if s.focusOnButton && len(s.checkOrder) > 0 {
|
||||
var res []string
|
||||
for _, idx := range s.checkOrder {
|
||||
res = append(res, s.items[idx].Name)
|
||||
}
|
||||
return true, res, nil
|
||||
} else if !s.focusOnButton {
|
||||
s.toggleItem()
|
||||
}
|
||||
case eventTab:
|
||||
if len(s.checkOrder) > 0 {
|
||||
s.focusOnButton = !s.focusOnButton
|
||||
}
|
||||
case eventEscape:
|
||||
return true, nil, errCancelled
|
||||
case eventBackspace:
|
||||
if len(s.filter) > 0 {
|
||||
s.filter = s.filter[:len(s.filter)-1]
|
||||
s.highlighted = 0
|
||||
s.scrollOffset = 0
|
||||
s.focusOnButton = false
|
||||
}
|
||||
case eventUp:
|
||||
if s.focusOnButton {
|
||||
s.focusOnButton = false
|
||||
} else if s.highlighted > 0 {
|
||||
s.highlighted--
|
||||
if s.highlighted < s.scrollOffset {
|
||||
s.scrollOffset = s.highlighted
|
||||
}
|
||||
}
|
||||
case eventDown:
|
||||
if s.focusOnButton {
|
||||
s.focusOnButton = false
|
||||
} else if s.highlighted < len(filtered)-1 {
|
||||
s.highlighted++
|
||||
if s.highlighted >= s.scrollOffset+maxDisplayedItems {
|
||||
s.scrollOffset = s.highlighted - maxDisplayedItems + 1
|
||||
}
|
||||
}
|
||||
case eventChar:
|
||||
s.filter += string(char)
|
||||
s.highlighted = 0
|
||||
s.scrollOffset = 0
|
||||
s.focusOnButton = false
|
||||
}
|
||||
|
||||
return false, nil, nil
|
||||
}
|
||||
|
||||
func (s *multiSelectState) selectedCount() int {
|
||||
return len(s.checkOrder)
|
||||
}
|
||||
|
||||
// Terminal I/O handling
|
||||
|
||||
type terminalState struct {
|
||||
fd int
|
||||
oldState *term.State
|
||||
}
|
||||
|
||||
func enterRawMode() (*terminalState, error) {
|
||||
fd := int(os.Stdin.Fd())
|
||||
oldState, err := term.MakeRaw(fd)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
fmt.Fprint(os.Stderr, ansiHideCursor)
|
||||
return &terminalState{fd: fd, oldState: oldState}, nil
|
||||
}
|
||||
|
||||
func (t *terminalState) restore() {
|
||||
fmt.Fprint(os.Stderr, ansiShowCursor)
|
||||
term.Restore(t.fd, t.oldState)
|
||||
}
|
||||
|
||||
func clearLines(n int) {
|
||||
if n > 0 {
|
||||
fmt.Fprintf(os.Stderr, "\033[%dA", n)
|
||||
fmt.Fprint(os.Stderr, ansiClearDown)
|
||||
}
|
||||
}
|
||||
|
||||
func parseInput(r io.Reader) (inputEvent, byte, error) {
|
||||
buf := make([]byte, 3)
|
||||
n, err := r.Read(buf)
|
||||
if err != nil {
|
||||
return 0, 0, err
|
||||
}
|
||||
|
||||
switch {
|
||||
case n == 1 && buf[0] == 13:
|
||||
return eventEnter, 0, nil
|
||||
case n == 1 && (buf[0] == 3 || buf[0] == 27):
|
||||
return eventEscape, 0, nil
|
||||
case n == 1 && buf[0] == 9:
|
||||
return eventTab, 0, nil
|
||||
case n == 1 && buf[0] == 127:
|
||||
return eventBackspace, 0, nil
|
||||
case n == 3 && buf[0] == 27 && buf[1] == 91 && buf[2] == 65:
|
||||
return eventUp, 0, nil
|
||||
case n == 3 && buf[0] == 27 && buf[1] == 91 && buf[2] == 66:
|
||||
return eventDown, 0, nil
|
||||
case n == 1 && buf[0] >= 32 && buf[0] < 127:
|
||||
return eventChar, buf[0], nil
|
||||
}
|
||||
|
||||
return eventNone, 0, nil
|
||||
}
|
||||
|
||||
// Rendering
|
||||
|
||||
func renderSelect(w io.Writer, prompt string, s *selectState) int {
|
||||
filtered := s.filtered()
|
||||
|
||||
fmt.Fprintf(w, "%s %s\r\n", prompt, s.filter)
|
||||
lineCount := 1
|
||||
|
||||
if len(filtered) == 0 {
|
||||
fmt.Fprintf(w, " %s(no matches)%s\r\n", ansiGray, ansiReset)
|
||||
lineCount++
|
||||
} else {
|
||||
displayCount := min(len(filtered), maxDisplayedItems)
|
||||
|
||||
for i := range displayCount {
|
||||
idx := s.scrollOffset + i
|
||||
if idx >= len(filtered) {
|
||||
break
|
||||
}
|
||||
item := filtered[idx]
|
||||
prefix := " "
|
||||
if idx == s.selected {
|
||||
prefix = " " + ansiBold + "> "
|
||||
}
|
||||
if item.Description != "" {
|
||||
fmt.Fprintf(w, "%s%s%s %s- %s%s\r\n", prefix, item.Name, ansiReset, ansiGray, item.Description, ansiReset)
|
||||
} else {
|
||||
fmt.Fprintf(w, "%s%s%s\r\n", prefix, item.Name, ansiReset)
|
||||
}
|
||||
lineCount++
|
||||
}
|
||||
|
||||
if remaining := len(filtered) - s.scrollOffset - displayCount; remaining > 0 {
|
||||
fmt.Fprintf(w, " %s... and %d more%s\r\n", ansiGray, remaining, ansiReset)
|
||||
lineCount++
|
||||
}
|
||||
}
|
||||
|
||||
return lineCount
|
||||
}
|
||||
|
||||
func renderMultiSelect(w io.Writer, prompt string, s *multiSelectState) int {
|
||||
filtered := s.filtered()
|
||||
|
||||
fmt.Fprintf(w, "%s %s\r\n", prompt, s.filter)
|
||||
lineCount := 1
|
||||
|
||||
if len(filtered) == 0 {
|
||||
fmt.Fprintf(w, " %s(no matches)%s\r\n", ansiGray, ansiReset)
|
||||
lineCount++
|
||||
} else {
|
||||
displayCount := min(len(filtered), maxDisplayedItems)
|
||||
|
||||
for i := range displayCount {
|
||||
idx := s.scrollOffset + i
|
||||
if idx >= len(filtered) {
|
||||
break
|
||||
}
|
||||
item := filtered[idx]
|
||||
origIdx := s.itemIndex[item.Name]
|
||||
|
||||
checkbox := "[ ]"
|
||||
if s.checked[origIdx] {
|
||||
checkbox = "[x]"
|
||||
}
|
||||
|
||||
prefix := " "
|
||||
suffix := ""
|
||||
if idx == s.highlighted && !s.focusOnButton {
|
||||
prefix = "> "
|
||||
}
|
||||
if len(s.checkOrder) > 0 && s.checkOrder[0] == origIdx {
|
||||
suffix = " " + ansiGray + "(default)" + ansiReset
|
||||
}
|
||||
|
||||
if idx == s.highlighted && !s.focusOnButton {
|
||||
fmt.Fprintf(w, " %s%s %s %s%s%s\r\n", ansiBold, prefix, checkbox, item.Name, ansiReset, suffix)
|
||||
} else {
|
||||
fmt.Fprintf(w, " %s %s %s%s\r\n", prefix, checkbox, item.Name, suffix)
|
||||
}
|
||||
lineCount++
|
||||
}
|
||||
|
||||
if remaining := len(filtered) - s.scrollOffset - displayCount; remaining > 0 {
|
||||
fmt.Fprintf(w, " %s... and %d more%s\r\n", ansiGray, remaining, ansiReset)
|
||||
lineCount++
|
||||
}
|
||||
}
|
||||
|
||||
fmt.Fprintf(w, "\r\n")
|
||||
lineCount++
|
||||
count := s.selectedCount()
|
||||
switch {
|
||||
case count == 0:
|
||||
fmt.Fprintf(w, " %sSelect at least one model.%s\r\n", ansiGray, ansiReset)
|
||||
case s.focusOnButton:
|
||||
fmt.Fprintf(w, " %s> [ Continue ]%s %s(%d selected)%s\r\n", ansiBold, ansiReset, ansiGray, count, ansiReset)
|
||||
default:
|
||||
fmt.Fprintf(w, " %s[ Continue ] (%d selected) - press Tab%s\r\n", ansiGray, count, ansiReset)
|
||||
}
|
||||
lineCount++
|
||||
|
||||
return lineCount
|
||||
}
|
||||
|
||||
// selectPrompt prompts the user to select a single item from a list.
|
||||
func selectPrompt(prompt string, items []selectItem) (string, error) {
|
||||
if len(items) == 0 {
|
||||
return "", fmt.Errorf("no items to select from")
|
||||
}
|
||||
|
||||
ts, err := enterRawMode()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
defer ts.restore()
|
||||
|
||||
state := newSelectState(items)
|
||||
var lastLineCount int
|
||||
|
||||
render := func() {
|
||||
clearLines(lastLineCount)
|
||||
lastLineCount = renderSelect(os.Stderr, prompt, state)
|
||||
}
|
||||
|
||||
render()
|
||||
|
||||
for {
|
||||
event, char, err := parseInput(os.Stdin)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
done, result, err := state.handleInput(event, char)
|
||||
if done {
|
||||
clearLines(lastLineCount)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
render()
|
||||
}
|
||||
}
|
||||
|
||||
// multiSelectPrompt prompts the user to select multiple items from a list.
|
||||
func multiSelectPrompt(prompt string, items []selectItem, preChecked []string) ([]string, error) {
|
||||
if len(items) == 0 {
|
||||
return nil, fmt.Errorf("no items to select from")
|
||||
}
|
||||
|
||||
ts, err := enterRawMode()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer ts.restore()
|
||||
|
||||
state := newMultiSelectState(items, preChecked)
|
||||
var lastLineCount int
|
||||
|
||||
render := func() {
|
||||
clearLines(lastLineCount)
|
||||
lastLineCount = renderMultiSelect(os.Stderr, prompt, state)
|
||||
}
|
||||
|
||||
render()
|
||||
|
||||
for {
|
||||
event, char, err := parseInput(os.Stdin)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
done, result, err := state.handleInput(event, char)
|
||||
if done {
|
||||
clearLines(lastLineCount)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
render()
|
||||
}
|
||||
}
|
||||
|
||||
func confirmPrompt(prompt string) (bool, error) {
|
||||
fd := int(os.Stdin.Fd())
|
||||
oldState, err := term.MakeRaw(fd)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
defer term.Restore(fd, oldState)
|
||||
|
||||
fmt.Fprintf(os.Stderr, "%s (\033[1my\033[0m/n) ", prompt)
|
||||
|
||||
buf := make([]byte, 1)
|
||||
for {
|
||||
if _, err := os.Stdin.Read(buf); err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
switch buf[0] {
|
||||
case 'Y', 'y', 13:
|
||||
fmt.Fprintf(os.Stderr, "yes\r\n")
|
||||
return true, nil
|
||||
case 'N', 'n', 27, 3:
|
||||
fmt.Fprintf(os.Stderr, "no\r\n")
|
||||
return false, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func filterItems(items []selectItem, filter string) []selectItem {
|
||||
if filter == "" {
|
||||
return items
|
||||
}
|
||||
var result []selectItem
|
||||
filterLower := strings.ToLower(filter)
|
||||
for _, item := range items {
|
||||
if strings.Contains(strings.ToLower(item.Name), filterLower) {
|
||||
result = append(result, item)
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
913
cmd/config/selector_test.go
Normal file
913
cmd/config/selector_test.go
Normal file
@@ -0,0 +1,913 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestFilterItems(t *testing.T) {
|
||||
items := []selectItem{
|
||||
{Name: "llama3.2:latest"},
|
||||
{Name: "qwen2.5:7b"},
|
||||
{Name: "deepseek-v3:cloud"},
|
||||
{Name: "GPT-OSS:20b"},
|
||||
}
|
||||
|
||||
t.Run("EmptyFilter_ReturnsAllItems", func(t *testing.T) {
|
||||
result := filterItems(items, "")
|
||||
if len(result) != len(items) {
|
||||
t.Errorf("expected %d items, got %d", len(items), len(result))
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("CaseInsensitive_UppercaseFilterMatchesLowercase", func(t *testing.T) {
|
||||
result := filterItems(items, "LLAMA")
|
||||
if len(result) != 1 || result[0].Name != "llama3.2:latest" {
|
||||
t.Errorf("expected llama3.2:latest, got %v", result)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("CaseInsensitive_LowercaseFilterMatchesUppercase", func(t *testing.T) {
|
||||
result := filterItems(items, "gpt")
|
||||
if len(result) != 1 || result[0].Name != "GPT-OSS:20b" {
|
||||
t.Errorf("expected GPT-OSS:20b, got %v", result)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("PartialMatch", func(t *testing.T) {
|
||||
result := filterItems(items, "deep")
|
||||
if len(result) != 1 || result[0].Name != "deepseek-v3:cloud" {
|
||||
t.Errorf("expected deepseek-v3:cloud, got %v", result)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("NoMatch_ReturnsEmpty", func(t *testing.T) {
|
||||
result := filterItems(items, "nonexistent")
|
||||
if len(result) != 0 {
|
||||
t.Errorf("expected 0 items, got %d", len(result))
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestSelectState(t *testing.T) {
|
||||
items := []selectItem{
|
||||
{Name: "item1"},
|
||||
{Name: "item2"},
|
||||
{Name: "item3"},
|
||||
}
|
||||
|
||||
t.Run("InitialState", func(t *testing.T) {
|
||||
s := newSelectState(items)
|
||||
if s.selected != 0 {
|
||||
t.Errorf("expected selected=0, got %d", s.selected)
|
||||
}
|
||||
if s.filter != "" {
|
||||
t.Errorf("expected empty filter, got %q", s.filter)
|
||||
}
|
||||
if s.scrollOffset != 0 {
|
||||
t.Errorf("expected scrollOffset=0, got %d", s.scrollOffset)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Enter_SelectsCurrentItem", func(t *testing.T) {
|
||||
s := newSelectState(items)
|
||||
done, result, err := s.handleInput(eventEnter, 0)
|
||||
if !done || result != "item1" || err != nil {
|
||||
t.Errorf("expected (true, item1, nil), got (%v, %v, %v)", done, result, err)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Enter_WithFilter_SelectsFilteredItem", func(t *testing.T) {
|
||||
s := newSelectState(items)
|
||||
s.filter = "item3"
|
||||
done, result, err := s.handleInput(eventEnter, 0)
|
||||
if !done || result != "item3" || err != nil {
|
||||
t.Errorf("expected (true, item3, nil), got (%v, %v, %v)", done, result, err)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Enter_EmptyFilteredList_DoesNothing", func(t *testing.T) {
|
||||
s := newSelectState(items)
|
||||
s.filter = "nonexistent"
|
||||
done, result, err := s.handleInput(eventEnter, 0)
|
||||
if done || result != "" || err != nil {
|
||||
t.Errorf("expected (false, '', nil), got (%v, %v, %v)", done, result, err)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Escape_ReturnsCancelledError", func(t *testing.T) {
|
||||
s := newSelectState(items)
|
||||
done, result, err := s.handleInput(eventEscape, 0)
|
||||
if !done || result != "" || err != errCancelled {
|
||||
t.Errorf("expected (true, '', errCancelled), got (%v, %v, %v)", done, result, err)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Down_MovesSelection", func(t *testing.T) {
|
||||
s := newSelectState(items)
|
||||
s.handleInput(eventDown, 0)
|
||||
if s.selected != 1 {
|
||||
t.Errorf("expected selected=1, got %d", s.selected)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Down_AtBottom_StaysAtBottom", func(t *testing.T) {
|
||||
s := newSelectState(items)
|
||||
s.selected = 2
|
||||
s.handleInput(eventDown, 0)
|
||||
if s.selected != 2 {
|
||||
t.Errorf("expected selected=2 (stayed at bottom), got %d", s.selected)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Up_MovesSelection", func(t *testing.T) {
|
||||
s := newSelectState(items)
|
||||
s.selected = 2
|
||||
s.handleInput(eventUp, 0)
|
||||
if s.selected != 1 {
|
||||
t.Errorf("expected selected=1, got %d", s.selected)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Up_AtTop_StaysAtTop", func(t *testing.T) {
|
||||
s := newSelectState(items)
|
||||
s.handleInput(eventUp, 0)
|
||||
if s.selected != 0 {
|
||||
t.Errorf("expected selected=0 (stayed at top), got %d", s.selected)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Char_AppendsToFilter", func(t *testing.T) {
|
||||
s := newSelectState(items)
|
||||
s.handleInput(eventChar, 'i')
|
||||
s.handleInput(eventChar, 't')
|
||||
s.handleInput(eventChar, 'e')
|
||||
s.handleInput(eventChar, 'm')
|
||||
s.handleInput(eventChar, '2')
|
||||
if s.filter != "item2" {
|
||||
t.Errorf("expected filter='item2', got %q", s.filter)
|
||||
}
|
||||
filtered := s.filtered()
|
||||
if len(filtered) != 1 || filtered[0].Name != "item2" {
|
||||
t.Errorf("expected [item2], got %v", filtered)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Char_ResetsSelectionToZero", func(t *testing.T) {
|
||||
s := newSelectState(items)
|
||||
s.selected = 2
|
||||
s.handleInput(eventChar, 'x')
|
||||
if s.selected != 0 {
|
||||
t.Errorf("expected selected=0 after typing, got %d", s.selected)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Backspace_RemovesLastFilterChar", func(t *testing.T) {
|
||||
s := newSelectState(items)
|
||||
s.filter = "test"
|
||||
s.handleInput(eventBackspace, 0)
|
||||
if s.filter != "tes" {
|
||||
t.Errorf("expected filter='tes', got %q", s.filter)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Backspace_EmptyFilter_DoesNothing", func(t *testing.T) {
|
||||
s := newSelectState(items)
|
||||
s.handleInput(eventBackspace, 0)
|
||||
if s.filter != "" {
|
||||
t.Errorf("expected filter='', got %q", s.filter)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Backspace_ResetsSelectionToZero", func(t *testing.T) {
|
||||
s := newSelectState(items)
|
||||
s.filter = "test"
|
||||
s.selected = 2
|
||||
s.handleInput(eventBackspace, 0)
|
||||
if s.selected != 0 {
|
||||
t.Errorf("expected selected=0 after backspace, got %d", s.selected)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Scroll_DownPastVisibleItems_ScrollsViewport", func(t *testing.T) {
|
||||
// maxDisplayedItems is 10, so with 15 items we need to scroll
|
||||
manyItems := make([]selectItem, 15)
|
||||
for i := range manyItems {
|
||||
manyItems[i] = selectItem{Name: string(rune('a' + i))}
|
||||
}
|
||||
s := newSelectState(manyItems)
|
||||
|
||||
// move down 12 times (past the 10-item viewport)
|
||||
for range 12 {
|
||||
s.handleInput(eventDown, 0)
|
||||
}
|
||||
|
||||
if s.selected != 12 {
|
||||
t.Errorf("expected selected=12, got %d", s.selected)
|
||||
}
|
||||
if s.scrollOffset != 3 {
|
||||
t.Errorf("expected scrollOffset=3 (12-10+1), got %d", s.scrollOffset)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Scroll_UpPastScrollOffset_ScrollsViewport", func(t *testing.T) {
|
||||
manyItems := make([]selectItem, 15)
|
||||
for i := range manyItems {
|
||||
manyItems[i] = selectItem{Name: string(rune('a' + i))}
|
||||
}
|
||||
s := newSelectState(manyItems)
|
||||
s.selected = 5
|
||||
s.scrollOffset = 5
|
||||
|
||||
s.handleInput(eventUp, 0)
|
||||
|
||||
if s.selected != 4 {
|
||||
t.Errorf("expected selected=4, got %d", s.selected)
|
||||
}
|
||||
if s.scrollOffset != 4 {
|
||||
t.Errorf("expected scrollOffset=4, got %d", s.scrollOffset)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestMultiSelectState(t *testing.T) {
|
||||
items := []selectItem{
|
||||
{Name: "item1"},
|
||||
{Name: "item2"},
|
||||
{Name: "item3"},
|
||||
}
|
||||
|
||||
t.Run("InitialState_NoPrechecked", func(t *testing.T) {
|
||||
s := newMultiSelectState(items, nil)
|
||||
if s.highlighted != 0 {
|
||||
t.Errorf("expected highlighted=0, got %d", s.highlighted)
|
||||
}
|
||||
if s.selectedCount() != 0 {
|
||||
t.Errorf("expected 0 selected, got %d", s.selectedCount())
|
||||
}
|
||||
if s.focusOnButton {
|
||||
t.Error("expected focusOnButton=false initially")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("InitialState_WithPrechecked", func(t *testing.T) {
|
||||
s := newMultiSelectState(items, []string{"item2", "item3"})
|
||||
if s.selectedCount() != 2 {
|
||||
t.Errorf("expected 2 selected, got %d", s.selectedCount())
|
||||
}
|
||||
if !s.checked[1] || !s.checked[2] {
|
||||
t.Error("expected item2 and item3 to be checked")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Prechecked_PreservesSelectionOrder", func(t *testing.T) {
|
||||
// order matters: first checked = default model
|
||||
s := newMultiSelectState(items, []string{"item3", "item1"})
|
||||
if len(s.checkOrder) != 2 {
|
||||
t.Fatalf("expected 2 in checkOrder, got %d", len(s.checkOrder))
|
||||
}
|
||||
if s.checkOrder[0] != 2 || s.checkOrder[1] != 0 {
|
||||
t.Errorf("expected checkOrder=[2,0] (item3 first), got %v", s.checkOrder)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Prechecked_IgnoresInvalidNames", func(t *testing.T) {
|
||||
s := newMultiSelectState(items, []string{"item1", "nonexistent"})
|
||||
if s.selectedCount() != 1 {
|
||||
t.Errorf("expected 1 selected (nonexistent ignored), got %d", s.selectedCount())
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Toggle_ChecksUncheckedItem", func(t *testing.T) {
|
||||
s := newMultiSelectState(items, nil)
|
||||
s.toggleItem()
|
||||
if !s.checked[0] {
|
||||
t.Error("expected item1 to be checked after toggle")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Toggle_UnchecksCheckedItem", func(t *testing.T) {
|
||||
s := newMultiSelectState(items, []string{"item1"})
|
||||
s.toggleItem()
|
||||
if s.checked[0] {
|
||||
t.Error("expected item1 to be unchecked after toggle")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Toggle_RemovesFromCheckOrder", func(t *testing.T) {
|
||||
s := newMultiSelectState(items, []string{"item1", "item2", "item3"})
|
||||
s.highlighted = 1 // toggle item2
|
||||
s.toggleItem()
|
||||
|
||||
if len(s.checkOrder) != 2 {
|
||||
t.Fatalf("expected 2 in checkOrder, got %d", len(s.checkOrder))
|
||||
}
|
||||
// should be [0, 2] (item1, item3) with item2 removed
|
||||
if s.checkOrder[0] != 0 || s.checkOrder[1] != 2 {
|
||||
t.Errorf("expected checkOrder=[0,2], got %v", s.checkOrder)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Enter_TogglesWhenNotOnButton", func(t *testing.T) {
|
||||
s := newMultiSelectState(items, nil)
|
||||
s.handleInput(eventEnter, 0)
|
||||
if !s.checked[0] {
|
||||
t.Error("expected item1 to be checked after enter")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Enter_OnButton_ReturnsSelection", func(t *testing.T) {
|
||||
s := newMultiSelectState(items, []string{"item2", "item1"})
|
||||
s.focusOnButton = true
|
||||
|
||||
done, result, err := s.handleInput(eventEnter, 0)
|
||||
|
||||
if !done || err != nil {
|
||||
t.Errorf("expected done=true, err=nil, got done=%v, err=%v", done, err)
|
||||
}
|
||||
// result should preserve selection order
|
||||
if len(result) != 2 || result[0] != "item2" || result[1] != "item1" {
|
||||
t.Errorf("expected [item2, item1], got %v", result)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Enter_OnButton_EmptySelection_DoesNothing", func(t *testing.T) {
|
||||
s := newMultiSelectState(items, nil)
|
||||
s.focusOnButton = true
|
||||
done, result, err := s.handleInput(eventEnter, 0)
|
||||
if done || result != nil || err != nil {
|
||||
t.Errorf("expected (false, nil, nil), got (%v, %v, %v)", done, result, err)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Tab_SwitchesToButton_WhenHasSelection", func(t *testing.T) {
|
||||
s := newMultiSelectState(items, []string{"item1"})
|
||||
s.handleInput(eventTab, 0)
|
||||
if !s.focusOnButton {
|
||||
t.Error("expected focus on button after tab")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Tab_DoesNothing_WhenNoSelection", func(t *testing.T) {
|
||||
s := newMultiSelectState(items, nil)
|
||||
s.handleInput(eventTab, 0)
|
||||
if s.focusOnButton {
|
||||
t.Error("tab should not focus button when nothing selected")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Tab_TogglesButtonFocus", func(t *testing.T) {
|
||||
s := newMultiSelectState(items, []string{"item1"})
|
||||
s.handleInput(eventTab, 0)
|
||||
if !s.focusOnButton {
|
||||
t.Error("expected focus on button after first tab")
|
||||
}
|
||||
s.handleInput(eventTab, 0)
|
||||
if s.focusOnButton {
|
||||
t.Error("expected focus back on list after second tab")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Escape_ReturnsCancelledError", func(t *testing.T) {
|
||||
s := newMultiSelectState(items, []string{"item1"})
|
||||
done, result, err := s.handleInput(eventEscape, 0)
|
||||
if !done || result != nil || err != errCancelled {
|
||||
t.Errorf("expected (true, nil, errCancelled), got (%v, %v, %v)", done, result, err)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("IsDefault_TrueForFirstChecked", func(t *testing.T) {
|
||||
s := newMultiSelectState(items, []string{"item2", "item1"})
|
||||
if !(len(s.checkOrder) > 0 && s.checkOrder[0] == 1) {
|
||||
t.Error("expected item2 (idx 1) to be default (first checked)")
|
||||
}
|
||||
if len(s.checkOrder) > 0 && s.checkOrder[0] == 0 {
|
||||
t.Error("expected item1 (idx 0) to NOT be default")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("IsDefault_FalseWhenNothingChecked", func(t *testing.T) {
|
||||
s := newMultiSelectState(items, nil)
|
||||
if len(s.checkOrder) > 0 && s.checkOrder[0] == 0 {
|
||||
t.Error("expected isDefault=false when nothing checked")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Down_MovesHighlight", func(t *testing.T) {
|
||||
s := newMultiSelectState(items, nil)
|
||||
s.handleInput(eventDown, 0)
|
||||
if s.highlighted != 1 {
|
||||
t.Errorf("expected highlighted=1, got %d", s.highlighted)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Up_MovesHighlight", func(t *testing.T) {
|
||||
s := newMultiSelectState(items, nil)
|
||||
s.highlighted = 1
|
||||
s.handleInput(eventUp, 0)
|
||||
if s.highlighted != 0 {
|
||||
t.Errorf("expected highlighted=0, got %d", s.highlighted)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Arrow_ReturnsFocusFromButton", func(t *testing.T) {
|
||||
s := newMultiSelectState(items, []string{"item1"})
|
||||
s.focusOnButton = true
|
||||
s.handleInput(eventDown, 0)
|
||||
if s.focusOnButton {
|
||||
t.Error("expected focus to return to list on arrow key")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Char_AppendsToFilter", func(t *testing.T) {
|
||||
s := newMultiSelectState(items, nil)
|
||||
s.handleInput(eventChar, 'x')
|
||||
if s.filter != "x" {
|
||||
t.Errorf("expected filter='x', got %q", s.filter)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Char_ResetsHighlightAndScroll", func(t *testing.T) {
|
||||
manyItems := make([]selectItem, 15)
|
||||
for i := range manyItems {
|
||||
manyItems[i] = selectItem{Name: string(rune('a' + i))}
|
||||
}
|
||||
s := newMultiSelectState(manyItems, nil)
|
||||
s.highlighted = 10
|
||||
s.scrollOffset = 5
|
||||
|
||||
s.handleInput(eventChar, 'x')
|
||||
|
||||
if s.highlighted != 0 {
|
||||
t.Errorf("expected highlighted=0, got %d", s.highlighted)
|
||||
}
|
||||
if s.scrollOffset != 0 {
|
||||
t.Errorf("expected scrollOffset=0, got %d", s.scrollOffset)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Backspace_RemovesLastFilterChar", func(t *testing.T) {
|
||||
s := newMultiSelectState(items, nil)
|
||||
s.filter = "test"
|
||||
s.handleInput(eventBackspace, 0)
|
||||
if s.filter != "tes" {
|
||||
t.Errorf("expected filter='tes', got %q", s.filter)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Backspace_RemovesFocusFromButton", func(t *testing.T) {
|
||||
s := newMultiSelectState(items, []string{"item1"})
|
||||
s.filter = "x"
|
||||
s.focusOnButton = true
|
||||
s.handleInput(eventBackspace, 0)
|
||||
if s.focusOnButton {
|
||||
t.Error("expected focusOnButton=false after backspace")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestParseInput(t *testing.T) {
|
||||
t.Run("Enter", func(t *testing.T) {
|
||||
event, char, err := parseInput(bytes.NewReader([]byte{13}))
|
||||
if err != nil || event != eventEnter || char != 0 {
|
||||
t.Errorf("expected (eventEnter, 0, nil), got (%v, %v, %v)", event, char, err)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Escape", func(t *testing.T) {
|
||||
event, _, err := parseInput(bytes.NewReader([]byte{27}))
|
||||
if err != nil || event != eventEscape {
|
||||
t.Errorf("expected eventEscape, got %v", event)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("CtrlC_TreatedAsEscape", func(t *testing.T) {
|
||||
event, _, err := parseInput(bytes.NewReader([]byte{3}))
|
||||
if err != nil || event != eventEscape {
|
||||
t.Errorf("expected eventEscape for Ctrl+C, got %v", event)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Tab", func(t *testing.T) {
|
||||
event, _, err := parseInput(bytes.NewReader([]byte{9}))
|
||||
if err != nil || event != eventTab {
|
||||
t.Errorf("expected eventTab, got %v", event)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Backspace", func(t *testing.T) {
|
||||
event, _, err := parseInput(bytes.NewReader([]byte{127}))
|
||||
if err != nil || event != eventBackspace {
|
||||
t.Errorf("expected eventBackspace, got %v", event)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("UpArrow", func(t *testing.T) {
|
||||
event, _, err := parseInput(bytes.NewReader([]byte{27, 91, 65}))
|
||||
if err != nil || event != eventUp {
|
||||
t.Errorf("expected eventUp, got %v", event)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("DownArrow", func(t *testing.T) {
|
||||
event, _, err := parseInput(bytes.NewReader([]byte{27, 91, 66}))
|
||||
if err != nil || event != eventDown {
|
||||
t.Errorf("expected eventDown, got %v", event)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("PrintableChars", func(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
char byte
|
||||
}{
|
||||
{"lowercase", 'a'},
|
||||
{"uppercase", 'Z'},
|
||||
{"digit", '5'},
|
||||
{"space", ' '},
|
||||
{"tilde", '~'},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
event, char, err := parseInput(bytes.NewReader([]byte{tt.char}))
|
||||
if err != nil || event != eventChar || char != tt.char {
|
||||
t.Errorf("expected (eventChar, %q), got (%v, %q)", tt.char, event, char)
|
||||
}
|
||||
})
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestRenderSelect(t *testing.T) {
|
||||
items := []selectItem{
|
||||
{Name: "item1", Description: "first item"},
|
||||
{Name: "item2"},
|
||||
}
|
||||
|
||||
t.Run("ShowsPromptAndItems", func(t *testing.T) {
|
||||
s := newSelectState(items)
|
||||
var buf bytes.Buffer
|
||||
lineCount := renderSelect(&buf, "Select:", s)
|
||||
|
||||
output := buf.String()
|
||||
if !strings.Contains(output, "Select:") {
|
||||
t.Error("expected prompt in output")
|
||||
}
|
||||
if !strings.Contains(output, "item1") {
|
||||
t.Error("expected item1 in output")
|
||||
}
|
||||
if !strings.Contains(output, "first item") {
|
||||
t.Error("expected description in output")
|
||||
}
|
||||
if !strings.Contains(output, "item2") {
|
||||
t.Error("expected item2 in output")
|
||||
}
|
||||
if lineCount != 3 { // 1 prompt + 2 items
|
||||
t.Errorf("expected 3 lines, got %d", lineCount)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("EmptyFilteredList_ShowsNoMatches", func(t *testing.T) {
|
||||
s := newSelectState(items)
|
||||
s.filter = "xyz"
|
||||
var buf bytes.Buffer
|
||||
renderSelect(&buf, "Select:", s)
|
||||
|
||||
if !strings.Contains(buf.String(), "no matches") {
|
||||
t.Error("expected 'no matches' message")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("LongList_ShowsRemainingCount", func(t *testing.T) {
|
||||
manyItems := make([]selectItem, 15)
|
||||
for i := range manyItems {
|
||||
manyItems[i] = selectItem{Name: string(rune('a' + i))}
|
||||
}
|
||||
s := newSelectState(manyItems)
|
||||
var buf bytes.Buffer
|
||||
renderSelect(&buf, "Select:", s)
|
||||
|
||||
// 15 items - 10 displayed = 5 more
|
||||
if !strings.Contains(buf.String(), "5 more") {
|
||||
t.Error("expected '5 more' indicator")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestRenderMultiSelect(t *testing.T) {
|
||||
items := []selectItem{
|
||||
{Name: "item1"},
|
||||
{Name: "item2"},
|
||||
}
|
||||
|
||||
t.Run("ShowsCheckboxes", func(t *testing.T) {
|
||||
s := newMultiSelectState(items, []string{"item1"})
|
||||
var buf bytes.Buffer
|
||||
renderMultiSelect(&buf, "Select:", s)
|
||||
|
||||
output := buf.String()
|
||||
if !strings.Contains(output, "[x]") {
|
||||
t.Error("expected checked checkbox [x]")
|
||||
}
|
||||
if !strings.Contains(output, "[ ]") {
|
||||
t.Error("expected unchecked checkbox [ ]")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("ShowsDefaultMarker", func(t *testing.T) {
|
||||
s := newMultiSelectState(items, []string{"item1"})
|
||||
var buf bytes.Buffer
|
||||
renderMultiSelect(&buf, "Select:", s)
|
||||
|
||||
if !strings.Contains(buf.String(), "(default)") {
|
||||
t.Error("expected (default) marker for first checked item")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("ShowsSelectedCount", func(t *testing.T) {
|
||||
s := newMultiSelectState(items, []string{"item1", "item2"})
|
||||
var buf bytes.Buffer
|
||||
renderMultiSelect(&buf, "Select:", s)
|
||||
|
||||
if !strings.Contains(buf.String(), "2 selected") {
|
||||
t.Error("expected '2 selected' in output")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("NoSelection_ShowsHelperText", func(t *testing.T) {
|
||||
s := newMultiSelectState(items, nil)
|
||||
var buf bytes.Buffer
|
||||
renderMultiSelect(&buf, "Select:", s)
|
||||
|
||||
if !strings.Contains(buf.String(), "Select at least one") {
|
||||
t.Error("expected 'Select at least one' helper text")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestErrCancelled(t *testing.T) {
|
||||
t.Run("NotNil", func(t *testing.T) {
|
||||
if errCancelled == nil {
|
||||
t.Error("errCancelled should not be nil")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Message", func(t *testing.T) {
|
||||
if errCancelled.Error() != "cancelled" {
|
||||
t.Errorf("expected 'cancelled', got %q", errCancelled.Error())
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// Edge case tests for selector.go
|
||||
|
||||
// TestSelectState_SingleItem verifies that single item list works without crash.
|
||||
// List with only one item should still work.
|
||||
func TestSelectState_SingleItem(t *testing.T) {
|
||||
items := []selectItem{{Name: "only-one"}}
|
||||
|
||||
s := newSelectState(items)
|
||||
|
||||
// Down should do nothing (already at bottom)
|
||||
s.handleInput(eventDown, 0)
|
||||
if s.selected != 0 {
|
||||
t.Errorf("down on single item: expected selected=0, got %d", s.selected)
|
||||
}
|
||||
|
||||
// Up should do nothing (already at top)
|
||||
s.handleInput(eventUp, 0)
|
||||
if s.selected != 0 {
|
||||
t.Errorf("up on single item: expected selected=0, got %d", s.selected)
|
||||
}
|
||||
|
||||
// Enter should select the only item
|
||||
done, result, err := s.handleInput(eventEnter, 0)
|
||||
if !done || result != "only-one" || err != nil {
|
||||
t.Errorf("enter on single item: expected (true, 'only-one', nil), got (%v, %q, %v)", done, result, err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestSelectState_ExactlyMaxItems verifies boundary condition at maxDisplayedItems.
|
||||
// List with exactly maxDisplayedItems items should not scroll.
|
||||
func TestSelectState_ExactlyMaxItems(t *testing.T) {
|
||||
items := make([]selectItem, maxDisplayedItems)
|
||||
for i := range items {
|
||||
items[i] = selectItem{Name: string(rune('a' + i))}
|
||||
}
|
||||
|
||||
s := newSelectState(items)
|
||||
|
||||
// Move to last item
|
||||
for range maxDisplayedItems - 1 {
|
||||
s.handleInput(eventDown, 0)
|
||||
}
|
||||
|
||||
if s.selected != maxDisplayedItems-1 {
|
||||
t.Errorf("expected selected=%d, got %d", maxDisplayedItems-1, s.selected)
|
||||
}
|
||||
|
||||
// Should not scroll when exactly at max
|
||||
if s.scrollOffset != 0 {
|
||||
t.Errorf("expected scrollOffset=0 for exactly maxDisplayedItems, got %d", s.scrollOffset)
|
||||
}
|
||||
|
||||
// One more down should do nothing
|
||||
s.handleInput(eventDown, 0)
|
||||
if s.selected != maxDisplayedItems-1 {
|
||||
t.Errorf("down at max: expected selected=%d, got %d", maxDisplayedItems-1, s.selected)
|
||||
}
|
||||
}
|
||||
|
||||
// TestFilterItems_RegexSpecialChars verifies that filter is literal, not regex.
|
||||
// User typing "model.v1" shouldn't match "modelsv1".
|
||||
func TestFilterItems_RegexSpecialChars(t *testing.T) {
|
||||
items := []selectItem{
|
||||
{Name: "model.v1"},
|
||||
{Name: "modelsv1"},
|
||||
{Name: "model-v1"},
|
||||
}
|
||||
|
||||
// Filter with dot should only match literal dot
|
||||
result := filterItems(items, "model.v1")
|
||||
if len(result) != 1 {
|
||||
t.Errorf("expected 1 exact match, got %d", len(result))
|
||||
}
|
||||
if len(result) > 0 && result[0].Name != "model.v1" {
|
||||
t.Errorf("expected 'model.v1', got %s", result[0].Name)
|
||||
}
|
||||
|
||||
// Other regex special chars should be literal too
|
||||
items2 := []selectItem{
|
||||
{Name: "test[0]"},
|
||||
{Name: "test0"},
|
||||
{Name: "test(1)"},
|
||||
}
|
||||
|
||||
result2 := filterItems(items2, "test[0]")
|
||||
if len(result2) != 1 || result2[0].Name != "test[0]" {
|
||||
t.Errorf("expected only 'test[0]', got %v", result2)
|
||||
}
|
||||
}
|
||||
|
||||
// TestMultiSelectState_DuplicateNames documents handling of duplicate item names.
|
||||
// itemIndex uses name as key - duplicates cause collision. This documents
|
||||
// the current behavior: the last index for a duplicate name is stored
|
||||
func TestMultiSelectState_DuplicateNames(t *testing.T) {
|
||||
// Duplicate names - this is an edge case that shouldn't happen in practice
|
||||
items := []selectItem{
|
||||
{Name: "duplicate"},
|
||||
{Name: "duplicate"},
|
||||
{Name: "unique"},
|
||||
}
|
||||
|
||||
s := newMultiSelectState(items, nil)
|
||||
|
||||
// DOCUMENTED BEHAVIOR: itemIndex maps name to LAST index
|
||||
// When there are duplicates, only the last occurrence's index is stored
|
||||
if s.itemIndex["duplicate"] != 1 {
|
||||
t.Errorf("itemIndex should map 'duplicate' to last index (1), got %d", s.itemIndex["duplicate"])
|
||||
}
|
||||
|
||||
// Toggle item at highlighted=0 (first "duplicate")
|
||||
// Due to name collision, toggleItem uses itemIndex["duplicate"] = 1
|
||||
// So it actually toggles the SECOND duplicate item, not the first
|
||||
s.toggleItem()
|
||||
|
||||
// This documents the potentially surprising behavior:
|
||||
// We toggled at highlighted=0, but itemIndex lookup returned 1
|
||||
if !s.checked[1] {
|
||||
t.Error("toggle should check index 1 (due to name collision in itemIndex)")
|
||||
}
|
||||
if s.checked[0] {
|
||||
t.Log("Note: index 0 is NOT checked, even though highlighted=0 (name collision behavior)")
|
||||
}
|
||||
}
|
||||
|
||||
// TestSelectState_FilterReducesBelowSelection verifies selection resets when filter reduces list.
|
||||
// Prevents index-out-of-bounds on next keystroke
|
||||
func TestSelectState_FilterReducesBelowSelection(t *testing.T) {
|
||||
items := []selectItem{
|
||||
{Name: "apple"},
|
||||
{Name: "banana"},
|
||||
{Name: "cherry"},
|
||||
}
|
||||
|
||||
s := newSelectState(items)
|
||||
s.selected = 2 // Select "cherry"
|
||||
|
||||
// Type a filter that removes cherry from results
|
||||
s.handleInput(eventChar, 'a') // Filter to "a" - matches "apple" and "banana"
|
||||
|
||||
// Selection should reset to 0
|
||||
if s.selected != 0 {
|
||||
t.Errorf("expected selected=0 after filter, got %d", s.selected)
|
||||
}
|
||||
|
||||
filtered := s.filtered()
|
||||
if len(filtered) != 2 {
|
||||
t.Errorf("expected 2 filtered items, got %d", len(filtered))
|
||||
}
|
||||
}
|
||||
|
||||
// TestFilterItems_UnicodeCharacters verifies filtering works with UTF-8.
|
||||
// Model names might contain unicode characters
|
||||
func TestFilterItems_UnicodeCharacters(t *testing.T) {
|
||||
items := []selectItem{
|
||||
{Name: "llama-日本語"},
|
||||
{Name: "模型-chinese"},
|
||||
{Name: "émoji-🦙"},
|
||||
{Name: "regular-model"},
|
||||
}
|
||||
|
||||
t.Run("filter japanese", func(t *testing.T) {
|
||||
result := filterItems(items, "日本")
|
||||
if len(result) != 1 || result[0].Name != "llama-日本語" {
|
||||
t.Errorf("expected llama-日本語, got %v", result)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("filter chinese", func(t *testing.T) {
|
||||
result := filterItems(items, "模型")
|
||||
if len(result) != 1 || result[0].Name != "模型-chinese" {
|
||||
t.Errorf("expected 模型-chinese, got %v", result)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("filter emoji", func(t *testing.T) {
|
||||
result := filterItems(items, "🦙")
|
||||
if len(result) != 1 || result[0].Name != "émoji-🦙" {
|
||||
t.Errorf("expected émoji-🦙, got %v", result)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("filter accented char", func(t *testing.T) {
|
||||
result := filterItems(items, "émoji")
|
||||
if len(result) != 1 || result[0].Name != "émoji-🦙" {
|
||||
t.Errorf("expected émoji-🦙, got %v", result)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// TestMultiSelectState_FilterReducesBelowHighlight verifies highlight resets when filter reduces list.
|
||||
func TestMultiSelectState_FilterReducesBelowHighlight(t *testing.T) {
|
||||
items := []selectItem{
|
||||
{Name: "apple"},
|
||||
{Name: "banana"},
|
||||
{Name: "cherry"},
|
||||
}
|
||||
|
||||
s := newMultiSelectState(items, nil)
|
||||
s.highlighted = 2 // Highlight "cherry"
|
||||
|
||||
// Type a filter that removes cherry
|
||||
s.handleInput(eventChar, 'a')
|
||||
|
||||
if s.highlighted != 0 {
|
||||
t.Errorf("expected highlighted=0 after filter, got %d", s.highlighted)
|
||||
}
|
||||
}
|
||||
|
||||
// TestMultiSelectState_EmptyItems verifies handling of empty item list.
|
||||
// Empty list should be handled gracefully.
|
||||
func TestMultiSelectState_EmptyItems(t *testing.T) {
|
||||
s := newMultiSelectState([]selectItem{}, nil)
|
||||
|
||||
// Toggle should not panic on empty list
|
||||
s.toggleItem()
|
||||
|
||||
if s.selectedCount() != 0 {
|
||||
t.Errorf("expected 0 selected for empty list, got %d", s.selectedCount())
|
||||
}
|
||||
|
||||
// Render should handle empty list
|
||||
var buf bytes.Buffer
|
||||
lineCount := renderMultiSelect(&buf, "Select:", s)
|
||||
if lineCount == 0 {
|
||||
t.Error("renderMultiSelect should produce output even for empty list")
|
||||
}
|
||||
if !strings.Contains(buf.String(), "no matches") {
|
||||
t.Error("expected 'no matches' for empty list")
|
||||
}
|
||||
}
|
||||
|
||||
// TestSelectState_RenderWithDescriptions verifies rendering items with descriptions.
|
||||
func TestSelectState_RenderWithDescriptions(t *testing.T) {
|
||||
items := []selectItem{
|
||||
{Name: "item1", Description: "First item description"},
|
||||
{Name: "item2", Description: ""},
|
||||
{Name: "item3", Description: "Third item"},
|
||||
}
|
||||
|
||||
s := newSelectState(items)
|
||||
var buf bytes.Buffer
|
||||
renderSelect(&buf, "Select:", s)
|
||||
|
||||
output := buf.String()
|
||||
if !strings.Contains(output, "First item description") {
|
||||
t.Error("expected description to be rendered")
|
||||
}
|
||||
if !strings.Contains(output, "item2") {
|
||||
t.Error("expected item without description to be rendered")
|
||||
}
|
||||
}
|
||||
@@ -159,6 +159,7 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
|
||||
sb.WriteString(before)
|
||||
if !ok {
|
||||
fmt.Fprintln(&sb)
|
||||
scanner.Prompt.UseAlt = true
|
||||
continue
|
||||
}
|
||||
|
||||
|
||||
@@ -6,6 +6,10 @@ import (
|
||||
"log/slog"
|
||||
"regexp"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/pdevine/tensor"
|
||||
"github.com/pdevine/tensor/native"
|
||||
|
||||
"github.com/ollama/ollama/fs/ggml"
|
||||
)
|
||||
@@ -69,6 +73,9 @@ func (p *glm4MoeLiteModel) KV(t *Tokenizer) KV {
|
||||
kv["glm4moelite.rope.dimension_count"] = p.QKRopeHeadDim
|
||||
kv["glm4moelite.rope.freq_base"] = cmp.Or(p.RopeTheta, float32(1000000.0))
|
||||
|
||||
kv["glm4moelite.attention.key_length_mla"] = p.KVLoraRank + p.QKRopeHeadDim
|
||||
kv["glm4moelite.attention.value_length_mla"] = p.KVLoraRank
|
||||
|
||||
kv["tokenizer.ggml.pre"] = "glm4"
|
||||
|
||||
return kv
|
||||
@@ -100,6 +107,67 @@ func (p *glm4MoeLiteModel) Replacements() []string {
|
||||
}
|
||||
}
|
||||
|
||||
// repackKVB extracts K or V from the combined KV_B tensor for MLA absorption.
|
||||
// K output row-major: [n_head, kv_lora_rank, qk_nope] -> GGML ne[]={qk_nope, kv_lora_rank, n_head}
|
||||
// V output row-major: [n_head, v_head, kv_lora_rank] -> GGML ne[]={kv_lora_rank, v_head, n_head}
|
||||
func (p *glm4MoeLiteModel) repackKVB(extractK bool, kvFirst bool, numHeads int) Repacker {
|
||||
qkNope := int(p.QKNopeHeadDim)
|
||||
vHeadDim := int(p.VHeadDim)
|
||||
kvLoraRank := int(p.KVLoraRank)
|
||||
kvPerHead := qkNope + vHeadDim
|
||||
|
||||
return func(_ string, data []float32, shape []uint64) ([]float32, error) {
|
||||
dims := make([]int, len(shape))
|
||||
for i := range shape {
|
||||
dims[i] = int(shape[i])
|
||||
}
|
||||
|
||||
var tt tensor.Tensor = tensor.New(tensor.WithShape(dims...), tensor.WithBacking(data))
|
||||
var err error
|
||||
|
||||
// Normalize to [n_head * (qk_nope + v_head), kv_lora_rank] layout
|
||||
if kvFirst {
|
||||
tt, err = tensor.Transpose(tt, 1, 0)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
tt = tensor.Materialize(tt)
|
||||
}
|
||||
|
||||
// Reshape to [n_head, qk_nope + v_head, kv_lora_rank]
|
||||
if err := tt.Reshape(numHeads, kvPerHead, kvLoraRank); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if extractK {
|
||||
// Slice K: [n_head, qk_nope, kv_lora_rank]
|
||||
tt, err = tt.Slice(nil, tensor.S(0, qkNope), nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
tt = tensor.Materialize(tt)
|
||||
// Transpose to [n_head, kv_lora_rank, qk_nope]
|
||||
tt, err = tensor.Transpose(tt, 0, 2, 1)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
tt = tensor.Materialize(tt)
|
||||
} else {
|
||||
// Slice V: [n_head, v_head, kv_lora_rank] - already correct layout
|
||||
tt, err = tt.Slice(nil, tensor.S(qkNope, kvPerHead), nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
tt = tensor.Materialize(tt)
|
||||
}
|
||||
|
||||
if err := tt.Reshape(tt.Shape().TotalSize()); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return native.VectorF32(tt.(*tensor.Dense))
|
||||
}
|
||||
}
|
||||
|
||||
func (p *glm4MoeLiteModel) Tensors(s []Tensor) (out []*ggml.Tensor) {
|
||||
merges := make([]merge, p.HiddenLayers*3)
|
||||
for i := range p.HiddenLayers {
|
||||
@@ -139,6 +207,52 @@ func (p *glm4MoeLiteModel) Tensors(s []Tensor) (out []*ggml.Tensor) {
|
||||
slog.Debug("skipping layer", "name", t.Name())
|
||||
continue
|
||||
}
|
||||
|
||||
// Split attn_kv_b into separate attn_k_b and attn_v_b for MLA absorption
|
||||
if strings.HasSuffix(t.Name(), ".attn_kv_b.weight") {
|
||||
qkNope := int(p.QKNopeHeadDim)
|
||||
vHeadDim := int(p.VHeadDim)
|
||||
kvLoraRank := int(p.KVLoraRank)
|
||||
kvPerHead := qkNope + vHeadDim
|
||||
numHeads := int(p.NumAttentionHeads)
|
||||
kvFirst := true
|
||||
if len(t.Shape()) == 2 {
|
||||
switch {
|
||||
case int(t.Shape()[0]) == kvLoraRank:
|
||||
if kvPerHead > 0 && int(t.Shape()[1])%kvPerHead == 0 {
|
||||
numHeads = int(t.Shape()[1]) / kvPerHead
|
||||
}
|
||||
kvFirst = true
|
||||
case int(t.Shape()[1]) == kvLoraRank:
|
||||
if kvPerHead > 0 && int(t.Shape()[0])%kvPerHead == 0 {
|
||||
numHeads = int(t.Shape()[0]) / kvPerHead
|
||||
}
|
||||
kvFirst = false
|
||||
default:
|
||||
slog.Warn("glm4moelite: unexpected attn_kv_b layout", "name", t.Name(), "shape", t.Shape())
|
||||
}
|
||||
}
|
||||
|
||||
kTensor := t.Clone()
|
||||
kTensor.SetRepacker(p.repackKVB(true, kvFirst, numHeads))
|
||||
out = append(out, &ggml.Tensor{
|
||||
Name: strings.Replace(t.Name(), "attn_kv_b", "attn_k_b", 1),
|
||||
Kind: t.Kind(),
|
||||
Shape: []uint64{uint64(numHeads), uint64(kvLoraRank), uint64(qkNope)},
|
||||
WriterTo: kTensor,
|
||||
})
|
||||
|
||||
vTensor := t.Clone()
|
||||
vTensor.SetRepacker(p.repackKVB(false, kvFirst, numHeads))
|
||||
out = append(out, &ggml.Tensor{
|
||||
Name: strings.Replace(t.Name(), "attn_kv_b", "attn_v_b", 1),
|
||||
Kind: t.Kind(),
|
||||
Shape: []uint64{uint64(numHeads), uint64(vHeadDim), uint64(kvLoraRank)},
|
||||
WriterTo: vTensor,
|
||||
})
|
||||
continue
|
||||
}
|
||||
|
||||
out = append(out, &ggml.Tensor{
|
||||
Name: t.Name(),
|
||||
Kind: t.Kind(),
|
||||
|
||||
@@ -4,16 +4,6 @@ title: Anthropic compatibility
|
||||
|
||||
Ollama provides compatibility with the [Anthropic Messages API](https://docs.anthropic.com/en/api/messages) to help connect existing applications to Ollama, including tools like Claude Code.
|
||||
|
||||
## Recommended models
|
||||
|
||||
For coding use cases, models like `glm-4.7:cloud`, `minimax-m2.1:cloud`, and `qwen3-coder` are recommended.
|
||||
|
||||
Pull a model before use:
|
||||
```shell
|
||||
ollama pull qwen3-coder
|
||||
ollama pull glm-4.7:cloud
|
||||
```
|
||||
|
||||
## Usage
|
||||
|
||||
### Environment variables
|
||||
@@ -22,8 +12,8 @@ To use Ollama with tools that expect the Anthropic API (like Claude Code), set t
|
||||
|
||||
```shell
|
||||
export ANTHROPIC_AUTH_TOKEN=ollama # required but ignored
|
||||
export ANTHROPIC_API_KEY="" # required but ignored
|
||||
export ANTHROPIC_BASE_URL=http://localhost:11434
|
||||
export ANTHROPIC_API_KEY=ollama # required but ignored
|
||||
```
|
||||
|
||||
### Simple `/v1/messages` example
|
||||
@@ -245,10 +235,41 @@ curl -X POST http://localhost:11434/v1/messages \
|
||||
|
||||
## Using with Claude Code
|
||||
|
||||
[Claude Code](https://code.claude.com/docs/en/overview) can be configured to use Ollama as its backend:
|
||||
[Claude Code](https://code.claude.com/docs/en/overview) can be configured to use Ollama as its backend.
|
||||
|
||||
### Recommended models
|
||||
|
||||
For coding use cases, models like `glm-4.7`, `minimax-m2.1`, and `qwen3-coder` are recommended.
|
||||
|
||||
Download a model before use:
|
||||
|
||||
```shell
|
||||
ANTHROPIC_AUTH_TOKEN=ollama ANTHROPIC_BASE_URL=http://localhost:11434 ANTHROPIC_API_KEY=ollama claude --model qwen3-coder
|
||||
ollama pull qwen3-coder
|
||||
```
|
||||
> Note: Qwen 3 coder is a 30B parameter model requiring at least 24GB of VRAM to run smoothly. More is required for longer context lengths.
|
||||
|
||||
```shell
|
||||
ollama pull glm-4.7:cloud
|
||||
```
|
||||
|
||||
### Quick setup
|
||||
|
||||
```shell
|
||||
ollama launch claude
|
||||
```
|
||||
|
||||
This will prompt you to select a model, configure Claude Code automatically, and launch it. To configure without launching:
|
||||
|
||||
```shell
|
||||
ollama launch claude --config
|
||||
```
|
||||
|
||||
### Manual setup
|
||||
|
||||
Set the environment variables and run Claude Code:
|
||||
|
||||
```shell
|
||||
ANTHROPIC_AUTH_TOKEN=ollama ANTHROPIC_BASE_URL=http://localhost:11434 ANTHROPIC_API_KEY="" claude --model qwen3-coder
|
||||
```
|
||||
|
||||
Or set the environment variables in your shell profile:
|
||||
@@ -256,19 +277,13 @@ Or set the environment variables in your shell profile:
|
||||
```shell
|
||||
export ANTHROPIC_AUTH_TOKEN=ollama
|
||||
export ANTHROPIC_BASE_URL=http://localhost:11434
|
||||
export ANTHROPIC_API_KEY=ollama
|
||||
export ANTHROPIC_API_KEY=""
|
||||
```
|
||||
|
||||
Then run Claude Code with any Ollama model:
|
||||
|
||||
```shell
|
||||
# Local models
|
||||
claude --model qwen3-coder
|
||||
claude --model gpt-oss:20b
|
||||
|
||||
# Cloud models
|
||||
claude --model glm-4.7:cloud
|
||||
claude --model minimax-m2.1:cloud
|
||||
```
|
||||
|
||||
## Endpoints
|
||||
|
||||
41
docs/cli.mdx
41
docs/cli.mdx
@@ -8,6 +8,47 @@ title: CLI Reference
|
||||
ollama run gemma3
|
||||
```
|
||||
|
||||
### Launch integrations
|
||||
|
||||
```
|
||||
ollama launch
|
||||
```
|
||||
|
||||
Configure and launch external applications to use Ollama models. This provides an interactive way to set up and start integrations with supported apps.
|
||||
|
||||
#### Supported integrations
|
||||
|
||||
- **OpenCode** - Open-source coding assistant
|
||||
- **Claude Code** - Anthropic's agentic coding tool
|
||||
- **Codex** - OpenAI's coding assistant
|
||||
- **Droid** - Factory's AI coding agent
|
||||
|
||||
#### Examples
|
||||
|
||||
Launch an integration interactively:
|
||||
|
||||
```
|
||||
ollama launch
|
||||
```
|
||||
|
||||
Launch a specific integration:
|
||||
|
||||
```
|
||||
ollama launch claude
|
||||
```
|
||||
|
||||
Launch with a specific model:
|
||||
|
||||
```
|
||||
ollama launch claude --model qwen3-coder
|
||||
```
|
||||
|
||||
Configure without launching:
|
||||
|
||||
```
|
||||
ollama launch droid --config
|
||||
```
|
||||
|
||||
#### Multiline input
|
||||
|
||||
For multiline input, you can wrap text with `"""`:
|
||||
|
||||
@@ -3,8 +3,6 @@ title: Cloud
|
||||
sidebarTitle: Cloud
|
||||
---
|
||||
|
||||
<Info>Ollama's cloud is currently in preview.</Info>
|
||||
|
||||
## Cloud Models
|
||||
|
||||
Ollama's cloud models are a new kind of model in Ollama that can run without a powerful GPU. Instead, cloud models are automatically offloaded to Ollama's cloud service while offering the same capabilities as local models, making it possible to keep using your local tools while running larger models that wouldn't fit on a personal computer.
|
||||
|
||||
@@ -8,7 +8,7 @@ Context length is the maximum number of tokens that the model has access to in m
|
||||
The default context length in Ollama is 4096 tokens.
|
||||
</Note>
|
||||
|
||||
Tasks which require large context like web search, agents, and coding tools should be set to at least 32000 tokens.
|
||||
Tasks which require large context like web search, agents, and coding tools should be set to at least 64000 tokens.
|
||||
|
||||
## Setting context length
|
||||
|
||||
@@ -24,7 +24,7 @@ Change the slider in the Ollama app under settings to your desired context lengt
|
||||
### CLI
|
||||
If editing the context length for Ollama is not possible, the context length can also be updated when serving Ollama.
|
||||
```
|
||||
OLLAMA_CONTEXT_LENGTH=32000 ollama serve
|
||||
OLLAMA_CONTEXT_LENGTH=64000 ollama serve
|
||||
```
|
||||
|
||||
### Check allocated context length and model offloading
|
||||
|
||||
@@ -102,18 +102,19 @@
|
||||
"group": "Integrations",
|
||||
"pages": [
|
||||
"/integrations/claude-code",
|
||||
"/integrations/vscode",
|
||||
"/integrations/jetbrains",
|
||||
"/integrations/codex",
|
||||
"/integrations/cline",
|
||||
"/integrations/codex",
|
||||
"/integrations/droid",
|
||||
"/integrations/goose",
|
||||
"/integrations/zed",
|
||||
"/integrations/roo-code",
|
||||
"/integrations/jetbrains",
|
||||
"/integrations/marimo",
|
||||
"/integrations/n8n",
|
||||
"/integrations/xcode",
|
||||
"/integrations/onyx",
|
||||
"/integrations/marimo"
|
||||
"/integrations/opencode",
|
||||
"/integrations/roo-code",
|
||||
"/integrations/vscode",
|
||||
"/integrations/xcode",
|
||||
"/integrations/zed"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
||||
@@ -9,7 +9,7 @@ sidebarTitle: Welcome
|
||||
|
||||
<CardGroup cols={2}>
|
||||
<Card title="Quickstart" icon="rocket" href="/quickstart">
|
||||
Get up and running with your first model
|
||||
Get up and running with your first model or integrate Ollama with your favorite tools
|
||||
</Card>
|
||||
<Card
|
||||
title="Download Ollama"
|
||||
|
||||
@@ -4,7 +4,7 @@ title: Claude Code
|
||||
|
||||
Claude Code is Anthropic's agentic coding tool that can read, modify, and execute code in your working directory.
|
||||
|
||||
Open models can be used with Claude Code through Ollama's Anthropic-compatible API, enabling you to use models such as `qwen3-coder`, `gpt-oss:20b`, or other models.
|
||||
Open models can be used with Claude Code through Ollama's Anthropic-compatible API, enabling you to use models such as `glm-4.7`, `qwen3-coder`, `gpt-oss`.
|
||||
|
||||

|
||||
|
||||
@@ -26,12 +26,27 @@ irm https://claude.ai/install.ps1 | iex
|
||||
|
||||
## Usage with Ollama
|
||||
|
||||
### Quick setup
|
||||
|
||||
```shell
|
||||
ollama launch claude
|
||||
```
|
||||
|
||||
To configure without launching:
|
||||
|
||||
```shell
|
||||
ollama launch claude --config
|
||||
```
|
||||
|
||||
### Manual setup
|
||||
|
||||
Claude Code connects to Ollama using the Anthropic-compatible API.
|
||||
|
||||
1. Set the environment variables:
|
||||
|
||||
```shell
|
||||
export ANTHROPIC_AUTH_TOKEN=ollama
|
||||
export ANTHROPIC_API_KEY=""
|
||||
export ANTHROPIC_BASE_URL=http://localhost:11434
|
||||
```
|
||||
|
||||
@@ -44,35 +59,17 @@ claude --model gpt-oss:20b
|
||||
Or run with environment variables inline:
|
||||
|
||||
```shell
|
||||
ANTHROPIC_AUTH_TOKEN=ollama ANTHROPIC_BASE_URL=http://localhost:11434 claude --model gpt-oss:20b
|
||||
ANTHROPIC_AUTH_TOKEN=ollama ANTHROPIC_BASE_URL=http://localhost:11434 ANTHROPIC_API_KEY="" claude --model qwen3-coder
|
||||
```
|
||||
|
||||
**Note:** Claude Code requires a large context window. We recommend at least 32K tokens. See the [context length documentation](/context-length) for how to adjust context length in Ollama.
|
||||
|
||||
## Connecting to ollama.com
|
||||
|
||||
1. Create an [API key](https://ollama.com/settings/keys) on ollama.com
|
||||
2. Set the environment variables:
|
||||
|
||||
```shell
|
||||
export ANTHROPIC_BASE_URL=https://ollama.com
|
||||
export ANTHROPIC_API_KEY=<your-api-key>
|
||||
```
|
||||
|
||||
3. Run Claude Code with a cloud model:
|
||||
|
||||
```shell
|
||||
claude --model glm-4.7:cloud
|
||||
```
|
||||
**Note:** Claude Code requires a large context window. We recommend at least 64k tokens. See the [context length documentation](/context-length) for how to adjust context length in Ollama.
|
||||
|
||||
## Recommended Models
|
||||
|
||||
### Cloud models
|
||||
- `glm-4.7:cloud` - High-performance cloud model
|
||||
- `minimax-m2.1:cloud` - Fast cloud model
|
||||
- `qwen3-coder:480b` - Large coding model
|
||||
- `qwen3-coder`
|
||||
- `glm-4.7`
|
||||
- `gpt-oss:20b`
|
||||
- `gpt-oss:120b`
|
||||
|
||||
Cloud models are also available at [ollama.com/search?c=cloud](https://ollama.com/search?c=cloud).
|
||||
|
||||
### Local models
|
||||
- `qwen3-coder` - Excellent for coding tasks
|
||||
- `gpt-oss:20b` - Strong general-purpose model
|
||||
- `gpt-oss:120b` - Larger general-purpose model for more complex tasks
|
||||
@@ -13,7 +13,21 @@ npm install -g @openai/codex
|
||||
|
||||
## Usage with Ollama
|
||||
|
||||
<Note>Codex requires a larger context window. It is recommended to use a context window of at least 32K tokens.</Note>
|
||||
<Note>Codex requires a larger context window. It is recommended to use a context window of at least 64k tokens.</Note>
|
||||
|
||||
### Quick setup
|
||||
|
||||
```
|
||||
ollama launch codex
|
||||
```
|
||||
|
||||
To configure without launching:
|
||||
|
||||
```shell
|
||||
ollama launch codex --config
|
||||
```
|
||||
|
||||
### Manual setup
|
||||
|
||||
To use `codex` with Ollama, use the `--oss` flag:
|
||||
|
||||
|
||||
@@ -11,10 +11,24 @@ Install the [Droid CLI](https://factory.ai/):
|
||||
curl -fsSL https://app.factory.ai/cli | sh
|
||||
```
|
||||
|
||||
<Note>Droid requires a larger context window. It is recommended to use a context window of at least 32K tokens. See [Context length](/context-length) for more information.</Note>
|
||||
<Note>Droid requires a larger context window. It is recommended to use a context window of at least 64k tokens. See [Context length](/context-length) for more information.</Note>
|
||||
|
||||
## Usage with Ollama
|
||||
|
||||
### Quick setup
|
||||
|
||||
```bash
|
||||
ollama launch droid
|
||||
```
|
||||
|
||||
To configure without launching:
|
||||
|
||||
```shell
|
||||
ollama launch droid --config
|
||||
```
|
||||
|
||||
### Manual setup
|
||||
|
||||
Add a local configuration block to `~/.factory/config.json`:
|
||||
|
||||
```json
|
||||
@@ -73,4 +87,4 @@ Add the cloud configuration block to `~/.factory/config.json`:
|
||||
}
|
||||
```
|
||||
|
||||
Run `droid` in a new terminal to load the new settings.
|
||||
Run `droid` in a new terminal to load the new settings.
|
||||
|
||||
106
docs/integrations/opencode.mdx
Normal file
106
docs/integrations/opencode.mdx
Normal file
@@ -0,0 +1,106 @@
|
||||
---
|
||||
title: OpenCode
|
||||
---
|
||||
|
||||
OpenCode is an open-source AI coding assistant that runs in your terminal.
|
||||
|
||||
## Install
|
||||
|
||||
Install the [OpenCode CLI](https://opencode.ai):
|
||||
|
||||
```bash
|
||||
curl -fsSL https://opencode.ai/install.sh | bash
|
||||
```
|
||||
|
||||
<Note>OpenCode requires a larger context window. It is recommended to use a context window of at least 64k tokens. See [Context length](/context-length) for more information.</Note>
|
||||
|
||||
## Usage with Ollama
|
||||
|
||||
### Quick setup
|
||||
|
||||
```bash
|
||||
ollama launch opencode
|
||||
```
|
||||
|
||||
To configure without launching:
|
||||
|
||||
```shell
|
||||
ollama launch opencode --config
|
||||
```
|
||||
|
||||
### Manual setup
|
||||
|
||||
Add a configuration block to `~/.config/opencode/opencode.json`:
|
||||
|
||||
```json
|
||||
{
|
||||
"$schema": "https://opencode.ai/config.json",
|
||||
"provider": {
|
||||
"ollama": {
|
||||
"npm": "@ai-sdk/openai-compatible",
|
||||
"name": "Ollama",
|
||||
"options": {
|
||||
"baseURL": "http://localhost:11434/v1"
|
||||
},
|
||||
"models": {
|
||||
"qwen3-coder": {
|
||||
"name": "qwen3-coder"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
## Cloud Models
|
||||
|
||||
`glm-4.7:cloud` is the recommended model for use with OpenCode.
|
||||
|
||||
Add the cloud configuration to `~/.config/opencode/opencode.json`:
|
||||
|
||||
```json
|
||||
{
|
||||
"$schema": "https://opencode.ai/config.json",
|
||||
"provider": {
|
||||
"ollama": {
|
||||
"npm": "@ai-sdk/openai-compatible",
|
||||
"name": "Ollama",
|
||||
"options": {
|
||||
"baseURL": "http://localhost:11434/v1"
|
||||
},
|
||||
"models": {
|
||||
"glm-4.7:cloud": {
|
||||
"name": "glm-4.7:cloud"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
## Connecting to ollama.com
|
||||
|
||||
1. Create an [API key](https://ollama.com/settings/keys) from ollama.com and export it as `OLLAMA_API_KEY`.
|
||||
2. Update `~/.config/opencode/opencode.json` to point to ollama.com:
|
||||
|
||||
```json
|
||||
{
|
||||
"$schema": "https://opencode.ai/config.json",
|
||||
"provider": {
|
||||
"ollama": {
|
||||
"npm": "@ai-sdk/openai-compatible",
|
||||
"name": "Ollama Cloud",
|
||||
"options": {
|
||||
"baseURL": "https://ollama.com/v1"
|
||||
},
|
||||
"models": {
|
||||
"glm-4.7:cloud": {
|
||||
"name": "glm-4.7:cloud"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
Run `opencode` in a new terminal to load the new settings.
|
||||
@@ -18,13 +18,13 @@ This quickstart will walk your through running your first model with Ollama. To
|
||||
<Tab title="CLI">
|
||||
Open a terminal and run the command:
|
||||
|
||||
```
|
||||
```sh
|
||||
ollama run gemma3
|
||||
```
|
||||
|
||||
</Tab>
|
||||
<Tab title="cURL">
|
||||
```
|
||||
```sh
|
||||
ollama pull gemma3
|
||||
```
|
||||
|
||||
@@ -45,13 +45,13 @@ This quickstart will walk your through running your first model with Ollama. To
|
||||
<Tab title="Python">
|
||||
Start by downloading a model:
|
||||
|
||||
```
|
||||
```sh
|
||||
ollama pull gemma3
|
||||
```
|
||||
|
||||
Then install Ollama's Python library:
|
||||
|
||||
```
|
||||
```sh
|
||||
pip install ollama
|
||||
```
|
||||
|
||||
@@ -101,3 +101,42 @@ This quickstart will walk your through running your first model with Ollama. To
|
||||
</Tabs>
|
||||
|
||||
See a full list of available models [here](https://ollama.com/models).
|
||||
|
||||
## Coding
|
||||
|
||||
For coding use cases, we recommend using the `glm-4.7-flash` model.
|
||||
|
||||
Note: this model requires 23 GB of VRAM with 64000 tokens context length.
|
||||
```sh
|
||||
ollama pull glm-4.7-flash
|
||||
```
|
||||
|
||||
Alternatively, you can use a more powerful cloud model (with full context length):
|
||||
```sh
|
||||
ollama pull glm-4.7:cloud
|
||||
```
|
||||
|
||||
Use `ollama launch` to quickly set up a coding tool with Ollama models:
|
||||
|
||||
```sh
|
||||
ollama launch
|
||||
```
|
||||
|
||||
### Supported integrations
|
||||
|
||||
- [OpenCode](/integrations/opencode) - Open-source coding assistant
|
||||
- [Claude Code](/integrations/claude-code) - Anthropic's agentic coding tool
|
||||
- [Codex](/integrations/codex) - OpenAI's coding assistant
|
||||
- [Droid](/integrations/droid) - Factory's AI coding agent
|
||||
|
||||
### Launch with a specific model
|
||||
|
||||
```sh
|
||||
ollama launch claude --model glm-4.7-flash
|
||||
```
|
||||
|
||||
### Configure without launching
|
||||
|
||||
```sh
|
||||
ollama launch claude --config
|
||||
```
|
||||
|
||||
@@ -0,0 +1,309 @@
|
||||
From 0000000000000000000000000000000000000000 Mon Sep 17 00:00:00 2001
|
||||
From: nobody <>
|
||||
Date: Sat, 24 Jan 2026 02:31:01 +0000
|
||||
Subject: [PATCH] ggml: enable MLA flash attention for GLM-4.7-flash
|
||||
|
||||
Add support for gqa_ratio 4 in MLA flash attention kernels. GLM-4.7-flash
|
||||
uses head size 576 with gqa_ratio 4, which was previously only supported
|
||||
for gqa_ratio 16 (DeepSeek).
|
||||
|
||||
Metal changes:
|
||||
- Enable head size 576 for flash attention
|
||||
- Increase simdgroups to 8 for large heads (>=512)
|
||||
- Add case 8 kernel dispatch for 8 simdgroups
|
||||
|
||||
CUDA changes:
|
||||
- Add gqa_ratio 4 support for head 576/512
|
||||
- Add tile configs for (576, 512, 4) and (576, 512, 8)
|
||||
- Add MMA config cases for ncols 4
|
||||
- Add template instances for ncols2=4
|
||||
- Fix nbatch_fa values in nvidia_fp32 config (32->64)
|
||||
---
|
||||
ggml/src/ggml-cuda/fattn-mma-f16.cuh | 40 +++++++++++++++----
|
||||
ggml/src/ggml-cuda/fattn-tile.cuh | 16 ++++++++
|
||||
ggml/src/ggml-cuda/fattn.cu | 12 ++++--
|
||||
...ttn-mma-f16-instance-ncols1_16-ncols2_4.cu | 1 +
|
||||
...attn-mma-f16-instance-ncols1_2-ncols2_4.cu | 1 +
|
||||
...attn-mma-f16-instance-ncols1_4-ncols2_4.cu | 1 +
|
||||
...attn-mma-f16-instance-ncols1_8-ncols2_4.cu | 1 +
|
||||
ggml/src/ggml-metal/ggml-metal-device.m | 8 +---
|
||||
ggml/src/ggml-metal/ggml-metal-ops.cpp | 2 +-
|
||||
ggml/src/ggml-metal/ggml-metal.metal | 1 +
|
||||
10 files changed, 64 insertions(+), 19 deletions(-)
|
||||
|
||||
diff --git a/ggml/src/ggml-cuda/fattn-mma-f16.cuh b/ggml/src/ggml-cuda/fattn-mma-f16.cuh
|
||||
index 7bd1044c1..3dea2205e 100644
|
||||
--- a/ggml/src/ggml-cuda/fattn-mma-f16.cuh
|
||||
+++ b/ggml/src/ggml-cuda/fattn-mma-f16.cuh
|
||||
@@ -66,7 +66,8 @@ static constexpr __host__ __device__ fattn_mma_config ggml_cuda_fattn_mma_get_co
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 32, 128, 2, 32, 128, 128, 128, 2, true);
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 64, 128, 2, 32, 128, 128, 128, 2, true);
|
||||
|
||||
- GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 8, 64, 4, 32, 288, 256, 128, 1, false);
|
||||
+ GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 4, 64, 4, 32, 288, 256, 128, 1, false);
|
||||
+ GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 8, 64, 4, 32, 288, 256, 128, 1, true);
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 16, 64, 4, 32, 288, 256, 128, 1, false);
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 32, 128, 2, 32, 160, 128, 128, 1, false);
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 64, 256, 1, 32, 160, 128, 128, 1, false);
|
||||
@@ -80,7 +81,8 @@ static constexpr __host__ __device__ fattn_mma_config ggml_cuda_fattn_mma_get_co
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 32, 128, 2, 64, 128, 128, 64, 2, true);
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 64, 128, 2, 64, 128, 128, 64, 2, true);
|
||||
|
||||
- GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 8, 64, 4, 32, 96, 64, 128, 1, false);
|
||||
+ GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 4, 64, 4, 32, 96, 64, 128, 1, false);
|
||||
+ GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 8, 64, 4, 32, 96, 64, 128, 1, true);
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 16, 64, 4, 32, 96, 64, 128, 1, false);
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 32, 128, 2, 32, 160, 128, 128, 1, false);
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 64, 256, 1, 32, 160, 128, 128, 1, false);
|
||||
@@ -89,7 +91,8 @@ static constexpr __host__ __device__ fattn_mma_config ggml_cuda_fattn_mma_get_co
|
||||
}
|
||||
|
||||
static constexpr __host__ __device__ fattn_mma_config ggml_cuda_fattn_mma_get_config_volta(const int DKQ, const int DV, const int ncols) {
|
||||
- GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 8, 64, 4, 32, 288, 256, 64, 1, false);
|
||||
+ GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 4, 64, 4, 32, 288, 256, 64, 1, false);
|
||||
+ GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 8, 64, 4, 32, 288, 256, 64, 1, true);
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 16, 64, 4, 32, 288, 256, 64, 1, false);
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 32, 128, 2, 32, 160, 128, 64, 1, false);
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 64, 256, 1, 32, 160, 128, 64, 1, false);
|
||||
@@ -397,7 +400,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
|
||||
constexpr int ncols = ncols1 * ncols2;
|
||||
constexpr int cols_per_warp = T_B_KQ::I;
|
||||
constexpr int cols_per_thread = 2; // This is specifically KQ columns, Volta only has a single VKQ column.
|
||||
- constexpr int np = nwarps * (cols_per_warp/ncols2) / ncols1; // Number of parallel CUDA warps per Q column.
|
||||
+ constexpr int np = cols_per_warp > ncols ? nwarps : nwarps * cols_per_warp/ncols; // Number of parallel CUDA warps per Q column.
|
||||
constexpr int nbatch_fa = ggml_cuda_fattn_mma_get_nbatch_fa(DKQ, DV, ncols);
|
||||
constexpr int nbatch_K2 = ggml_cuda_fattn_mma_get_nbatch_K2(DKQ, DV, ncols);
|
||||
constexpr int nbatch_V2 = ggml_cuda_fattn_mma_get_nbatch_V2(DKQ, DV, ncols);
|
||||
@@ -467,7 +470,6 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
|
||||
}
|
||||
}
|
||||
} else {
|
||||
- static_assert(cols_per_warp != 8, "cols_per_warp == 8 not implemented");
|
||||
#pragma unroll
|
||||
for (int k_KQ_0 = k0_start; k_KQ_0 < k0_stop; k_KQ_0 += T_A_KQ::J) {
|
||||
load_ldmatrix(Q_B[0], tile_Q + (threadIdx.y / np)*(T_B_KQ::I*stride_tile_Q) + k_KQ_0, stride_tile_Q);
|
||||
@@ -479,8 +481,18 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
|
||||
T_A_KQ K_A;
|
||||
load_ldmatrix(K_A, tile_K + i_KQ_0*stride_tile_K + (k_KQ_0 - k0_start), stride_tile_K);
|
||||
|
||||
- // Wide version of KQ_C is column-major => swap A and B.
|
||||
- mma(KQ_C[i_KQ_00/(np*T_A_KQ::I)], Q_B[0], K_A);
|
||||
+ if constexpr (cols_per_warp == 8) {
|
||||
+ mma(KQ_C[i_KQ_00/(np*T_A_KQ::I)], K_A, Q_B[0]);
|
||||
+ } else {
|
||||
+ // Wide version of KQ_C is column-major
|
||||
+#if defined(AMD_WMMA_AVAILABLE)
|
||||
+ // RDNA matrix C is column-major.
|
||||
+ mma(KQ_C[i_KQ_00/(np*T_A_KQ::I)], K_A, Q_B[0]);
|
||||
+#else
|
||||
+ // swap A and B for CUDA.
|
||||
+ mma(KQ_C[i_KQ_00/(np*T_A_KQ::I)], Q_B[0], K_A);
|
||||
+#endif // defined(AMD_WMMA_AVAILABLE)
|
||||
+ }
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -841,7 +853,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
|
||||
|
||||
constexpr int cols_per_warp = T_B_KQ::I;
|
||||
constexpr int cols_per_thread = 2; // This is specifically KQ columns, Volta only has a single VKQ column.
|
||||
- constexpr int np = nwarps * (cols_per_warp/ncols2) / ncols1; // Number of parallel CUDA warps per Q column.
|
||||
+ constexpr int np = cols_per_warp > ncols ? nwarps : nwarps * cols_per_warp/ncols; // Number of parallel CUDA warps per Q column.
|
||||
constexpr int nbatch_fa = ggml_cuda_fattn_mma_get_nbatch_fa (DKQ, DV, ncols);
|
||||
constexpr int nbatch_K2 = ggml_cuda_fattn_mma_get_nbatch_K2 (DKQ, DV, ncols);
|
||||
constexpr int nbatch_V2 = ggml_cuda_fattn_mma_get_nbatch_V2 (DKQ, DV, ncols);
|
||||
@@ -1353,6 +1365,13 @@ static __global__ void flash_attn_ext_f16(
|
||||
NO_DEVICE_CODE;
|
||||
return;
|
||||
}
|
||||
+#ifdef VOLTA_MMA_AVAILABLE
|
||||
+ if (ncols1*ncols2 < 32) {
|
||||
+ NO_DEVICE_CODE;
|
||||
+ return;
|
||||
+ }
|
||||
+#endif // VOLTA_MMA_AVAILABLE
|
||||
+
|
||||
#if __CUDA_ARCH__ == GGML_CUDA_CC_TURING
|
||||
if (ncols1*ncols2 > 32) {
|
||||
NO_DEVICE_CODE;
|
||||
@@ -1585,3 +1604,8 @@ DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(256, 256, 64)
|
||||
extern DECL_FATTN_MMA_F16_CASE(576, 512, 1, 16);
|
||||
extern DECL_FATTN_MMA_F16_CASE(576, 512, 2, 16);
|
||||
extern DECL_FATTN_MMA_F16_CASE(576, 512, 4, 16);
|
||||
+
|
||||
+// For GLM 4.7 Flash
|
||||
+extern DECL_FATTN_MMA_F16_CASE(576, 512, 4, 4);
|
||||
+extern DECL_FATTN_MMA_F16_CASE(576, 512, 8, 4);
|
||||
+extern DECL_FATTN_MMA_F16_CASE(576, 512, 16, 4);
|
||||
diff --git a/ggml/src/ggml-cuda/fattn-tile.cuh b/ggml/src/ggml-cuda/fattn-tile.cuh
|
||||
index 7c4d6fe67..371be7442 100644
|
||||
--- a/ggml/src/ggml-cuda/fattn-tile.cuh
|
||||
+++ b/ggml/src/ggml-cuda/fattn-tile.cuh
|
||||
@@ -68,6 +68,8 @@ static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_nv
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 16, 256, 2, 64, 64)
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 32, 256, 2, 64, 64)
|
||||
|
||||
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 4, 128, 2, 64, 64)
|
||||
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 8, 256, 2, 64, 64)
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 16, 256, 2, 64, 64)
|
||||
|
||||
return 0;
|
||||
@@ -122,6 +124,8 @@ static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_nv
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 16, 256, 2, 32, 128)
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 32, 256, 2, 32, 64)
|
||||
|
||||
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 4, 128, 2, 32, 64)
|
||||
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 8, 256, 2, 32, 64)
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 16, 256, 2, 32, 64)
|
||||
|
||||
return 0;
|
||||
@@ -183,6 +187,8 @@ static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_am
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 16, 256, 2, 32, 128)
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 32, 256, 2, 32, 128)
|
||||
|
||||
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 4, 128, 2, 64, 64)
|
||||
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 8, 256, 2, 64, 64)
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 16, 256, 2, 64, 64)
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 32, 512, 1, 128, 64)
|
||||
|
||||
@@ -245,6 +251,8 @@ static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_am
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 16, 256, 5, 32, 256)
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 32, 256, 3, 64, 128)
|
||||
|
||||
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 4, 128, 2, 64, 64)
|
||||
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 8, 256, 2, 64, 64)
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 16, 256, 4, 64, 64)
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 32, 256, 2, 128, 64)
|
||||
|
||||
@@ -1187,6 +1195,14 @@ static void launch_fattn_tile_switch_ncols2(ggml_backend_cuda_context & ctx, ggm
|
||||
launch_fattn_tile_switch_ncols1<DKQ, DV, 16, use_logit_softcap>(ctx, dst);
|
||||
return;
|
||||
}
|
||||
+ if (use_gqa_opt && gqa_ratio % 8 == 0) {
|
||||
+ launch_fattn_tile_switch_ncols1<DKQ, DV, 8, use_logit_softcap>(ctx, dst);
|
||||
+ return;
|
||||
+ }
|
||||
+ if (use_gqa_opt && gqa_ratio % 4 == 0) {
|
||||
+ launch_fattn_tile_switch_ncols1<DKQ, DV, 4, use_logit_softcap>(ctx, dst);
|
||||
+ return;
|
||||
+ }
|
||||
}
|
||||
|
||||
if constexpr (DV <= 256) {
|
||||
diff --git a/ggml/src/ggml-cuda/fattn.cu b/ggml/src/ggml-cuda/fattn.cu
|
||||
index 015540666..1693479cb 100644
|
||||
--- a/ggml/src/ggml-cuda/fattn.cu
|
||||
+++ b/ggml/src/ggml-cuda/fattn.cu
|
||||
@@ -111,7 +111,7 @@ static void ggml_cuda_flash_attn_ext_mma_f16(ggml_backend_cuda_context & ctx, gg
|
||||
ggml_cuda_flash_attn_ext_mma_f16_switch_ncols2<256, 256>(ctx, dst);
|
||||
break;
|
||||
case 576: {
|
||||
- // For Deepseek, go straight to the ncols1 switch to avoid compiling unnecessary kernels.
|
||||
+ // For Deepseek/GLM4, go straight to the ncols1 switch to avoid compiling unnecessary kernels.
|
||||
GGML_ASSERT(V->ne[0] == 512);
|
||||
float max_bias = 0.0f;
|
||||
memcpy(&max_bias, (const float *) KQV->op_params + 1, sizeof(float));
|
||||
@@ -121,8 +121,12 @@ static void ggml_cuda_flash_attn_ext_mma_f16(ggml_backend_cuda_context & ctx, gg
|
||||
|
||||
GGML_ASSERT(Q->ne[2] % K->ne[2] == 0);
|
||||
const int gqa_ratio = Q->ne[2] / K->ne[2];
|
||||
- GGML_ASSERT(gqa_ratio % 16 == 0);
|
||||
- ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<576, 512, 16>(ctx, dst);
|
||||
+ GGML_ASSERT(gqa_ratio % 4 == 0);
|
||||
+ if (gqa_ratio % 16 == 0) {
|
||||
+ ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<576, 512, 16>(ctx, dst);
|
||||
+ } else {
|
||||
+ ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<576, 512, 4>(ctx, dst);
|
||||
+ }
|
||||
} break;
|
||||
default:
|
||||
GGML_ABORT("fatal error");
|
||||
@@ -251,7 +255,7 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const
|
||||
if (V->ne[0] != 512) {
|
||||
return BEST_FATTN_KERNEL_NONE;
|
||||
}
|
||||
- if (!gqa_opt_applies || gqa_ratio % 16 != 0) {
|
||||
+ if (!gqa_opt_applies || gqa_ratio % 4 != 0) {
|
||||
return BEST_FATTN_KERNEL_NONE;
|
||||
}
|
||||
break;
|
||||
diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_4.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_4.cu
|
||||
index 2074e954a..517993cb0 100644
|
||||
--- a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_4.cu
|
||||
+++ b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_4.cu
|
||||
@@ -8,3 +8,4 @@ DECL_FATTN_MMA_F16_CASE(96, 96, 16, 4);
|
||||
DECL_FATTN_MMA_F16_CASE(112, 112, 16, 4);
|
||||
DECL_FATTN_MMA_F16_CASE(128, 128, 16, 4);
|
||||
DECL_FATTN_MMA_F16_CASE(256, 256, 16, 4);
|
||||
+DECL_FATTN_MMA_F16_CASE(576, 512, 16, 4);
|
||||
diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_4.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_4.cu
|
||||
index 24c64cf00..97b19c67a 100644
|
||||
--- a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_4.cu
|
||||
+++ b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_4.cu
|
||||
@@ -8,3 +8,4 @@ DECL_FATTN_MMA_F16_CASE(96, 96, 2, 4);
|
||||
DECL_FATTN_MMA_F16_CASE(112, 112, 2, 4);
|
||||
DECL_FATTN_MMA_F16_CASE(128, 128, 2, 4);
|
||||
DECL_FATTN_MMA_F16_CASE(256, 256, 2, 4);
|
||||
+DECL_FATTN_MMA_F16_CASE(576, 512, 2, 4);
|
||||
diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_4.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_4.cu
|
||||
index 1ada657f1..989626dfa 100644
|
||||
--- a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_4.cu
|
||||
+++ b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_4.cu
|
||||
@@ -8,3 +8,4 @@ DECL_FATTN_MMA_F16_CASE(96, 96, 4, 4);
|
||||
DECL_FATTN_MMA_F16_CASE(112, 112, 4, 4);
|
||||
DECL_FATTN_MMA_F16_CASE(128, 128, 4, 4);
|
||||
DECL_FATTN_MMA_F16_CASE(256, 256, 4, 4);
|
||||
+DECL_FATTN_MMA_F16_CASE(576, 512, 4, 4);
|
||||
diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_4.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_4.cu
|
||||
index 86d4ffae2..173de7aac 100644
|
||||
--- a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_4.cu
|
||||
+++ b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_4.cu
|
||||
@@ -8,3 +8,4 @@ DECL_FATTN_MMA_F16_CASE(96, 96, 8, 4);
|
||||
DECL_FATTN_MMA_F16_CASE(112, 112, 8, 4);
|
||||
DECL_FATTN_MMA_F16_CASE(128, 128, 8, 4);
|
||||
DECL_FATTN_MMA_F16_CASE(256, 256, 8, 4);
|
||||
+DECL_FATTN_MMA_F16_CASE(576, 512, 8, 4);
|
||||
diff --git a/ggml/src/ggml-metal/ggml-metal-device.m b/ggml/src/ggml-metal/ggml-metal-device.m
|
||||
index f24270bb1..7b5ee968c 100644
|
||||
--- a/ggml/src/ggml-metal/ggml-metal-device.m
|
||||
+++ b/ggml/src/ggml-metal/ggml-metal-device.m
|
||||
@@ -1071,12 +1071,8 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
|
||||
op->src[0]->ne[0] != 112 &&
|
||||
op->src[0]->ne[0] != 128 &&
|
||||
op->src[0]->ne[0] != 192 &&
|
||||
- op->src[0]->ne[0] != 256) {
|
||||
- return false;
|
||||
- }
|
||||
- if (op->src[0]->ne[0] == 576) {
|
||||
- // DeepSeek sizes
|
||||
- // TODO: disabled for now, until optmized
|
||||
+ op->src[0]->ne[0] != 256 &&
|
||||
+ op->src[0]->ne[0] != 576) {
|
||||
return false;
|
||||
}
|
||||
if (op->src[1]->type != op->src[2]->type) {
|
||||
diff --git a/ggml/src/ggml-metal/ggml-metal-ops.cpp b/ggml/src/ggml-metal/ggml-metal-ops.cpp
|
||||
index e99c1763f..80864f303 100644
|
||||
--- a/ggml/src/ggml-metal/ggml-metal-ops.cpp
|
||||
+++ b/ggml/src/ggml-metal/ggml-metal-ops.cpp
|
||||
@@ -2456,7 +2456,7 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) {
|
||||
|
||||
// simdgroups per threadgroup (a.k.a. warps)
|
||||
//nsg = ne01 <= nqptg ? MAX(4, MIN(nsgmax, MIN(ne11/ncpsg, (int64_t) pipeline.maxTotalThreadsPerThreadgroup/32))) : 4;
|
||||
- int32_t nsg = 4;
|
||||
+ int32_t nsg = ne00 >= 512 ? 8 : 4;
|
||||
|
||||
const size_t smem = FATTN_SMEM(nsg);
|
||||
|
||||
diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal
|
||||
index c98d269d1..d33c16079 100644
|
||||
--- a/ggml/src/ggml-metal/ggml-metal.metal
|
||||
+++ b/ggml/src/ggml-metal/ggml-metal.metal
|
||||
@@ -6166,6 +6166,7 @@ kernel void kernel_flash_attn_ext(
|
||||
//case 1: kernel_flash_attn_ext_impl<FWD_TMPL, 1>(FWD_ARGS); break;
|
||||
//case 2: kernel_flash_attn_ext_impl<FWD_TMPL, 2>(FWD_ARGS); break;
|
||||
case 4: kernel_flash_attn_ext_impl<FWD_TMPL, 4>(FWD_ARGS); break;
|
||||
+ case 8: kernel_flash_attn_ext_impl<FWD_TMPL, 8>(FWD_ARGS); break;
|
||||
}
|
||||
#undef FWD_TMPL
|
||||
#undef FWD_ARGS
|
||||
@@ -609,3 +609,49 @@ func ImageGenerationsMiddleware() gin.HandlerFunc {
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
func ImageEditsMiddleware() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
var req openai.ImageEditRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, openai.NewError(http.StatusBadRequest, err.Error()))
|
||||
return
|
||||
}
|
||||
|
||||
if req.Prompt == "" {
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, openai.NewError(http.StatusBadRequest, "prompt is required"))
|
||||
return
|
||||
}
|
||||
|
||||
if req.Model == "" {
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, openai.NewError(http.StatusBadRequest, "model is required"))
|
||||
return
|
||||
}
|
||||
|
||||
if req.Image == "" {
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, openai.NewError(http.StatusBadRequest, "image is required"))
|
||||
return
|
||||
}
|
||||
|
||||
genReq, err := openai.FromImageEditRequest(req)
|
||||
if err != nil {
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, openai.NewError(http.StatusBadRequest, err.Error()))
|
||||
return
|
||||
}
|
||||
|
||||
var b bytes.Buffer
|
||||
if err := json.NewEncoder(&b).Encode(genReq); err != nil {
|
||||
c.AbortWithStatusJSON(http.StatusInternalServerError, openai.NewError(http.StatusInternalServerError, err.Error()))
|
||||
return
|
||||
}
|
||||
|
||||
c.Request.Body = io.NopCloser(&b)
|
||||
|
||||
w := &ImageWriter{
|
||||
BaseWriter: BaseWriter{ResponseWriter: c.Writer},
|
||||
}
|
||||
|
||||
c.Writer = w
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1112,3 +1112,129 @@ func TestImageWriterResponse(t *testing.T) {
|
||||
t.Errorf("expected image data 'dGVzdC1pbWFnZS1kYXRh', got %s", imageResp.Data[0].B64JSON)
|
||||
}
|
||||
}
|
||||
|
||||
func TestImageEditsMiddleware(t *testing.T) {
|
||||
type testCase struct {
|
||||
name string
|
||||
body string
|
||||
req api.GenerateRequest
|
||||
err openai.ErrorResponse
|
||||
}
|
||||
|
||||
var capturedRequest *api.GenerateRequest
|
||||
|
||||
// Base64-encoded test image (1x1 pixel PNG)
|
||||
testImage := ""
|
||||
decodedImage, _ := base64.StdEncoding.DecodeString("iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR42mNk+A8AAQUBAScY42YAAAAASUVORK5CYII=")
|
||||
|
||||
testCases := []testCase{
|
||||
{
|
||||
name: "image edit basic",
|
||||
body: `{
|
||||
"model": "test-model",
|
||||
"prompt": "make it blue",
|
||||
"image": "` + testImage + `"
|
||||
}`,
|
||||
req: api.GenerateRequest{
|
||||
Model: "test-model",
|
||||
Prompt: "make it blue",
|
||||
Images: []api.ImageData{decodedImage},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "image edit with size",
|
||||
body: `{
|
||||
"model": "test-model",
|
||||
"prompt": "make it blue",
|
||||
"image": "` + testImage + `",
|
||||
"size": "512x768"
|
||||
}`,
|
||||
req: api.GenerateRequest{
|
||||
Model: "test-model",
|
||||
Prompt: "make it blue",
|
||||
Images: []api.ImageData{decodedImage},
|
||||
Width: 512,
|
||||
Height: 768,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "image edit missing prompt",
|
||||
body: `{
|
||||
"model": "test-model",
|
||||
"image": "` + testImage + `"
|
||||
}`,
|
||||
err: openai.ErrorResponse{
|
||||
Error: openai.Error{
|
||||
Message: "prompt is required",
|
||||
Type: "invalid_request_error",
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "image edit missing model",
|
||||
body: `{
|
||||
"prompt": "make it blue",
|
||||
"image": "` + testImage + `"
|
||||
}`,
|
||||
err: openai.ErrorResponse{
|
||||
Error: openai.Error{
|
||||
Message: "model is required",
|
||||
Type: "invalid_request_error",
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "image edit missing image",
|
||||
body: `{
|
||||
"model": "test-model",
|
||||
"prompt": "make it blue"
|
||||
}`,
|
||||
err: openai.ErrorResponse{
|
||||
Error: openai.Error{
|
||||
Message: "image is required",
|
||||
Type: "invalid_request_error",
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
endpoint := func(c *gin.Context) {
|
||||
c.Status(http.StatusOK)
|
||||
}
|
||||
|
||||
gin.SetMode(gin.TestMode)
|
||||
router := gin.New()
|
||||
router.Use(ImageEditsMiddleware(), captureRequestMiddleware(&capturedRequest))
|
||||
router.Handle(http.MethodPost, "/api/generate", endpoint)
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
req, _ := http.NewRequest(http.MethodPost, "/api/generate", strings.NewReader(tc.body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
defer func() { capturedRequest = nil }()
|
||||
|
||||
resp := httptest.NewRecorder()
|
||||
router.ServeHTTP(resp, req)
|
||||
|
||||
if tc.err.Error.Message != "" {
|
||||
var errResp openai.ErrorResponse
|
||||
if err := json.Unmarshal(resp.Body.Bytes(), &errResp); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if diff := cmp.Diff(tc.err, errResp); diff != "" {
|
||||
t.Fatalf("errors did not match:\n%s", diff)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
if resp.Code != http.StatusOK {
|
||||
t.Fatalf("expected status 200, got %d: %s", resp.Code, resp.Body.String())
|
||||
}
|
||||
|
||||
if diff := cmp.Diff(&tc.req, capturedRequest); diff != "" {
|
||||
t.Fatalf("requests did not match:\n%s", diff)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -66,7 +66,8 @@ static constexpr __host__ __device__ fattn_mma_config ggml_cuda_fattn_mma_get_co
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 32, 128, 2, 32, 128, 128, 128, 2, true);
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 64, 128, 2, 32, 128, 128, 128, 2, true);
|
||||
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 8, 64, 4, 32, 288, 256, 128, 1, false);
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 4, 64, 4, 32, 288, 256, 128, 1, false);
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 8, 64, 4, 32, 288, 256, 128, 1, true);
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 16, 64, 4, 32, 288, 256, 128, 1, false);
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 32, 128, 2, 32, 160, 128, 128, 1, false);
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 64, 256, 1, 32, 160, 128, 128, 1, false);
|
||||
@@ -80,7 +81,8 @@ static constexpr __host__ __device__ fattn_mma_config ggml_cuda_fattn_mma_get_co
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 32, 128, 2, 64, 128, 128, 64, 2, true);
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 64, 128, 2, 64, 128, 128, 64, 2, true);
|
||||
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 8, 64, 4, 32, 96, 64, 128, 1, false);
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 4, 64, 4, 32, 96, 64, 128, 1, false);
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 8, 64, 4, 32, 96, 64, 128, 1, true);
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 16, 64, 4, 32, 96, 64, 128, 1, false);
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 32, 128, 2, 32, 160, 128, 128, 1, false);
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 64, 256, 1, 32, 160, 128, 128, 1, false);
|
||||
@@ -89,7 +91,8 @@ static constexpr __host__ __device__ fattn_mma_config ggml_cuda_fattn_mma_get_co
|
||||
}
|
||||
|
||||
static constexpr __host__ __device__ fattn_mma_config ggml_cuda_fattn_mma_get_config_volta(const int DKQ, const int DV, const int ncols) {
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 8, 64, 4, 32, 288, 256, 64, 1, false);
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 4, 64, 4, 32, 288, 256, 64, 1, false);
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 8, 64, 4, 32, 288, 256, 64, 1, true);
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 16, 64, 4, 32, 288, 256, 64, 1, false);
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 32, 128, 2, 32, 160, 128, 64, 1, false);
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 64, 256, 1, 32, 160, 128, 64, 1, false);
|
||||
@@ -397,7 +400,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
|
||||
constexpr int ncols = ncols1 * ncols2;
|
||||
constexpr int cols_per_warp = T_B_KQ::I;
|
||||
constexpr int cols_per_thread = 2; // This is specifically KQ columns, Volta only has a single VKQ column.
|
||||
constexpr int np = nwarps * (cols_per_warp/ncols2) / ncols1; // Number of parallel CUDA warps per Q column.
|
||||
constexpr int np = cols_per_warp > ncols ? nwarps : nwarps * cols_per_warp/ncols; // Number of parallel CUDA warps per Q column.
|
||||
constexpr int nbatch_fa = ggml_cuda_fattn_mma_get_nbatch_fa(DKQ, DV, ncols);
|
||||
constexpr int nbatch_K2 = ggml_cuda_fattn_mma_get_nbatch_K2(DKQ, DV, ncols);
|
||||
constexpr int nbatch_V2 = ggml_cuda_fattn_mma_get_nbatch_V2(DKQ, DV, ncols);
|
||||
@@ -467,7 +470,6 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
|
||||
}
|
||||
}
|
||||
} else {
|
||||
static_assert(cols_per_warp != 8, "cols_per_warp == 8 not implemented");
|
||||
#pragma unroll
|
||||
for (int k_KQ_0 = k0_start; k_KQ_0 < k0_stop; k_KQ_0 += T_A_KQ::J) {
|
||||
load_ldmatrix(Q_B[0], tile_Q + (threadIdx.y / np)*(T_B_KQ::I*stride_tile_Q) + k_KQ_0, stride_tile_Q);
|
||||
@@ -479,8 +481,18 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
|
||||
T_A_KQ K_A;
|
||||
load_ldmatrix(K_A, tile_K + i_KQ_0*stride_tile_K + (k_KQ_0 - k0_start), stride_tile_K);
|
||||
|
||||
// Wide version of KQ_C is column-major => swap A and B.
|
||||
mma(KQ_C[i_KQ_00/(np*T_A_KQ::I)], Q_B[0], K_A);
|
||||
if constexpr (cols_per_warp == 8) {
|
||||
mma(KQ_C[i_KQ_00/(np*T_A_KQ::I)], K_A, Q_B[0]);
|
||||
} else {
|
||||
// Wide version of KQ_C is column-major
|
||||
#if defined(AMD_WMMA_AVAILABLE)
|
||||
// RDNA matrix C is column-major.
|
||||
mma(KQ_C[i_KQ_00/(np*T_A_KQ::I)], K_A, Q_B[0]);
|
||||
#else
|
||||
// swap A and B for CUDA.
|
||||
mma(KQ_C[i_KQ_00/(np*T_A_KQ::I)], Q_B[0], K_A);
|
||||
#endif // defined(AMD_WMMA_AVAILABLE)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -841,7 +853,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
|
||||
|
||||
constexpr int cols_per_warp = T_B_KQ::I;
|
||||
constexpr int cols_per_thread = 2; // This is specifically KQ columns, Volta only has a single VKQ column.
|
||||
constexpr int np = nwarps * (cols_per_warp/ncols2) / ncols1; // Number of parallel CUDA warps per Q column.
|
||||
constexpr int np = cols_per_warp > ncols ? nwarps : nwarps * cols_per_warp/ncols; // Number of parallel CUDA warps per Q column.
|
||||
constexpr int nbatch_fa = ggml_cuda_fattn_mma_get_nbatch_fa (DKQ, DV, ncols);
|
||||
constexpr int nbatch_K2 = ggml_cuda_fattn_mma_get_nbatch_K2 (DKQ, DV, ncols);
|
||||
constexpr int nbatch_V2 = ggml_cuda_fattn_mma_get_nbatch_V2 (DKQ, DV, ncols);
|
||||
@@ -1353,6 +1365,13 @@ static __global__ void flash_attn_ext_f16(
|
||||
NO_DEVICE_CODE;
|
||||
return;
|
||||
}
|
||||
#ifdef VOLTA_MMA_AVAILABLE
|
||||
if (ncols1*ncols2 < 32) {
|
||||
NO_DEVICE_CODE;
|
||||
return;
|
||||
}
|
||||
#endif // VOLTA_MMA_AVAILABLE
|
||||
|
||||
#if __CUDA_ARCH__ == GGML_CUDA_CC_TURING
|
||||
if (ncols1*ncols2 > 32) {
|
||||
NO_DEVICE_CODE;
|
||||
@@ -1585,3 +1604,8 @@ DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(256, 256, 64)
|
||||
extern DECL_FATTN_MMA_F16_CASE(576, 512, 1, 16);
|
||||
extern DECL_FATTN_MMA_F16_CASE(576, 512, 2, 16);
|
||||
extern DECL_FATTN_MMA_F16_CASE(576, 512, 4, 16);
|
||||
|
||||
// For GLM 4.7 Flash
|
||||
extern DECL_FATTN_MMA_F16_CASE(576, 512, 4, 4);
|
||||
extern DECL_FATTN_MMA_F16_CASE(576, 512, 8, 4);
|
||||
extern DECL_FATTN_MMA_F16_CASE(576, 512, 16, 4);
|
||||
|
||||
@@ -68,6 +68,8 @@ static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_nv
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 16, 256, 2, 64, 64)
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 32, 256, 2, 64, 64)
|
||||
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 4, 128, 2, 64, 64)
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 8, 256, 2, 64, 64)
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 16, 256, 2, 64, 64)
|
||||
|
||||
return 0;
|
||||
@@ -122,6 +124,8 @@ static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_nv
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 16, 256, 2, 32, 128)
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 32, 256, 2, 32, 64)
|
||||
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 4, 128, 2, 32, 64)
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 8, 256, 2, 32, 64)
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 16, 256, 2, 32, 64)
|
||||
|
||||
return 0;
|
||||
@@ -183,6 +187,8 @@ static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_am
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 16, 256, 2, 32, 128)
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 32, 256, 2, 32, 128)
|
||||
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 4, 128, 2, 64, 64)
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 8, 256, 2, 64, 64)
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 16, 256, 2, 64, 64)
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 32, 512, 1, 128, 64)
|
||||
|
||||
@@ -245,6 +251,8 @@ static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_am
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 16, 256, 5, 32, 256)
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 32, 256, 3, 64, 128)
|
||||
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 4, 128, 2, 64, 64)
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 8, 256, 2, 64, 64)
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 16, 256, 4, 64, 64)
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 32, 256, 2, 128, 64)
|
||||
|
||||
@@ -1187,6 +1195,14 @@ static void launch_fattn_tile_switch_ncols2(ggml_backend_cuda_context & ctx, ggm
|
||||
launch_fattn_tile_switch_ncols1<DKQ, DV, 16, use_logit_softcap>(ctx, dst);
|
||||
return;
|
||||
}
|
||||
if (use_gqa_opt && gqa_ratio % 8 == 0) {
|
||||
launch_fattn_tile_switch_ncols1<DKQ, DV, 8, use_logit_softcap>(ctx, dst);
|
||||
return;
|
||||
}
|
||||
if (use_gqa_opt && gqa_ratio % 4 == 0) {
|
||||
launch_fattn_tile_switch_ncols1<DKQ, DV, 4, use_logit_softcap>(ctx, dst);
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
if constexpr (DV <= 256) {
|
||||
|
||||
12
ml/backend/ggml/ggml/src/ggml-cuda/fattn.cu
vendored
12
ml/backend/ggml/ggml/src/ggml-cuda/fattn.cu
vendored
@@ -111,7 +111,7 @@ static void ggml_cuda_flash_attn_ext_mma_f16(ggml_backend_cuda_context & ctx, gg
|
||||
ggml_cuda_flash_attn_ext_mma_f16_switch_ncols2<256, 256>(ctx, dst);
|
||||
break;
|
||||
case 576: {
|
||||
// For Deepseek, go straight to the ncols1 switch to avoid compiling unnecessary kernels.
|
||||
// For Deepseek/GLM4, go straight to the ncols1 switch to avoid compiling unnecessary kernels.
|
||||
GGML_ASSERT(V->ne[0] == 512);
|
||||
float max_bias = 0.0f;
|
||||
memcpy(&max_bias, (const float *) KQV->op_params + 1, sizeof(float));
|
||||
@@ -121,8 +121,12 @@ static void ggml_cuda_flash_attn_ext_mma_f16(ggml_backend_cuda_context & ctx, gg
|
||||
|
||||
GGML_ASSERT(Q->ne[2] % K->ne[2] == 0);
|
||||
const int gqa_ratio = Q->ne[2] / K->ne[2];
|
||||
GGML_ASSERT(gqa_ratio % 16 == 0);
|
||||
ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<576, 512, 16>(ctx, dst);
|
||||
GGML_ASSERT(gqa_ratio % 4 == 0);
|
||||
if (gqa_ratio % 16 == 0) {
|
||||
ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<576, 512, 16>(ctx, dst);
|
||||
} else {
|
||||
ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<576, 512, 4>(ctx, dst);
|
||||
}
|
||||
} break;
|
||||
default:
|
||||
GGML_ABORT("fatal error");
|
||||
@@ -251,7 +255,7 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const
|
||||
if (V->ne[0] != 512) {
|
||||
return BEST_FATTN_KERNEL_NONE;
|
||||
}
|
||||
if (!gqa_opt_applies || gqa_ratio % 16 != 0) {
|
||||
if (!gqa_opt_applies || gqa_ratio % 4 != 0) {
|
||||
return BEST_FATTN_KERNEL_NONE;
|
||||
}
|
||||
break;
|
||||
|
||||
@@ -8,3 +8,4 @@ DECL_FATTN_MMA_F16_CASE(96, 96, 16, 4);
|
||||
DECL_FATTN_MMA_F16_CASE(112, 112, 16, 4);
|
||||
DECL_FATTN_MMA_F16_CASE(128, 128, 16, 4);
|
||||
DECL_FATTN_MMA_F16_CASE(256, 256, 16, 4);
|
||||
DECL_FATTN_MMA_F16_CASE(576, 512, 16, 4);
|
||||
|
||||
@@ -8,3 +8,4 @@ DECL_FATTN_MMA_F16_CASE(96, 96, 2, 4);
|
||||
DECL_FATTN_MMA_F16_CASE(112, 112, 2, 4);
|
||||
DECL_FATTN_MMA_F16_CASE(128, 128, 2, 4);
|
||||
DECL_FATTN_MMA_F16_CASE(256, 256, 2, 4);
|
||||
DECL_FATTN_MMA_F16_CASE(576, 512, 2, 4);
|
||||
|
||||
@@ -8,3 +8,4 @@ DECL_FATTN_MMA_F16_CASE(96, 96, 4, 4);
|
||||
DECL_FATTN_MMA_F16_CASE(112, 112, 4, 4);
|
||||
DECL_FATTN_MMA_F16_CASE(128, 128, 4, 4);
|
||||
DECL_FATTN_MMA_F16_CASE(256, 256, 4, 4);
|
||||
DECL_FATTN_MMA_F16_CASE(576, 512, 4, 4);
|
||||
|
||||
@@ -8,3 +8,4 @@ DECL_FATTN_MMA_F16_CASE(96, 96, 8, 4);
|
||||
DECL_FATTN_MMA_F16_CASE(112, 112, 8, 4);
|
||||
DECL_FATTN_MMA_F16_CASE(128, 128, 8, 4);
|
||||
DECL_FATTN_MMA_F16_CASE(256, 256, 8, 4);
|
||||
DECL_FATTN_MMA_F16_CASE(576, 512, 8, 4);
|
||||
|
||||
@@ -1071,12 +1071,8 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
|
||||
op->src[0]->ne[0] != 112 &&
|
||||
op->src[0]->ne[0] != 128 &&
|
||||
op->src[0]->ne[0] != 192 &&
|
||||
op->src[0]->ne[0] != 256) {
|
||||
return false;
|
||||
}
|
||||
if (op->src[0]->ne[0] == 576) {
|
||||
// DeepSeek sizes
|
||||
// TODO: disabled for now, until optmized
|
||||
op->src[0]->ne[0] != 256 &&
|
||||
op->src[0]->ne[0] != 576) {
|
||||
return false;
|
||||
}
|
||||
if (op->src[1]->type != op->src[2]->type) {
|
||||
|
||||
@@ -8967,6 +8967,7 @@ kernel void kernel_flash_attn_ext(
|
||||
//case 1: kernel_flash_attn_ext_impl<FWD_TMPL, 1>(FWD_ARGS); break;
|
||||
//case 2: kernel_flash_attn_ext_impl<FWD_TMPL, 2>(FWD_ARGS); break;
|
||||
case 4: kernel_flash_attn_ext_impl<FWD_TMPL, 4>(FWD_ARGS); break;
|
||||
case 8: kernel_flash_attn_ext_impl<FWD_TMPL, 8>(FWD_ARGS); break;
|
||||
}
|
||||
#undef FWD_TMPL
|
||||
#undef FWD_ARGS
|
||||
|
||||
@@ -2456,7 +2456,7 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) {
|
||||
|
||||
// simdgroups per threadgroup (a.k.a. warps)
|
||||
//nsg = ne01 <= nqptg ? MAX(4, MIN(nsgmax, MIN(ne11/ncpsg, (int64_t) pipeline.maxTotalThreadsPerThreadgroup/32))) : 4;
|
||||
int32_t nsg = 4;
|
||||
int32_t nsg = ne00 >= 512 ? 8 : 4;
|
||||
|
||||
const size_t smem = FATTN_SMEM(nsg);
|
||||
|
||||
|
||||
@@ -6166,6 +6166,7 @@ kernel void kernel_flash_attn_ext(
|
||||
//case 1: kernel_flash_attn_ext_impl<FWD_TMPL, 1>(FWD_ARGS); break;
|
||||
//case 2: kernel_flash_attn_ext_impl<FWD_TMPL, 2>(FWD_ARGS); break;
|
||||
case 4: kernel_flash_attn_ext_impl<FWD_TMPL, 4>(FWD_ARGS); break;
|
||||
case 8: kernel_flash_attn_ext_impl<FWD_TMPL, 8>(FWD_ARGS); break;
|
||||
}
|
||||
#undef FWD_TMPL
|
||||
#undef FWD_ARGS
|
||||
|
||||
@@ -39,6 +39,13 @@ type Model interface {
|
||||
Config() config
|
||||
}
|
||||
|
||||
// Validator is an optional interface that models can implement to perform
|
||||
// validation after tensors have been loaded. If validation fails, model
|
||||
// loading will fail with the returned error.
|
||||
type Validator interface {
|
||||
Validate() error
|
||||
}
|
||||
|
||||
// MultimodalProcessor must be implemented by multimodal models.
|
||||
type MultimodalProcessor interface {
|
||||
// EncodeMultimodal processes a single input (such as an image) and
|
||||
@@ -116,6 +123,13 @@ func New(modelPath string, params ml.BackendParams) (Model, error) {
|
||||
base := Base{b: b, config: m.Config()}
|
||||
v := reflect.ValueOf(m)
|
||||
v.Elem().Set(populateFields(base, v.Elem()))
|
||||
|
||||
if validator, ok := m.(Validator); ok {
|
||||
if err := validator.Validate(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
return m, nil
|
||||
}
|
||||
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package glm4moelite
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"math"
|
||||
|
||||
"github.com/ollama/ollama/fs"
|
||||
@@ -11,6 +12,8 @@ import (
|
||||
"github.com/ollama/ollama/model/input"
|
||||
)
|
||||
|
||||
var ErrOldModelFormat = errors.New("this model uses a weight format that is no longer supported; please re-download it")
|
||||
|
||||
type Options struct {
|
||||
numExpertsUsed int
|
||||
numExperts int
|
||||
@@ -47,7 +50,9 @@ type Attention struct {
|
||||
|
||||
KVA *nn.Linear `gguf:"attn_kv_a_mqa"`
|
||||
KVANorm *nn.RMSNorm `gguf:"attn_kv_a_norm"`
|
||||
KVB *nn.Linear `gguf:"attn_kv_b"`
|
||||
|
||||
KB *nn.Linear `gguf:"attn_k_b"`
|
||||
VB *nn.Linear `gguf:"attn_v_b"`
|
||||
|
||||
Output *nn.Linear `gguf:"attn_out,alt:attn_output"`
|
||||
}
|
||||
@@ -78,15 +83,16 @@ func (attn *Attention) Forward(ctx ml.Context, hiddenStates, positions ml.Tensor
|
||||
qRot := opts.applyRotaryPositionEmbeddings(ctx, queryChunks[1], positions)
|
||||
kRot = opts.applyRotaryPositionEmbeddings(ctx, kRot, positions)
|
||||
kPass = attn.KVANorm.Forward(ctx, kPass, opts.eps)
|
||||
kPass = attn.KVB.Forward(ctx, kPass)
|
||||
|
||||
kv := kPass.Reshape(ctx, kPass.Dim(0)/opts.numKVHeads, opts.numKVHeads, seqLength)
|
||||
kvChunks := kv.ChunkSections(ctx, 0, opts.kqNopeHeadDim, opts.vHeadDim)
|
||||
// MLA absorption: absorb K projection into query
|
||||
qPass := queryChunks[0].Permute(ctx, 0, 2, 1, 3)
|
||||
qPassAbsorb := attn.KB.Forward(ctx, qPass).Permute(ctx, 0, 2, 1, 3)
|
||||
query = qRot.Concat(ctx, qPassAbsorb, 0)
|
||||
|
||||
kRot = kRot.Repeat(ctx, 1, queryChunks[0].Dim(1))
|
||||
query = qRot.Concat(ctx, queryChunks[0], 0)
|
||||
key := kRot.Concat(ctx, kvChunks[0], 0)
|
||||
attention := nn.Attention(ctx, query, key, kvChunks[1], opts.kqScale, cache)
|
||||
kPass = kPass.Reshape(ctx, opts.kvLoraRank, 1, seqLength)
|
||||
key := kRot.Concat(ctx, kPass, 0)
|
||||
|
||||
attention := nn.AttentionWithVMLA(ctx, query, key, kPass, nil, attn.VB.Weight, opts.kqScale, cache)
|
||||
|
||||
attention = attention.Reshape(ctx, attention.Dim(0)*attention.Dim(1), seqLength)
|
||||
return attn.Output.Forward(ctx, attention)
|
||||
@@ -217,7 +223,6 @@ func New(c fs.Config) (model.Model, error) {
|
||||
|
||||
keyLength := int(c.Uint("attention.key_length"))
|
||||
valueLength := int(c.Uint("attention.value_length"))
|
||||
|
||||
kqScale := 1.0 / math.Sqrt(float64(keyLength))
|
||||
|
||||
var pre []string
|
||||
@@ -236,7 +241,7 @@ func New(c fs.Config) (model.Model, error) {
|
||||
Values: c.Strings("tokenizer.ggml.tokens"),
|
||||
Types: c.Ints("tokenizer.ggml.token_type"),
|
||||
Merges: c.Strings("tokenizer.ggml.merges"),
|
||||
AddBOS: c.Bool("tokenizer.ggml.add_bos_token", true),
|
||||
AddBOS: c.Bool("tokenizer.ggml.add_bos_token", false),
|
||||
BOS: []int32{int32(c.Uint("tokenizer.ggml.bos_token_id"))},
|
||||
AddEOS: c.Bool("tokenizer.ggml.add_eos_token", false),
|
||||
EOS: append(
|
||||
@@ -279,6 +284,15 @@ func (m Model) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor
|
||||
return m.applyRotaryPositionEmbeddings(ctx, key, shift), nil
|
||||
}
|
||||
|
||||
func (m *Model) Validate() error {
|
||||
for _, layer := range m.Layers {
|
||||
if layer.Attention != nil && (layer.Attention.KB == nil || layer.Attention.VB == nil) {
|
||||
return ErrOldModelFormat
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
|
||||
positions := ctx.Input().FromInts(batch.Positions, len(batch.Positions))
|
||||
|
||||
|
||||
73
model/models/glm4moelite/model_test.go
Normal file
73
model/models/glm4moelite/model_test.go
Normal file
@@ -0,0 +1,73 @@
|
||||
package glm4moelite
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/ollama/ollama/ml/nn"
|
||||
)
|
||||
|
||||
func TestValidate(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
model *Model
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "valid model with KB and VB",
|
||||
model: &Model{
|
||||
Layers: []Layer{
|
||||
{Attention: &Attention{KB: &nn.Linear{}, VB: &nn.Linear{}}},
|
||||
},
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "missing KB",
|
||||
model: &Model{
|
||||
Layers: []Layer{
|
||||
{Attention: &Attention{VB: &nn.Linear{}}},
|
||||
},
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "missing VB",
|
||||
model: &Model{
|
||||
Layers: []Layer{
|
||||
{Attention: &Attention{KB: &nn.Linear{}}},
|
||||
},
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "missing both KB and VB",
|
||||
model: &Model{
|
||||
Layers: []Layer{
|
||||
{Attention: &Attention{}},
|
||||
},
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "nil Attention is ok",
|
||||
model: &Model{
|
||||
Layers: []Layer{
|
||||
{Attention: nil},
|
||||
},
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := tt.model.Validate()
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("Validate() error = %v, wantErr %v", err, tt.wantErr)
|
||||
}
|
||||
if tt.wantErr && err != ErrOldModelFormat {
|
||||
t.Errorf("Validate() error = %v, want %v", err, ErrOldModelFormat)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -794,3 +794,47 @@ func ToImageGenerationResponse(resp api.GenerateResponse) ImageGenerationRespons
|
||||
Data: data,
|
||||
}
|
||||
}
|
||||
|
||||
// ImageEditRequest is an OpenAI-compatible image edit request.
|
||||
type ImageEditRequest struct {
|
||||
Model string `json:"model"`
|
||||
Prompt string `json:"prompt"`
|
||||
Image string `json:"image"` // Base64-encoded image data
|
||||
Size string `json:"size,omitempty"` // e.g., "1024x1024"
|
||||
Seed *int64 `json:"seed,omitempty"`
|
||||
}
|
||||
|
||||
// FromImageEditRequest converts an OpenAI image edit request to an Ollama GenerateRequest.
|
||||
func FromImageEditRequest(r ImageEditRequest) (api.GenerateRequest, error) {
|
||||
req := api.GenerateRequest{
|
||||
Model: r.Model,
|
||||
Prompt: r.Prompt,
|
||||
}
|
||||
|
||||
// Decode the input image
|
||||
if r.Image != "" {
|
||||
imgData, err := decodeImageURL(r.Image)
|
||||
if err != nil {
|
||||
return api.GenerateRequest{}, fmt.Errorf("invalid image: %w", err)
|
||||
}
|
||||
req.Images = append(req.Images, imgData)
|
||||
}
|
||||
|
||||
// Parse size if provided (e.g., "1024x768")
|
||||
if r.Size != "" {
|
||||
var w, h int32
|
||||
if _, err := fmt.Sscanf(r.Size, "%dx%d", &w, &h); err == nil {
|
||||
req.Width = w
|
||||
req.Height = h
|
||||
}
|
||||
}
|
||||
|
||||
if r.Seed != nil {
|
||||
if req.Options == nil {
|
||||
req.Options = map[string]any{}
|
||||
}
|
||||
req.Options["seed"] = *r.Seed
|
||||
}
|
||||
|
||||
return req, nil
|
||||
}
|
||||
|
||||
@@ -448,3 +448,86 @@ func TestFromChatRequest_TopLogprobsRange(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestFromImageEditRequest_Basic(t *testing.T) {
|
||||
req := ImageEditRequest{
|
||||
Model: "test-model",
|
||||
Prompt: "make it blue",
|
||||
Image: prefix + image,
|
||||
}
|
||||
|
||||
result, err := FromImageEditRequest(req)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
if result.Model != "test-model" {
|
||||
t.Errorf("expected model 'test-model', got %q", result.Model)
|
||||
}
|
||||
|
||||
if result.Prompt != "make it blue" {
|
||||
t.Errorf("expected prompt 'make it blue', got %q", result.Prompt)
|
||||
}
|
||||
|
||||
if len(result.Images) != 1 {
|
||||
t.Fatalf("expected 1 image, got %d", len(result.Images))
|
||||
}
|
||||
}
|
||||
|
||||
func TestFromImageEditRequest_WithSize(t *testing.T) {
|
||||
req := ImageEditRequest{
|
||||
Model: "test-model",
|
||||
Prompt: "make it blue",
|
||||
Image: prefix + image,
|
||||
Size: "512x768",
|
||||
}
|
||||
|
||||
result, err := FromImageEditRequest(req)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
if result.Width != 512 {
|
||||
t.Errorf("expected width 512, got %d", result.Width)
|
||||
}
|
||||
|
||||
if result.Height != 768 {
|
||||
t.Errorf("expected height 768, got %d", result.Height)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFromImageEditRequest_WithSeed(t *testing.T) {
|
||||
seed := int64(12345)
|
||||
req := ImageEditRequest{
|
||||
Model: "test-model",
|
||||
Prompt: "make it blue",
|
||||
Image: prefix + image,
|
||||
Seed: &seed,
|
||||
}
|
||||
|
||||
result, err := FromImageEditRequest(req)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
if result.Options == nil {
|
||||
t.Fatal("expected options to be set")
|
||||
}
|
||||
|
||||
if result.Options["seed"] != seed {
|
||||
t.Errorf("expected seed %d, got %v", seed, result.Options["seed"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestFromImageEditRequest_InvalidImage(t *testing.T) {
|
||||
req := ImageEditRequest{
|
||||
Model: "test-model",
|
||||
Prompt: "make it blue",
|
||||
Image: "not-valid-base64",
|
||||
}
|
||||
|
||||
_, err := FromImageEditRequest(req)
|
||||
if err == nil {
|
||||
t.Error("expected error for invalid image")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -95,7 +95,21 @@ func (i *Instance) Readline() (string, error) {
|
||||
|
||||
var currentLineBuf []rune
|
||||
|
||||
// draining tracks if we're processing buffered input from cooked mode.
|
||||
// In cooked mode Enter sends \n, but in raw mode Ctrl+J sends \n.
|
||||
// We treat \n from cooked mode as submit, not multiline.
|
||||
// We check Buffered() after the first read since the bufio buffer is
|
||||
// empty until then. This is compatible with """ multiline mode in
|
||||
// interactive.go since each Readline() call is independent.
|
||||
var draining, stopDraining bool
|
||||
|
||||
for {
|
||||
// Apply deferred state change from previous iteration
|
||||
if stopDraining {
|
||||
draining = false
|
||||
stopDraining = false
|
||||
}
|
||||
|
||||
// don't show placeholder when pasting unless we're in multiline mode
|
||||
showPlaceholder := !i.Pasting || i.Prompt.UseAlt
|
||||
if buf.IsEmpty() && showPlaceholder {
|
||||
@@ -105,6 +119,15 @@ func (i *Instance) Readline() (string, error) {
|
||||
|
||||
r, err := i.Terminal.Read()
|
||||
|
||||
// After reading, check if there's more buffered data. If so, we're
|
||||
// processing cooked-mode input. Once buffer empties, the current
|
||||
// char is the last buffered one (still drain it), then stop next iteration.
|
||||
if i.Terminal.reader.Buffered() > 0 {
|
||||
draining = true
|
||||
} else if draining {
|
||||
stopDraining = true
|
||||
}
|
||||
|
||||
if buf.IsEmpty() {
|
||||
fmt.Print(ClearToEOL)
|
||||
}
|
||||
@@ -232,15 +255,20 @@ func (i *Instance) Readline() (string, error) {
|
||||
fd := os.Stdin.Fd()
|
||||
return handleCharCtrlZ(fd, i.Terminal.termios)
|
||||
case CharCtrlJ:
|
||||
i.pastedLines = append(i.pastedLines, buf.String())
|
||||
buf.Buf.Clear()
|
||||
buf.Pos = 0
|
||||
buf.DisplayPos = 0
|
||||
buf.LineHasSpace.Clear()
|
||||
fmt.Println()
|
||||
fmt.Print(i.Prompt.AltPrompt)
|
||||
i.Prompt.UseAlt = true
|
||||
continue
|
||||
// If not draining cooked-mode input, treat as multiline
|
||||
if !draining {
|
||||
i.pastedLines = append(i.pastedLines, buf.String())
|
||||
buf.Buf.Clear()
|
||||
buf.Pos = 0
|
||||
buf.DisplayPos = 0
|
||||
buf.LineHasSpace.Clear()
|
||||
fmt.Println()
|
||||
fmt.Print(i.Prompt.AltPrompt)
|
||||
i.Prompt.UseAlt = true
|
||||
continue
|
||||
}
|
||||
// Draining cooked-mode input: treat \n as submit
|
||||
fallthrough
|
||||
case CharEnter:
|
||||
output := buf.String()
|
||||
if len(i.pastedLines) > 0 {
|
||||
|
||||
@@ -3,7 +3,7 @@ package runner
|
||||
import (
|
||||
"github.com/ollama/ollama/runner/llamarunner"
|
||||
"github.com/ollama/ollama/runner/ollamarunner"
|
||||
imagerunner "github.com/ollama/ollama/x/imagegen/runner"
|
||||
"github.com/ollama/ollama/x/mlxrunner"
|
||||
)
|
||||
|
||||
func Execute(args []string) error {
|
||||
@@ -12,18 +12,18 @@ func Execute(args []string) error {
|
||||
}
|
||||
|
||||
var newRunner bool
|
||||
var imageRunner bool
|
||||
var mlxRunner bool
|
||||
if len(args) > 0 && args[0] == "--ollama-engine" {
|
||||
args = args[1:]
|
||||
newRunner = true
|
||||
}
|
||||
if len(args) > 0 && args[0] == "--image-engine" {
|
||||
if len(args) > 0 && args[0] == "--mlx-engine" {
|
||||
args = args[1:]
|
||||
imageRunner = true
|
||||
mlxRunner = true
|
||||
}
|
||||
|
||||
if imageRunner {
|
||||
return imagerunner.Execute(args)
|
||||
if mlxRunner {
|
||||
return mlxrunner.Execute(args)
|
||||
} else if newRunner {
|
||||
return ollamarunner.Execute(args)
|
||||
} else {
|
||||
|
||||
@@ -14,8 +14,8 @@
|
||||
VOL_NAME=${VOL_NAME:-"Ollama"}
|
||||
export VERSION=${VERSION:-$(git describe --tags --first-parent --abbrev=7 --long --dirty --always | sed -e "s/^v//g")}
|
||||
export GOFLAGS="'-ldflags=-w -s \"-X=github.com/ollama/ollama/version.Version=${VERSION#v}\" \"-X=github.com/ollama/ollama/server.mode=release\"'"
|
||||
export CGO_CFLAGS="-mmacosx-version-min=14.0"
|
||||
export CGO_CXXFLAGS="-mmacosx-version-min=14.0"
|
||||
export CGO_CFLAGS="-O3 -mmacosx-version-min=14.0"
|
||||
export CGO_CXXFLAGS="-O3 -mmacosx-version-min=14.0"
|
||||
export CGO_LDFLAGS="-mmacosx-version-min=14.0"
|
||||
|
||||
set -e
|
||||
|
||||
@@ -56,6 +56,12 @@ function checkEnv {
|
||||
|
||||
$script:DIST_DIR="${script:SRC_DIR}\dist\windows-${script:TARGET_ARCH}"
|
||||
$env:CGO_ENABLED="1"
|
||||
if (-not $env:CGO_CFLAGS) {
|
||||
$env:CGO_CFLAGS = "-O3"
|
||||
}
|
||||
if (-not $env:CGO_CXXFLAGS) {
|
||||
$env:CGO_CXXFLAGS = "-O3"
|
||||
}
|
||||
Write-Output "Checking version"
|
||||
if (!$env:VERSION) {
|
||||
$data=(git describe --tags --first-parent --abbrev=7 --long --dirty --always)
|
||||
|
||||
@@ -75,12 +75,6 @@ type Model struct {
|
||||
func (m *Model) Capabilities() []model.Capability {
|
||||
capabilities := []model.Capability{}
|
||||
|
||||
// Check for image generation model via config capabilities
|
||||
if slices.Contains(m.Config.Capabilities, "image") {
|
||||
return []model.Capability{model.CapabilityImage}
|
||||
}
|
||||
|
||||
// Check for completion capability
|
||||
if m.ModelPath != "" {
|
||||
f, err := gguf.Open(m.ModelPath)
|
||||
if err == nil {
|
||||
|
||||
@@ -56,6 +56,15 @@ func TestModelCapabilities(t *testing.T) {
|
||||
},
|
||||
expectedCaps: []model.Capability{model.CapabilityImage},
|
||||
},
|
||||
{
|
||||
name: "model with image and vision capability (image editing)",
|
||||
model: Model{
|
||||
Config: model.ConfigV2{
|
||||
Capabilities: []string{"image", "vision"},
|
||||
},
|
||||
},
|
||||
expectedCaps: []model.Capability{model.CapabilityImage, model.CapabilityVision},
|
||||
},
|
||||
{
|
||||
name: "model with completion capability",
|
||||
model: Model{
|
||||
|
||||
@@ -95,6 +95,13 @@ func getTensorNewType(kv fsggml.KV, qs *quantizeState, newType fsggml.TensorType
|
||||
// for the 8-expert model, bumping this to Q8_0 trades just ~128MB
|
||||
newType = fsggml.TensorTypeQ8_0
|
||||
}
|
||||
} else if strings.Contains(name, "attn_k_b.weight") ||
|
||||
strings.Contains(name, "attn_v_b.weight") ||
|
||||
strings.Contains(name, "attn_kv_a_mqa.weight") ||
|
||||
strings.Contains(name, "attn_q_a.weight") ||
|
||||
strings.Contains(name, "attn_q_b.weight") {
|
||||
// MLA tensors need higher precision to avoid quality degradation
|
||||
newType = fsggml.TensorTypeQ8_0
|
||||
} else if strings.Contains(name, "ffn_down") {
|
||||
iLayer := qs.iFfnDown
|
||||
n_layer := qs.nFfnDown
|
||||
|
||||
@@ -1604,8 +1604,9 @@ func (s *Server) GenerateRoutes(rc *ollama.Registry) (http.Handler, error) {
|
||||
r.GET("/v1/models", middleware.ListMiddleware(), s.ListHandler)
|
||||
r.GET("/v1/models/:model", middleware.RetrieveMiddleware(), s.ShowHandler)
|
||||
r.POST("/v1/responses", middleware.ResponsesMiddleware(), s.ChatHandler)
|
||||
// OpenAI-compatible image generation endpoint
|
||||
// OpenAI-compatible image generation endpoints
|
||||
r.POST("/v1/images/generations", middleware.ImageGenerationsMiddleware(), s.GenerateHandler)
|
||||
r.POST("/v1/images/edits", middleware.ImageEditsMiddleware(), s.GenerateHandler)
|
||||
|
||||
// Inference (Anthropic compatibility)
|
||||
r.POST("/v1/messages", middleware.AnthropicMessagesMiddleware(), s.ChatHandler)
|
||||
@@ -2507,8 +2508,14 @@ func (s *Server) handleImageGenerate(c *gin.Context, req api.GenerateRequest, mo
|
||||
return
|
||||
}
|
||||
|
||||
// Set headers for streaming response
|
||||
c.Header("Content-Type", "application/x-ndjson")
|
||||
// Check streaming preference
|
||||
isStreaming := req.Stream == nil || *req.Stream
|
||||
|
||||
contentType := "application/x-ndjson"
|
||||
if !isStreaming {
|
||||
contentType = "application/json; charset=utf-8"
|
||||
}
|
||||
c.Header("Content-Type", contentType)
|
||||
|
||||
// Get seed from options if provided
|
||||
var seed int64
|
||||
@@ -2523,13 +2530,21 @@ func (s *Server) handleImageGenerate(c *gin.Context, req api.GenerateRequest, mo
|
||||
}
|
||||
}
|
||||
|
||||
var images []llm.ImageData
|
||||
for i, imgData := range req.Images {
|
||||
images = append(images, llm.ImageData{ID: i, Data: imgData})
|
||||
}
|
||||
|
||||
var streamStarted bool
|
||||
var finalResponse api.GenerateResponse
|
||||
|
||||
if err := runner.Completion(c.Request.Context(), llm.CompletionRequest{
|
||||
Prompt: req.Prompt,
|
||||
Width: req.Width,
|
||||
Height: req.Height,
|
||||
Steps: req.Steps,
|
||||
Seed: seed,
|
||||
Images: images,
|
||||
}, func(cr llm.CompletionResponse) {
|
||||
streamStarted = true
|
||||
res := api.GenerateResponse{
|
||||
@@ -2553,6 +2568,11 @@ func (s *Server) handleImageGenerate(c *gin.Context, req api.GenerateRequest, mo
|
||||
res.Metrics.LoadDuration = checkpointLoaded.Sub(checkpointStart)
|
||||
}
|
||||
|
||||
if !isStreaming {
|
||||
finalResponse = res
|
||||
return
|
||||
}
|
||||
|
||||
data, _ := json.Marshal(res)
|
||||
c.Writer.Write(append(data, '\n'))
|
||||
c.Writer.Flush()
|
||||
@@ -2562,5 +2582,10 @@ func (s *Server) handleImageGenerate(c *gin.Context, req api.GenerateRequest, mo
|
||||
if !streamStarted {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
if !isStreaming {
|
||||
c.JSON(http.StatusOK, finalResponse)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -19,7 +19,9 @@ import (
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/fs/ggml"
|
||||
"github.com/ollama/ollama/llm"
|
||||
"github.com/ollama/ollama/manifest"
|
||||
"github.com/ollama/ollama/ml"
|
||||
"github.com/ollama/ollama/types/model"
|
||||
)
|
||||
|
||||
// testPropsMap creates a ToolPropertiesMap from a map (convenience function for tests)
|
||||
@@ -71,6 +73,8 @@ func (mockRunner) Tokenize(_ context.Context, s string) (tokens []int, err error
|
||||
return
|
||||
}
|
||||
|
||||
func (mockRunner) Ping(_ context.Context) error { return nil }
|
||||
|
||||
func newMockServer(mock *mockRunner) func(ml.SystemInfo, []ml.DeviceInfo, string, *ggml.GGML, []string, []string, api.Options, int) (llm.LlamaServer, error) {
|
||||
return func(_ ml.SystemInfo, _ []ml.DeviceInfo, _ string, _ *ggml.GGML, _, _ []string, _ api.Options, _ int) (llm.LlamaServer, error) {
|
||||
return mock, nil
|
||||
@@ -2193,3 +2197,246 @@ func TestGenerateUnload(t *testing.T) {
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestGenerateWithImages(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
mock := mockRunner{
|
||||
CompletionResponse: llm.CompletionResponse{
|
||||
Done: true,
|
||||
DoneReason: llm.DoneReasonStop,
|
||||
PromptEvalCount: 1,
|
||||
PromptEvalDuration: 1,
|
||||
EvalCount: 1,
|
||||
EvalDuration: 1,
|
||||
},
|
||||
}
|
||||
|
||||
s := Server{
|
||||
sched: &Scheduler{
|
||||
pendingReqCh: make(chan *LlmRequest, 1),
|
||||
finishedReqCh: make(chan *LlmRequest, 1),
|
||||
expiredCh: make(chan *runnerRef, 1),
|
||||
unloadedCh: make(chan any, 1),
|
||||
loaded: make(map[string]*runnerRef),
|
||||
newServerFn: newMockServer(&mock),
|
||||
getGpuFn: getGpuFn,
|
||||
getSystemInfoFn: getSystemInfoFn,
|
||||
waitForRecovery: 250 * time.Millisecond,
|
||||
loadFn: func(req *LlmRequest, _ *ggml.GGML, _ ml.SystemInfo, _ []ml.DeviceInfo, _ bool) bool {
|
||||
time.Sleep(time.Millisecond)
|
||||
req.successCh <- &runnerRef{
|
||||
llama: &mock,
|
||||
}
|
||||
return false
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
go s.sched.Run(t.Context())
|
||||
|
||||
_, digest := createBinFile(t, ggml.KV{
|
||||
"general.architecture": "llama",
|
||||
"llama.block_count": uint32(1),
|
||||
"llama.context_length": uint32(8192),
|
||||
"llama.embedding_length": uint32(4096),
|
||||
"llama.attention.head_count": uint32(32),
|
||||
"llama.attention.head_count_kv": uint32(8),
|
||||
"tokenizer.ggml.tokens": []string{""},
|
||||
"tokenizer.ggml.scores": []float32{0},
|
||||
"tokenizer.ggml.token_type": []int32{0},
|
||||
}, []*ggml.Tensor{
|
||||
{Name: "token_embd.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
|
||||
{Name: "blk.0.attn_norm.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
|
||||
{Name: "blk.0.ffn_down.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
|
||||
{Name: "blk.0.ffn_gate.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
|
||||
{Name: "blk.0.ffn_up.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
|
||||
{Name: "blk.0.ffn_norm.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
|
||||
{Name: "blk.0.attn_k.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
|
||||
{Name: "blk.0.attn_output.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
|
||||
{Name: "blk.0.attn_q.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
|
||||
{Name: "blk.0.attn_v.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
|
||||
{Name: "output.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
|
||||
})
|
||||
|
||||
w := createRequest(t, s.CreateHandler, api.CreateRequest{
|
||||
Model: "test",
|
||||
Files: map[string]string{"file.gguf": digest},
|
||||
Stream: &stream,
|
||||
})
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("expected status 200, got %d", w.Code)
|
||||
}
|
||||
|
||||
t.Run("images passed to completion request", func(t *testing.T) {
|
||||
testImage := []byte("test-image-data")
|
||||
|
||||
mock.CompletionResponse.Content = "Image processed"
|
||||
w := createRequest(t, s.GenerateHandler, api.GenerateRequest{
|
||||
Model: "test",
|
||||
Prompt: "Describe this image",
|
||||
Images: []api.ImageData{testImage},
|
||||
Stream: &stream,
|
||||
})
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("expected status 200, got %d: %s", w.Code, w.Body.String())
|
||||
}
|
||||
|
||||
// Verify images were passed to the completion request
|
||||
if len(mock.CompletionRequest.Images) != 1 {
|
||||
t.Fatalf("expected 1 image in completion request, got %d", len(mock.CompletionRequest.Images))
|
||||
}
|
||||
|
||||
if !bytes.Equal(mock.CompletionRequest.Images[0].Data, testImage) {
|
||||
t.Errorf("image data mismatch in completion request")
|
||||
}
|
||||
|
||||
if mock.CompletionRequest.Images[0].ID != 0 {
|
||||
t.Errorf("expected image ID 0, got %d", mock.CompletionRequest.Images[0].ID)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("multiple images passed to completion request", func(t *testing.T) {
|
||||
testImage1 := []byte("test-image-1")
|
||||
testImage2 := []byte("test-image-2")
|
||||
|
||||
mock.CompletionResponse.Content = "Images processed"
|
||||
w := createRequest(t, s.GenerateHandler, api.GenerateRequest{
|
||||
Model: "test",
|
||||
Prompt: "Compare these images",
|
||||
Images: []api.ImageData{testImage1, testImage2},
|
||||
Stream: &stream,
|
||||
})
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("expected status 200, got %d: %s", w.Code, w.Body.String())
|
||||
}
|
||||
|
||||
// Verify both images were passed
|
||||
if len(mock.CompletionRequest.Images) != 2 {
|
||||
t.Fatalf("expected 2 images in completion request, got %d", len(mock.CompletionRequest.Images))
|
||||
}
|
||||
|
||||
if !bytes.Equal(mock.CompletionRequest.Images[0].Data, testImage1) {
|
||||
t.Errorf("first image data mismatch")
|
||||
}
|
||||
|
||||
if !bytes.Equal(mock.CompletionRequest.Images[1].Data, testImage2) {
|
||||
t.Errorf("second image data mismatch")
|
||||
}
|
||||
|
||||
if mock.CompletionRequest.Images[0].ID != 0 || mock.CompletionRequest.Images[1].ID != 1 {
|
||||
t.Errorf("expected image IDs 0 and 1, got %d and %d",
|
||||
mock.CompletionRequest.Images[0].ID, mock.CompletionRequest.Images[1].ID)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("no images when none provided", func(t *testing.T) {
|
||||
mock.CompletionResponse.Content = "No images"
|
||||
w := createRequest(t, s.GenerateHandler, api.GenerateRequest{
|
||||
Model: "test",
|
||||
Prompt: "Hello",
|
||||
Stream: &stream,
|
||||
})
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("expected status 200, got %d: %s", w.Code, w.Body.String())
|
||||
}
|
||||
|
||||
// Verify no images in completion request
|
||||
if len(mock.CompletionRequest.Images) != 0 {
|
||||
t.Fatalf("expected 0 images in completion request, got %d", len(mock.CompletionRequest.Images))
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// TestImageGenerateStreamFalse tests that image generation respects stream=false
|
||||
// and returns a single JSON response instead of streaming ndjson.
|
||||
func TestImageGenerateStreamFalse(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
p := t.TempDir()
|
||||
t.Setenv("OLLAMA_MODELS", p)
|
||||
|
||||
mock := mockRunner{}
|
||||
mock.CompletionFn = func(ctx context.Context, r llm.CompletionRequest, fn func(r llm.CompletionResponse)) error {
|
||||
fn(llm.CompletionResponse{Step: 1, TotalSteps: 3, Done: false})
|
||||
fn(llm.CompletionResponse{Step: 2, TotalSteps: 3, Done: false})
|
||||
fn(llm.CompletionResponse{Step: 3, TotalSteps: 3, Done: true, DoneReason: llm.DoneReasonStop, Image: "base64image"})
|
||||
return nil
|
||||
}
|
||||
|
||||
opts := api.DefaultOptions()
|
||||
s := Server{
|
||||
sched: &Scheduler{
|
||||
pendingReqCh: make(chan *LlmRequest, 1),
|
||||
finishedReqCh: make(chan *LlmRequest, 1),
|
||||
expiredCh: make(chan *runnerRef, 1),
|
||||
unloadedCh: make(chan any, 1),
|
||||
loaded: map[string]*runnerRef{
|
||||
"": {
|
||||
llama: &mock,
|
||||
Options: &opts,
|
||||
model: &Model{Config: model.ConfigV2{Capabilities: []string{"image"}}},
|
||||
numParallel: 1,
|
||||
},
|
||||
},
|
||||
newServerFn: newMockServer(&mock),
|
||||
getGpuFn: getGpuFn,
|
||||
getSystemInfoFn: getSystemInfoFn,
|
||||
},
|
||||
}
|
||||
|
||||
go s.sched.Run(t.Context())
|
||||
|
||||
// Create model manifest with image capability
|
||||
n := model.ParseName("test-image")
|
||||
cfg := model.ConfigV2{Capabilities: []string{"image"}}
|
||||
var b bytes.Buffer
|
||||
if err := json.NewEncoder(&b).Encode(&cfg); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
configLayer, err := manifest.NewLayer(&b, "application/vnd.docker.container.image.v1+json")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := manifest.WriteManifest(n, configLayer, nil); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
streamFalse := false
|
||||
w := createRequest(t, s.GenerateHandler, api.GenerateRequest{
|
||||
Model: "test-image",
|
||||
Prompt: "test prompt",
|
||||
Stream: &streamFalse,
|
||||
})
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("expected status 200, got %d: %s", w.Code, w.Body.String())
|
||||
}
|
||||
|
||||
if ct := w.Header().Get("Content-Type"); ct != "application/json; charset=utf-8" {
|
||||
t.Errorf("expected Content-Type 'application/json; charset=utf-8', got %q", ct)
|
||||
}
|
||||
|
||||
body := w.Body.String()
|
||||
lines := strings.Split(strings.TrimSpace(body), "\n")
|
||||
if len(lines) != 1 {
|
||||
t.Errorf("expected 1 response line, got %d:\n%s", len(lines), body)
|
||||
}
|
||||
|
||||
var resp api.GenerateResponse
|
||||
if err := json.Unmarshal([]byte(lines[0]), &resp); err != nil {
|
||||
t.Fatalf("failed to parse response: %v", err)
|
||||
}
|
||||
|
||||
if resp.Image != "base64image" {
|
||||
t.Errorf("expected image 'base64image', got %q", resp.Image)
|
||||
}
|
||||
|
||||
if !resp.Done {
|
||||
t.Errorf("expected done=true")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -21,7 +21,7 @@ import (
|
||||
"github.com/ollama/ollama/logutil"
|
||||
"github.com/ollama/ollama/ml"
|
||||
"github.com/ollama/ollama/types/model"
|
||||
"github.com/ollama/ollama/x/imagegen"
|
||||
"github.com/ollama/ollama/x/mlxrunner"
|
||||
)
|
||||
|
||||
type LlmRequest struct {
|
||||
@@ -195,14 +195,25 @@ func (s *Scheduler) processPending(ctx context.Context) {
|
||||
slog.Debug("updating default concurrency", "OLLAMA_MAX_LOADED_MODELS", maxRunners, "gpu_count", len(gpus))
|
||||
}
|
||||
|
||||
// Check for image generation model before attempting GGML load
|
||||
// Check for image generation models - all use MLX runner
|
||||
if slices.Contains(pending.model.Config.Capabilities, "image") {
|
||||
if s.loadImageGen(pending) {
|
||||
if s.loadMLX(pending) {
|
||||
break
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
// Check for experimental safetensors LLM models
|
||||
if pending.model.Config.ModelFormat == "safetensors" {
|
||||
if slices.Contains(pending.model.Config.Capabilities, "completion") {
|
||||
// LLM model with safetensors format - use MLX runner
|
||||
if s.loadMLX(pending) {
|
||||
break
|
||||
}
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
// Load model for fitting
|
||||
logutil.Trace("loading model metadata", "model", pending.model.ModelPath)
|
||||
ggml, err := llm.LoadModel(pending.model.ModelPath, 1024)
|
||||
@@ -552,11 +563,20 @@ iGPUScan:
|
||||
return false
|
||||
}
|
||||
|
||||
// loadImageGen loads an image generation model.
|
||||
func (s *Scheduler) loadImageGen(req *LlmRequest) bool {
|
||||
// Use model name for imagegen (it resolves manifests by name, not file path)
|
||||
// loadMLX loads an experimental safetensors model using the unified MLX runner.
|
||||
// This supports both LLM (completion) and image generation models.
|
||||
func (s *Scheduler) loadMLX(req *LlmRequest) bool {
|
||||
// Determine mode based on capabilities
|
||||
var mode mlxrunner.ModelMode
|
||||
if slices.Contains(req.model.Config.Capabilities, "image") {
|
||||
mode = mlxrunner.ModeImageGen
|
||||
} else {
|
||||
mode = mlxrunner.ModeLLM
|
||||
}
|
||||
|
||||
// Use model name for MLX (it resolves manifests by name, not file path)
|
||||
modelName := req.model.ShortName
|
||||
server, err := imagegen.NewServer(modelName)
|
||||
server, err := mlxrunner.NewServer(modelName, mode)
|
||||
if err != nil {
|
||||
req.errCh <- err
|
||||
return true
|
||||
|
||||
@@ -11,6 +11,9 @@ import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
||||
"github.com/ollama/ollama/manifest"
|
||||
"github.com/ollama/ollama/progress"
|
||||
@@ -32,7 +35,7 @@ type ModelfileConfig struct {
|
||||
type CreateOptions struct {
|
||||
ModelName string
|
||||
ModelDir string
|
||||
Quantize string // "fp8" for quantization
|
||||
Quantize string // "q4", "q8", "nvfp4", or "mxfp8" for quantization
|
||||
Modelfile *ModelfileConfig // template/system/license from Modelfile
|
||||
}
|
||||
|
||||
@@ -51,10 +54,20 @@ func CreateModel(opts CreateOptions, p *progress.Progress) error {
|
||||
// Determine model type settings
|
||||
var modelType, spinnerKey string
|
||||
var capabilities []string
|
||||
var parserName, rendererName string
|
||||
if isSafetensors {
|
||||
modelType = "safetensors model"
|
||||
spinnerKey = "create"
|
||||
capabilities = []string{"completion"}
|
||||
|
||||
// Check if model supports thinking based on architecture
|
||||
if supportsThinking(opts.ModelDir) {
|
||||
capabilities = append(capabilities, "thinking")
|
||||
}
|
||||
|
||||
// Set parser and renderer name based on architecture
|
||||
parserName = getParserName(opts.ModelDir)
|
||||
rendererName = getRendererName(opts.ModelDir)
|
||||
} else {
|
||||
modelType = "image generation model"
|
||||
spinnerKey = "imagegen"
|
||||
@@ -79,14 +92,14 @@ func CreateModel(opts CreateOptions, p *progress.Progress) error {
|
||||
err = create.CreateSafetensorsModel(
|
||||
opts.ModelName, opts.ModelDir, opts.Quantize,
|
||||
newLayerCreator(), newTensorLayerCreator(),
|
||||
newManifestWriter(opts, capabilities),
|
||||
newManifestWriter(opts, capabilities, parserName, rendererName),
|
||||
progressFn,
|
||||
)
|
||||
} else {
|
||||
err = create.CreateImageGenModel(
|
||||
opts.ModelName, opts.ModelDir, opts.Quantize,
|
||||
newLayerCreator(), newTensorLayerCreator(),
|
||||
newManifestWriter(opts, capabilities),
|
||||
newManifestWriter(opts, capabilities, "", ""),
|
||||
progressFn,
|
||||
)
|
||||
}
|
||||
@@ -202,18 +215,33 @@ func createUnquantizedLayer(r io.Reader, name string) ([]create.LayerInfo, error
|
||||
}
|
||||
|
||||
// newManifestWriter returns a ManifestWriter callback for writing the model manifest.
|
||||
func newManifestWriter(opts CreateOptions, capabilities []string) create.ManifestWriter {
|
||||
func newManifestWriter(opts CreateOptions, capabilities []string, parserName, rendererName string) create.ManifestWriter {
|
||||
return func(modelName string, config create.LayerInfo, layers []create.LayerInfo) error {
|
||||
name := model.ParseName(modelName)
|
||||
if !name.IsValid() {
|
||||
return fmt.Errorf("invalid model name: %s", modelName)
|
||||
}
|
||||
|
||||
// TODO: find a better way to detect image input support
|
||||
// For now, hardcode Flux2KleinPipeline as supporting vision (image input)
|
||||
caps := capabilities
|
||||
modelIndex := filepath.Join(opts.ModelDir, "model_index.json")
|
||||
if data, err := os.ReadFile(modelIndex); err == nil {
|
||||
var cfg struct {
|
||||
ClassName string `json:"_class_name"`
|
||||
}
|
||||
if json.Unmarshal(data, &cfg) == nil && cfg.ClassName == "Flux2KleinPipeline" {
|
||||
caps = append(caps, "vision")
|
||||
}
|
||||
}
|
||||
|
||||
// Create config blob with version requirement
|
||||
configData := model.ConfigV2{
|
||||
ModelFormat: "safetensors",
|
||||
Capabilities: capabilities,
|
||||
Capabilities: caps,
|
||||
Requires: MinOllamaVersion,
|
||||
Parser: parserName,
|
||||
Renderer: rendererName,
|
||||
}
|
||||
configJSON, err := json.Marshal(configData)
|
||||
if err != nil {
|
||||
@@ -280,3 +308,146 @@ func createModelfileLayers(mf *ModelfileConfig) ([]manifest.Layer, error) {
|
||||
|
||||
return layers, nil
|
||||
}
|
||||
|
||||
// supportsThinking checks if the model supports thinking mode based on its architecture.
|
||||
// This reads the config.json from the model directory and checks the architectures field.
|
||||
func supportsThinking(modelDir string) bool {
|
||||
configPath := filepath.Join(modelDir, "config.json")
|
||||
data, err := os.ReadFile(configPath)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
var cfg struct {
|
||||
Architectures []string `json:"architectures"`
|
||||
ModelType string `json:"model_type"`
|
||||
}
|
||||
if err := json.Unmarshal(data, &cfg); err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
// Check architectures that support thinking
|
||||
thinkingArchitectures := []string{
|
||||
"glm4moe", // GLM-4 MoE models
|
||||
"deepseek", // DeepSeek models
|
||||
"qwen3", // Qwen3 models
|
||||
}
|
||||
|
||||
// Check the architecture list
|
||||
for _, arch := range cfg.Architectures {
|
||||
archLower := strings.ToLower(arch)
|
||||
for _, thinkArch := range thinkingArchitectures {
|
||||
if strings.Contains(archLower, thinkArch) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Also check model_type
|
||||
if cfg.ModelType != "" {
|
||||
typeLower := strings.ToLower(cfg.ModelType)
|
||||
for _, thinkArch := range thinkingArchitectures {
|
||||
if strings.Contains(typeLower, thinkArch) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// getParserName returns the parser name for a model based on its architecture.
|
||||
// This reads the config.json from the model directory and determines the appropriate parser.
|
||||
func getParserName(modelDir string) string {
|
||||
configPath := filepath.Join(modelDir, "config.json")
|
||||
data, err := os.ReadFile(configPath)
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
var cfg struct {
|
||||
Architectures []string `json:"architectures"`
|
||||
ModelType string `json:"model_type"`
|
||||
}
|
||||
if err := json.Unmarshal(data, &cfg); err != nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
// Check architectures for known parsers
|
||||
for _, arch := range cfg.Architectures {
|
||||
archLower := strings.ToLower(arch)
|
||||
if strings.Contains(archLower, "glm4") || strings.Contains(archLower, "glm-4") {
|
||||
return "glm-4.7"
|
||||
}
|
||||
if strings.Contains(archLower, "deepseek") {
|
||||
return "deepseek3"
|
||||
}
|
||||
if strings.Contains(archLower, "qwen3") {
|
||||
return "qwen3-coder"
|
||||
}
|
||||
}
|
||||
|
||||
// Also check model_type
|
||||
if cfg.ModelType != "" {
|
||||
typeLower := strings.ToLower(cfg.ModelType)
|
||||
if strings.Contains(typeLower, "glm4") || strings.Contains(typeLower, "glm-4") {
|
||||
return "glm-4.7"
|
||||
}
|
||||
if strings.Contains(typeLower, "deepseek") {
|
||||
return "deepseek3"
|
||||
}
|
||||
if strings.Contains(typeLower, "qwen3") {
|
||||
return "qwen3-coder"
|
||||
}
|
||||
}
|
||||
|
||||
return ""
|
||||
}
|
||||
|
||||
// getRendererName returns the renderer name for a model based on its architecture.
|
||||
// This reads the config.json from the model directory and determines the appropriate renderer.
|
||||
func getRendererName(modelDir string) string {
|
||||
configPath := filepath.Join(modelDir, "config.json")
|
||||
data, err := os.ReadFile(configPath)
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
var cfg struct {
|
||||
Architectures []string `json:"architectures"`
|
||||
ModelType string `json:"model_type"`
|
||||
}
|
||||
if err := json.Unmarshal(data, &cfg); err != nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
// Check architectures for known renderers
|
||||
for _, arch := range cfg.Architectures {
|
||||
archLower := strings.ToLower(arch)
|
||||
if strings.Contains(archLower, "glm4") || strings.Contains(archLower, "glm-4") {
|
||||
return "glm-4.7"
|
||||
}
|
||||
if strings.Contains(archLower, "deepseek") {
|
||||
return "deepseek3"
|
||||
}
|
||||
if strings.Contains(archLower, "qwen3") {
|
||||
return "qwen3-coder"
|
||||
}
|
||||
}
|
||||
|
||||
// Also check model_type
|
||||
if cfg.ModelType != "" {
|
||||
typeLower := strings.ToLower(cfg.ModelType)
|
||||
if strings.Contains(typeLower, "glm4") || strings.Contains(typeLower, "glm-4") {
|
||||
return "glm-4.7"
|
||||
}
|
||||
if strings.Contains(typeLower, "deepseek") {
|
||||
return "deepseek3"
|
||||
}
|
||||
if strings.Contains(typeLower, "qwen3") {
|
||||
return "qwen3-coder"
|
||||
}
|
||||
}
|
||||
|
||||
return ""
|
||||
}
|
||||
|
||||
@@ -13,7 +13,11 @@ import (
|
||||
|
||||
// quantizeTensor loads a tensor from safetensors format, quantizes it,
|
||||
// and returns safetensors data for the quantized weights, scales, and biases.
|
||||
// Supported quantization types: "fp8" (affine 8-bit)
|
||||
// Supported quantization types:
|
||||
// - "q4": affine 4-bit, group_size=32 (with qbiases)
|
||||
// - "nvfp4": NVIDIA FP4, group_size=16 (no qbiases, E4M3 scales)
|
||||
// - "q8": affine 8-bit, group_size=64 (with qbiases)
|
||||
// - "mxfp8": Microsoft MX FP8, group_size=32 (no qbiases, E4M3 scales)
|
||||
// Uses MLX's native SaveSafetensors to ensure correct dtype handling (especially uint32 for quantized weights).
|
||||
func quantizeTensor(r io.Reader, name, dtype string, shape []int32, quantize string) (qweightData, scalesData, qbiasData []byte, qweightShape, scalesShape, qbiasShape []int32, err error) {
|
||||
tmpDir := ensureTempDir()
|
||||
@@ -54,12 +58,18 @@ func quantizeTensor(r io.Reader, name, dtype string, shape []int32, quantize str
|
||||
// Quantize based on quantization type
|
||||
var qweight, scales, qbiases *mlx.Array
|
||||
switch quantize {
|
||||
case "fp4":
|
||||
// affine mode: group_size=32, bits=4
|
||||
case "q4":
|
||||
// affine mode: group_size=32, bits=4 (with qbiases for zero-point offset)
|
||||
qweight, scales, qbiases = mlx.Quantize(arr, 32, 4, "affine")
|
||||
case "fp8":
|
||||
// affine mode: group_size=32, bits=8
|
||||
qweight, scales, qbiases = mlx.Quantize(arr, 32, 8, "affine")
|
||||
case "nvfp4":
|
||||
// NVIDIA FP4: group_size=16, bits=4 (no qbiases, E4M3 scales)
|
||||
qweight, scales, qbiases = mlx.Quantize(arr, 16, 4, "nvfp4")
|
||||
case "q8":
|
||||
// affine mode: group_size=64, bits=8 (with qbiases for zero-point offset)
|
||||
qweight, scales, qbiases = mlx.Quantize(arr, 64, 8, "affine")
|
||||
case "mxfp8":
|
||||
// Microsoft MX FP8: group_size=32, bits=8, E4M3 scales (no qbiases)
|
||||
qweight, scales, qbiases = mlx.Quantize(arr, 32, 8, "mxfp8")
|
||||
default:
|
||||
return nil, nil, nil, nil, nil, nil, fmt.Errorf("unsupported quantization type: %s", quantize)
|
||||
}
|
||||
|
||||
@@ -228,7 +228,7 @@ type LayerCreator func(r io.Reader, mediaType, name string) (LayerInfo, error)
|
||||
type TensorLayerCreator func(r io.Reader, name, dtype string, shape []int32) (LayerInfo, error)
|
||||
|
||||
// QuantizingTensorLayerCreator creates tensor layers with optional quantization.
|
||||
// When quantize is non-empty (e.g., "fp8"), returns multiple layers (weight + scales + biases).
|
||||
// When quantize is non-empty (e.g., "q8"), returns multiple layers (weight + scales + biases).
|
||||
type QuantizingTensorLayerCreator func(r io.Reader, name, dtype string, shape []int32, quantize string) ([]LayerInfo, error)
|
||||
|
||||
// ManifestWriter writes the manifest file.
|
||||
@@ -262,36 +262,134 @@ func ShouldQuantize(name, component string) bool {
|
||||
return strings.HasSuffix(name, ".weight")
|
||||
}
|
||||
|
||||
// ShouldQuantizeTensor returns true if a tensor should be quantized based on name and shape.
|
||||
// ShouldQuantizeTensor returns true if a tensor should be quantized based on name, shape, and quantize type.
|
||||
// This is a more detailed check that also considers tensor dimensions.
|
||||
func ShouldQuantizeTensor(name string, shape []int32) bool {
|
||||
// The quantize parameter specifies the quantization type (e.g., "q4", "nvfp4", "q8", "mxfp8").
|
||||
func ShouldQuantizeTensor(name string, shape []int32, quantize string) bool {
|
||||
return GetTensorQuantization(name, shape, quantize) != ""
|
||||
}
|
||||
|
||||
// normalizeQuantType converts various quantization type aliases to canonical forms.
|
||||
// Supports: q4/Q4/int4/INT4/fp4/FP4 -> q4, q8/Q8/int8/INT8/fp8/FP8 -> q8, nvfp4/NVFP4, mxfp8/MXFP8
|
||||
func normalizeQuantType(quantize string) string {
|
||||
switch strings.ToUpper(quantize) {
|
||||
case "Q4", "INT4", "FP4":
|
||||
return "q4"
|
||||
case "Q8", "INT8", "FP8":
|
||||
return "q8"
|
||||
case "NVFP4":
|
||||
return "nvfp4"
|
||||
case "MXFP8":
|
||||
return "mxfp8"
|
||||
default:
|
||||
return quantize
|
||||
}
|
||||
}
|
||||
|
||||
// getQuantGroupSize returns the group size for a given quantization type.
|
||||
// These must match the values used in quantize.go when creating quantized models.
|
||||
func getQuantGroupSize(quantize string) int {
|
||||
switch normalizeQuantType(quantize) {
|
||||
case "nvfp4":
|
||||
return 16
|
||||
case "q4":
|
||||
return 32
|
||||
case "mxfp8":
|
||||
return 32
|
||||
case "q8":
|
||||
return 64
|
||||
default:
|
||||
return 32
|
||||
}
|
||||
}
|
||||
|
||||
// GetTensorQuantization returns the appropriate quantization type for a tensor.
|
||||
// Returns "" if the tensor should not be quantized.
|
||||
// This implements mixed-precision quantization:
|
||||
// - Attention MLA weights (q_a, q_b, kv_a, kv_b): unquantized (most sensitive)
|
||||
// - Output projection, gate/up weights: q4 (less sensitive)
|
||||
// - Down projection weights: q8 (more sensitive, would be Q6 in GGML but no MLX kernel)
|
||||
// - Norms, embeddings, biases, routing gates: no quantization
|
||||
func GetTensorQuantization(name string, shape []int32, quantize string) string {
|
||||
// Use basic name-based check first
|
||||
if !ShouldQuantize(name, "") {
|
||||
return false
|
||||
return ""
|
||||
}
|
||||
|
||||
// Only quantize 2D tensors (linear layers) - skip 1D (biases, norms) and higher-D (convolutions if any)
|
||||
if len(shape) != 2 {
|
||||
return false
|
||||
return ""
|
||||
}
|
||||
|
||||
// Skip small tensors (less than 1024 elements) - not worth quantizing
|
||||
if len(shape) >= 2 && int64(shape[0])*int64(shape[1]) < 1024 {
|
||||
return false
|
||||
return ""
|
||||
}
|
||||
|
||||
// MLX quantization requires last dimension to be divisible by group size (32)
|
||||
if shape[len(shape)-1]%32 != 0 {
|
||||
return false
|
||||
// Normalize quantization type to canonical form
|
||||
quantNorm := normalizeQuantType(quantize)
|
||||
|
||||
// MLX quantization requires last dimension to be divisible by group size
|
||||
// nvfp4: 16, q4/mxfp8: 32, q8: 64
|
||||
groupSize := int32(32)
|
||||
switch quantNorm {
|
||||
case "nvfp4":
|
||||
groupSize = 16
|
||||
case "q8":
|
||||
groupSize = 64
|
||||
}
|
||||
if shape[len(shape)-1]%groupSize != 0 {
|
||||
return ""
|
||||
}
|
||||
|
||||
return true
|
||||
// Skip routing gate weights (should stay high precision)
|
||||
// In safetensors these are: mlp.gate.weight (not mlp.gate_proj.weight)
|
||||
if strings.Contains(name, "mlp.gate.weight") && !strings.Contains(name, "_proj") {
|
||||
return ""
|
||||
}
|
||||
|
||||
// For NVFP4 or MXFP8, use the same quantization for all (no mixed precision)
|
||||
if quantNorm == "nvfp4" || quantNorm == "mxfp8" {
|
||||
return quantNorm
|
||||
}
|
||||
|
||||
// Attention MLA weights - keep unquantized (bf16)
|
||||
// These are highly sensitive: errors accumulate in the KV cache over time
|
||||
// q_a_proj, q_b_proj, kv_a_proj_with_mqa, kv_b_proj
|
||||
if strings.Contains(name, "q_a_proj") ||
|
||||
strings.Contains(name, "q_b_proj") ||
|
||||
strings.Contains(name, "kv_a_proj") ||
|
||||
strings.Contains(name, "kv_b_proj") {
|
||||
return "" // No quantization - keep bf16
|
||||
}
|
||||
|
||||
// Down projection weights - use Q8 (would be Q6_K in GGML, but MLX has no Q6 kernel)
|
||||
// mlp.down_proj, mlp.experts.X.down_proj, mlp.shared_experts.down_proj
|
||||
if strings.Contains(name, "down_proj") {
|
||||
return "q8"
|
||||
}
|
||||
|
||||
// Output projection, gate/up weights - use requested quantization (Q4)
|
||||
// o_proj, gate_proj, up_proj
|
||||
if strings.Contains(name, "o_proj") ||
|
||||
strings.Contains(name, "gate_proj") ||
|
||||
strings.Contains(name, "up_proj") {
|
||||
return quantNorm
|
||||
}
|
||||
|
||||
// LM head - use requested quantization
|
||||
if strings.Contains(name, "lm_head") {
|
||||
return quantNorm
|
||||
}
|
||||
|
||||
// Default to requested quantization for other weights
|
||||
return quantNorm
|
||||
}
|
||||
|
||||
// CreateSafetensorsModel imports a standard safetensors model from a directory.
|
||||
// This handles Hugging Face style models with config.json and *.safetensors files.
|
||||
// Stores each tensor as a separate blob for fine-grained deduplication.
|
||||
// If quantize is non-empty (e.g., "fp8"), eligible tensors will be quantized.
|
||||
// If quantize is non-empty (e.g., "q8"), eligible tensors will be quantized.
|
||||
func CreateSafetensorsModel(modelName, modelDir, quantize string, createLayer LayerCreator, createTensorLayer QuantizingTensorLayerCreator, writeManifest ManifestWriter, fn func(status string)) error {
|
||||
var layers []LayerInfo
|
||||
var configLayer LayerInfo
|
||||
@@ -330,9 +428,10 @@ func CreateSafetensorsModel(modelName, modelDir, quantize string, createLayer La
|
||||
}
|
||||
|
||||
// Determine quantization type for this tensor (empty string if not quantizing)
|
||||
// GetTensorQuantization handles mixed-precision (e.g., Q8 for attention, Q4 for FFN)
|
||||
quantizeType := ""
|
||||
if quantize != "" && ShouldQuantizeTensor(tensorName, td.Shape) {
|
||||
quantizeType = quantize
|
||||
if quantize != "" {
|
||||
quantizeType = GetTensorQuantization(tensorName, td.Shape, quantize)
|
||||
}
|
||||
|
||||
// Store as minimal safetensors format (88 bytes header overhead)
|
||||
@@ -388,6 +487,23 @@ func CreateSafetensorsModel(modelName, modelDir, quantize string, createLayer La
|
||||
return fmt.Errorf("config.json not found in %s", modelDir)
|
||||
}
|
||||
|
||||
// Create model_index.json with quantization info if quantizing
|
||||
if quantize != "" {
|
||||
modelIndex := map[string]any{
|
||||
"quantization": strings.ToUpper(quantize),
|
||||
"group_size": getQuantGroupSize(quantize),
|
||||
}
|
||||
indexData, err := json.MarshalIndent(modelIndex, "", " ")
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to marshal model_index.json: %w", err)
|
||||
}
|
||||
indexLayer, err := createLayer(strings.NewReader(string(indexData)), "application/vnd.ollama.image.json", "model_index.json")
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create model_index.json layer: %w", err)
|
||||
}
|
||||
layers = append(layers, indexLayer)
|
||||
}
|
||||
|
||||
fn(fmt.Sprintf("writing manifest for %s", modelName))
|
||||
|
||||
if err := writeManifest(modelName, configLayer, layers); err != nil {
|
||||
|
||||
@@ -536,41 +536,51 @@ func TestShouldQuantize(t *testing.T) {
|
||||
|
||||
func TestShouldQuantizeTensor(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
tensor string
|
||||
shape []int32
|
||||
want bool
|
||||
name string
|
||||
tensor string
|
||||
shape []int32
|
||||
quantize string
|
||||
want bool
|
||||
}{
|
||||
// 2D tensors with sufficient size should be quantized
|
||||
{"large 2D weight", "q_proj.weight", []int32{4096, 4096}, true},
|
||||
{"medium 2D weight", "small_proj.weight", []int32{128, 128}, true},
|
||||
{"large 2D weight fp8", "q_proj.weight", []int32{4096, 4096}, "fp8", true},
|
||||
{"medium 2D weight fp8", "small_proj.weight", []int32{128, 128}, "fp8", true},
|
||||
{"large 2D weight nvfp4", "q_proj.weight", []int32{4096, 4096}, "nvfp4", true},
|
||||
|
||||
// Small tensors should not be quantized (< 1024 elements)
|
||||
{"tiny 2D weight", "tiny.weight", []int32{16, 16}, false},
|
||||
{"small 2D weight", "small.weight", []int32{31, 31}, false},
|
||||
{"tiny 2D weight", "tiny.weight", []int32{16, 16}, "fp8", false},
|
||||
{"small 2D weight", "small.weight", []int32{31, 31}, "fp8", false},
|
||||
|
||||
// 1D tensors should not be quantized
|
||||
{"1D tensor", "layer_norm.weight", []int32{4096}, false},
|
||||
{"1D tensor", "layer_norm.weight", []int32{4096}, "fp8", false},
|
||||
|
||||
// 3D+ tensors should not be quantized
|
||||
{"3D tensor", "conv.weight", []int32{64, 64, 3}, false},
|
||||
{"4D tensor", "conv2d.weight", []int32{64, 64, 3, 3}, false},
|
||||
{"3D tensor", "conv.weight", []int32{64, 64, 3}, "fp8", false},
|
||||
{"4D tensor", "conv2d.weight", []int32{64, 64, 3, 3}, "fp8", false},
|
||||
|
||||
// Embeddings should not be quantized regardless of shape
|
||||
{"embedding 2D", "embed_tokens.weight", []int32{32000, 4096}, false},
|
||||
{"embedding 2D", "embed_tokens.weight", []int32{32000, 4096}, "fp8", false},
|
||||
|
||||
// Norms should not be quantized regardless of shape
|
||||
{"norm 2D", "layer_norm.weight", []int32{4096, 1}, false},
|
||||
{"norm 2D", "layer_norm.weight", []int32{4096, 1}, "fp8", false},
|
||||
|
||||
// Biases should not be quantized
|
||||
{"bias 2D", "proj.bias", []int32{4096, 1}, false},
|
||||
{"bias 2D", "proj.bias", []int32{4096, 1}, "fp8", false},
|
||||
|
||||
// Group size divisibility tests
|
||||
// FP8/FP4 require divisible by 32
|
||||
{"not divisible by 32 fp8", "proj.weight", []int32{128, 48}, "fp8", false},
|
||||
{"divisible by 32 fp8", "proj.weight", []int32{128, 64}, "fp8", true},
|
||||
// NVFP4 requires divisible by 16
|
||||
{"not divisible by 16 nvfp4", "proj.weight", []int32{128, 24}, "nvfp4", false},
|
||||
{"divisible by 16 nvfp4", "proj.weight", []int32{128, 48}, "nvfp4", true},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := ShouldQuantizeTensor(tt.tensor, tt.shape)
|
||||
got := ShouldQuantizeTensor(tt.tensor, tt.shape, tt.quantize)
|
||||
if got != tt.want {
|
||||
t.Errorf("ShouldQuantizeTensor(%q, %v) = %v, want %v", tt.tensor, tt.shape, got, tt.want)
|
||||
t.Errorf("ShouldQuantizeTensor(%q, %v, %q) = %v, want %v", tt.tensor, tt.shape, tt.quantize, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -741,7 +751,7 @@ func TestCreateImageGenModel_WithQuantize(t *testing.T) {
|
||||
|
||||
progressFn := func(status string) {}
|
||||
|
||||
err := CreateImageGenModel("test-imagegen", dir, "fp8", createLayer, createTensorLayer, writeManifest, progressFn)
|
||||
err := CreateImageGenModel("test-imagegen", dir, "q8", createLayer, createTensorLayer, writeManifest, progressFn)
|
||||
if err != nil {
|
||||
t.Fatalf("CreateImageGenModel failed: %v", err)
|
||||
}
|
||||
|
||||
@@ -15,15 +15,15 @@ import (
|
||||
// CreateImageGenModel imports an image generation model from a directory.
|
||||
// Stores each tensor as a separate blob for fine-grained deduplication.
|
||||
// If quantize is specified, linear weights in transformer/text_encoder are quantized.
|
||||
// Supported quantization types: fp8 (or empty for no quantization).
|
||||
// Supported quantization types: q4, q8, nvfp4, mxfp8 (or empty for no quantization).
|
||||
// Layer creation and manifest writing are done via callbacks to avoid import cycles.
|
||||
func CreateImageGenModel(modelName, modelDir, quantize string, createLayer LayerCreator, createTensorLayer QuantizingTensorLayerCreator, writeManifest ManifestWriter, fn func(status string)) error {
|
||||
// Validate quantization type
|
||||
switch quantize {
|
||||
case "", "fp4", "fp8":
|
||||
case "", "q4", "q8", "nvfp4", "mxfp8":
|
||||
// valid
|
||||
default:
|
||||
return fmt.Errorf("unsupported quantization type %q: supported types are fp4, fp8", quantize)
|
||||
return fmt.Errorf("unsupported quantization type %q: supported types are q4, q8, nvfp4, mxfp8", quantize)
|
||||
}
|
||||
|
||||
var layers []LayerInfo
|
||||
@@ -89,7 +89,7 @@ func CreateImageGenModel(modelName, modelDir, quantize string, createLayer Layer
|
||||
|
||||
// Determine quantization type for this tensor (empty string if not quantizing)
|
||||
quantizeType := ""
|
||||
if quantize != "" && ShouldQuantize(tensorName, component) && canQuantizeShape(td.Shape) {
|
||||
if quantize != "" && ShouldQuantize(tensorName, component) && canQuantizeShape(td.Shape, quantize) {
|
||||
quantizeType = quantize
|
||||
}
|
||||
|
||||
@@ -213,10 +213,18 @@ func CreateImageGenModel(modelName, modelDir, quantize string, createLayer Layer
|
||||
}
|
||||
|
||||
// canQuantizeShape returns true if a tensor shape is compatible with MLX quantization.
|
||||
// MLX requires the last dimension to be divisible by the group size (32).
|
||||
func canQuantizeShape(shape []int32) bool {
|
||||
// MLX requires the last dimension to be divisible by the group size.
|
||||
// nvfp4: 16, q4/mxfp8: 32, q8: 64
|
||||
func canQuantizeShape(shape []int32, quantize string) bool {
|
||||
if len(shape) < 2 {
|
||||
return false
|
||||
}
|
||||
return shape[len(shape)-1]%32 == 0
|
||||
groupSize := int32(32)
|
||||
switch strings.ToUpper(quantize) {
|
||||
case "NVFP4":
|
||||
groupSize = 16
|
||||
case "Q8":
|
||||
groupSize = 64
|
||||
}
|
||||
return shape[len(shape)-1]%groupSize == 0
|
||||
}
|
||||
|
||||
16
x/imagegen/cache/cache.go
vendored
16
x/imagegen/cache/cache.go
vendored
@@ -9,6 +9,7 @@ type Cache interface {
|
||||
Offset() int
|
||||
Len() int
|
||||
State() []*mlx.Array
|
||||
Reset()
|
||||
}
|
||||
|
||||
type KVCache struct {
|
||||
@@ -63,6 +64,13 @@ func (c *KVCache) State() []*mlx.Array {
|
||||
func (c *KVCache) Offset() int { return c.offset }
|
||||
func (c *KVCache) Len() int { return c.offset }
|
||||
|
||||
// Reset clears the cache state for a new generation session
|
||||
func (c *KVCache) Reset() {
|
||||
c.keys = nil
|
||||
c.values = nil
|
||||
c.offset = 0
|
||||
}
|
||||
|
||||
// RotatingKVCache implements sliding window attention with bounded memory
|
||||
type RotatingKVCache struct {
|
||||
keys, values *mlx.Array
|
||||
@@ -154,3 +162,11 @@ func (c *RotatingKVCache) State() []*mlx.Array {
|
||||
|
||||
func (c *RotatingKVCache) Offset() int { return c.offset }
|
||||
func (c *RotatingKVCache) Len() int { return min(c.offset, c.maxSize) }
|
||||
|
||||
// Reset clears the cache state for a new generation session
|
||||
func (c *RotatingKVCache) Reset() {
|
||||
c.keys = nil
|
||||
c.values = nil
|
||||
c.offset = 0
|
||||
c.idx = 0
|
||||
}
|
||||
|
||||
@@ -10,7 +10,10 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"os"
|
||||
"regexp"
|
||||
"slices"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
@@ -75,6 +78,7 @@ Image Generation Flags (experimental):
|
||||
// RunCLI handles the CLI for image generation models.
|
||||
// Returns true if it handled the request, false if the caller should continue with normal flow.
|
||||
// Supports flags: --width, --height, --steps, --seed, --negative
|
||||
// Image paths can be included in the prompt and will be extracted automatically.
|
||||
func RunCLI(cmd *cobra.Command, name string, prompt string, interactive bool, keepAlive *api.Duration) error {
|
||||
// Get options from flags (with env var defaults)
|
||||
opts := DefaultOptions()
|
||||
@@ -111,9 +115,16 @@ func generateImageWithOptions(cmd *cobra.Command, modelName, prompt string, keep
|
||||
return err
|
||||
}
|
||||
|
||||
// Extract any image paths from the prompt
|
||||
prompt, images, err := extractFileData(prompt)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
req := &api.GenerateRequest{
|
||||
Model: modelName,
|
||||
Prompt: prompt,
|
||||
Images: images,
|
||||
Width: int32(opts.Width),
|
||||
Height: int32(opts.Height),
|
||||
Steps: int32(opts.Steps),
|
||||
@@ -254,14 +265,33 @@ func runInteractive(cmd *cobra.Command, modelName string, keepAlive *api.Duratio
|
||||
printCurrentSettings(opts)
|
||||
continue
|
||||
case strings.HasPrefix(line, "/"):
|
||||
fmt.Fprintf(os.Stderr, "Unknown command: %s (try /help)\n", line)
|
||||
// Check if it's a file path, not a command
|
||||
args := strings.Fields(line)
|
||||
isFile := false
|
||||
for _, f := range extractFileNames(line) {
|
||||
if strings.HasPrefix(f, args[0]) {
|
||||
isFile = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !isFile {
|
||||
fmt.Fprintf(os.Stderr, "Unknown command: %s (try /help)\n", args[0])
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
// Extract any image paths from the input
|
||||
prompt, images, err := extractFileData(line)
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "Error: %v\n", err)
|
||||
continue
|
||||
}
|
||||
|
||||
// Generate image with current options
|
||||
req := &api.GenerateRequest{
|
||||
Model: modelName,
|
||||
Prompt: line,
|
||||
Prompt: prompt,
|
||||
Images: images,
|
||||
Width: int32(opts.Width),
|
||||
Height: int32(opts.Height),
|
||||
Steps: int32(opts.Steps),
|
||||
@@ -486,3 +516,61 @@ func displayImageInTerminal(imagePath string) bool {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// extractFileNames finds image file paths in the input string.
|
||||
func extractFileNames(input string) []string {
|
||||
// Regex to match file paths with image extensions
|
||||
regexPattern := `(?:[a-zA-Z]:)?(?:\./|/|\\)[\S\\ ]+?\.(?i:jpg|jpeg|png|webp)\b`
|
||||
re := regexp.MustCompile(regexPattern)
|
||||
return re.FindAllString(input, -1)
|
||||
}
|
||||
|
||||
// extractFileData extracts image data from file paths found in the input.
|
||||
// Returns the cleaned prompt (with file paths removed) and the image data.
|
||||
func extractFileData(input string) (string, []api.ImageData, error) {
|
||||
filePaths := extractFileNames(input)
|
||||
var imgs []api.ImageData
|
||||
|
||||
for _, fp := range filePaths {
|
||||
// Normalize shell escapes
|
||||
nfp := strings.ReplaceAll(fp, "\\ ", " ")
|
||||
nfp = strings.ReplaceAll(nfp, "\\(", "(")
|
||||
nfp = strings.ReplaceAll(nfp, "\\)", ")")
|
||||
nfp = strings.ReplaceAll(nfp, "%20", " ")
|
||||
|
||||
data, err := getImageData(nfp)
|
||||
if errors.Is(err, os.ErrNotExist) {
|
||||
continue
|
||||
} else if err != nil {
|
||||
return "", nil, err
|
||||
}
|
||||
fmt.Fprintf(os.Stderr, "Added image '%s'\n", nfp)
|
||||
input = strings.ReplaceAll(input, fp, "")
|
||||
imgs = append(imgs, data)
|
||||
}
|
||||
return strings.TrimSpace(input), imgs, nil
|
||||
}
|
||||
|
||||
// getImageData reads and validates image data from a file.
|
||||
func getImageData(filePath string) ([]byte, error) {
|
||||
file, err := os.Open(filePath)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer file.Close()
|
||||
|
||||
buf := make([]byte, 512)
|
||||
_, err = file.Read(buf)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
contentType := http.DetectContentType(buf)
|
||||
allowedTypes := []string{"image/jpeg", "image/jpg", "image/png", "image/webp"}
|
||||
if !slices.Contains(allowedTypes, contentType) {
|
||||
return nil, fmt.Errorf("invalid image type: %s", contentType)
|
||||
}
|
||||
|
||||
// Re-read the full file
|
||||
return os.ReadFile(filePath)
|
||||
}
|
||||
|
||||
@@ -7,6 +7,8 @@ import (
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"image"
|
||||
"image/color"
|
||||
"image/draw"
|
||||
_ "image/jpeg"
|
||||
"image/png"
|
||||
"os"
|
||||
@@ -111,6 +113,7 @@ func clampF(v, min, max float32) float32 {
|
||||
}
|
||||
|
||||
// DecodeImage decodes image bytes with EXIF orientation applied.
|
||||
// Transparent images are composited onto a white background.
|
||||
func DecodeImage(data []byte) (image.Image, error) {
|
||||
orientation := readJPEGOrientation(data)
|
||||
|
||||
@@ -119,9 +122,33 @@ func DecodeImage(data []byte) (image.Image, error) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
img = flattenAlpha(img)
|
||||
return applyOrientation(img, orientation), nil
|
||||
}
|
||||
|
||||
// flattenAlpha composites an image onto a white background,
|
||||
// removing any transparency. This is needed because image
|
||||
// generation models don't handle alpha channels well.
|
||||
func flattenAlpha(img image.Image) image.Image {
|
||||
if _, ok := img.(*image.RGBA); !ok {
|
||||
if _, ok := img.(*image.NRGBA); !ok {
|
||||
// No alpha channel, return as-is
|
||||
return img
|
||||
}
|
||||
}
|
||||
|
||||
bounds := img.Bounds()
|
||||
dst := image.NewRGBA(bounds)
|
||||
|
||||
// Fill with white background
|
||||
draw.Draw(dst, bounds, &image.Uniform{color.White}, image.Point{}, draw.Src)
|
||||
|
||||
// Composite the image on top
|
||||
draw.Draw(dst, bounds, img, bounds.Min, draw.Over)
|
||||
|
||||
return dst
|
||||
}
|
||||
|
||||
// readJPEGOrientation extracts EXIF orientation from JPEG bytes.
|
||||
// Returns 1 (normal) for non-JPEG or if orientation not found.
|
||||
func readJPEGOrientation(data []byte) int {
|
||||
|
||||
@@ -102,14 +102,17 @@ func (m *ModelManifest) BlobPath(digest string) string {
|
||||
return filepath.Join(m.BlobDir, blobName)
|
||||
}
|
||||
|
||||
// GetTensorLayers returns all tensor layers for a given component.
|
||||
// Component should be "text_encoder", "transformer", or "vae".
|
||||
// Tensor names are path-style: "component/tensor_name" (e.g., "text_encoder/model.embed_tokens.weight").
|
||||
// GetTensorLayers returns tensor layers, optionally filtered by component.
|
||||
// If component is empty, returns all tensor layers (for LLM models).
|
||||
// If component is specified (e.g., "text_encoder", "transformer", "vae"),
|
||||
// returns only layers with that prefix.
|
||||
func (m *ModelManifest) GetTensorLayers(component string) []ManifestLayer {
|
||||
prefix := component + "/"
|
||||
var layers []ManifestLayer
|
||||
for _, layer := range m.Manifest.Layers {
|
||||
if layer.MediaType == "application/vnd.ollama.image.tensor" && strings.HasPrefix(layer.Name, prefix) {
|
||||
if layer.MediaType != "application/vnd.ollama.image.tensor" {
|
||||
continue
|
||||
}
|
||||
if component == "" || strings.HasPrefix(layer.Name, component+"/") {
|
||||
layers = append(layers, layer)
|
||||
}
|
||||
}
|
||||
@@ -161,6 +164,17 @@ func (m *ModelManifest) HasTensorLayers() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
// TotalTensorSize returns the total size in bytes of all tensor layers.
|
||||
func (m *ModelManifest) TotalTensorSize() int64 {
|
||||
var total int64
|
||||
for _, layer := range m.Manifest.Layers {
|
||||
if layer.MediaType == "application/vnd.ollama.image.tensor" {
|
||||
total += layer.Size
|
||||
}
|
||||
}
|
||||
return total
|
||||
}
|
||||
|
||||
// ModelInfo contains metadata about an image generation model.
|
||||
type ModelInfo struct {
|
||||
Architecture string
|
||||
@@ -195,7 +209,7 @@ func GetModelInfo(modelName string) (*ModelInfo, error) {
|
||||
if info.Quantization == "" {
|
||||
for _, layer := range manifest.Manifest.Layers {
|
||||
if strings.HasSuffix(layer.Name, ".weight_scale") {
|
||||
info.Quantization = "FP8"
|
||||
info.Quantization = "Q8"
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
@@ -5,6 +5,37 @@ import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestTotalTensorSize(t *testing.T) {
|
||||
m := &ModelManifest{
|
||||
Manifest: &Manifest{
|
||||
Layers: []ManifestLayer{
|
||||
{MediaType: "application/vnd.ollama.image.tensor", Size: 1000},
|
||||
{MediaType: "application/vnd.ollama.image.tensor", Size: 2000},
|
||||
{MediaType: "application/vnd.ollama.image.json", Size: 500}, // not a tensor
|
||||
{MediaType: "application/vnd.ollama.image.tensor", Size: 3000},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
got := m.TotalTensorSize()
|
||||
want := int64(6000)
|
||||
if got != want {
|
||||
t.Errorf("TotalTensorSize() = %d, want %d", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestTotalTensorSizeEmpty(t *testing.T) {
|
||||
m := &ModelManifest{
|
||||
Manifest: &Manifest{
|
||||
Layers: []ManifestLayer{},
|
||||
},
|
||||
}
|
||||
|
||||
if got := m.TotalTensorSize(); got != 0 {
|
||||
t.Errorf("TotalTensorSize() = %d, want 0", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestManifestAndBlobDirsRespectOLLAMAModels(t *testing.T) {
|
||||
modelsDir := filepath.Join(t.TempDir(), "models")
|
||||
|
||||
|
||||
@@ -16,18 +16,9 @@ import (
|
||||
"runtime"
|
||||
)
|
||||
|
||||
// GB is a convenience constant for gigabytes.
|
||||
const GB = 1024 * 1024 * 1024
|
||||
|
||||
// SupportedBackends lists the backends that support image generation.
|
||||
var SupportedBackends = []string{"metal", "cuda", "cpu"}
|
||||
|
||||
// modelVRAMEstimates maps pipeline class names to their estimated VRAM requirements.
|
||||
var modelVRAMEstimates = map[string]uint64{
|
||||
"ZImagePipeline": 21 * GB, // ~21GB for Z-Image (text encoder + transformer + VAE)
|
||||
"FluxPipeline": 20 * GB, // ~20GB for Flux
|
||||
}
|
||||
|
||||
// CheckPlatformSupport validates that image generation is supported on the current platform.
|
||||
// Returns nil if supported, or an error describing why it's not supported.
|
||||
func CheckPlatformSupport() error {
|
||||
@@ -47,17 +38,6 @@ func CheckPlatformSupport() error {
|
||||
}
|
||||
}
|
||||
|
||||
// CheckMemoryRequirements validates that there's enough memory for image generation.
|
||||
// Returns nil if memory is sufficient, or an error if not.
|
||||
func CheckMemoryRequirements(modelName string, availableMemory uint64) error {
|
||||
required := EstimateVRAM(modelName)
|
||||
if availableMemory < required {
|
||||
return fmt.Errorf("insufficient memory for image generation: need %d GB, have %d GB",
|
||||
required/GB, availableMemory/GB)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// ResolveModelName checks if a model name is a known image generation model.
|
||||
// Returns the normalized model name if found, empty string otherwise.
|
||||
func ResolveModelName(modelName string) string {
|
||||
@@ -68,16 +48,6 @@ func ResolveModelName(modelName string) string {
|
||||
return ""
|
||||
}
|
||||
|
||||
// EstimateVRAM returns the estimated VRAM needed for an image generation model.
|
||||
// Returns a conservative default of 21GB if the model type cannot be determined.
|
||||
func EstimateVRAM(modelName string) uint64 {
|
||||
className := DetectModelType(modelName)
|
||||
if estimate, ok := modelVRAMEstimates[className]; ok {
|
||||
return estimate
|
||||
}
|
||||
return 21 * GB
|
||||
}
|
||||
|
||||
// DetectModelType reads model_index.json and returns the model type.
|
||||
// Checks both "architecture" (Ollama format) and "_class_name" (diffusers format).
|
||||
// Returns empty string if detection fails.
|
||||
|
||||
@@ -30,69 +30,6 @@ func TestCheckPlatformSupport(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestCheckMemoryRequirements(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
availableMemory uint64
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "sufficient memory",
|
||||
availableMemory: 32 * GB,
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "exactly enough memory",
|
||||
availableMemory: 21 * GB,
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "insufficient memory",
|
||||
availableMemory: 16 * GB,
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "zero memory",
|
||||
availableMemory: 0,
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// Use a non-existent model name which will default to 21GB estimate
|
||||
err := CheckMemoryRequirements("nonexistent-model", tt.availableMemory)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("CheckMemoryRequirements() error = %v, wantErr %v", err, tt.wantErr)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestModelVRAMEstimates(t *testing.T) {
|
||||
// Verify the VRAM estimates map has expected entries
|
||||
expected := map[string]uint64{
|
||||
"ZImagePipeline": 21 * GB,
|
||||
"FluxPipeline": 20 * GB,
|
||||
}
|
||||
|
||||
for name, expectedVRAM := range expected {
|
||||
if actual, ok := modelVRAMEstimates[name]; !ok {
|
||||
t.Errorf("Missing VRAM estimate for %s", name)
|
||||
} else if actual != expectedVRAM {
|
||||
t.Errorf("VRAM estimate for %s = %d GB, want %d GB", name, actual/GB, expectedVRAM/GB)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestEstimateVRAMDefault(t *testing.T) {
|
||||
// Non-existent model should return default 21GB
|
||||
vram := EstimateVRAM("nonexistent-model-that-does-not-exist")
|
||||
if vram != 21*GB {
|
||||
t.Errorf("EstimateVRAM() = %d GB, want 21 GB", vram/GB)
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveModelName(t *testing.T) {
|
||||
// Non-existent model should return empty string
|
||||
result := ResolveModelName("nonexistent-model")
|
||||
|
||||
@@ -991,6 +991,19 @@ func Concat(a, b *Array, axis int) *Array {
|
||||
return Concatenate([]*Array{a, b}, axis)
|
||||
}
|
||||
|
||||
// Stack stacks arrays along a new axis (axis 0 by default)
|
||||
func Stack(arrays []*Array, axis int) *Array {
|
||||
handles := make([]C.mlx_array, len(arrays))
|
||||
for i, arr := range arrays {
|
||||
handles[i] = arr.c
|
||||
}
|
||||
vec := C.mlx_vector_array_new_data(&handles[0], C.size_t(len(handles)))
|
||||
res := C.mlx_array_new()
|
||||
C.mlx_stack_axis(&res, vec, C.int(axis), C.default_stream())
|
||||
C.mlx_vector_array_free(vec)
|
||||
return newArray(res)
|
||||
}
|
||||
|
||||
// Slice slices the array
|
||||
func Slice(a *Array, start, stop []int32) *Array {
|
||||
n := len(start)
|
||||
|
||||
@@ -177,6 +177,20 @@ func (m *Model) GenerateImage(ctx context.Context, prompt string, width, height
|
||||
})
|
||||
}
|
||||
|
||||
// GenerateImageWithInputs implements runner.ImageEditModel interface.
|
||||
// It generates an image conditioned on the provided input images for image editing.
|
||||
func (m *Model) GenerateImageWithInputs(ctx context.Context, prompt string, width, height int32, steps int, seed int64, inputImages []image.Image, progress func(step, total int)) (*mlx.Array, error) {
|
||||
return m.GenerateFromConfig(ctx, &GenerateConfig{
|
||||
Prompt: prompt,
|
||||
Width: width,
|
||||
Height: height,
|
||||
Steps: steps,
|
||||
Seed: seed,
|
||||
InputImages: inputImages,
|
||||
Progress: progress,
|
||||
})
|
||||
}
|
||||
|
||||
// MaxOutputPixels is the maximum output resolution (4 megapixels, ~2048x2048)
|
||||
const MaxOutputPixels = 2048 * 2048
|
||||
|
||||
|
||||
840
x/imagegen/models/glm4_moe_lite/glm4_moe_lite.go
Normal file
840
x/imagegen/models/glm4_moe_lite/glm4_moe_lite.go
Normal file
@@ -0,0 +1,840 @@
|
||||
//go:build mlx
|
||||
|
||||
// Package glm4_moe_lite provides the GLM4-MoE-Lite implementation for MLX.
|
||||
// This model uses Multi-head Latent Attention (MLA) and Mixture of Experts (MoE).
|
||||
package glm4_moe_lite
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"math"
|
||||
|
||||
"github.com/ollama/ollama/x/imagegen"
|
||||
"github.com/ollama/ollama/x/imagegen/cache"
|
||||
"github.com/ollama/ollama/x/imagegen/mlx"
|
||||
"github.com/ollama/ollama/x/imagegen/nn"
|
||||
"github.com/ollama/ollama/x/imagegen/safetensors"
|
||||
"github.com/ollama/ollama/x/imagegen/tokenizer"
|
||||
)
|
||||
|
||||
// RopeScaling holds RoPE scaling configuration
|
||||
type RopeScaling struct {
|
||||
Factor float32 `json:"factor"`
|
||||
MscaleAllDim float32 `json:"mscale_all_dim"`
|
||||
}
|
||||
|
||||
// Config holds GLM4-MoE-Lite model configuration
|
||||
type Config struct {
|
||||
HiddenSize int32 `json:"hidden_size"`
|
||||
NumHiddenLayers int32 `json:"num_hidden_layers"`
|
||||
IntermediateSize int32 `json:"intermediate_size"`
|
||||
MoEIntermediateSize int32 `json:"moe_intermediate_size"`
|
||||
NumAttentionHeads int32 `json:"num_attention_heads"`
|
||||
NumKeyValueHeads int32 `json:"num_key_value_heads"`
|
||||
VocabSize int32 `json:"vocab_size"`
|
||||
RMSNormEps float32 `json:"rms_norm_eps"`
|
||||
RopeTheta float32 `json:"rope_theta"`
|
||||
MaxPositionEmbeddings int32 `json:"max_position_embeddings"`
|
||||
AttentionBias bool `json:"attention_bias"`
|
||||
|
||||
// MLA (Multi-head Latent Attention) parameters
|
||||
QLoraRank int32 `json:"q_lora_rank"`
|
||||
KVLoraRank int32 `json:"kv_lora_rank"`
|
||||
QKRopeHeadDim int32 `json:"qk_rope_head_dim"`
|
||||
QKNopeHeadDim int32 `json:"qk_nope_head_dim"`
|
||||
VHeadDim int32 `json:"v_head_dim"`
|
||||
|
||||
// MoE parameters
|
||||
NRoutedExperts int32 `json:"n_routed_experts"`
|
||||
NSharedExperts int32 `json:"n_shared_experts"`
|
||||
NumExpertsPerTok int32 `json:"num_experts_per_tok"`
|
||||
RoutedScalingFactor float32 `json:"routed_scaling_factor"`
|
||||
NormTopKProb bool `json:"norm_topk_prob"`
|
||||
FirstKDenseReplace int32 `json:"first_k_dense_replace"`
|
||||
NGroup int32 `json:"n_group"`
|
||||
TopKGroup int32 `json:"topk_group"`
|
||||
|
||||
// RoPE scaling
|
||||
RopeScaling *RopeScaling `json:"rope_scaling"`
|
||||
|
||||
// Quantization parameters (set during load based on model quantization)
|
||||
QuantGroupSize int `json:"-"` // Group size for quantization (default 64)
|
||||
QuantBits int `json:"-"` // Bits per weight (4 or 8)
|
||||
QuantMode string `json:"-"` // Quantization mode ("affine", etc.)
|
||||
|
||||
// Computed fields
|
||||
QHeadDim int32 `json:"-"` // qk_nope_head_dim + qk_rope_head_dim
|
||||
Scale float32 `json:"-"` // 1/sqrt(QHeadDim) with mscale adjustment
|
||||
}
|
||||
|
||||
// MLAAttention implements Multi-head Latent Attention with absorption.
|
||||
// This uses absorbed MLA which operates in latent space for reduced KV cache.
|
||||
type MLAAttention struct {
|
||||
// Low-rank query projections
|
||||
QAProj nn.LinearLayer `weight:"self_attn.q_a_proj"`
|
||||
QALayerNorm *nn.RMSNorm `weight:"self_attn.q_a_layernorm"`
|
||||
QBProj nn.LinearLayer `weight:"self_attn.q_b_proj"`
|
||||
|
||||
// Low-rank KV projections (with shared rope component)
|
||||
KVAProjWithMQA nn.LinearLayer `weight:"self_attn.kv_a_proj_with_mqa"`
|
||||
KVALayerNorm *nn.RMSNorm `weight:"self_attn.kv_a_layernorm"`
|
||||
|
||||
// Absorbed MLA projections (derived from kv_b_proj)
|
||||
// EmbedQ: projects q_nope to latent space [num_heads, kv_lora_rank, qk_nope_head_dim]
|
||||
// UnembedOut: projects attention output from latent space [num_heads, v_head_dim, kv_lora_rank]
|
||||
EmbedQ *nn.MultiLinear `weight:"-"`
|
||||
UnembedOut *nn.MultiLinear `weight:"-"`
|
||||
|
||||
// Output projection
|
||||
OProj nn.LinearLayer `weight:"self_attn.o_proj"`
|
||||
}
|
||||
|
||||
// Forward computes absorbed MLA attention output.
|
||||
// This operates in latent space for reduced KV cache memory.
|
||||
func (a *MLAAttention) Forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *Config) *mlx.Array {
|
||||
// Query path: q_a_proj -> layernorm -> q_b_proj
|
||||
q := a.QAProj.Forward(x)
|
||||
q = a.QALayerNorm.Forward(q, cfg.RMSNormEps)
|
||||
q = a.QBProj.Forward(q)
|
||||
|
||||
// Reshape Q: [B, L, num_heads * q_head_dim] -> [B, num_heads, L, q_head_dim]
|
||||
q = mlx.Reshape(q, B, L, cfg.NumAttentionHeads, cfg.QHeadDim)
|
||||
q = mlx.Transpose(q, 0, 2, 1, 3)
|
||||
|
||||
// Split Q into nope and rope parts
|
||||
qNope := mlx.Slice(q, []int32{0, 0, 0, 0}, []int32{B, cfg.NumAttentionHeads, L, cfg.QKNopeHeadDim})
|
||||
qPE := mlx.Slice(q, []int32{0, 0, 0, cfg.QKNopeHeadDim}, []int32{B, cfg.NumAttentionHeads, L, cfg.QHeadDim})
|
||||
|
||||
// KV path: get compressed KV and k_pe
|
||||
compressedKV := a.KVAProjWithMQA.Forward(x)
|
||||
|
||||
// Split into compressed_kv and k_pe (shared rope component)
|
||||
kvCompressed := mlx.Slice(compressedKV, []int32{0, 0, 0}, []int32{B, L, cfg.KVLoraRank})
|
||||
kPE := mlx.Slice(compressedKV, []int32{0, 0, cfg.KVLoraRank}, []int32{B, L, cfg.KVLoraRank + cfg.QKRopeHeadDim})
|
||||
|
||||
// k_pe is shared across heads (MQA-style): [B, L, rope_dim] -> [B, 1, L, rope_dim]
|
||||
kPE = mlx.Reshape(kPE, B, L, 1, cfg.QKRopeHeadDim)
|
||||
kPE = mlx.Transpose(kPE, 0, 2, 1, 3)
|
||||
|
||||
// Apply layernorm to get kv latent representation
|
||||
kvLatent := a.KVALayerNorm.Forward(kvCompressed, cfg.RMSNormEps)
|
||||
// kvLatent: [B, L, kv_lora_rank] -> [B, 1, L, kv_lora_rank] for broadcasting
|
||||
kvLatent = mlx.ExpandDims(kvLatent, 1)
|
||||
|
||||
// Apply RoPE to the rope parts
|
||||
offset := 0
|
||||
if c != nil {
|
||||
offset = c.Offset()
|
||||
}
|
||||
qPE = mlx.RoPE(qPE, int(cfg.QKRopeHeadDim), true, cfg.RopeTheta, 1.0, offset)
|
||||
kPE = mlx.RoPE(kPE, int(cfg.QKRopeHeadDim), true, cfg.RopeTheta, 1.0, offset)
|
||||
|
||||
// ABSORBED MLA: project q_nope to latent space
|
||||
// qNope: [B, num_heads, L, qk_nope_head_dim]
|
||||
// EmbedQ: [num_heads, kv_lora_rank, qk_nope_head_dim]
|
||||
// Result: [B, num_heads, L, kv_lora_rank]
|
||||
qLatent := a.EmbedQ.Forward(qNope)
|
||||
|
||||
// Keys = concat(kvLatent, kPE)
|
||||
// kvLatent: [B, 1, L, kv_lora_rank]
|
||||
// kPE: [B, 1, L, qk_rope_head_dim]
|
||||
// keys: [B, 1, L, kv_lora_rank + qk_rope_head_dim]
|
||||
keys := mlx.Concatenate([]*mlx.Array{kvLatent, kPE}, 3)
|
||||
|
||||
// Cache the smaller latent representation
|
||||
// We cache keys (latent + rope) and use empty values since values are derived from keys
|
||||
cachedL := L
|
||||
if c != nil {
|
||||
// Create placeholder values with 0 dims for cache (we don't actually use cached values)
|
||||
placeholderValues := mlx.Zeros([]int32{B, 1, L, 0}, mlx.DtypeFloat32)
|
||||
keys, _ = c.Update(keys, placeholderValues, int(L))
|
||||
cachedL = int32(keys.Shape()[2])
|
||||
}
|
||||
|
||||
// Values are the first kv_lora_rank dims of keys (slice off rope part)
|
||||
values := mlx.Slice(keys, []int32{0, 0, 0, 0}, []int32{B, 1, cachedL, cfg.KVLoraRank})
|
||||
|
||||
// Queries = concat(qLatent, qPE)
|
||||
// qLatent: [B, num_heads, L, kv_lora_rank]
|
||||
// qPE: [B, num_heads, L, qk_rope_head_dim]
|
||||
// queries: [B, num_heads, L, kv_lora_rank + qk_rope_head_dim]
|
||||
queries := mlx.Concatenate([]*mlx.Array{qLatent, qPE}, 3)
|
||||
|
||||
// Attention in latent space
|
||||
// queries: [B, num_heads, L, kv_lora_rank + rope_dim]
|
||||
// keys: [B, 1, cachedL, kv_lora_rank + rope_dim]
|
||||
// values: [B, 1, cachedL, kv_lora_rank]
|
||||
out := mlx.ScaledDotProductAttention(queries, keys, values, cfg.Scale, L > 1)
|
||||
|
||||
// ABSORBED MLA: unembed from latent space
|
||||
// out: [B, num_heads, L, kv_lora_rank]
|
||||
// UnembedOut: [num_heads, v_head_dim, kv_lora_rank]
|
||||
// Result: [B, num_heads, L, v_head_dim]
|
||||
out = a.UnembedOut.Forward(out)
|
||||
|
||||
// Reshape back: [B, num_heads, L, v_head_dim] -> [B, L, num_heads * v_head_dim]
|
||||
out = mlx.Reshape(mlx.Transpose(out, 0, 2, 1, 3), B, L, cfg.NumAttentionHeads*cfg.VHeadDim)
|
||||
|
||||
return a.OProj.Forward(out)
|
||||
}
|
||||
|
||||
// DenseMLP implements the standard SwiGLU MLP for dense layers
|
||||
type DenseMLP struct {
|
||||
GateProj nn.LinearLayer `weight:"mlp.gate_proj"`
|
||||
UpProj nn.LinearLayer `weight:"mlp.up_proj"`
|
||||
DownProj nn.LinearLayer `weight:"mlp.down_proj"`
|
||||
}
|
||||
|
||||
// Forward applies the SwiGLU MLP
|
||||
func (m *DenseMLP) Forward(x *mlx.Array) *mlx.Array {
|
||||
gate := mlx.SiLU(m.GateProj.Forward(x))
|
||||
up := m.UpProj.Forward(x)
|
||||
return m.DownProj.Forward(mlx.Mul(gate, up))
|
||||
}
|
||||
|
||||
// MoEGate implements the expert gating mechanism
|
||||
type MoEGate struct {
|
||||
Gate nn.LinearLayer `weight:"mlp.gate"`
|
||||
EScoreCorrectionBias *mlx.Array `weight:"mlp.gate.e_score_correction_bias,optional"`
|
||||
}
|
||||
|
||||
// Forward computes expert selection indices and scores
|
||||
func (g *MoEGate) Forward(x *mlx.Array, cfg *Config) (*mlx.Array, *mlx.Array) {
|
||||
// Compute gate logits through linear layer (handles both quantized and non-quantized)
|
||||
gates := g.Gate.Forward(x)
|
||||
|
||||
// Sigmoid scoring
|
||||
scores := mlx.Sigmoid(gates)
|
||||
origScores := scores
|
||||
|
||||
// Add correction bias if present
|
||||
if g.EScoreCorrectionBias != nil {
|
||||
scores = mlx.Add(scores, g.EScoreCorrectionBias)
|
||||
}
|
||||
|
||||
// Group-wise expert selection (simplified for n_group=1)
|
||||
// Select top-k experts
|
||||
topK := cfg.NumExpertsPerTok
|
||||
negScores := mlx.Neg(scores)
|
||||
inds := mlx.Argpartition(negScores, int(topK)-1, -1)
|
||||
|
||||
shape := inds.Shape()
|
||||
inds = mlx.Slice(inds, []int32{0, 0, 0}, []int32{shape[0], shape[1], topK})
|
||||
|
||||
// Get scores for selected experts
|
||||
scores = mlx.TakeAlongAxis(origScores, inds, -1)
|
||||
|
||||
// Normalize if configured
|
||||
if topK > 1 && cfg.NormTopKProb {
|
||||
sumScores := mlx.Sum(scores, -1, true)
|
||||
scores = mlx.Div(scores, sumScores)
|
||||
}
|
||||
|
||||
// Apply routing scaling factor
|
||||
scores = mlx.MulScalar(scores, cfg.RoutedScalingFactor)
|
||||
|
||||
return inds, scores
|
||||
}
|
||||
|
||||
// SwitchMLP implements the MoE expert computation using stacked weights
|
||||
// Note: No weight tags - these are populated manually by stacking expert weights
|
||||
type SwitchMLP struct {
|
||||
// Dequantized weights (used when GatherQMM not available)
|
||||
GateWeight *mlx.Array
|
||||
UpWeight *mlx.Array
|
||||
DownWeight *mlx.Array
|
||||
|
||||
// Quantized weights (used with GatherQMM for 4/8-bit affine)
|
||||
GateWeightQ, GateScales, GateBiases *mlx.Array
|
||||
UpWeightQ, UpScales, UpBiases *mlx.Array
|
||||
DownWeightQ, DownScales, DownBiases *mlx.Array
|
||||
|
||||
// Quantization bits per projection (supports mixed precision Q4/Q8)
|
||||
GateBits int
|
||||
UpBits int
|
||||
DownBits int
|
||||
|
||||
// Quantization group size per projection (detected from tensor shapes)
|
||||
GateGroupSize int
|
||||
UpGroupSize int
|
||||
DownGroupSize int
|
||||
|
||||
// If true, use GatherQMM with quantized weights
|
||||
UseQuantized bool
|
||||
}
|
||||
|
||||
// Forward applies the switched expert MLP
|
||||
func (s *SwitchMLP) Forward(x *mlx.Array, indices *mlx.Array, cfg *Config) *mlx.Array {
|
||||
shape := x.Shape()
|
||||
B, L := shape[0], shape[1]
|
||||
topK := cfg.NumExpertsPerTok
|
||||
|
||||
// Expand x for expert computation: [B, L, D] -> [B, L, 1, 1, D]
|
||||
xExpanded := mlx.ExpandDims(mlx.ExpandDims(x, -2), -2)
|
||||
|
||||
// Flatten for gather_mm: [B*L, 1, 1, D]
|
||||
xFlat := mlx.Reshape(xExpanded, B*L, 1, 1, cfg.HiddenSize)
|
||||
|
||||
// Flatten indices: [B, L, topK] -> [B*L, topK]
|
||||
idxFlat := mlx.Reshape(indices, B*L, topK)
|
||||
|
||||
// Sort for efficient gather (when we have many tokens)
|
||||
doSort := B*L >= 64
|
||||
var invOrder *mlx.Array
|
||||
n := B * L * topK
|
||||
|
||||
if doSort {
|
||||
idxAll := mlx.Flatten(idxFlat)
|
||||
order := mlx.Argsort(idxAll, 0)
|
||||
invOrder = mlx.Argsort(order, 0)
|
||||
// Reorder x based on sorted indices
|
||||
xFlat = mlx.ExpandDims(mlx.Take(mlx.Squeeze(xFlat, 1), mlx.FloorDivideScalar(order, topK), 0), 1)
|
||||
idxFlat = mlx.Reshape(mlx.Take(idxAll, order, 0), n, 1)
|
||||
}
|
||||
|
||||
var gate, up, hidden, down *mlx.Array
|
||||
|
||||
if s.UseQuantized {
|
||||
// Use GatherQMM for quantized weights (faster, keeps weights quantized)
|
||||
// Each projection may have different bits and group sizes (mixed precision: Q4 for gate/up, Q8 for down)
|
||||
gate = mlx.GatherQMM(xFlat, s.GateWeightQ, s.GateScales, s.GateBiases,
|
||||
nil, idxFlat, true, s.GateGroupSize, s.GateBits, cfg.QuantMode, doSort)
|
||||
up = mlx.GatherQMM(xFlat, s.UpWeightQ, s.UpScales, s.UpBiases,
|
||||
nil, idxFlat, true, s.UpGroupSize, s.UpBits, cfg.QuantMode, doSort)
|
||||
|
||||
hidden = mlx.Mul(mlx.SiLU(gate), up)
|
||||
|
||||
down = mlx.GatherQMM(hidden, s.DownWeightQ, s.DownScales, s.DownBiases,
|
||||
nil, idxFlat, true, s.DownGroupSize, s.DownBits, cfg.QuantMode, doSort)
|
||||
} else {
|
||||
// Use GatherMM for dequantized/non-quantized weights
|
||||
gate = mlx.GatherMM(xFlat, mlx.Transpose(s.GateWeight, 0, 2, 1), nil, idxFlat, doSort)
|
||||
up = mlx.GatherMM(xFlat, mlx.Transpose(s.UpWeight, 0, 2, 1), nil, idxFlat, doSort)
|
||||
|
||||
hidden = mlx.Mul(mlx.SiLU(gate), up)
|
||||
|
||||
down = mlx.GatherMM(hidden, mlx.Transpose(s.DownWeight, 0, 2, 1), nil, idxFlat, doSort)
|
||||
}
|
||||
|
||||
// Unsort if we sorted
|
||||
if doSort {
|
||||
down = mlx.Reshape(mlx.Take(mlx.Squeeze(mlx.Squeeze(down, 2), 1), invOrder, 0), B*L, topK, cfg.HiddenSize)
|
||||
} else {
|
||||
down = mlx.Squeeze(down, 2)
|
||||
}
|
||||
|
||||
return mlx.Reshape(down, B, L, topK, cfg.HiddenSize)
|
||||
}
|
||||
|
||||
// SharedExperts implements the shared expert MLP
|
||||
type SharedExperts struct {
|
||||
GateProj nn.LinearLayer `weight:"mlp.shared_experts.gate_proj"`
|
||||
UpProj nn.LinearLayer `weight:"mlp.shared_experts.up_proj"`
|
||||
DownProj nn.LinearLayer `weight:"mlp.shared_experts.down_proj"`
|
||||
}
|
||||
|
||||
// Forward applies the shared expert MLP
|
||||
func (s *SharedExperts) Forward(x *mlx.Array) *mlx.Array {
|
||||
gate := mlx.SiLU(s.GateProj.Forward(x))
|
||||
up := s.UpProj.Forward(x)
|
||||
return s.DownProj.Forward(mlx.Mul(gate, up))
|
||||
}
|
||||
|
||||
// MoE implements the full Mixture of Experts layer
|
||||
type MoE struct {
|
||||
Gate *MoEGate
|
||||
SwitchMLP *SwitchMLP
|
||||
SharedExperts *SharedExperts
|
||||
}
|
||||
|
||||
// Forward applies the MoE layer
|
||||
func (m *MoE) Forward(x *mlx.Array, cfg *Config) *mlx.Array {
|
||||
shape := x.Shape()
|
||||
B, L := shape[0], shape[1]
|
||||
|
||||
// Get expert indices and scores
|
||||
inds, scores := m.Gate.Forward(x, cfg)
|
||||
|
||||
// Apply routed experts
|
||||
expertOut := m.SwitchMLP.Forward(x, inds, cfg)
|
||||
|
||||
// Weight by scores: [B, L, topK, D] * [B, L, topK, 1] -> sum over topK
|
||||
scoresExpanded := mlx.ExpandDims(scores, -1)
|
||||
y := mlx.Sum(mlx.Mul(expertOut, scoresExpanded), 2, false)
|
||||
|
||||
// Add shared experts if present
|
||||
if m.SharedExperts != nil {
|
||||
y = mlx.Add(y, m.SharedExperts.Forward(x))
|
||||
}
|
||||
|
||||
return mlx.Reshape(y, B, L, cfg.HiddenSize)
|
||||
}
|
||||
|
||||
// DenseBlock represents a dense transformer block (for first_k_dense_replace layers)
|
||||
type DenseBlock struct {
|
||||
Attention *MLAAttention
|
||||
MLP *DenseMLP
|
||||
InputLayerNorm *nn.RMSNorm `weight:"input_layernorm"`
|
||||
PostAttentionLayerNorm *nn.RMSNorm `weight:"post_attention_layernorm"`
|
||||
}
|
||||
|
||||
// Forward applies the dense block
|
||||
func (b *DenseBlock) Forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *Config) *mlx.Array {
|
||||
// Pre-norm attention with residual
|
||||
r := b.Attention.Forward(b.InputLayerNorm.Forward(x, cfg.RMSNormEps), c, B, L, cfg)
|
||||
h := mlx.Add(x, r)
|
||||
|
||||
// Pre-norm MLP with residual
|
||||
r = b.MLP.Forward(b.PostAttentionLayerNorm.Forward(h, cfg.RMSNormEps))
|
||||
return mlx.Add(h, r)
|
||||
}
|
||||
|
||||
// MoEBlock represents a MoE transformer block
|
||||
type MoEBlock struct {
|
||||
Attention *MLAAttention
|
||||
MoE *MoE
|
||||
InputLayerNorm *nn.RMSNorm `weight:"input_layernorm"`
|
||||
PostAttentionLayerNorm *nn.RMSNorm `weight:"post_attention_layernorm"`
|
||||
}
|
||||
|
||||
// Forward applies the MoE block
|
||||
func (b *MoEBlock) Forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *Config) *mlx.Array {
|
||||
// Pre-norm attention with residual
|
||||
r := b.Attention.Forward(b.InputLayerNorm.Forward(x, cfg.RMSNormEps), c, B, L, cfg)
|
||||
h := mlx.Add(x, r)
|
||||
|
||||
// Pre-norm MoE with residual
|
||||
r = b.MoE.Forward(b.PostAttentionLayerNorm.Forward(h, cfg.RMSNormEps), cfg)
|
||||
return mlx.Add(h, r)
|
||||
}
|
||||
|
||||
// Block interface for both dense and MoE blocks
|
||||
type Block interface {
|
||||
Forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *Config) *mlx.Array
|
||||
}
|
||||
|
||||
// Model represents the complete GLM4-MoE-Lite model
|
||||
type Model struct {
|
||||
EmbedTokens *nn.Embedding `weight:"model.embed_tokens"`
|
||||
Layers []Block `weight:"-"` // Loaded manually due to different block types
|
||||
Norm *nn.RMSNorm `weight:"model.norm"`
|
||||
LMHead nn.LinearLayer `weight:"lm_head"`
|
||||
|
||||
tok *tokenizer.Tokenizer
|
||||
*Config
|
||||
}
|
||||
|
||||
// computeScale computes the attention scale.
|
||||
// Uses the full key head dimension (qkNopeHeadDim + qkRopeHeadDim) to match the Ollama runner.
|
||||
func computeScale(cfg *Config) float32 {
|
||||
keyLength := cfg.QKNopeHeadDim + cfg.QKRopeHeadDim
|
||||
scale := float32(1.0 / math.Sqrt(float64(keyLength)))
|
||||
if cfg.RopeScaling != nil && cfg.RopeScaling.MscaleAllDim > 0 && cfg.RopeScaling.Factor > 1 {
|
||||
s := 0.1*cfg.RopeScaling.MscaleAllDim*float32(math.Log(float64(cfg.RopeScaling.Factor))) + 1.0
|
||||
scale *= s * s
|
||||
}
|
||||
return scale
|
||||
}
|
||||
|
||||
// supportsGatherQMM returns true if the quantization mode has GatherQMM kernel support.
|
||||
// Currently only 4-bit and 8-bit affine quantization are supported.
|
||||
func supportsGatherQMM(mode string, bits int) bool {
|
||||
return mode == "affine" && (bits == 4 || bits == 8)
|
||||
}
|
||||
|
||||
// ExpertWeight holds a single expert's weight with optional quantization components.
|
||||
type ExpertWeight struct {
|
||||
Weight *mlx.Array // Quantized weight (if quantized) or dequantized weight
|
||||
Scales *mlx.Array // Quantization scales (nil if not quantized)
|
||||
Biases *mlx.Array // Quantization biases (nil if not quantized or mode doesn't use biases)
|
||||
Bits int // Quantization bits (4 or 8), 0 if not quantized
|
||||
GroupSize int // Quantization group size, 0 if not quantized
|
||||
}
|
||||
|
||||
// getQuantParams returns quantization parameters from model metadata.
|
||||
// Returns groupSize, bits, and mode for the model's quantization type.
|
||||
func getQuantParams(weights safetensors.WeightSource) (groupSize, bits int, mode string) {
|
||||
groupSize, bits, mode = safetensors.QuantizationParams(weights.Quantization())
|
||||
// Use metadata group_size if available (overrides default)
|
||||
if gs := weights.GroupSize(); gs > 0 {
|
||||
groupSize = gs
|
||||
}
|
||||
return groupSize, bits, mode
|
||||
}
|
||||
|
||||
// loadExpertWeight loads an expert weight.
|
||||
// If useQuantized is true and the weight is quantized with a supported mode, returns quantized components.
|
||||
// Otherwise dequantizes and returns only the weight.
|
||||
func loadExpertWeight(weights safetensors.WeightSource, path string, useQuantized bool, cfg *Config) *ExpertWeight {
|
||||
w, _ := weights.GetTensor(path + ".weight")
|
||||
if w == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Check if this is a quantized weight by looking for scales
|
||||
scalePath := path + ".weight_scale"
|
||||
if weights.HasTensor(scalePath) {
|
||||
scales, _ := weights.GetTensor(scalePath)
|
||||
var qbiases *mlx.Array
|
||||
qbiasPath := path + ".weight_qbias"
|
||||
if weights.HasTensor(qbiasPath) {
|
||||
qbiases, _ = weights.GetTensor(qbiasPath)
|
||||
}
|
||||
|
||||
// Get quantization params from metadata
|
||||
groupSize, bits, mode := getQuantParams(weights)
|
||||
|
||||
// Update config with group size (for GatherQMM calls)
|
||||
if cfg.QuantGroupSize == 0 {
|
||||
cfg.QuantGroupSize = groupSize
|
||||
}
|
||||
|
||||
// If GatherQMM is supported and requested, return quantized components
|
||||
if useQuantized && supportsGatherQMM(mode, bits) {
|
||||
return &ExpertWeight{Weight: w, Scales: scales, Biases: qbiases, Bits: bits, GroupSize: groupSize}
|
||||
}
|
||||
|
||||
// Otherwise dequantize
|
||||
return &ExpertWeight{Weight: mlx.Dequantize(w, scales, qbiases, groupSize, bits, mode)}
|
||||
}
|
||||
|
||||
return &ExpertWeight{Weight: w}
|
||||
}
|
||||
|
||||
// sanitizeMLAWeights transforms kv_b_proj weights into absorbed MLA format.
|
||||
// Returns embed_q and unembed_out weights for per-head projections.
|
||||
//
|
||||
// kv_b_proj.weight shape: [num_heads * (qk_nope_head_dim + v_head_dim), kv_lora_rank]
|
||||
// Output:
|
||||
// - embed_q: [num_heads, kv_lora_rank, qk_nope_head_dim] - projects q_nope to latent
|
||||
// - unembed_out: [num_heads, v_head_dim, kv_lora_rank] - projects latent to output
|
||||
func sanitizeMLAWeights(weights safetensors.WeightSource, prefix string, cfg *Config) (*mlx.Array, *mlx.Array) {
|
||||
path := prefix + ".self_attn.kv_b_proj"
|
||||
w, err := weights.GetTensor(path + ".weight")
|
||||
if err != nil || w == nil {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// Check if quantized and dequantize
|
||||
scalePath := path + ".weight_scale"
|
||||
if weights.HasTensor(scalePath) {
|
||||
scales, _ := weights.GetTensor(scalePath)
|
||||
var qbiases *mlx.Array
|
||||
qbiasPath := path + ".weight_qbias"
|
||||
if weights.HasTensor(qbiasPath) {
|
||||
qbiases, _ = weights.GetTensor(qbiasPath)
|
||||
}
|
||||
|
||||
groupSize, bits, mode := getQuantParams(weights)
|
||||
w = mlx.Dequantize(w, scales, qbiases, groupSize, bits, mode)
|
||||
}
|
||||
|
||||
// w: [num_heads * (qk_nope_head_dim + v_head_dim), kv_lora_rank]
|
||||
// Reshape to [num_heads, qk_nope_head_dim + v_head_dim, kv_lora_rank]
|
||||
headDim := cfg.QKNopeHeadDim + cfg.VHeadDim
|
||||
w = mlx.Reshape(w, cfg.NumAttentionHeads, headDim, cfg.KVLoraRank)
|
||||
|
||||
// Split into wk and wv
|
||||
// wk: [num_heads, qk_nope_head_dim, kv_lora_rank]
|
||||
// wv: [num_heads, v_head_dim, kv_lora_rank]
|
||||
wk := mlx.Slice(w, []int32{0, 0, 0}, []int32{cfg.NumAttentionHeads, cfg.QKNopeHeadDim, cfg.KVLoraRank})
|
||||
wv := mlx.Slice(w, []int32{0, cfg.QKNopeHeadDim, 0}, []int32{cfg.NumAttentionHeads, headDim, cfg.KVLoraRank})
|
||||
|
||||
// Transform for absorbed MLA:
|
||||
// embed_q: transpose(wk) -> [num_heads, kv_lora_rank, qk_nope_head_dim]
|
||||
// This allows: q_nope @ embed_q.T = q_nope @ wk (absorbed key projection)
|
||||
embedQ := mlx.Transpose(wk, 0, 2, 1)
|
||||
|
||||
// unembed_out: wv stays [num_heads, v_head_dim, kv_lora_rank]
|
||||
// This allows: latent_out @ unembed_out.T = latent_out @ wv.T (absorbed value projection)
|
||||
unembedOut := wv
|
||||
|
||||
return embedQ, unembedOut
|
||||
}
|
||||
|
||||
// StackedExpertWeights holds stacked weights for all experts.
|
||||
type StackedExpertWeights struct {
|
||||
Weight *mlx.Array // Stacked weights [num_experts, out, in] or [num_experts, out, in_packed]
|
||||
Scales *mlx.Array // Stacked scales (nil if not quantized)
|
||||
Biases *mlx.Array // Stacked biases (nil if not quantized)
|
||||
Bits int // Quantization bits (4 or 8), 0 if not quantized
|
||||
GroupSize int // Quantization group size, 0 if not quantized
|
||||
}
|
||||
|
||||
// collectAndStackExpertWeights loads and stacks expert weights for one projection type.
|
||||
func collectAndStackExpertWeights(
|
||||
weights safetensors.WeightSource,
|
||||
prefix string,
|
||||
projName string,
|
||||
numExperts int32,
|
||||
useQuantized bool,
|
||||
cfg *Config,
|
||||
) *StackedExpertWeights {
|
||||
var w, s, b []*mlx.Array
|
||||
var bits, groupSize int
|
||||
|
||||
for e := int32(0); e < numExperts; e++ {
|
||||
path := fmt.Sprintf("%s.mlp.experts.%d.%s", prefix, e, projName)
|
||||
ew := loadExpertWeight(weights, path, useQuantized, cfg)
|
||||
if ew == nil {
|
||||
continue
|
||||
}
|
||||
w = append(w, ew.Weight)
|
||||
if ew.Scales != nil {
|
||||
s = append(s, ew.Scales)
|
||||
}
|
||||
if ew.Biases != nil {
|
||||
b = append(b, ew.Biases)
|
||||
}
|
||||
if e == 0 {
|
||||
bits = ew.Bits
|
||||
groupSize = ew.GroupSize
|
||||
}
|
||||
}
|
||||
|
||||
result := &StackedExpertWeights{Bits: bits, GroupSize: groupSize}
|
||||
if len(w) > 0 {
|
||||
result.Weight = mlx.Stack(w, 0)
|
||||
if len(s) > 0 {
|
||||
result.Scales = mlx.Stack(s, 0)
|
||||
}
|
||||
if len(b) > 0 {
|
||||
result.Biases = mlx.Stack(b, 0)
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// sanitizeExpertWeights stacks individual expert weights into tensors.
|
||||
// If useQuantized is true and weights support GatherQMM, returns quantized components.
|
||||
// Otherwise returns dequantized weights with nil scales/biases.
|
||||
// Bits and GroupSize are detected per-weight to support mixed-precision (Q4 for gate/up, Q8 for down).
|
||||
func sanitizeExpertWeights(weights safetensors.WeightSource, prefix string, numExperts int32, useQuantized bool, cfg *Config) (gate, up, down *StackedExpertWeights) {
|
||||
gate = collectAndStackExpertWeights(weights, prefix, "gate_proj", numExperts, useQuantized, cfg)
|
||||
up = collectAndStackExpertWeights(weights, prefix, "up_proj", numExperts, useQuantized, cfg)
|
||||
down = collectAndStackExpertWeights(weights, prefix, "down_proj", numExperts, useQuantized, cfg)
|
||||
return gate, up, down
|
||||
}
|
||||
|
||||
// LoadFromManifest loads a GLM4-MoE-Lite model from a manifest (Ollama blob storage).
|
||||
func LoadFromManifest(manifest *imagegen.ModelManifest) (*Model, error) {
|
||||
// Read config from manifest
|
||||
configData, err := manifest.ReadConfig("config.json")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("load config: %w", err)
|
||||
}
|
||||
|
||||
var cfg Config
|
||||
if err := json.Unmarshal(configData, &cfg); err != nil {
|
||||
return nil, fmt.Errorf("parse config: %w", err)
|
||||
}
|
||||
|
||||
// Compute derived fields
|
||||
cfg.QHeadDim = cfg.QKNopeHeadDim + cfg.QKRopeHeadDim
|
||||
cfg.Scale = computeScale(&cfg)
|
||||
|
||||
// Load weights from manifest blobs
|
||||
weights, err := imagegen.LoadWeightsFromManifest(manifest, "")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("load weights: %w", err)
|
||||
}
|
||||
|
||||
if err := weights.Load(0); err != nil {
|
||||
return nil, fmt.Errorf("load weight data: %w", err)
|
||||
}
|
||||
|
||||
// Set up quantization parameters (only if model is actually quantized)
|
||||
// Note: QuantGroupSize will be detected dynamically from tensor shapes during weight loading
|
||||
quantization := weights.Quantization()
|
||||
useQuantized := false
|
||||
if quantization != "" {
|
||||
_, cfg.QuantBits, cfg.QuantMode = safetensors.QuantizationParams(quantization)
|
||||
useQuantized = supportsGatherQMM(cfg.QuantMode, cfg.QuantBits)
|
||||
}
|
||||
|
||||
// Load tokenizer from manifest with config files for EOS token detection
|
||||
tokData, err := manifest.ReadConfig("tokenizer.json")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("load tokenizer config: %w", err)
|
||||
}
|
||||
|
||||
// Build tokenizer config with companion files for EOS/BOS token loading
|
||||
tokConfig := &tokenizer.TokenizerConfig{
|
||||
ConfigJSON: configData, // Already loaded above, contains eos_token_id
|
||||
}
|
||||
|
||||
// Try to load generation_config.json if available (preferred source for EOS)
|
||||
if genConfigData, err := manifest.ReadConfig("generation_config.json"); err == nil {
|
||||
tokConfig.GenerationConfigJSON = genConfigData
|
||||
}
|
||||
|
||||
// Try to load tokenizer_config.json if available
|
||||
if tokConfigData, err := manifest.ReadConfig("tokenizer_config.json"); err == nil {
|
||||
tokConfig.TokenizerConfigJSON = tokConfigData
|
||||
}
|
||||
|
||||
tok, err := tokenizer.LoadFromBytesWithConfig(tokData, tokConfig)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("parse tokenizer: %w", err)
|
||||
}
|
||||
|
||||
m := &Model{
|
||||
Layers: make([]Block, cfg.NumHiddenLayers),
|
||||
Config: &cfg,
|
||||
tok: tok,
|
||||
}
|
||||
|
||||
// Load embedding, norm, and lm_head
|
||||
if err := safetensors.LoadModule(m, weights, ""); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Load layers manually due to different block types
|
||||
for i := int32(0); i < cfg.NumHiddenLayers; i++ {
|
||||
prefix := fmt.Sprintf("model.layers.%d", i)
|
||||
|
||||
// Load attention (same for both block types)
|
||||
attn := &MLAAttention{}
|
||||
if err := safetensors.LoadModule(attn, weights, prefix); err != nil {
|
||||
return nil, fmt.Errorf("layer %d attention: %w", i, err)
|
||||
}
|
||||
|
||||
// Sanitize MLA weights for absorbed attention
|
||||
embedQ, unembedOut := sanitizeMLAWeights(weights, prefix, &cfg)
|
||||
attn.EmbedQ = nn.NewMultiLinear(embedQ)
|
||||
attn.UnembedOut = nn.NewMultiLinear(unembedOut)
|
||||
|
||||
if i < cfg.FirstKDenseReplace {
|
||||
// Dense block
|
||||
block := &DenseBlock{Attention: attn}
|
||||
if err := safetensors.LoadModule(block, weights, prefix); err != nil {
|
||||
return nil, fmt.Errorf("layer %d dense: %w", i, err)
|
||||
}
|
||||
m.Layers[i] = block
|
||||
} else {
|
||||
// MoE block
|
||||
block := &MoEBlock{Attention: attn}
|
||||
if err := safetensors.LoadModule(block, weights, prefix); err != nil {
|
||||
return nil, fmt.Errorf("layer %d moe block: %w", i, err)
|
||||
}
|
||||
|
||||
// Stack expert weights (pass cfg so group sizes can be detected)
|
||||
gate, up, down := sanitizeExpertWeights(weights, prefix, cfg.NRoutedExperts, useQuantized, &cfg)
|
||||
|
||||
switchMLP := &SwitchMLP{UseQuantized: useQuantized}
|
||||
if useQuantized {
|
||||
switchMLP.GateWeightQ = gate.Weight
|
||||
switchMLP.GateScales = gate.Scales
|
||||
switchMLP.GateBiases = gate.Biases
|
||||
switchMLP.GateBits = gate.Bits
|
||||
switchMLP.GateGroupSize = gate.GroupSize
|
||||
switchMLP.UpWeightQ = up.Weight
|
||||
switchMLP.UpScales = up.Scales
|
||||
switchMLP.UpBiases = up.Biases
|
||||
switchMLP.UpBits = up.Bits
|
||||
switchMLP.UpGroupSize = up.GroupSize
|
||||
switchMLP.DownWeightQ = down.Weight
|
||||
switchMLP.DownScales = down.Scales
|
||||
switchMLP.DownBiases = down.Biases
|
||||
switchMLP.DownBits = down.Bits
|
||||
switchMLP.DownGroupSize = down.GroupSize
|
||||
} else {
|
||||
switchMLP.GateWeight = gate.Weight
|
||||
switchMLP.UpWeight = up.Weight
|
||||
switchMLP.DownWeight = down.Weight
|
||||
}
|
||||
|
||||
block.MoE = &MoE{
|
||||
Gate: &MoEGate{},
|
||||
SwitchMLP: switchMLP,
|
||||
}
|
||||
|
||||
// Load gate weights
|
||||
if err := safetensors.LoadModule(block.MoE.Gate, weights, prefix); err != nil {
|
||||
return nil, fmt.Errorf("layer %d gate: %w", i, err)
|
||||
}
|
||||
|
||||
// Load shared experts if present
|
||||
if cfg.NSharedExperts > 0 {
|
||||
block.MoE.SharedExperts = &SharedExperts{}
|
||||
if err := safetensors.LoadModule(block.MoE.SharedExperts, weights, prefix); err != nil {
|
||||
return nil, fmt.Errorf("layer %d shared experts: %w", i, err)
|
||||
}
|
||||
}
|
||||
|
||||
m.Layers[i] = block
|
||||
}
|
||||
}
|
||||
|
||||
mlx.Eval(mlx.Collect(m)...)
|
||||
weights.ReleaseAll()
|
||||
|
||||
return m, nil
|
||||
}
|
||||
|
||||
// Forward computes the forward pass of the model
|
||||
func (m *Model) Forward(tokens *mlx.Array, caches []cache.Cache) *mlx.Array {
|
||||
B, L := tokens.Shape()[0], tokens.Shape()[1]
|
||||
|
||||
h := m.EmbedTokens.Forward(tokens)
|
||||
|
||||
for i, layer := range m.Layers {
|
||||
var c cache.Cache
|
||||
if caches != nil {
|
||||
c = caches[i]
|
||||
}
|
||||
h = layer.Forward(h, c, B, L, m.Config)
|
||||
}
|
||||
|
||||
h = m.Norm.Forward(h, m.RMSNormEps)
|
||||
return m.LMHead.Forward(h)
|
||||
}
|
||||
|
||||
// Interface methods
|
||||
|
||||
// NumLayers returns the number of transformer layers
|
||||
func (m *Model) NumLayers() int { return len(m.Layers) }
|
||||
|
||||
// MaxContextLength returns the maximum context length
|
||||
func (m *Model) MaxContextLength() int32 { return m.MaxPositionEmbeddings }
|
||||
|
||||
// VocabSize returns the vocabulary size
|
||||
func (m *Model) VocabSize() int32 { return m.Config.VocabSize }
|
||||
|
||||
// Tokenizer returns the model's tokenizer
|
||||
func (m *Model) Tokenizer() *tokenizer.Tokenizer { return m.tok }
|
||||
|
||||
// NewCache creates a new KV cache for the model
|
||||
func (m *Model) NewCache(maxSeqLen int32) []cache.Cache {
|
||||
caches := make([]cache.Cache, len(m.Layers))
|
||||
for i := range caches {
|
||||
caches[i] = cache.NewKVCache()
|
||||
}
|
||||
return caches
|
||||
}
|
||||
|
||||
// FormatPrompt applies the GLM-4 chat template with thinking enabled by default.
|
||||
// This follows the GLM-4.7 format with <think> tag for reasoning mode.
|
||||
func (m *Model) FormatPrompt(prompt string) string {
|
||||
return "[gMASK]<sop><|user|>" + prompt + "<|assistant|><think>"
|
||||
}
|
||||
|
||||
// FormatPromptWithThinking applies the GLM-4 chat template with explicit thinking control.
|
||||
// When think is true, the prompt ends with <think> to enable reasoning mode.
|
||||
// When think is false, the prompt ends with </think> to skip reasoning.
|
||||
func (m *Model) FormatPromptWithThinking(prompt string, think bool) string {
|
||||
if think {
|
||||
return "[gMASK]<sop><|user|>" + prompt + "<|assistant|><think>"
|
||||
}
|
||||
return "[gMASK]<sop><|user|>" + prompt + "<|assistant|></think>"
|
||||
}
|
||||
|
||||
// NewRenderer returns a new Renderer for formatting multi-turn conversations.
|
||||
func (m *Model) NewRenderer() *Renderer {
|
||||
return &Renderer{}
|
||||
}
|
||||
|
||||
// NewParser returns a new Parser for extracting thinking and tool calls from output.
|
||||
func (m *Model) NewParser() *Parser {
|
||||
return &Parser{}
|
||||
}
|
||||
479
x/imagegen/models/glm4_moe_lite/parser.go
Normal file
479
x/imagegen/models/glm4_moe_lite/parser.go
Normal file
@@ -0,0 +1,479 @@
|
||||
//go:build mlx
|
||||
|
||||
package glm4_moe_lite
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"encoding/xml"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"strings"
|
||||
"unicode"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/logutil"
|
||||
)
|
||||
|
||||
type parserState int
|
||||
|
||||
const (
|
||||
parserState_LookingForThinkingOpen parserState = iota
|
||||
parserState_ThinkingStartedEatingWhitespace
|
||||
parserState_CollectingThinking
|
||||
parserState_ThinkingDoneEatingWhitespace
|
||||
parserState_CollectingContent
|
||||
parserState_ToolStartedEatingWhitespace
|
||||
parserState_CollectingToolContent
|
||||
)
|
||||
|
||||
const (
|
||||
thinkingOpenTag = "<think>"
|
||||
thinkingCloseTag = "</think>"
|
||||
toolOpenTag = "<tool_call>"
|
||||
toolCloseTag = "</tool_call>"
|
||||
)
|
||||
|
||||
// Parser parses GLM4-MoE-Lite model output to extract thinking and tool calls.
|
||||
// GLM-4's prompt ends with <think> when thinking is enabled, so the parser
|
||||
// must start in CollectingThinking state (the model outputs thinking content directly).
|
||||
type Parser struct {
|
||||
state parserState
|
||||
buffer strings.Builder
|
||||
tools []api.Tool
|
||||
}
|
||||
|
||||
// HasToolSupport returns true as GLM4 supports tool calling.
|
||||
func (p *Parser) HasToolSupport() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
// HasThinkingSupport returns true as GLM4 supports thinking mode.
|
||||
func (p *Parser) HasThinkingSupport() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
// Init initializes the parser with tools and thinking configuration.
|
||||
func (p *Parser) Init(tools []api.Tool, lastMessage *api.Message, thinkValue *api.ThinkValue) []api.Tool {
|
||||
p.tools = tools
|
||||
// When thinking is enabled (nil or true), the prompt ends with <think>,
|
||||
// so model output starts directly with thinking content (no opening tag).
|
||||
if thinkValue == nil || thinkValue.Bool() {
|
||||
p.state = parserState_CollectingThinking
|
||||
}
|
||||
return tools
|
||||
}
|
||||
|
||||
type parserEvent interface {
|
||||
isParserEvent()
|
||||
}
|
||||
|
||||
type eventContent struct {
|
||||
content string
|
||||
}
|
||||
|
||||
func (eventContent) isParserEvent() {}
|
||||
|
||||
type eventRawToolCall struct {
|
||||
raw string
|
||||
}
|
||||
|
||||
func (eventRawToolCall) isParserEvent() {}
|
||||
|
||||
type eventThinkingContent struct {
|
||||
content string
|
||||
}
|
||||
|
||||
func (eventThinkingContent) isParserEvent() {}
|
||||
|
||||
// Add processes new output text and returns parsed content, thinking, and tool calls.
|
||||
func (p *Parser) Add(s string, done bool) (content string, thinking string, calls []api.ToolCall, err error) {
|
||||
p.buffer.WriteString(s)
|
||||
events := p.parseEvents()
|
||||
|
||||
var toolCalls []api.ToolCall
|
||||
var contentSb strings.Builder
|
||||
var thinkingSb strings.Builder
|
||||
|
||||
for _, event := range events {
|
||||
switch event := event.(type) {
|
||||
case eventRawToolCall:
|
||||
toolCall, err := parseToolCall(event, p.tools)
|
||||
if err != nil {
|
||||
slog.Warn("glm-4 tool call parsing failed", "error", err)
|
||||
return "", "", nil, err
|
||||
}
|
||||
toolCalls = append(toolCalls, toolCall)
|
||||
case eventThinkingContent:
|
||||
thinkingSb.WriteString(event.content)
|
||||
case eventContent:
|
||||
contentSb.WriteString(event.content)
|
||||
}
|
||||
}
|
||||
|
||||
return contentSb.String(), thinkingSb.String(), toolCalls, nil
|
||||
}
|
||||
|
||||
func (p *Parser) parseEvents() []parserEvent {
|
||||
var all []parserEvent
|
||||
|
||||
keepLooping := true
|
||||
for keepLooping {
|
||||
var events []parserEvent
|
||||
events, keepLooping = p.eat()
|
||||
if len(events) > 0 {
|
||||
all = append(all, events...)
|
||||
}
|
||||
}
|
||||
|
||||
if len(all) > 0 {
|
||||
slog.Log(context.TODO(), logutil.LevelTrace, "glm-4 events parsed", "events", all, "state", p.state, "buffer", p.buffer.String())
|
||||
}
|
||||
|
||||
return all
|
||||
}
|
||||
|
||||
// eatLeadingWhitespaceAndTransitionTo consumes leading whitespace from the buffer
|
||||
// and transitions to the next state. Returns (nil, false) if only whitespace remains
|
||||
// in the buffer (needs more input), or (nil, true) if we successfully transitioned.
|
||||
func (p *Parser) eatLeadingWhitespaceAndTransitionTo(nextState parserState) ([]parserEvent, bool) {
|
||||
trimmed := strings.TrimLeftFunc(p.buffer.String(), unicode.IsSpace)
|
||||
p.buffer.Reset()
|
||||
if trimmed == "" {
|
||||
return nil, false // Still only whitespace, keep waiting for more input
|
||||
}
|
||||
p.state = nextState
|
||||
p.buffer.WriteString(trimmed)
|
||||
return nil, true // Successfully transitioned
|
||||
}
|
||||
|
||||
// splitAtTag splits the buffer at the given tag, returns the content before (trimmed of trailing whitespace),
|
||||
// the content after (optionally trimmed of leading whitespace), and updates the buffer
|
||||
func (p *Parser) splitAtTag(tag string, trimAfter bool) (string, string) {
|
||||
split := strings.SplitN(p.buffer.String(), tag, 2)
|
||||
before := split[0]
|
||||
before = strings.TrimRightFunc(before, unicode.IsSpace)
|
||||
after := split[1]
|
||||
if trimAfter {
|
||||
after = strings.TrimLeftFunc(after, unicode.IsSpace)
|
||||
}
|
||||
p.buffer.Reset()
|
||||
p.buffer.WriteString(after)
|
||||
return before, after
|
||||
}
|
||||
|
||||
func (p *Parser) eat() ([]parserEvent, bool) {
|
||||
var events []parserEvent
|
||||
|
||||
switch p.state {
|
||||
case parserState_LookingForThinkingOpen:
|
||||
trimmed := strings.TrimLeftFunc(p.buffer.String(), unicode.IsSpace)
|
||||
if strings.HasPrefix(trimmed, thinkingOpenTag) {
|
||||
// Found <think> opening tag
|
||||
after := strings.TrimPrefix(trimmed, thinkingOpenTag)
|
||||
after = strings.TrimLeftFunc(after, unicode.IsSpace)
|
||||
p.buffer.Reset()
|
||||
p.buffer.WriteString(after)
|
||||
if after == "" {
|
||||
p.state = parserState_ThinkingStartedEatingWhitespace
|
||||
} else {
|
||||
p.state = parserState_CollectingThinking
|
||||
}
|
||||
return events, true
|
||||
} else if strings.HasPrefix(thinkingOpenTag, trimmed) {
|
||||
// Partial opening tag seen, keep accumulating
|
||||
return events, false
|
||||
} else if trimmed == "" {
|
||||
// Only whitespace, keep accumulating
|
||||
return events, false
|
||||
} else {
|
||||
// No thinking tag found, skip to content collection
|
||||
p.state = parserState_CollectingContent
|
||||
// Don't trim - we want to keep the original content
|
||||
return events, true
|
||||
}
|
||||
|
||||
case parserState_ThinkingStartedEatingWhitespace:
|
||||
return p.eatLeadingWhitespaceAndTransitionTo(parserState_CollectingThinking)
|
||||
|
||||
case parserState_CollectingThinking:
|
||||
acc := p.buffer.String()
|
||||
if strings.Contains(acc, thinkingCloseTag) {
|
||||
thinking, remaining := p.splitAtTag(thinkingCloseTag, true)
|
||||
if len(thinking) > 0 {
|
||||
events = append(events, eventThinkingContent{content: thinking})
|
||||
}
|
||||
if remaining == "" {
|
||||
p.state = parserState_ThinkingDoneEatingWhitespace
|
||||
} else {
|
||||
p.state = parserState_CollectingContent
|
||||
}
|
||||
return events, true
|
||||
} else if overlapLen := overlap(acc, thinkingCloseTag); overlapLen > 0 {
|
||||
// Partial closing tag - withhold it along with any trailing whitespace before it
|
||||
beforePartialTag := acc[:len(acc)-overlapLen]
|
||||
trailingWsLen := trailingWhitespaceLen(beforePartialTag)
|
||||
ambiguousStart := len(beforePartialTag) - trailingWsLen
|
||||
|
||||
unambiguous := acc[:ambiguousStart]
|
||||
ambiguous := acc[ambiguousStart:]
|
||||
p.buffer.Reset()
|
||||
p.buffer.WriteString(ambiguous)
|
||||
if len(unambiguous) > 0 {
|
||||
events = append(events, eventThinkingContent{content: unambiguous})
|
||||
}
|
||||
return events, false
|
||||
} else {
|
||||
// Pure thinking content - withhold trailing whitespace (might precede closing tag)
|
||||
whitespaceLen := trailingWhitespaceLen(acc)
|
||||
ambiguousStart := len(acc) - whitespaceLen
|
||||
|
||||
unambiguous := acc[:ambiguousStart]
|
||||
ambiguous := acc[ambiguousStart:]
|
||||
p.buffer.Reset()
|
||||
p.buffer.WriteString(ambiguous)
|
||||
if len(unambiguous) > 0 {
|
||||
events = append(events, eventThinkingContent{content: unambiguous})
|
||||
}
|
||||
return events, false
|
||||
}
|
||||
|
||||
case parserState_ThinkingDoneEatingWhitespace:
|
||||
return p.eatLeadingWhitespaceAndTransitionTo(parserState_CollectingContent)
|
||||
|
||||
case parserState_CollectingContent:
|
||||
if strings.Contains(p.buffer.String(), toolOpenTag) {
|
||||
before, after := p.splitAtTag(toolOpenTag, true)
|
||||
if len(before) > 0 {
|
||||
events = append(events, eventContent{content: before})
|
||||
}
|
||||
if after == "" {
|
||||
p.state = parserState_ToolStartedEatingWhitespace
|
||||
} else {
|
||||
p.state = parserState_CollectingToolContent
|
||||
}
|
||||
return events, true
|
||||
} else if overlapLen := overlap(p.buffer.String(), toolOpenTag); overlapLen > 0 {
|
||||
beforePartialTag := p.buffer.String()[:len(p.buffer.String())-overlapLen]
|
||||
trailingWsLen := trailingWhitespaceLen(beforePartialTag)
|
||||
ambiguousStart := len(beforePartialTag) - trailingWsLen
|
||||
|
||||
unambiguous := p.buffer.String()[:ambiguousStart]
|
||||
ambiguous := p.buffer.String()[ambiguousStart:]
|
||||
p.buffer.Reset()
|
||||
p.buffer.WriteString(ambiguous)
|
||||
if len(unambiguous) > 0 {
|
||||
events = append(events, eventContent{content: unambiguous})
|
||||
}
|
||||
return events, false
|
||||
} else {
|
||||
whitespaceLen := trailingWhitespaceLen(p.buffer.String())
|
||||
ambiguousStart := len(p.buffer.String()) - whitespaceLen
|
||||
|
||||
unambiguous := p.buffer.String()[:ambiguousStart]
|
||||
ambiguous := p.buffer.String()[ambiguousStart:]
|
||||
p.buffer.Reset()
|
||||
p.buffer.WriteString(ambiguous)
|
||||
if len(unambiguous) > 0 {
|
||||
events = append(events, eventContent{content: unambiguous})
|
||||
}
|
||||
return events, false
|
||||
}
|
||||
|
||||
case parserState_ToolStartedEatingWhitespace:
|
||||
return p.eatLeadingWhitespaceAndTransitionTo(parserState_CollectingToolContent)
|
||||
|
||||
case parserState_CollectingToolContent:
|
||||
acc := p.buffer.String()
|
||||
if strings.Contains(acc, toolCloseTag) {
|
||||
toolContent, _ := p.splitAtTag(toolCloseTag, true)
|
||||
if len(toolContent) == 0 {
|
||||
slog.Warn("glm4 tool call closing tag found but no content before it")
|
||||
}
|
||||
events = append(events, eventRawToolCall{raw: toolContent})
|
||||
p.state = parserState_CollectingContent
|
||||
return events, true
|
||||
} else {
|
||||
// Keep accumulating - tool calls are not streamed
|
||||
// We just wait for the closing tag
|
||||
return events, false
|
||||
}
|
||||
|
||||
default:
|
||||
panic("unreachable")
|
||||
}
|
||||
}
|
||||
|
||||
// overlap returns the length of the overlap between the end of s and the start of tag.
|
||||
func overlap(s, tag string) int {
|
||||
for i := 1; i <= len(tag) && i <= len(s); i++ {
|
||||
if strings.HasSuffix(s, tag[:i]) {
|
||||
return i
|
||||
}
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
// trailingWhitespaceLen returns the length of trailing whitespace in s.
|
||||
func trailingWhitespaceLen(s string) int {
|
||||
trimmed := strings.TrimRightFunc(s, unicode.IsSpace)
|
||||
return len(s) - len(trimmed)
|
||||
}
|
||||
|
||||
// ToolCallXML represents the structure of a GLM-4 tool call for XML parsing
|
||||
type ToolCallXML struct {
|
||||
XMLName xml.Name `xml:"tool_call"`
|
||||
Content string `xml:",chardata"` // Function name (text nodes between tags)
|
||||
Keys []string `xml:"arg_key"` // All arg_key elements in document order
|
||||
Values []string `xml:"arg_value"` // All arg_value elements in document order
|
||||
}
|
||||
|
||||
// escapeContent escapes XML entities in text content while preserving arg_key/arg_value tags
|
||||
func escapeContent(s string) string {
|
||||
var result strings.Builder
|
||||
inTag := false
|
||||
|
||||
for i := range len(s) {
|
||||
ch := s[i]
|
||||
|
||||
if ch == '<' {
|
||||
// Check if this is a known tag
|
||||
if strings.HasPrefix(s[i:], "<arg_key>") ||
|
||||
strings.HasPrefix(s[i:], "</arg_key>") ||
|
||||
strings.HasPrefix(s[i:], "<arg_value>") ||
|
||||
strings.HasPrefix(s[i:], "</arg_value>") {
|
||||
inTag = true
|
||||
}
|
||||
}
|
||||
|
||||
if inTag {
|
||||
result.WriteByte(ch)
|
||||
if ch == '>' {
|
||||
inTag = false
|
||||
}
|
||||
} else {
|
||||
// Escape special characters in text content
|
||||
switch ch {
|
||||
case '&':
|
||||
result.WriteString("&")
|
||||
case '<':
|
||||
result.WriteString("<")
|
||||
case '>':
|
||||
result.WriteString(">")
|
||||
default:
|
||||
result.WriteByte(ch)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return result.String()
|
||||
}
|
||||
|
||||
func parseToolCall(raw eventRawToolCall, tools []api.Tool) (api.ToolCall, error) {
|
||||
// Escape any unescaped entities in text content
|
||||
escaped := escapeContent(raw.raw)
|
||||
|
||||
// Wrap the content in a root element to make it valid XML
|
||||
xmlString := "<tool_call>" + escaped + "</tool_call>"
|
||||
|
||||
// Parse XML into struct
|
||||
var parsed ToolCallXML
|
||||
if err := xml.Unmarshal([]byte(xmlString), &parsed); err != nil {
|
||||
return api.ToolCall{}, fmt.Errorf("failed to parse XML: %w", err)
|
||||
}
|
||||
|
||||
// Extract and trim function name
|
||||
functionName := strings.TrimSpace(parsed.Content)
|
||||
if functionName == "" {
|
||||
return api.ToolCall{}, fmt.Errorf("empty function name")
|
||||
}
|
||||
|
||||
// Verify keys and values are paired correctly
|
||||
if len(parsed.Keys) != len(parsed.Values) {
|
||||
return api.ToolCall{}, fmt.Errorf("mismatched arg_key and arg_value counts: %d keys, %d values", len(parsed.Keys), len(parsed.Values))
|
||||
}
|
||||
|
||||
// Find the matching tool to get parameter types
|
||||
var matchedTool *api.Tool
|
||||
for i := range tools {
|
||||
if tools[i].Function.Name == functionName {
|
||||
matchedTool = &tools[i]
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
// Build arguments map by pairing keys and values
|
||||
toolCall := api.ToolCall{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: functionName,
|
||||
Arguments: api.NewToolCallFunctionArguments(),
|
||||
},
|
||||
}
|
||||
|
||||
for i := range parsed.Keys {
|
||||
key := strings.TrimSpace(parsed.Keys[i])
|
||||
value := parsed.Values[i] // Don't trim here - parseValue handles it
|
||||
|
||||
// Look up parameter type
|
||||
var paramType api.PropertyType
|
||||
if matchedTool != nil && matchedTool.Function.Parameters.Properties != nil {
|
||||
if prop, ok := matchedTool.Function.Parameters.Properties.Get(key); ok {
|
||||
// Handle anyOf by collecting all types from the union
|
||||
if len(prop.AnyOf) > 0 {
|
||||
for _, anyOfProp := range prop.AnyOf {
|
||||
paramType = append(paramType, anyOfProp.Type...)
|
||||
}
|
||||
} else {
|
||||
paramType = prop.Type
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Parse value with type coercion
|
||||
toolCall.Function.Arguments.Set(key, parseValue(value, paramType))
|
||||
}
|
||||
|
||||
return toolCall, nil
|
||||
}
|
||||
|
||||
// parseValue parses a string value and coerces it to the appropriate type based on paramType.
|
||||
func parseValue(value string, paramType api.PropertyType) any {
|
||||
value = strings.TrimSpace(value)
|
||||
|
||||
// If no type specified, return as string
|
||||
if len(paramType) == 0 {
|
||||
return value
|
||||
}
|
||||
|
||||
// Try to parse based on specified types
|
||||
for _, t := range paramType {
|
||||
switch t {
|
||||
case "boolean":
|
||||
if value == "true" {
|
||||
return true
|
||||
}
|
||||
if value == "false" {
|
||||
return false
|
||||
}
|
||||
case "integer":
|
||||
var i int64
|
||||
if _, err := fmt.Sscanf(value, "%d", &i); err == nil {
|
||||
return i
|
||||
}
|
||||
case "number":
|
||||
var f float64
|
||||
if _, err := fmt.Sscanf(value, "%f", &f); err == nil {
|
||||
return f
|
||||
}
|
||||
case "array", "object":
|
||||
// Try to parse as JSON
|
||||
var result any
|
||||
if err := json.Unmarshal([]byte(value), &result); err == nil {
|
||||
return result
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Default to string
|
||||
return value
|
||||
}
|
||||
192
x/imagegen/models/glm4_moe_lite/parser_test.go
Normal file
192
x/imagegen/models/glm4_moe_lite/parser_test.go
Normal file
@@ -0,0 +1,192 @@
|
||||
//go:build mlx
|
||||
|
||||
package glm4_moe_lite
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
)
|
||||
|
||||
func TestParserThinking(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
thinkEnabled bool
|
||||
wantContent string
|
||||
wantThinking string
|
||||
wantToolCalls int
|
||||
}{
|
||||
{
|
||||
name: "thinking enabled - simple thinking then content",
|
||||
input: "Let me think about this...</think>Here is my answer.",
|
||||
thinkEnabled: true,
|
||||
wantThinking: "Let me think about this...",
|
||||
wantContent: "Here is my answer.",
|
||||
},
|
||||
{
|
||||
name: "thinking enabled - only thinking",
|
||||
input: "I need to consider multiple factors...",
|
||||
thinkEnabled: true,
|
||||
wantThinking: "I need to consider multiple factors...",
|
||||
wantContent: "",
|
||||
},
|
||||
{
|
||||
name: "thinking disabled - direct content",
|
||||
input: "Here is my direct answer.",
|
||||
thinkEnabled: false,
|
||||
wantThinking: "",
|
||||
wantContent: "Here is my direct answer.",
|
||||
},
|
||||
{
|
||||
name: "thinking with tool call",
|
||||
input: "Let me search for that...</think>I'll use a tool.<tool_call>search<arg_key>query</arg_key><arg_value>test</arg_value></tool_call>",
|
||||
thinkEnabled: true,
|
||||
wantThinking: "Let me search for that...",
|
||||
wantContent: "I'll use a tool.",
|
||||
wantToolCalls: 1,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
p := &Parser{}
|
||||
|
||||
var thinkValue *api.ThinkValue
|
||||
if tt.thinkEnabled {
|
||||
thinkValue = &api.ThinkValue{Value: true}
|
||||
} else {
|
||||
thinkValue = &api.ThinkValue{Value: false}
|
||||
}
|
||||
|
||||
// Define tools for tool call tests
|
||||
props := api.NewToolPropertiesMap()
|
||||
props.Set("query", api.ToolProperty{Type: api.PropertyType{"string"}})
|
||||
tools := []api.Tool{
|
||||
{
|
||||
Function: api.ToolFunction{
|
||||
Name: "search",
|
||||
Parameters: api.ToolFunctionParameters{
|
||||
Properties: props,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
p.Init(tools, nil, thinkValue)
|
||||
|
||||
content, thinking, calls, err := p.Add(tt.input, true)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
if thinking != tt.wantThinking {
|
||||
t.Errorf("thinking = %q, want %q", thinking, tt.wantThinking)
|
||||
}
|
||||
if content != tt.wantContent {
|
||||
t.Errorf("content = %q, want %q", content, tt.wantContent)
|
||||
}
|
||||
if len(calls) != tt.wantToolCalls {
|
||||
t.Errorf("len(calls) = %d, want %d", len(calls), tt.wantToolCalls)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestParserToolCall(t *testing.T) {
|
||||
p := &Parser{}
|
||||
|
||||
props := api.NewToolPropertiesMap()
|
||||
props.Set("location", api.ToolProperty{Type: api.PropertyType{"string"}})
|
||||
props.Set("unit", api.ToolProperty{Type: api.PropertyType{"string"}})
|
||||
tools := []api.Tool{
|
||||
{
|
||||
Function: api.ToolFunction{
|
||||
Name: "get_weather",
|
||||
Parameters: api.ToolFunctionParameters{
|
||||
Properties: props,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
// Initialize with thinking disabled
|
||||
tv := &api.ThinkValue{Value: false}
|
||||
p.Init(tools, nil, tv)
|
||||
|
||||
input := "<tool_call>get_weather<arg_key>location</arg_key><arg_value>San Francisco</arg_value><arg_key>unit</arg_key><arg_value>celsius</arg_value></tool_call>"
|
||||
|
||||
_, _, calls, err := p.Add(input, true)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
if len(calls) != 1 {
|
||||
t.Fatalf("expected 1 tool call, got %d", len(calls))
|
||||
}
|
||||
|
||||
call := calls[0]
|
||||
if call.Function.Name != "get_weather" {
|
||||
t.Errorf("function name = %q, want %q", call.Function.Name, "get_weather")
|
||||
}
|
||||
|
||||
location, ok := call.Function.Arguments.Get("location")
|
||||
if !ok || location != "San Francisco" {
|
||||
t.Errorf("location = %v, want %q", location, "San Francisco")
|
||||
}
|
||||
|
||||
unit, ok := call.Function.Arguments.Get("unit")
|
||||
if !ok || unit != "celsius" {
|
||||
t.Errorf("unit = %v, want %q", unit, "celsius")
|
||||
}
|
||||
}
|
||||
|
||||
func TestOverlap(t *testing.T) {
|
||||
tests := []struct {
|
||||
s string
|
||||
tag string
|
||||
want int
|
||||
}{
|
||||
{"hello<", "</think>", 1},
|
||||
{"hello</", "</think>", 2},
|
||||
{"hello</t", "</think>", 3},
|
||||
{"hello</th", "</think>", 4},
|
||||
{"hello</thi", "</think>", 5},
|
||||
{"hello</thin", "</think>", 6},
|
||||
{"hello</think", "</think>", 7},
|
||||
{"hello</think>", "</think>", 8}, // Complete tag at end returns full length
|
||||
{"hello", "</think>", 0},
|
||||
{"", "</think>", 0},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.s+"_"+tt.tag, func(t *testing.T) {
|
||||
got := overlap(tt.s, tt.tag)
|
||||
if got != tt.want {
|
||||
t.Errorf("overlap(%q, %q) = %d, want %d", tt.s, tt.tag, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestTrailingWhitespaceLen(t *testing.T) {
|
||||
tests := []struct {
|
||||
s string
|
||||
want int
|
||||
}{
|
||||
{"hello ", 3},
|
||||
{"hello\n\t ", 3},
|
||||
{"hello", 0},
|
||||
{"", 0},
|
||||
{" ", 3},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.s, func(t *testing.T) {
|
||||
got := trailingWhitespaceLen(tt.s)
|
||||
if got != tt.want {
|
||||
t.Errorf("trailingWhitespaceLen(%q) = %d, want %d", tt.s, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
175
x/imagegen/models/glm4_moe_lite/render.go
Normal file
175
x/imagegen/models/glm4_moe_lite/render.go
Normal file
@@ -0,0 +1,175 @@
|
||||
//go:build mlx
|
||||
|
||||
package glm4_moe_lite
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
)
|
||||
|
||||
// Renderer renders messages for GLM4-MoE-Lite models.
|
||||
//
|
||||
// GLM-4 Thinking Modes (ref: https://docs.z.ai/guides/capabilities/thinking-mode):
|
||||
//
|
||||
// 1. INTERLEAVED THINKING
|
||||
// The model thinks between tool calls and after receiving tool results.
|
||||
// This enables complex step-by-step reasoning: interpreting each tool output
|
||||
// before deciding what to do next. Thinking blocks are preserved and returned
|
||||
// with tool results to maintain reasoning continuity.
|
||||
//
|
||||
// 2. PRESERVED THINKING
|
||||
// The model retains reasoning content from previous assistant turns in context.
|
||||
// This preserves reasoning continuity across multi-turn conversations. The
|
||||
// upstream API has a "clear_thinking" parameter to control this:
|
||||
// - clear_thinking=true: clears reasoning from previous turns (outputs </think>)
|
||||
// - clear_thinking=false: preserves <think>...</think> blocks from previous turns
|
||||
//
|
||||
// 3. TURN-LEVEL THINKING
|
||||
// Controls whether the model should reason on each turn. The upstream API
|
||||
// uses "enable_thinking" parameter:
|
||||
// - enable_thinking=true: outputs <think> to start reasoning
|
||||
// - enable_thinking=false: outputs </think> to skip reasoning
|
||||
//
|
||||
// OLLAMA DEFAULTS:
|
||||
// - Thinking is ENABLED by default (thinkValue=nil or true outputs <think>)
|
||||
// - Thinking is PRESERVED by default (reasoning content from previous turns is always
|
||||
// included in <think>...</think> blocks, equivalent to clear_thinking=false)
|
||||
// - Users can disable thinking per-turn via thinkValue=false
|
||||
type Renderer struct{}
|
||||
|
||||
// Render renders messages into the GLM4 chat format.
|
||||
func (r *Renderer) Render(messages []api.Message, tools []api.Tool, thinkValue *api.ThinkValue) (string, error) {
|
||||
var sb strings.Builder
|
||||
|
||||
sb.WriteString("[gMASK]<sop>")
|
||||
|
||||
if len(tools) > 0 {
|
||||
sb.WriteString("<|system|>\n")
|
||||
sb.WriteString("# Tools\n\n")
|
||||
sb.WriteString("You may call one or more functions to assist with the user query.\n\n")
|
||||
sb.WriteString("You are provided with function signatures within <tools></tools> XML tags:\n")
|
||||
sb.WriteString("<tools>\n")
|
||||
for _, tool := range tools {
|
||||
d, _ := json.Marshal(tool)
|
||||
sb.WriteString(formatToolJSON(d))
|
||||
sb.WriteString("\n")
|
||||
}
|
||||
sb.WriteString("</tools>\n\n")
|
||||
sb.WriteString("For each function call, output the function name and arguments within the following XML format:\n")
|
||||
sb.WriteString("<tool_call>{function-name}<arg_key>{arg-key-1}</arg_key><arg_value>{arg-value-1}</arg_value><arg_key>{arg-key-2}</arg_key><arg_value>{arg-value-2}</arg_value>...</tool_call>")
|
||||
}
|
||||
|
||||
think := true
|
||||
if thinkValue != nil && !thinkValue.Bool() {
|
||||
think = false
|
||||
}
|
||||
|
||||
for i, message := range messages {
|
||||
switch message.Role {
|
||||
case "user":
|
||||
sb.WriteString("<|user|>")
|
||||
sb.WriteString(message.Content)
|
||||
case "assistant":
|
||||
sb.WriteString("<|assistant|>")
|
||||
if message.Thinking != "" {
|
||||
sb.WriteString("<think>" + message.Thinking + "</think>")
|
||||
} else {
|
||||
sb.WriteString("</think>")
|
||||
}
|
||||
if message.Content != "" {
|
||||
sb.WriteString(message.Content)
|
||||
}
|
||||
if len(message.ToolCalls) > 0 {
|
||||
for _, toolCall := range message.ToolCalls {
|
||||
sb.WriteString("<tool_call>" + toolCall.Function.Name)
|
||||
sb.WriteString(renderToolArguments(toolCall.Function.Arguments))
|
||||
sb.WriteString("</tool_call>")
|
||||
}
|
||||
}
|
||||
case "tool":
|
||||
if i == 0 || messages[i-1].Role != "tool" {
|
||||
sb.WriteString("<|observation|>")
|
||||
}
|
||||
sb.WriteString("<tool_response>")
|
||||
sb.WriteString(message.Content)
|
||||
sb.WriteString("</tool_response>")
|
||||
case "system":
|
||||
sb.WriteString("<|system|>")
|
||||
sb.WriteString(message.Content)
|
||||
}
|
||||
}
|
||||
|
||||
sb.WriteString("<|assistant|>")
|
||||
if think {
|
||||
sb.WriteString("<think>")
|
||||
} else {
|
||||
sb.WriteString("</think>")
|
||||
}
|
||||
|
||||
return sb.String(), nil
|
||||
}
|
||||
|
||||
// renderToolArguments converts tool call arguments to GLM4 XML format.
|
||||
func renderToolArguments(args api.ToolCallFunctionArguments) string {
|
||||
var sb strings.Builder
|
||||
for key, value := range args.All() {
|
||||
sb.WriteString("<arg_key>" + key + "</arg_key>")
|
||||
var valueStr string
|
||||
if str, ok := value.(string); ok {
|
||||
valueStr = str
|
||||
} else {
|
||||
jsonBytes, err := json.Marshal(value)
|
||||
if err != nil {
|
||||
valueStr = fmt.Sprintf("%v", value)
|
||||
} else {
|
||||
valueStr = string(jsonBytes)
|
||||
}
|
||||
}
|
||||
|
||||
sb.WriteString("<arg_value>" + valueStr + "</arg_value>")
|
||||
}
|
||||
|
||||
return sb.String()
|
||||
}
|
||||
|
||||
// formatToolJSON formats JSON for GLM4 tool definitions by adding spaces after : and ,
|
||||
func formatToolJSON(raw []byte) string {
|
||||
var sb strings.Builder
|
||||
sb.Grow(len(raw) + len(raw)/10)
|
||||
|
||||
inString := false
|
||||
escaped := false
|
||||
for i := range raw {
|
||||
ch := raw[i]
|
||||
sb.WriteByte(ch)
|
||||
|
||||
if inString {
|
||||
if escaped {
|
||||
escaped = false
|
||||
continue
|
||||
}
|
||||
if ch == '\\' {
|
||||
escaped = true
|
||||
continue
|
||||
}
|
||||
if ch == '"' {
|
||||
inString = false
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
if ch == '"' {
|
||||
inString = true
|
||||
continue
|
||||
}
|
||||
|
||||
if ch == ':' || ch == ',' {
|
||||
sb.WriteByte(' ')
|
||||
}
|
||||
}
|
||||
|
||||
return sb.String()
|
||||
}
|
||||
205
x/imagegen/models/glm4_moe_lite/render_test.go
Normal file
205
x/imagegen/models/glm4_moe_lite/render_test.go
Normal file
@@ -0,0 +1,205 @@
|
||||
//go:build mlx
|
||||
|
||||
package glm4_moe_lite
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
)
|
||||
|
||||
func TestRendererSimple(t *testing.T) {
|
||||
r := &Renderer{}
|
||||
|
||||
messages := []api.Message{
|
||||
{Role: "user", Content: "Hello"},
|
||||
}
|
||||
|
||||
// Thinking enabled (default)
|
||||
result, err := r.Render(messages, nil, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
expected := "[gMASK]<sop><|user|>Hello<|assistant|><think>"
|
||||
if result != expected {
|
||||
t.Errorf("result = %q, want %q", result, expected)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRendererThinkingDisabled(t *testing.T) {
|
||||
r := &Renderer{}
|
||||
|
||||
messages := []api.Message{
|
||||
{Role: "user", Content: "Hello"},
|
||||
}
|
||||
|
||||
tv := &api.ThinkValue{Value: false}
|
||||
|
||||
result, err := r.Render(messages, nil, tv)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
expected := "[gMASK]<sop><|user|>Hello<|assistant|></think>"
|
||||
if result != expected {
|
||||
t.Errorf("result = %q, want %q", result, expected)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRendererMultiTurn(t *testing.T) {
|
||||
r := &Renderer{}
|
||||
|
||||
messages := []api.Message{
|
||||
{Role: "user", Content: "What is 2+2?"},
|
||||
{Role: "assistant", Content: "4", Thinking: "Let me calculate: 2+2=4"},
|
||||
{Role: "user", Content: "And 3+3?"},
|
||||
}
|
||||
|
||||
result, err := r.Render(messages, nil, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
// Check key parts
|
||||
if !strings.Contains(result, "[gMASK]<sop>") {
|
||||
t.Error("missing [gMASK]<sop> prefix")
|
||||
}
|
||||
if !strings.Contains(result, "<|user|>What is 2+2?") {
|
||||
t.Error("missing first user message")
|
||||
}
|
||||
if !strings.Contains(result, "<|assistant|><think>Let me calculate: 2+2=4</think>4") {
|
||||
t.Error("missing assistant message with thinking")
|
||||
}
|
||||
if !strings.Contains(result, "<|user|>And 3+3?") {
|
||||
t.Error("missing second user message")
|
||||
}
|
||||
if !strings.HasSuffix(result, "<|assistant|><think>") {
|
||||
t.Errorf("should end with <|assistant|><think>, got suffix: %q", result[len(result)-30:])
|
||||
}
|
||||
}
|
||||
|
||||
func TestRendererWithSystem(t *testing.T) {
|
||||
r := &Renderer{}
|
||||
|
||||
messages := []api.Message{
|
||||
{Role: "system", Content: "You are a helpful assistant."},
|
||||
{Role: "user", Content: "Hello"},
|
||||
}
|
||||
|
||||
result, err := r.Render(messages, nil, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
if !strings.Contains(result, "<|system|>You are a helpful assistant.") {
|
||||
t.Error("missing system message")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRendererWithTools(t *testing.T) {
|
||||
r := &Renderer{}
|
||||
|
||||
messages := []api.Message{
|
||||
{Role: "user", Content: "What's the weather?"},
|
||||
}
|
||||
|
||||
props := api.NewToolPropertiesMap()
|
||||
props.Set("location", api.ToolProperty{Type: api.PropertyType{"string"}, Description: "The city"})
|
||||
tools := []api.Tool{
|
||||
{
|
||||
Function: api.ToolFunction{
|
||||
Name: "get_weather",
|
||||
Description: "Get the weather for a location",
|
||||
Parameters: api.ToolFunctionParameters{
|
||||
Type: "object",
|
||||
Properties: props,
|
||||
Required: []string{"location"},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
result, err := r.Render(messages, tools, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
// Check for tool system prompt
|
||||
if !strings.Contains(result, "<|system|>") {
|
||||
t.Error("missing system tag for tools")
|
||||
}
|
||||
if !strings.Contains(result, "# Tools") {
|
||||
t.Error("missing tools header")
|
||||
}
|
||||
if !strings.Contains(result, "<tools>") {
|
||||
t.Error("missing tools tag")
|
||||
}
|
||||
if !strings.Contains(result, "get_weather") {
|
||||
t.Error("missing tool name")
|
||||
}
|
||||
if !strings.Contains(result, "</tools>") {
|
||||
t.Error("missing closing tools tag")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRendererWithToolCalls(t *testing.T) {
|
||||
r := &Renderer{}
|
||||
|
||||
args := api.NewToolCallFunctionArguments()
|
||||
args.Set("location", "San Francisco")
|
||||
|
||||
messages := []api.Message{
|
||||
{Role: "user", Content: "What's the weather in SF?"},
|
||||
{
|
||||
Role: "assistant",
|
||||
ToolCalls: []api.ToolCall{
|
||||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get_weather",
|
||||
Arguments: args,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{Role: "tool", Content: "Sunny, 72F"},
|
||||
}
|
||||
|
||||
result, err := r.Render(messages, nil, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
if !strings.Contains(result, "<tool_call>get_weather") {
|
||||
t.Error("missing tool call")
|
||||
}
|
||||
if !strings.Contains(result, "<arg_key>location</arg_key>") {
|
||||
t.Error("missing arg_key")
|
||||
}
|
||||
if !strings.Contains(result, "<arg_value>San Francisco</arg_value>") {
|
||||
t.Error("missing arg_value")
|
||||
}
|
||||
if !strings.Contains(result, "</tool_call>") {
|
||||
t.Error("missing tool call closing tag")
|
||||
}
|
||||
if !strings.Contains(result, "<|observation|>") {
|
||||
t.Error("missing observation tag")
|
||||
}
|
||||
if !strings.Contains(result, "<tool_response>Sunny, 72F</tool_response>") {
|
||||
t.Error("missing tool response")
|
||||
}
|
||||
}
|
||||
|
||||
func TestFormatToolJSON(t *testing.T) {
|
||||
input := []byte(`{"name":"test","value":123}`)
|
||||
result := formatToolJSON(input)
|
||||
|
||||
// Should add spaces after : and ,
|
||||
if !strings.Contains(result, ": ") {
|
||||
t.Error("should add space after colon")
|
||||
}
|
||||
if !strings.Contains(result, ", ") {
|
||||
t.Error("should add space after comma")
|
||||
}
|
||||
}
|
||||
@@ -32,10 +32,16 @@ func NewLinear(weight *mlx.Array, bias *mlx.Array) *Linear {
|
||||
|
||||
// NewQuantizedLinear creates a quantized linear layer directly from bf16 weights.
|
||||
// Quantizes the weight immediately and evaluates to break lazy dependencies.
|
||||
// Note: For modes like "nvfp4", qbiases will be nil.
|
||||
func NewQuantizedLinear(weight *mlx.Array, bias *mlx.Array, groupSize, bits int, mode string) *QuantizedLinear {
|
||||
qw, scales, qbiases := mlx.Quantize(weight, groupSize, bits, mode)
|
||||
// Eval immediately so bf16 weight can be freed
|
||||
mlx.Eval(qw, scales, qbiases)
|
||||
// Handle modes that don't return qbiases (e.g., nvfp4)
|
||||
if qbiases != nil {
|
||||
mlx.Eval(qw, scales, qbiases)
|
||||
} else {
|
||||
mlx.Eval(qw, scales)
|
||||
}
|
||||
return &QuantizedLinear{
|
||||
Weight: qw,
|
||||
Scales: scales,
|
||||
@@ -77,10 +83,13 @@ func (l *Linear) ToQuantized(groupSize, bits int, mode string) *QuantizedLinear
|
||||
|
||||
// QuantizedLinear applies an affine transformation using quantized weights.
|
||||
// Equivalent to mlx.nn.QuantizedLinear.
|
||||
// Supports multiple quantization modes:
|
||||
// - "affine": scale + zero-point bias (QBiases required)
|
||||
// - "nvfp4": NVIDIA FP4 with E4M3 scales (QBiases nil)
|
||||
type QuantizedLinear struct {
|
||||
Weight *mlx.Array // Quantized weight data
|
||||
Scales *mlx.Array // Scale factors for dequantization
|
||||
QBiases *mlx.Array // Quantization biases (NOT layer bias)
|
||||
QBiases *mlx.Array // Quantization biases (NOT layer bias), nil for nvfp4
|
||||
Bias *mlx.Array // Layer bias [output_dims] or nil
|
||||
GroupSize int
|
||||
Bits int
|
||||
@@ -220,3 +229,32 @@ func (ln *LayerNorm) Forward(x *mlx.Array) *mlx.Array {
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
// MultiLinearLayer is an interface for per-head linear layers.
|
||||
// This allows swapping between MultiLinear (bf16) and pre-dequantized weights.
|
||||
type MultiLinearLayer interface {
|
||||
Forward(x *mlx.Array) *mlx.Array
|
||||
}
|
||||
|
||||
// MultiLinear performs per-head linear projections.
|
||||
// Weight shape: [num_heads, output_dims, input_dims]
|
||||
// Input shape: [B, num_heads, L, input_dims]
|
||||
// Output shape: [B, num_heads, L, output_dims]
|
||||
type MultiLinear struct {
|
||||
Weight *mlx.Array `weight:"weight"`
|
||||
}
|
||||
|
||||
// NewMultiLinear creates a MultiLinear layer with the given weight.
|
||||
func NewMultiLinear(weight *mlx.Array) *MultiLinear {
|
||||
return &MultiLinear{Weight: weight}
|
||||
}
|
||||
|
||||
// Forward applies per-head linear transformation: x @ weight.T per head via broadcasting.
|
||||
func (ml *MultiLinear) Forward(x *mlx.Array) *mlx.Array {
|
||||
// Weight: [num_heads, output_dims, input_dims]
|
||||
// x: [B, num_heads, L, input_dims]
|
||||
// wT: [num_heads, input_dims, output_dims]
|
||||
// Result: [B, num_heads, L, output_dims]
|
||||
wT := mlx.Transpose(ml.Weight, 0, 2, 1)
|
||||
return mlx.Matmul(x, wT)
|
||||
}
|
||||
|
||||
@@ -1,233 +0,0 @@
|
||||
//go:build mlx
|
||||
|
||||
// Package runner provides a subprocess server for image generation.
|
||||
// It listens on a port and handles HTTP requests for image generation.
|
||||
package runner
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"flag"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"os"
|
||||
"os/signal"
|
||||
"sync"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"github.com/ollama/ollama/x/imagegen"
|
||||
"github.com/ollama/ollama/x/imagegen/mlx"
|
||||
"github.com/ollama/ollama/x/imagegen/models/flux2"
|
||||
"github.com/ollama/ollama/x/imagegen/models/zimage"
|
||||
)
|
||||
|
||||
// Request is the image generation request format
|
||||
type Request struct {
|
||||
Prompt string `json:"prompt"`
|
||||
Width int32 `json:"width,omitempty"`
|
||||
Height int32 `json:"height,omitempty"`
|
||||
Steps int `json:"steps,omitempty"`
|
||||
Seed int64 `json:"seed,omitempty"`
|
||||
}
|
||||
|
||||
// Response is streamed back for each progress update
|
||||
type Response struct {
|
||||
Content string `json:"content,omitempty"`
|
||||
Image string `json:"image,omitempty"` // Base64-encoded PNG
|
||||
Done bool `json:"done"`
|
||||
Step int `json:"step,omitempty"`
|
||||
Total int `json:"total,omitempty"`
|
||||
}
|
||||
|
||||
// ImageModel is the interface for image generation models
|
||||
type ImageModel interface {
|
||||
GenerateImage(ctx context.Context, prompt string, width, height int32, steps int, seed int64, progress func(step, total int)) (*mlx.Array, error)
|
||||
}
|
||||
|
||||
// Server holds the model and handles requests
|
||||
type Server struct {
|
||||
mu sync.Mutex
|
||||
model ImageModel
|
||||
modelName string
|
||||
}
|
||||
|
||||
// Execute is the entry point for the image runner subprocess
|
||||
func Execute(args []string) error {
|
||||
fs := flag.NewFlagSet("image-runner", flag.ExitOnError)
|
||||
modelName := fs.String("model", "", "path to image model")
|
||||
port := fs.Int("port", 0, "port to listen on")
|
||||
|
||||
if err := fs.Parse(args); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if *modelName == "" {
|
||||
return fmt.Errorf("--model is required")
|
||||
}
|
||||
if *port == 0 {
|
||||
return fmt.Errorf("--port is required")
|
||||
}
|
||||
|
||||
err := mlx.InitMLX()
|
||||
if err != nil {
|
||||
slog.Error("unable to initialize MLX", "error", err)
|
||||
return err
|
||||
}
|
||||
slog.Info("MLX library initialized")
|
||||
slog.Info("starting image runner", "model", *modelName, "port", *port)
|
||||
|
||||
// Check memory requirements before loading
|
||||
requiredMemory := imagegen.EstimateVRAM(*modelName)
|
||||
availableMemory := mlx.GetMemoryLimit()
|
||||
if availableMemory > 0 && availableMemory < requiredMemory {
|
||||
return fmt.Errorf("insufficient memory for image generation: need %d GB, have %d GB",
|
||||
requiredMemory/(1024*1024*1024), availableMemory/(1024*1024*1024))
|
||||
}
|
||||
|
||||
// Detect model type and load appropriate model
|
||||
modelType := imagegen.DetectModelType(*modelName)
|
||||
slog.Info("detected model type", "type", modelType)
|
||||
|
||||
var model ImageModel
|
||||
switch modelType {
|
||||
case "Flux2KleinPipeline":
|
||||
m := &flux2.Model{}
|
||||
if err := m.Load(*modelName); err != nil {
|
||||
return fmt.Errorf("failed to load model: %w", err)
|
||||
}
|
||||
model = m
|
||||
default:
|
||||
// Default to Z-Image for ZImagePipeline, FluxPipeline, etc.
|
||||
m := &zimage.Model{}
|
||||
if err := m.Load(*modelName); err != nil {
|
||||
return fmt.Errorf("failed to load model: %w", err)
|
||||
}
|
||||
model = m
|
||||
}
|
||||
|
||||
server := &Server{
|
||||
model: model,
|
||||
modelName: *modelName,
|
||||
}
|
||||
|
||||
// Set up HTTP handlers
|
||||
mux := http.NewServeMux()
|
||||
mux.HandleFunc("/health", server.healthHandler)
|
||||
mux.HandleFunc("/completion", server.completionHandler)
|
||||
|
||||
httpServer := &http.Server{
|
||||
Addr: fmt.Sprintf("127.0.0.1:%d", *port),
|
||||
Handler: mux,
|
||||
}
|
||||
|
||||
// Handle shutdown
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
sigCh := make(chan os.Signal, 1)
|
||||
signal.Notify(sigCh, syscall.SIGINT, syscall.SIGTERM)
|
||||
<-sigCh
|
||||
slog.Info("shutting down image runner")
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
httpServer.Shutdown(ctx)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
slog.Info("image runner listening", "addr", httpServer.Addr)
|
||||
if err := httpServer.ListenAndServe(); err != http.ErrServerClosed {
|
||||
return err
|
||||
}
|
||||
|
||||
<-done
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Server) healthHandler(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
json.NewEncoder(w).Encode(map[string]string{"status": "ok"})
|
||||
}
|
||||
|
||||
func (s *Server) completionHandler(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodPost {
|
||||
http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
|
||||
return
|
||||
}
|
||||
|
||||
var req Request
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
http.Error(w, err.Error(), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
// Serialize generation requests - MLX model may not handle concurrent generation
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
// Model applies its own defaults for width/height/steps
|
||||
// Only seed needs to be set here if not provided
|
||||
if req.Seed <= 0 {
|
||||
req.Seed = time.Now().UnixNano()
|
||||
}
|
||||
|
||||
// Set up streaming response
|
||||
w.Header().Set("Content-Type", "application/x-ndjson")
|
||||
w.Header().Set("Transfer-Encoding", "chunked")
|
||||
flusher, ok := w.(http.Flusher)
|
||||
if !ok {
|
||||
http.Error(w, "streaming not supported", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
// Generate image using the common interface
|
||||
ctx := r.Context()
|
||||
enc := json.NewEncoder(w)
|
||||
|
||||
// Progress callback streams step updates
|
||||
progress := func(step, total int) {
|
||||
resp := Response{Step: step, Total: total}
|
||||
enc.Encode(resp)
|
||||
w.Write([]byte("\n"))
|
||||
flusher.Flush()
|
||||
}
|
||||
|
||||
img, err := s.model.GenerateImage(ctx, req.Prompt, req.Width, req.Height, req.Steps, req.Seed, progress)
|
||||
|
||||
if err != nil {
|
||||
// Don't send error for cancellation
|
||||
if ctx.Err() != nil {
|
||||
return
|
||||
}
|
||||
resp := Response{Content: fmt.Sprintf("error: %v", err), Done: true}
|
||||
data, _ := json.Marshal(resp)
|
||||
w.Write(data)
|
||||
w.Write([]byte("\n"))
|
||||
return
|
||||
}
|
||||
|
||||
// Encode image as base64 PNG
|
||||
imageData, err := imagegen.EncodeImageBase64(img)
|
||||
if err != nil {
|
||||
resp := Response{Content: fmt.Sprintf("error encoding: %v", err), Done: true}
|
||||
data, _ := json.Marshal(resp)
|
||||
w.Write(data)
|
||||
w.Write([]byte("\n"))
|
||||
return
|
||||
}
|
||||
|
||||
// Free the generated image array and clean up MLX state
|
||||
img.Free()
|
||||
mlx.ClearCache()
|
||||
mlx.MetalResetPeakMemory()
|
||||
|
||||
// Send final response with image data
|
||||
resp := Response{
|
||||
Image: imageData,
|
||||
Done: true,
|
||||
}
|
||||
data, _ := json.Marshal(resp)
|
||||
w.Write(data)
|
||||
w.Write([]byte("\n"))
|
||||
flusher.Flush()
|
||||
}
|
||||
@@ -17,17 +17,31 @@ type WeightSource interface {
|
||||
GetTensor(name string) (*mlx.Array, error)
|
||||
ListTensors() []string
|
||||
HasTensor(name string) bool
|
||||
Quantization() string // Returns "FP4", "FP8", or ""
|
||||
Quantization() string // Returns "NVFP4", "Q4", "Q8", or ""
|
||||
GroupSize() int // Returns quantization group size, or 0 if not specified
|
||||
}
|
||||
|
||||
// quantizationParams returns groupSize, bits, mode for a quantization type.
|
||||
// Returns defaults (32, 8, "affine") for unknown types (backward compatibility).
|
||||
func quantizationParams(quantization string) (groupSize, bits int, mode string) {
|
||||
// QuantizationParams returns groupSize, bits, mode for a quantization type.
|
||||
// MLX quantization modes:
|
||||
// - "affine": scale + zero-point bias, group_size=32/64/128
|
||||
// - "nvfp4": NVIDIA FP4 with E4M3 scales, group_size=16 (no bias)
|
||||
// - "mxfp8": Microsoft MX FP8 with E4M3 scales, group_size=32 (no bias)
|
||||
func QuantizationParams(quantization string) (groupSize, bits int, mode string) {
|
||||
switch strings.ToUpper(quantization) {
|
||||
case "FP4":
|
||||
case "NVFP4":
|
||||
// NVIDIA FP4: group_size=16, bits=4, E4M3 scales (no qbias)
|
||||
return 16, 4, "nvfp4"
|
||||
case "FP4", "Q4", "INT4":
|
||||
// 4-bit quantization with affine mode (scale + qbias)
|
||||
return 32, 4, "affine"
|
||||
case "MXFP8":
|
||||
// Microsoft MX FP8: group_size=32, bits=8, E4M3 scales (no qbias)
|
||||
return 32, 8, "mxfp8"
|
||||
case "FP8", "Q8", "INT8", "":
|
||||
// 8-bit quantization with affine mode (default for quantized models)
|
||||
return 64, 8, "affine"
|
||||
default:
|
||||
return 32, 8, "affine" // FP8 or unknown
|
||||
return 32, 8, "affine" // Default to affine
|
||||
}
|
||||
}
|
||||
|
||||
@@ -122,7 +136,8 @@ func loadStruct(v reflect.Value, weights WeightSource, prefix string, errs *[]st
|
||||
}
|
||||
|
||||
// Handle nn.LinearLayer interface fields specially
|
||||
if field.Type == reflect.TypeOf((*nn.LinearLayer)(nil)).Elem() {
|
||||
linearLayerType := reflect.TypeOf((*nn.LinearLayer)(nil)).Elem()
|
||||
if field.Type == linearLayerType {
|
||||
if !hasTag {
|
||||
continue // no tag = skip
|
||||
}
|
||||
@@ -137,6 +152,23 @@ func loadStruct(v reflect.Value, weights WeightSource, prefix string, errs *[]st
|
||||
continue
|
||||
}
|
||||
|
||||
// Handle nn.MultiLinearLayer interface fields specially
|
||||
multiLinearLayerType := reflect.TypeOf((*nn.MultiLinearLayer)(nil)).Elem()
|
||||
if field.Type == multiLinearLayerType {
|
||||
if !hasTag {
|
||||
continue // no tag = skip
|
||||
}
|
||||
layer, err := LoadMultiLinearLayer(weights, fullPath)
|
||||
if err != nil {
|
||||
if !optional {
|
||||
*errs = append(*errs, fullPath+": "+err.Error())
|
||||
}
|
||||
continue
|
||||
}
|
||||
fieldVal.Set(reflect.ValueOf(layer))
|
||||
continue
|
||||
}
|
||||
|
||||
// Handle by kind
|
||||
switch fieldVal.Kind() {
|
||||
case reflect.Ptr:
|
||||
@@ -216,12 +248,86 @@ func joinPath(prefix, suffix string) string {
|
||||
return prefix + "." + suffix
|
||||
}
|
||||
|
||||
// LoadMultiLinearLayer loads a per-head linear layer from weights.
|
||||
// Weight shape should be [num_heads, output_dims, input_dims].
|
||||
// If quantized, always dequantizes since batched quantized matmul isn't supported.
|
||||
func LoadMultiLinearLayer(weights WeightSource, path string) (nn.MultiLinearLayer, error) {
|
||||
// Check if this is a quantized layer by looking for scale tensor
|
||||
scalePath := path + ".weight_scale"
|
||||
hasScale := weights.HasTensor(scalePath)
|
||||
|
||||
weight, err := weights.GetTensor(path + ".weight")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to load weight %s: %w", path, err)
|
||||
}
|
||||
|
||||
if hasScale {
|
||||
scales, err := weights.GetTensor(scalePath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to load scales %s: %w", scalePath, err)
|
||||
}
|
||||
|
||||
var qbiases *mlx.Array
|
||||
qbiasPath := path + ".weight_qbias"
|
||||
if weights.HasTensor(qbiasPath) {
|
||||
qbiases, _ = weights.GetTensor(qbiasPath)
|
||||
}
|
||||
|
||||
// Always dequantize for MultiLinear - no batched quantized matmul support
|
||||
// Detect bits from tensor shapes (supports mixed-precision Q4/Q8)
|
||||
weightShape := weight.Shape()
|
||||
scalesShape := scales.Shape()
|
||||
weightCols := int(weightShape[len(weightShape)-1])
|
||||
scalesCols := int(scalesShape[len(scalesShape)-1])
|
||||
|
||||
// Detect quantization from tensor shapes
|
||||
// groupSize = weightCols * packFactor / scalesCols
|
||||
// Note: groupSize4 = 2 * groupSize8 always, so ambiguous cases need metadata
|
||||
groupSize4 := weightCols * 8 / scalesCols
|
||||
groupSize8 := weightCols * 4 / scalesCols
|
||||
|
||||
var bits, groupSize int
|
||||
// Use metadata to help disambiguate when shapes are ambiguous
|
||||
// (e.g., Q4 with group_size=64 has same shapes as Q8 with group_size=32)
|
||||
quantType := strings.ToUpper(weights.Quantization())
|
||||
isQ8Type := quantType == "Q8" || quantType == "FP8" || quantType == "INT8"
|
||||
|
||||
if groupSize4 == 32 {
|
||||
// Unambiguous: Q4 with group_size=32
|
||||
bits = 4
|
||||
groupSize = 32
|
||||
} else if groupSize8 == 64 {
|
||||
// Unambiguous: Q8 with group_size=64
|
||||
bits = 8
|
||||
groupSize = 64
|
||||
} else if groupSize4 == 64 && groupSize8 == 32 {
|
||||
// Ambiguous: could be Q4/gs=64 or Q8/gs=32, use metadata
|
||||
if isQ8Type {
|
||||
bits = 8
|
||||
groupSize = 32
|
||||
} else {
|
||||
bits = 4
|
||||
groupSize = 64
|
||||
}
|
||||
} else {
|
||||
// Fallback: use global quantization params
|
||||
_, bits, _ = QuantizationParams(weights.Quantization())
|
||||
packFactor := 32 / bits
|
||||
groupSize = weightCols * packFactor / scalesCols
|
||||
}
|
||||
weight = mlx.Dequantize(weight, scales, qbiases, groupSize, bits, "affine")
|
||||
}
|
||||
|
||||
return nn.NewMultiLinear(weight), nil
|
||||
}
|
||||
|
||||
// LoadLinearLayer loads a linear layer from weights, automatically detecting if it's quantized.
|
||||
// If {path}.weight_scale exists, dequantizes the weights.
|
||||
// If {path}.weight_scale exists, creates a QuantizedLinear layer (or dequantizes if no kernel support).
|
||||
func LoadLinearLayer(weights WeightSource, path string) (nn.LinearLayer, error) {
|
||||
// Check if this is a quantized layer by looking for scale tensor
|
||||
scalePath := path + ".weight_scale"
|
||||
if weights.HasTensor(scalePath) {
|
||||
hasScale := weights.HasTensor(scalePath)
|
||||
if hasScale {
|
||||
weight, err := weights.GetTensor(path + ".weight")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to load quantized weight %s: %w", path, err)
|
||||
@@ -245,9 +351,52 @@ func LoadLinearLayer(weights WeightSource, path string) (nn.LinearLayer, error)
|
||||
qbiases, _ = weights.GetTensor(qbiasPath)
|
||||
}
|
||||
|
||||
groupSize, bits, mode := quantizationParams(weights.Quantization())
|
||||
// Detect bits from tensor shapes (supports mixed-precision Q4/Q8)
|
||||
weightShape := weight.Shape()
|
||||
scalesShape := scales.Shape()
|
||||
weightCols := int(weightShape[len(weightShape)-1])
|
||||
scalesCols := int(scalesShape[len(scalesShape)-1])
|
||||
|
||||
if mlx.MetalIsAvailable() {
|
||||
// Detect quantization from tensor shapes
|
||||
// groupSize = weightCols * packFactor / scalesCols
|
||||
// Note: groupSize4 = 2 * groupSize8 always, so ambiguous cases need metadata
|
||||
groupSize4 := weightCols * 8 / scalesCols
|
||||
groupSize8 := weightCols * 4 / scalesCols
|
||||
|
||||
var bits, groupSize int
|
||||
mode := "affine"
|
||||
// Use metadata to help disambiguate when shapes are ambiguous
|
||||
// (e.g., Q4 with group_size=64 has same shapes as Q8 with group_size=32)
|
||||
quantType := strings.ToUpper(weights.Quantization())
|
||||
isQ8Type := quantType == "Q8" || quantType == "FP8" || quantType == "INT8"
|
||||
|
||||
if groupSize4 == 32 {
|
||||
// Unambiguous: Q4 with group_size=32
|
||||
bits = 4
|
||||
groupSize = 32
|
||||
} else if groupSize8 == 64 {
|
||||
// Unambiguous: Q8 with group_size=64
|
||||
bits = 8
|
||||
groupSize = 64
|
||||
} else if groupSize4 == 64 && groupSize8 == 32 {
|
||||
// Ambiguous: could be Q4/gs=64 or Q8/gs=32, use metadata
|
||||
if isQ8Type {
|
||||
bits = 8
|
||||
groupSize = 32
|
||||
} else {
|
||||
bits = 4
|
||||
groupSize = 64
|
||||
}
|
||||
} else {
|
||||
// Fallback: use global quantization params
|
||||
_, bits, mode = QuantizationParams(weights.Quantization())
|
||||
packFactor := 32 / bits
|
||||
groupSize = weightCols * packFactor / scalesCols
|
||||
}
|
||||
|
||||
// NVFP4 and MXFP8 don't have native quantized matmul kernels in MLX,
|
||||
// so we always dequantize at load time. Affine modes (FP4, FP8) have kernel support.
|
||||
if mlx.MetalIsAvailable() && mode != "nvfp4" && mode != "mxfp8" {
|
||||
return &nn.QuantizedLinear{
|
||||
Weight: weight,
|
||||
Scales: scales,
|
||||
|
||||
@@ -303,6 +303,11 @@ func (mw *ModelWeights) Quantization() string {
|
||||
return ""
|
||||
}
|
||||
|
||||
// GroupSize returns 0 for directory-based weights (use default).
|
||||
func (mw *ModelWeights) GroupSize() int {
|
||||
return 0
|
||||
}
|
||||
|
||||
// ReleaseAll releases all cached native file handles.
|
||||
func (mw *ModelWeights) ReleaseAll() {
|
||||
for path, native := range mw.nativeCache {
|
||||
|
||||
@@ -1,82 +0,0 @@
|
||||
package imagegen
|
||||
|
||||
import (
|
||||
"runtime"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// TestPlatformSupport verifies platform validation works correctly.
|
||||
func TestPlatformSupport(t *testing.T) {
|
||||
err := CheckPlatformSupport()
|
||||
|
||||
switch runtime.GOOS {
|
||||
case "darwin":
|
||||
if runtime.GOARCH == "arm64" {
|
||||
// Apple Silicon should be supported
|
||||
if err != nil {
|
||||
t.Errorf("Expected nil error on darwin/arm64, got: %v", err)
|
||||
}
|
||||
} else {
|
||||
// Intel Mac should fail
|
||||
if err == nil {
|
||||
t.Error("Expected error on darwin/amd64 (Intel), got nil")
|
||||
}
|
||||
if err != nil && err.Error() == "" {
|
||||
t.Error("Expected meaningful error message for unsupported platform")
|
||||
}
|
||||
}
|
||||
case "linux", "windows":
|
||||
// Linux/Windows are allowed (CUDA support checked at runtime)
|
||||
if err != nil {
|
||||
t.Errorf("Expected nil error on %s, got: %v", runtime.GOOS, err)
|
||||
}
|
||||
default:
|
||||
// Other platforms should fail
|
||||
if err == nil {
|
||||
t.Errorf("Expected error on unsupported platform %s, got nil", runtime.GOOS)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TestMemoryRequirementsError verifies memory check returns clear error.
|
||||
func TestMemoryRequirementsError(t *testing.T) {
|
||||
// Test with insufficient memory
|
||||
err := CheckMemoryRequirements("test-model", 8*GB)
|
||||
if err == nil {
|
||||
t.Error("Expected error for insufficient memory (8GB < 21GB default)")
|
||||
}
|
||||
|
||||
// Test with sufficient memory
|
||||
err = CheckMemoryRequirements("test-model", 32*GB)
|
||||
if err != nil {
|
||||
t.Errorf("Expected no error for sufficient memory (32GB), got: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestEstimateVRAMReturnsReasonableDefaults verifies VRAM estimates are sensible.
|
||||
func TestEstimateVRAMReturnsReasonableDefaults(t *testing.T) {
|
||||
// Unknown model should return default (21GB)
|
||||
vram := EstimateVRAM("unknown-model")
|
||||
if vram < 10*GB || vram > 100*GB {
|
||||
t.Errorf("VRAM estimate %d GB is outside reasonable range (10-100 GB)", vram/GB)
|
||||
}
|
||||
|
||||
// Verify known pipeline estimates exist and are reasonable
|
||||
for name, estimate := range modelVRAMEstimates {
|
||||
if estimate < 10*GB {
|
||||
t.Errorf("VRAM estimate for %s (%d GB) is suspiciously low", name, estimate/GB)
|
||||
}
|
||||
if estimate > 200*GB {
|
||||
t.Errorf("VRAM estimate for %s (%d GB) is suspiciously high", name, estimate/GB)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TestServerInterfaceCompliance verifies Server implements llm.LlamaServer.
|
||||
// This is a compile-time check but we document it as a test.
|
||||
func TestServerInterfaceCompliance(t *testing.T) {
|
||||
// The var _ llm.LlamaServer = (*Server)(nil) line in server.go
|
||||
// ensures compile-time interface compliance.
|
||||
// This test documents that requirement.
|
||||
t.Log("Server implements llm.LlamaServer interface (compile-time checked)")
|
||||
}
|
||||
@@ -20,20 +20,28 @@ type ManifestWeights struct {
|
||||
nativeCache []*mlx.SafetensorsFile // keep native handles alive
|
||||
}
|
||||
|
||||
// LoadWeightsFromManifest creates a weight loader for a component from manifest storage.
|
||||
// LoadWeightsFromManifest creates a weight loader from manifest storage.
|
||||
// If component is empty, loads all tensors (for LLM models).
|
||||
// If component is specified, loads only tensors for that component and strips the prefix.
|
||||
func LoadWeightsFromManifest(manifest *ModelManifest, component string) (*ManifestWeights, error) {
|
||||
layers := manifest.GetTensorLayers(component)
|
||||
if len(layers) == 0 {
|
||||
if component == "" {
|
||||
return nil, fmt.Errorf("no tensor layers found in manifest")
|
||||
}
|
||||
return nil, fmt.Errorf("no tensor layers found for component %q", component)
|
||||
}
|
||||
|
||||
// Strip component prefix from tensor names for model loading
|
||||
// e.g., "text_encoder/model.embed_tokens.weight" -> "model.embed_tokens.weight"
|
||||
prefix := component + "/"
|
||||
tensors := make(map[string]ManifestLayer, len(layers))
|
||||
for _, layer := range layers {
|
||||
tensorName := strings.TrimPrefix(layer.Name, prefix)
|
||||
tensors[tensorName] = layer
|
||||
if component == "" {
|
||||
tensors[layer.Name] = layer
|
||||
} else {
|
||||
tensorName := strings.TrimPrefix(layer.Name, component+"/")
|
||||
tensors[tensorName] = layer
|
||||
}
|
||||
}
|
||||
|
||||
return &ManifestWeights{
|
||||
@@ -48,19 +56,30 @@ func LoadWeightsFromManifest(manifest *ModelManifest, component string) (*Manife
|
||||
// Blobs are stored in safetensors format for native mlx_load_safetensors mmap.
|
||||
// If dtype is non-zero, tensors are converted to the specified dtype.
|
||||
func (mw *ManifestWeights) Load(dtype mlx.Dtype) error {
|
||||
// Track native handles to free after batch eval
|
||||
nativeHandles := make([]*mlx.SafetensorsFile, 0, len(mw.tensors))
|
||||
arrays := make([]*mlx.Array, 0, len(mw.tensors))
|
||||
|
||||
for name, layer := range mw.tensors {
|
||||
path := mw.manifest.BlobPath(layer.Digest)
|
||||
|
||||
// Load blob as safetensors (native mmap, zero-copy)
|
||||
sf, err := mlx.LoadSafetensorsNative(path)
|
||||
if err != nil {
|
||||
// Free any handles we've accumulated
|
||||
for _, h := range nativeHandles {
|
||||
h.Free()
|
||||
}
|
||||
return fmt.Errorf("load %s: %w", name, err)
|
||||
}
|
||||
nativeHandles = append(nativeHandles, sf)
|
||||
|
||||
// Blob contains single tensor named "data"
|
||||
arr := sf.Get("data")
|
||||
if arr == nil {
|
||||
sf.Free()
|
||||
for _, h := range nativeHandles {
|
||||
h.Free()
|
||||
}
|
||||
return fmt.Errorf("tensor 'data' not found in blob for %s", name)
|
||||
}
|
||||
|
||||
@@ -68,11 +87,18 @@ func (mw *ManifestWeights) Load(dtype mlx.Dtype) error {
|
||||
if dtype != 0 && arr.Dtype() != dtype {
|
||||
arr = mlx.AsType(arr, dtype)
|
||||
}
|
||||
// ALWAYS make a contiguous copy to ensure independence from mmap
|
||||
// Make contiguous copy to ensure independence from mmap
|
||||
arr = mlx.Contiguous(arr)
|
||||
mlx.Eval(arr)
|
||||
mw.cache[name] = arr
|
||||
sf.Free() // Safe to free - arr is now an independent copy
|
||||
arrays = append(arrays, arr)
|
||||
}
|
||||
|
||||
// Batch evaluate all tensors at once (much faster than one at a time)
|
||||
mlx.Eval(arrays...)
|
||||
|
||||
// Now safe to free all native handles
|
||||
for _, sf := range nativeHandles {
|
||||
sf.Free()
|
||||
}
|
||||
|
||||
return nil
|
||||
@@ -107,18 +133,112 @@ func (mw *ManifestWeights) HasTensor(name string) bool {
|
||||
}
|
||||
|
||||
// Quantization returns the model's quantization type from model_index.json.
|
||||
// Returns empty string if not quantized or unknown.
|
||||
// Returns empty string if not quantized.
|
||||
// Falls back to detecting from tensor names and shapes if not in config.
|
||||
func (mw *ManifestWeights) Quantization() string {
|
||||
if mw.manifest == nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
// Try to read from model_index.json first
|
||||
var index struct {
|
||||
Quantization string `json:"quantization"`
|
||||
}
|
||||
if err := mw.manifest.ReadConfigJSON("model_index.json", &index); err != nil {
|
||||
if err := mw.manifest.ReadConfigJSON("model_index.json", &index); err == nil && index.Quantization != "" {
|
||||
return index.Quantization
|
||||
}
|
||||
|
||||
// Fallback: detect from tensor names
|
||||
// Check if any tensors have _scale suffix (indicates quantization)
|
||||
hasScales := false
|
||||
hasQBias := false
|
||||
for name := range mw.tensors {
|
||||
if strings.HasSuffix(name, ".weight_scale") {
|
||||
hasScales = true
|
||||
}
|
||||
if strings.HasSuffix(name, ".weight_qbias") {
|
||||
hasQBias = true
|
||||
}
|
||||
}
|
||||
|
||||
if !hasScales {
|
||||
// No scales = not quantized
|
||||
return ""
|
||||
}
|
||||
return index.Quantization
|
||||
|
||||
// Has scales but no qbias = NVFP4 (or other non-affine mode)
|
||||
if !hasQBias {
|
||||
return "NVFP4"
|
||||
}
|
||||
|
||||
// Has both scales and qbias = affine mode
|
||||
// Need to determine FP4 vs FP8 from tensor shapes
|
||||
// FP4: weight last dim is 1/8 of scales last dim * group_size
|
||||
// FP8: weight last dim is 1/4 of scales last dim * group_size
|
||||
//
|
||||
// For affine mode with group_size=32:
|
||||
// - FP4 (4 bits): 8 elements packed per uint32, so weight_dim = orig_dim / 8
|
||||
// - FP8 (8 bits): 4 elements packed per uint32, so weight_dim = orig_dim / 4
|
||||
// scales_dim = orig_dim / group_size
|
||||
// So: weight_dim / scales_dim = group_size / pack_factor
|
||||
// FP4: ratio = 32/8 = 4
|
||||
// FP8: ratio = 32/4 = 8
|
||||
|
||||
// Find a weight/scale pair to check the ratio
|
||||
for name := range mw.tensors {
|
||||
if !strings.HasSuffix(name, ".weight") || strings.Contains(name, "_scale") || strings.Contains(name, "_qbias") {
|
||||
continue
|
||||
}
|
||||
scaleName := name + "_scale"
|
||||
if _, ok := mw.tensors[scaleName]; !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
// Load both tensors to check shapes
|
||||
weightLayer := mw.tensors[name]
|
||||
scaleLayer := mw.tensors[scaleName]
|
||||
|
||||
// Get shapes from manifest layer metadata if available
|
||||
// For now, default to FP4 since it's more common
|
||||
// The actual shape check would require loading the tensor
|
||||
|
||||
// Simple heuristic: check if scale tensor is ~4x smaller than weight
|
||||
// FP4: weight is packed 8 per uint32, scales are 1 per group (32)
|
||||
// So scale size should be ~weight_size * 8 / 32 = weight_size / 4
|
||||
// FP8: weight is packed 4 per uint32, scales are 1 per group (32)
|
||||
// So scale size should be ~weight_size * 4 / 32 = weight_size / 8
|
||||
|
||||
// Rough size heuristic (assuming float16 scales)
|
||||
// Q4: scale_bytes ≈ weight_bytes / 4 * 2 / 4 = weight_bytes / 8
|
||||
// Q8: scale_bytes ≈ weight_bytes / 8 * 2 / 4 = weight_bytes / 16
|
||||
ratio := float64(weightLayer.Size) / float64(scaleLayer.Size)
|
||||
if ratio < 12 {
|
||||
// Closer to 8 = Q4
|
||||
return "Q4"
|
||||
}
|
||||
// Closer to 16 = Q8
|
||||
return "Q8"
|
||||
}
|
||||
|
||||
// Default to Q4 for affine mode (most common)
|
||||
return "Q4"
|
||||
}
|
||||
|
||||
// GroupSize returns the quantization group size from model_index.json.
|
||||
// Returns 0 if not specified (caller should use default based on quantization type).
|
||||
func (mw *ManifestWeights) GroupSize() int {
|
||||
if mw.manifest == nil {
|
||||
return 0
|
||||
}
|
||||
|
||||
var index struct {
|
||||
GroupSize int `json:"group_size"`
|
||||
}
|
||||
if err := mw.manifest.ReadConfigJSON("model_index.json", &index); err == nil && index.GroupSize > 0 {
|
||||
return index.GroupSize
|
||||
}
|
||||
|
||||
return 0
|
||||
}
|
||||
|
||||
// ReleaseAll frees all native handles and clears the tensor cache.
|
||||
|
||||
@@ -1,797 +1,144 @@
|
||||
//go:build mlx
|
||||
|
||||
package kvcache
|
||||
|
||||
// import (
|
||||
// "errors"
|
||||
// "fmt"
|
||||
// "log/slog"
|
||||
// "math"
|
||||
// "slices"
|
||||
|
||||
// "github.com/ollama/ollama/ml"
|
||||
// "github.com/ollama/ollama/model/input"
|
||||
// )
|
||||
|
||||
// type shiftFn func(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error)
|
||||
|
||||
// // Causal cache stores K and V tensors according to their position in the
|
||||
// // sequence. Returns the history and a mask for attending to past tokens
|
||||
// //
|
||||
// // The tensors are of shape embed dim, kv heads, batch size
|
||||
// // The mask is of shape history size, batch size
|
||||
// type Causal struct {
|
||||
// DType ml.DType
|
||||
|
||||
// // swaWindowSize is the number of tokens that will be included in the mask
|
||||
// // during attention operations. swaMemorySize is the number of tokens that
|
||||
// // will be retained in memory for partial prefix caching. Set to math.MaxInt32
|
||||
// // for unlimited or if sliding window attention is not being used.
|
||||
// swaWindowSize int32
|
||||
// swaMemorySize int32
|
||||
|
||||
// chunkSize int32
|
||||
|
||||
// opts CausalOptions
|
||||
|
||||
// // maxBatch is the largest batch that we might receive
|
||||
// maxBatch int
|
||||
|
||||
// // config controls mostly backend-specific optimizations
|
||||
// config *ml.CacheConfig
|
||||
|
||||
// // ** current forward pass **
|
||||
|
||||
// // size of the current batch
|
||||
// curBatchSize int
|
||||
|
||||
// // locations for data storage for this batch
|
||||
// curLoc ml.Tensor
|
||||
|
||||
// // mask of the cache as used by this batch
|
||||
// curMask ml.Tensor
|
||||
|
||||
// // the active layer for Get and Put
|
||||
// curLayer int
|
||||
|
||||
// // locations in the cache that are needed for this batch
|
||||
// curCellRange cellRange
|
||||
|
||||
// // curSequences is the sequences corresponding to this pass's entries in the cache
|
||||
// curSequences []int
|
||||
|
||||
// // curPositions is the positions corresponding to this pass's entries in the cache
|
||||
// curPositions []int32
|
||||
|
||||
// // ** cache metadata **
|
||||
|
||||
// // for each possible location in the cache, stores the position and set of sequences
|
||||
// // that reference the data there
|
||||
// cells []cacheCell
|
||||
|
||||
// // maps from sequence to the range of locations where it is stored in the cache
|
||||
// cellRanges map[int]cellRange
|
||||
|
||||
// // ** cache data storage **
|
||||
|
||||
// shiftFn shiftFn
|
||||
// backend ml.Backend
|
||||
// ctxs map[int]ml.Context
|
||||
// keys, values map[int]ml.Tensor
|
||||
|
||||
// kHeadDims, vHeadDims, numKVHeads map[int]int
|
||||
// }
|
||||
|
||||
// type cacheCell struct {
|
||||
// pos int32
|
||||
// sequences []int
|
||||
// }
|
||||
|
||||
// type cellRange struct {
|
||||
// min int
|
||||
// max int
|
||||
// }
|
||||
|
||||
// func NewCausalCache(shift shiftFn) *Causal {
|
||||
// return &Causal{
|
||||
// shiftFn: shift,
|
||||
// ctxs: make(map[int]ml.Context),
|
||||
// keys: make(map[int]ml.Tensor),
|
||||
// values: make(map[int]ml.Tensor),
|
||||
// kHeadDims: make(map[int]int),
|
||||
// vHeadDims: make(map[int]int),
|
||||
// numKVHeads: make(map[int]int),
|
||||
// }
|
||||
// }
|
||||
|
||||
// func NewSWACache(windowSize int32, shift shiftFn) *Causal {
|
||||
// return &Causal{
|
||||
// swaWindowSize: windowSize,
|
||||
// shiftFn: shift,
|
||||
// ctxs: make(map[int]ml.Context),
|
||||
// keys: make(map[int]ml.Tensor),
|
||||
// values: make(map[int]ml.Tensor),
|
||||
// kHeadDims: make(map[int]int),
|
||||
// vHeadDims: make(map[int]int),
|
||||
// numKVHeads: make(map[int]int),
|
||||
// }
|
||||
// }
|
||||
|
||||
// func NewSWAMemCache(windowSize int32, memorySize int32, shift shiftFn) *Causal {
|
||||
// return &Causal{
|
||||
// swaWindowSize: windowSize,
|
||||
// swaMemorySize: memorySize,
|
||||
// shiftFn: shift,
|
||||
// ctxs: make(map[int]ml.Context),
|
||||
// keys: make(map[int]ml.Tensor),
|
||||
// values: make(map[int]ml.Tensor),
|
||||
// kHeadDims: make(map[int]int),
|
||||
// vHeadDims: make(map[int]int),
|
||||
// numKVHeads: make(map[int]int),
|
||||
// }
|
||||
// }
|
||||
|
||||
// func NewChunkedAttentionCache(chunkSize int32, shift shiftFn) *Causal {
|
||||
// return &Causal{
|
||||
// chunkSize: chunkSize,
|
||||
// shiftFn: shift,
|
||||
// ctxs: make(map[int]ml.Context),
|
||||
// keys: make(map[int]ml.Tensor),
|
||||
// values: make(map[int]ml.Tensor),
|
||||
// kHeadDims: make(map[int]int),
|
||||
// vHeadDims: make(map[int]int),
|
||||
// numKVHeads: make(map[int]int),
|
||||
// }
|
||||
// }
|
||||
|
||||
// func (c *Causal) Init(backend ml.Backend, dtype ml.DType, maxSequences, capacity, maxBatch int) {
|
||||
// if c.config == nil {
|
||||
// var config ml.CacheConfig
|
||||
// if cc, ok := backend.(ml.BackendCacheConfig); ok {
|
||||
// config = cc.CacheConfig()
|
||||
// }
|
||||
// c.config = &config
|
||||
// }
|
||||
|
||||
// if c.config.CachePadding == 0 {
|
||||
// c.config.CachePadding = 1
|
||||
// }
|
||||
|
||||
// if c.config.MaskBatchPadding == 0 {
|
||||
// c.config.MaskBatchPadding = 1
|
||||
// }
|
||||
|
||||
// // TODO what types do we handle here?
|
||||
// // if c.config.MaskDType == ml.DTypeOther {
|
||||
// // c.config.MaskDType = ml.DTypeFloat32
|
||||
// // }
|
||||
|
||||
// if c.swaWindowSize == 0 {
|
||||
// c.swaWindowSize = math.MaxInt32
|
||||
// }
|
||||
// if c.swaMemorySize == 0 {
|
||||
// c.swaMemorySize = c.swaWindowSize
|
||||
// }
|
||||
// // We will allocate space in the cache for the stop token, which won't be part of a follow on
|
||||
// // sequence, so allocate an extra token of storage to ensure that we can jump back without
|
||||
// // causing a cache break. As an optimization, only do this when we have parallel sequences
|
||||
// // because the extra token will live in the batch buffer and won't get overwritten if we
|
||||
// // only have a single sequence.
|
||||
// if c.swaMemorySize != math.MaxInt32 && maxSequences > 1 {
|
||||
// c.swaMemorySize = max(c.swaMemorySize, c.swaWindowSize+1)
|
||||
// }
|
||||
// if int(c.swaMemorySize) >= capacity {
|
||||
// c.swaMemorySize = math.MaxInt32
|
||||
// }
|
||||
|
||||
// if c.swaMemorySize < c.swaWindowSize {
|
||||
// panic(fmt.Errorf("sliding window memory (%v) must be at least as large as the window (%v)", c.swaMemorySize, c.swaWindowSize))
|
||||
// }
|
||||
|
||||
// var cacheSize int
|
||||
// if c.swaMemorySize == math.MaxInt32 {
|
||||
// cacheSize = maxSequences * capacity
|
||||
// } else {
|
||||
// cacheSize = (maxSequences * int(c.swaMemorySize)) + maxBatch
|
||||
// }
|
||||
// cacheSize = roundUp(cacheSize, c.config.CachePadding)
|
||||
// c.cells = make([]cacheCell, cacheSize)
|
||||
|
||||
// c.DType = dtype
|
||||
// c.cellRanges = make(map[int]cellRange)
|
||||
// c.backend = backend
|
||||
// c.maxBatch = maxBatch
|
||||
// }
|
||||
|
||||
// func (c *Causal) SetConfig(config ml.CacheConfig) {
|
||||
// if c.config != nil {
|
||||
// panic("config cannot be changed after being previously set, either by the model or backend")
|
||||
// }
|
||||
|
||||
// c.config = &config
|
||||
// }
|
||||
|
||||
// func (c *Causal) Close() {
|
||||
// slog.Info("XXX Causal.Close called", "number of contexts", len(c.ctxs))
|
||||
// for _, ctx := range c.ctxs {
|
||||
// ctx.Close()
|
||||
// }
|
||||
// }
|
||||
|
||||
// func (c *Causal) StartForward(ctx ml.Context, batch input.Batch, reserve bool) error {
|
||||
// slog.Info("XXX Causal.StartForward", "cell count", len(c.cells), "prior batch size", c.curBatchSize, "positions", len(batch.Positions), "reserve", reserve, "batch", batch)
|
||||
// // panic("XXX Causal.StartForward")
|
||||
// c.curBatchSize = len(batch.Positions)
|
||||
// c.curSequences = batch.Sequences
|
||||
// c.curPositions = batch.Positions
|
||||
// c.opts.Except = nil
|
||||
|
||||
// var locs []int32
|
||||
// if !reserve {
|
||||
// c.updateSlidingWindow()
|
||||
|
||||
// var err error
|
||||
// locs, err = c.findLocs()
|
||||
// if err != nil {
|
||||
// return err
|
||||
// }
|
||||
// slog.Info("XXX Causal.StartForward", "findLocs len", len(locs))
|
||||
|
||||
// for i, pos := range batch.Positions {
|
||||
// seq := batch.Sequences[i]
|
||||
// loc := int(locs[i])
|
||||
|
||||
// c.cells[loc] = cacheCell{pos: pos, sequences: []int{seq}}
|
||||
|
||||
// seqRange, ok := c.cellRanges[seq]
|
||||
// if !ok {
|
||||
// seqRange = newRange()
|
||||
// }
|
||||
|
||||
// seqRange.min = min(seqRange.min, loc)
|
||||
// c.curCellRange.min = min(c.curCellRange.min, loc)
|
||||
|
||||
// seqRange.max = max(seqRange.max, loc)
|
||||
// c.curCellRange.max = max(c.curCellRange.max, loc)
|
||||
|
||||
// c.cellRanges[seq] = seqRange
|
||||
// }
|
||||
// } else {
|
||||
// // If we are reserving memory, don't update any of the cache metadata but set the size
|
||||
// // to the worst case.
|
||||
// locs = make([]int32, c.curBatchSize)
|
||||
// for i := range locs {
|
||||
// locs[i] = int32(i)
|
||||
// }
|
||||
// c.curCellRange.min = 0
|
||||
// c.curCellRange.max = len(c.cells) - 1
|
||||
// }
|
||||
|
||||
// // XXX Building up the locs for what's already processed (if any)
|
||||
// dummyLocs := []int{}
|
||||
// c.curCellRange.min = roundDown(c.curCellRange.min, c.config.CachePadding)
|
||||
// c.curCellRange.max = roundUp(c.curCellRange.max+1, c.config.CachePadding) - 1
|
||||
|
||||
// for i := range c.curBatchSize {
|
||||
// enabled := !slices.Contains(c.opts.Except, i)
|
||||
// for j := c.curCellRange.min; j <= c.curCellRange.max; j++ {
|
||||
// if !slices.Contains(c.cells[j].sequences, c.curSequences[i]) ||
|
||||
// (enabled && c.cells[j].pos > c.curPositions[i]) ||
|
||||
// c.chunkSize > 0 && c.cells[j].pos < c.curPositions[i]-c.curPositions[i]%c.chunkSize ||
|
||||
// c.cells[j].pos < c.curPositions[i]-c.swaWindowSize {
|
||||
// // mask[i*length+(j-c.curCellRange.min)] = float32(math.Inf(-1))
|
||||
// } else {
|
||||
// if len(dummyLocs) == 0 || dummyLocs[len(dummyLocs)-1] != i {
|
||||
// dummyLocs = append(dummyLocs, i)
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
// slog.Info("XXX Causa.StartForward calculated locations", "locs", dummyLocs)
|
||||
|
||||
// slog.Info("XXX Causal.StartForward", "locs", locs)
|
||||
// c.curLoc = ctx.Input().FromInts(locs, len(locs))
|
||||
// c.curMask = c.buildMask(ctx)
|
||||
|
||||
// return nil
|
||||
// }
|
||||
|
||||
// func newRange() cellRange {
|
||||
// return cellRange{
|
||||
// min: math.MaxInt,
|
||||
// max: 0,
|
||||
// }
|
||||
// }
|
||||
|
||||
// // Returns a slice of locations where each token in the batch should be stored
|
||||
// func (c *Causal) findLocs() ([]int32, error) {
|
||||
// loc := make([]int32, 0, c.curBatchSize)
|
||||
|
||||
// for i := range c.cells {
|
||||
// if len(c.cells[i].sequences) == 0 {
|
||||
// loc = append(loc, int32(i))
|
||||
// if len(loc) >= c.curBatchSize {
|
||||
// return loc, nil
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
|
||||
// return nil, fmt.Errorf("%w (cache: %v batch: %v)", ErrKvCacheFull, len(c.cells), c.curBatchSize)
|
||||
// }
|
||||
|
||||
// func (c *Causal) updateSlidingWindow() {
|
||||
// c.curCellRange = newRange()
|
||||
|
||||
// if c.swaMemorySize == math.MaxInt32 {
|
||||
// for _, seq := range c.curSequences {
|
||||
// if seqRange, ok := c.cellRanges[seq]; ok {
|
||||
// c.curCellRange.min = min(c.curCellRange.min, seqRange.min)
|
||||
// c.curCellRange.max = max(c.curCellRange.max, seqRange.max)
|
||||
// }
|
||||
// }
|
||||
|
||||
// return
|
||||
// }
|
||||
|
||||
// type lowestPosition struct {
|
||||
// pos int32
|
||||
// curBatch bool
|
||||
// }
|
||||
|
||||
// // create a map of unique sequences to the lowest position in that sequence
|
||||
// lowestPos := make(map[int]lowestPosition)
|
||||
// for i := range c.curPositions {
|
||||
// seq := c.curSequences[i]
|
||||
|
||||
// lowest, ok := lowestPos[seq]
|
||||
// if !ok {
|
||||
// lowest = lowestPosition{pos: c.curPositions[i], curBatch: true}
|
||||
// } else if c.curPositions[i] < lowest.pos {
|
||||
// lowest.pos = c.curPositions[i]
|
||||
// }
|
||||
|
||||
// lowestPos[seq] = lowest
|
||||
// }
|
||||
|
||||
// // for any sequences are not part of this batch, clean up any tokens
|
||||
// // that are no longer needed after the processing of the previous
|
||||
// // batch
|
||||
// for seq, seqRange := range c.cellRanges {
|
||||
// if _, ok := lowestPos[seq]; !ok {
|
||||
// var last int32
|
||||
// for i := seqRange.min; i <= seqRange.max; i++ {
|
||||
// if slices.Contains(c.cells[i].sequences, seq) {
|
||||
// last = max(last, c.cells[i].pos)
|
||||
// }
|
||||
// }
|
||||
|
||||
// lowestPos[seq] = lowestPosition{pos: last + 1, curBatch: false}
|
||||
// }
|
||||
// }
|
||||
|
||||
// // delete any entries that are beyond the window of the oldest position in the sequence
|
||||
// for seq, lowest := range lowestPos {
|
||||
// oldRange, ok := c.cellRanges[seq]
|
||||
// if !ok {
|
||||
// continue
|
||||
// }
|
||||
|
||||
// newRange := newRange()
|
||||
|
||||
// for i := oldRange.min; i <= oldRange.max; i++ {
|
||||
// if slices.Contains(c.cells[i].sequences, seq) {
|
||||
// if c.cells[i].pos < lowest.pos-c.swaMemorySize {
|
||||
// c.cells[i].sequences = slices.DeleteFunc(c.cells[i].sequences, func(s int) bool { return s == seq })
|
||||
// } else {
|
||||
// newRange.min = min(newRange.min, i)
|
||||
// newRange.max = max(newRange.max, i)
|
||||
// }
|
||||
// if lowest.curBatch && c.cells[i].pos >= lowest.pos-c.swaWindowSize {
|
||||
// c.curCellRange.min = min(c.curCellRange.min, i)
|
||||
// c.curCellRange.max = max(c.curCellRange.max, i)
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
|
||||
// c.cellRanges[seq] = newRange
|
||||
// }
|
||||
// }
|
||||
|
||||
// func roundDown(length, pad int) int {
|
||||
// return (length / pad) * pad
|
||||
// }
|
||||
|
||||
// func roundUp(length, pad int) int {
|
||||
// return ((length + pad - 1) / pad) * pad
|
||||
// }
|
||||
|
||||
// // Builds a mask of history x batch indicating whether for each token in the batch the
|
||||
// // token in the history should apply. This is based on both the sequence and causality (the
|
||||
// // position of the history is not ahead of the token in the batch).
|
||||
// func (c *Causal) buildMask(ctx ml.Context) ml.Tensor {
|
||||
// // Align and pad the two dimensions as required by the backend
|
||||
// batchSize := roundUp(c.curBatchSize, c.config.MaskBatchPadding)
|
||||
|
||||
// c.curCellRange.min = roundDown(c.curCellRange.min, c.config.CachePadding)
|
||||
// c.curCellRange.max = roundUp(c.curCellRange.max+1, c.config.CachePadding) - 1
|
||||
|
||||
// length := c.curCellRange.max - c.curCellRange.min + 1
|
||||
|
||||
// mask := make([]float32, batchSize*length)
|
||||
|
||||
// for i := range c.curBatchSize {
|
||||
// enabled := !slices.Contains(c.opts.Except, i)
|
||||
// for j := c.curCellRange.min; j <= c.curCellRange.max; j++ {
|
||||
// if !slices.Contains(c.cells[j].sequences, c.curSequences[i]) ||
|
||||
// (enabled && c.cells[j].pos > c.curPositions[i]) ||
|
||||
// c.chunkSize > 0 && c.cells[j].pos < c.curPositions[i]-c.curPositions[i]%c.chunkSize ||
|
||||
// c.cells[j].pos < c.curPositions[i]-c.swaWindowSize {
|
||||
// mask[i*length+(j-c.curCellRange.min)] = float32(math.Inf(-1))
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
|
||||
// // Mask out any padding tokens we added. For padding that we added to the cache history, this
|
||||
// // has already been masked out because the sequence doesn't match.
|
||||
// for i := c.curBatchSize * length; i < len(mask); i++ {
|
||||
// mask[i] = float32(math.Inf(-1))
|
||||
// }
|
||||
|
||||
// maskTensor := ctx.Input().FromFloats(mask, batchSize, length)
|
||||
|
||||
// // if c.config.MaskDType != ml.DTypeFloat32 {
|
||||
// // maskTensor = maskTensor.Cast(ctx, c.config.MaskDType)
|
||||
// // }
|
||||
|
||||
// slog.Info("XXX Causal.buildMask", "c.curBatchSize", c.curBatchSize, "c.config.MaskBatchPadding", c.config.MaskBatchPadding, "c.curCellRange.min", c.curCellRange.min, "c.curCellRange.max", c.curCellRange.max, "size", len(mask), "shape", []int{1, batchSize, length})
|
||||
|
||||
// return maskTensor
|
||||
// }
|
||||
|
||||
// func (c *Causal) SetLayer(layer int) {
|
||||
// c.curLayer = layer
|
||||
// }
|
||||
|
||||
// type CausalOptions struct {
|
||||
// // Enabled controls whether the causal mask is generated for a particular index in a batch
|
||||
// Except []int
|
||||
// }
|
||||
|
||||
// // SetCausal disables causal mask generation for a particular range of indicies in
|
||||
// // the current batch for subsequent calls to Get. The state resets for the next forward pass.
|
||||
// func (c *Causal) SetCausal(ctx ml.Context, opts CausalOptions) {
|
||||
// if !slices.Equal(c.opts.Except, opts.Except) {
|
||||
// c.opts = opts
|
||||
// if ctx != nil {
|
||||
// c.curMask = c.buildMask(ctx)
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
|
||||
// func (c *Causal) Get(ctx ml.Context) (ml.Tensor, ml.Tensor, ml.Tensor) {
|
||||
// key := c.keys[c.curLayer]
|
||||
// value := c.values[c.curLayer]
|
||||
|
||||
// kHeadDim := c.kHeadDims[c.curLayer]
|
||||
// vHeadDim := c.vHeadDims[c.curLayer]
|
||||
// numKVHeads := c.numKVHeads[c.curLayer]
|
||||
// // rowSize := numKVHeads * c.curBatchSize
|
||||
// // cachedSize := c.curMask.Dim(1)
|
||||
// cachedSize := c.curLoc.Dim(0)
|
||||
// // kCellSize := kHeadDim * numKVHeads
|
||||
// // vCellSize := vHeadDim * numKVHeads
|
||||
|
||||
// slog.Info("XXX Causal.Get full cache", "key", key)
|
||||
// slog.Info("XXX Causal.Get full cache", "value", value)
|
||||
// slog.Info("XXX Causal.Get full cache", "curloc", c.curLoc)
|
||||
// slog.Info("XXX Causal.Get", "curMask", c.curMask)
|
||||
// slog.Info("XXX Causal.Get", "kHeadDim", kHeadDim, "numKVHeads", numKVHeads, "cachedSize", cachedSize, "kHeadDim", kHeadDim)
|
||||
// // panic("XXX")
|
||||
|
||||
// // fmt.Fprintln(os.Stderr, key.ToString())
|
||||
// // panic("full cache value")
|
||||
|
||||
// // TODO we should use TakeAxes to gather the cells from curLoc, but for now to be consistent with GGML, just grab a larger chunk and mask
|
||||
// key = key.TakeAxes(ctx, c.curLoc, 0).Reshape(ctx, 1, numKVHeads, cachedSize, kHeadDim)
|
||||
// // key = key.AsStrided(ctx, []int{1, numKVHeads, cachedSize, kHeadDim}, []int{}, rowSize*c.curCellRange.min)
|
||||
|
||||
// // slog.Info("XXX Causal.Get after AsStrided", "key", key)
|
||||
// // panic("XXX")
|
||||
|
||||
// // if c.config.PermutedV {
|
||||
// // panic("permuted")
|
||||
// // // TODO not converted
|
||||
// // vHeadDim := value.Dim(1)
|
||||
// // elemSize := value.Stride(2)
|
||||
|
||||
// // value = value.AsStrided(ctx,
|
||||
// // []int{numKVHeads, vHeadDim, cachedSize},
|
||||
// // []int{value.Stride(0), value.Stride(1)},
|
||||
// // elemSize*c.curCellRange.min,
|
||||
// // )
|
||||
// // } else {
|
||||
// // vHeadDim := c.vHeadDims[c.curLayer]
|
||||
// // rowSize := value.Stride(2)
|
||||
// // slog.Info("XXX Causal.Get before AsStrided", "vHeadDim", vHeadDim, "rowSize", rowSize)
|
||||
// // panic("XXX")
|
||||
|
||||
// // TODO we should use TakeAxes to gather the cells from curLoc, but for now to be consistent with GGML, just grab a larger chunk and mask
|
||||
// value = value.TakeAxes(ctx, c.curLoc, 0).Reshape(ctx, 1, numKVHeads, cachedSize, vHeadDim)
|
||||
// // value = value.AsStrided(ctx, []int{1, numKVHeads, cachedSize, vHeadDim}, []int{}, rowSize*c.curCellRange.min)
|
||||
|
||||
// // slog.Info("XXX Causal.Get after AsStrided", "value", value)
|
||||
// // panic("XXX")
|
||||
|
||||
// // }
|
||||
|
||||
// // // TODO The mask changes from X,X to 1,X, and with the Row-order change
|
||||
// // // the 1 becomes trailing and messes up later operations
|
||||
// // // This isn't the right solution, but works around it...
|
||||
// // if c.curMask.Dim(1) == 1 {
|
||||
// // return key, value, c.curMask.Transpose(ctx, 1, 0, 2, 3)
|
||||
// // }
|
||||
// // fmt.Fprintln(os.Stderr, key.ToString())
|
||||
// // fmt.Fprintln(os.Stderr, value.ToString())
|
||||
// // panic("XXX")
|
||||
// slog.Info("XXX Mask", "curLayer", c.curLayer, "shape", c.curMask.Shape())
|
||||
|
||||
// return key, value, c.curMask
|
||||
// }
|
||||
|
||||
// func (c *Causal) Put(ctx ml.Context, key, value ml.Tensor) {
|
||||
// kHeadDim := key.Dim(3)
|
||||
// vHeadDim := value.Dim(3)
|
||||
// numKVHeads := key.Dim(1)
|
||||
// batchSize := key.Dim(2)
|
||||
// kCellSize := kHeadDim * numKVHeads
|
||||
// vCellSize := vHeadDim * numKVHeads
|
||||
|
||||
// // slog.Info("XXX Causal.Put", "key", key, "value", value)
|
||||
// slog.Info("XXX Causal.Put", "kHeadDim", kHeadDim, "vHeadDim", vHeadDim, "numKVHeads", numKVHeads, "batchSize", batchSize)
|
||||
// // panic("XXX")
|
||||
|
||||
// if c.curBatchSize != batchSize {
|
||||
// panic(fmt.Errorf("inconsistent batch sizes (layer: %v, batch size: %v layer batch size: %v)", c.curLayer, c.curBatchSize, batchSize))
|
||||
// }
|
||||
|
||||
// // slog.Info("XXX", "c.ctxs", c.ctxs, "c.curLayer", c.curLayer, "backend", c.backend)
|
||||
// if _, ok := c.ctxs[c.curLayer]; !ok {
|
||||
// slog.Info("XXX Causal.Put creating new context", "c.curLayer", c.curLayer)
|
||||
// c.ctxs[c.curLayer] = c.backend.NewContext().Layer(c.curLayer)
|
||||
// }
|
||||
|
||||
// if _, ok := c.keys[c.curLayer]; !ok {
|
||||
// slog.Info("XXX Causal.Put allocating keys", "c.curLayer", c.curLayer, "shape", []int{len(c.cells), kCellSize})
|
||||
|
||||
// c.keys[c.curLayer] = c.ctxs[c.curLayer].Zeros(c.DType, len(c.cells), kCellSize)
|
||||
// c.kHeadDims[c.curLayer] = kHeadDim
|
||||
// c.vHeadDims[c.curLayer] = vHeadDim
|
||||
// c.numKVHeads[c.curLayer] = numKVHeads
|
||||
// }
|
||||
|
||||
// if _, ok := c.values[c.curLayer]; !ok {
|
||||
// // if c.config.PermutedV {
|
||||
// // c.values[c.curLayer] = c.ctxs[c.curLayer].Zeros(c.DType, numKVHeads, vHeadDim, len(c.cells))
|
||||
// // } else {
|
||||
// c.values[c.curLayer] = c.ctxs[c.curLayer].Zeros(c.DType, len(c.cells), vCellSize)
|
||||
// // }
|
||||
// }
|
||||
|
||||
// key = key.Reshape(ctx, batchSize, 1, kCellSize) //.Contiguous(ctx, false) // TODO contiguous may not be needed
|
||||
|
||||
// // slog.Info("XXX Causal.Put after reshape", "keyCache", keyCache)
|
||||
// // panic("XXX")
|
||||
// // curLoc := 0 // TODO c.curLoc is now a tensor
|
||||
// // kSize := numKVHeads * kHeadDim
|
||||
// // vSize := numKVHeads * vHeadDim
|
||||
// // start := []int{int(curLoc), 0}
|
||||
// // kStop := []int{int(curLoc + batchSize), int(kSize)}
|
||||
// // vStop := []int{int(curLoc + batchSize), int(vSize)}
|
||||
// // strides := []int{1, 1}
|
||||
|
||||
// // slog.Info("XXX Causal.Put Key SliceUpdate", "keyCache", keyCache)
|
||||
// // slog.Info("XXX Causal.Put Key SliceUpdate", "key", key)
|
||||
|
||||
// // slog.Info("XXX Causal.Put Key SliceUpdate", "start", start, "kStop", kStop, "strides", strides)
|
||||
|
||||
// // ctx.Forward(c.keys[c.curLayer].SliceUpdate(ctx, key, start, kStop, strides))
|
||||
// ctx.Forward(c.keys[c.curLayer].Scatter(ctx, []ml.Tensor{c.curLoc}, key, []int{0}))
|
||||
// // fmt.Fprintln(os.Stderr, keyCache.ToString())
|
||||
// // panic("input value")
|
||||
|
||||
// // fmt.Fprintln(os.Stderr, t.ToString())
|
||||
// // panic("XXX")
|
||||
|
||||
// // if c.config.PermutedV {
|
||||
// // panic("permuted")
|
||||
// // // TODO not adjusted
|
||||
// // value = value.Reshape(ctx, vHeadDim*numKVHeads, 1, batchSize)
|
||||
// // value = value.Transpose(ctx, 2, 0, 1, 3)
|
||||
|
||||
// // valueCache := c.values[c.curLayer]
|
||||
// // valueCache = valueCache.Reshape(ctx, 1, len(c.cells), vHeadDim*numKVHeads)
|
||||
|
||||
// // ctx.Forward(valueCache.SliceUpdate(ctx, value, start, vStop, strides))
|
||||
// // } else {
|
||||
// value = value.Reshape(ctx, batchSize, 1, vCellSize) //.Contiguous(ctx, false) // TODO contiguous may not be needed
|
||||
// // slog.Info("XXX Causal.Put Value SliceUpdate", "valueCache", valueCache)
|
||||
// // slog.Info("XXX Causal.Put Value SliceUpdate", "value", value)
|
||||
// // slog.Info("XXX Causal.Put Value SliceUpdate", "start", start, "vStop", vStop, "strides", strides)
|
||||
|
||||
// ctx.Forward(c.values[c.curLayer].Scatter(ctx, []ml.Tensor{c.curLoc}, value, []int{0}))
|
||||
// // }
|
||||
// // fmt.Fprintln(os.Stderr, c.keys[c.curLayer].ToString())
|
||||
// // fmt.Fprintln(os.Stderr, c.values[c.curLayer].ToString())
|
||||
// // panic("XXX")
|
||||
|
||||
// }
|
||||
|
||||
// func (c *Causal) CopyPrefix(srcSeq, dstSeq int, len int32) {
|
||||
// seqRange := newRange()
|
||||
|
||||
// for i := range c.cells {
|
||||
// // Remove the contents of dstSeq so that we only have the copied prefix, metadata will be reset at the end
|
||||
// if slices.Contains(c.cells[i].sequences, dstSeq) {
|
||||
// c.cells[i].sequences = slices.DeleteFunc(c.cells[i].sequences, func(s int) bool { return s == dstSeq })
|
||||
// }
|
||||
|
||||
// if slices.Contains(c.cells[i].sequences, srcSeq) && c.cells[i].pos < len {
|
||||
// c.cells[i].sequences = append(c.cells[i].sequences, dstSeq)
|
||||
// if i < seqRange.min {
|
||||
// seqRange.min = i
|
||||
// }
|
||||
// if i > seqRange.max {
|
||||
// seqRange.max = i
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
|
||||
// c.cellRanges[dstSeq] = seqRange
|
||||
// }
|
||||
|
||||
// func (c *Causal) CanResume(seq int, pos int32) bool {
|
||||
// if c.swaMemorySize == math.MaxInt32 {
|
||||
// return true
|
||||
// }
|
||||
|
||||
// seqRange, ok := c.cellRanges[seq]
|
||||
// if !ok {
|
||||
// return false
|
||||
// }
|
||||
|
||||
// // for sliding window, check that the window of the new sequence is contained in
|
||||
// // the window of what we are storing
|
||||
// var first int32 = math.MaxInt32
|
||||
// var last int32 = -1
|
||||
// for i := seqRange.min; i <= seqRange.max; i++ {
|
||||
// if slices.Contains(c.cells[i].sequences, seq) {
|
||||
// first = min(first, c.cells[i].pos)
|
||||
// last = max(last, c.cells[i].pos)
|
||||
// }
|
||||
// }
|
||||
|
||||
// if last == -1 {
|
||||
// return false
|
||||
// }
|
||||
|
||||
// posWindowStart := max(0, pos-c.swaWindowSize)
|
||||
// return posWindowStart >= first && pos <= last+1
|
||||
// }
|
||||
|
||||
// func (c *Causal) shift(seq int, beginIndex, offset int32) error {
|
||||
// if c.shiftFn == nil {
|
||||
// return ErrNotSupported
|
||||
// }
|
||||
|
||||
// seqRange := c.cellRanges[seq]
|
||||
|
||||
// for start := seqRange.min; start <= seqRange.max; start += c.maxBatch {
|
||||
// size := min(seqRange.max-start+1, c.maxBatch)
|
||||
// offsets := make([]int32, size)
|
||||
|
||||
// var batchFirst, batchLast int
|
||||
|
||||
// batchFirst = -1
|
||||
// for i := range offsets {
|
||||
// cell := c.cells[start+i]
|
||||
|
||||
// if slices.Contains(cell.sequences, seq) && cell.pos >= beginIndex {
|
||||
// offsets[i] = offset
|
||||
// if batchFirst < 0 {
|
||||
// batchFirst = i
|
||||
// }
|
||||
// batchLast = i
|
||||
// }
|
||||
// }
|
||||
|
||||
// if batchFirst < 0 {
|
||||
// continue
|
||||
// }
|
||||
|
||||
// offsets = offsets[batchFirst : batchLast+1]
|
||||
|
||||
// slog.Info("XXX Causal.shift creating new temporary context")
|
||||
// ctx := c.backend.NewContext()
|
||||
// kShift := ctx.Input().FromInts(offsets, len(offsets))
|
||||
|
||||
// for i, key := range c.keys {
|
||||
// if key == nil {
|
||||
// continue
|
||||
// }
|
||||
|
||||
// kHeadDim := key.Dim(2)
|
||||
// numKVHeads := key.Dim(1)
|
||||
// rowSize := key.Stride(0)
|
||||
|
||||
// key = key.AsStrided(ctx,
|
||||
// []int{len(offsets), numKVHeads, kHeadDim},
|
||||
// []int{key.Stride(0), key.Stride(1)},
|
||||
// rowSize*(start+batchFirst),
|
||||
// )
|
||||
|
||||
// roped, err := c.shiftFn(ctx, i, key, kShift)
|
||||
// if err != nil {
|
||||
// ctx.Close()
|
||||
// return err
|
||||
// }
|
||||
|
||||
// ctx.Forward(roped.Copy(ctx, key))
|
||||
// }
|
||||
|
||||
// ctx.Compute()
|
||||
// ctx.Close()
|
||||
// }
|
||||
|
||||
// return nil
|
||||
// }
|
||||
|
||||
// func (c *Causal) Remove(seq int, beginIndex, endIndex int32) error {
|
||||
// // TODO(jessegross): We should check to see if removing the middle of the sequence will
|
||||
// // cause the sliding window to encompass tokens that we no longer have. If so, then we
|
||||
// // should return an error, which will trigger the runner to evaluate the full history and
|
||||
// // rebuild the window. However, if we have multimodal inputs in our history, this reuse
|
||||
// // results in use after free, so we don't do it for now.
|
||||
|
||||
// var offset int32
|
||||
// if endIndex != math.MaxInt32 {
|
||||
// offset = beginIndex - endIndex
|
||||
// }
|
||||
|
||||
// seqRange := newRange()
|
||||
|
||||
// for i := range c.cells {
|
||||
// if slices.Contains(c.cells[i].sequences, seq) {
|
||||
// if c.cells[i].pos >= beginIndex && c.cells[i].pos < endIndex {
|
||||
// c.cells[i].sequences = slices.DeleteFunc(c.cells[i].sequences, func(s int) bool { return s == seq })
|
||||
// } else {
|
||||
// if c.cells[i].pos >= endIndex {
|
||||
// if slices.ContainsFunc(c.cells[i].sequences, func(s int) bool { return s != seq }) {
|
||||
// return errors.New("shifting cells shared by multiple sequences not supported")
|
||||
// }
|
||||
|
||||
// c.cells[i].pos += offset
|
||||
// }
|
||||
// if i < seqRange.min {
|
||||
// seqRange.min = i
|
||||
// }
|
||||
// if i > seqRange.max {
|
||||
// seqRange.max = i
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
|
||||
// if seqRange == newRange() {
|
||||
// delete(c.cellRanges, seq)
|
||||
// return nil
|
||||
// }
|
||||
|
||||
// c.cellRanges[seq] = seqRange
|
||||
|
||||
// if endIndex != math.MaxInt32 {
|
||||
// err := c.shift(seq, endIndex+offset, offset)
|
||||
// if err != nil {
|
||||
// return err
|
||||
// }
|
||||
// }
|
||||
|
||||
// return nil
|
||||
// }
|
||||
import (
|
||||
"github.com/ollama/ollama/x/ml"
|
||||
"github.com/ollama/ollama/x/model/input"
|
||||
)
|
||||
|
||||
// Causal cache stores K and V tensors according to their position in the
|
||||
// sequence. Returns the history and a mask for attending to past tokens
|
||||
type Causal struct {
|
||||
DType ml.DType
|
||||
|
||||
// locations for data storage for this batch
|
||||
curLocPut ml.Tensor
|
||||
|
||||
// locations for data storage for this batch
|
||||
curLocGet ml.Tensor
|
||||
|
||||
// the active layer for Get and Put
|
||||
curLayer int
|
||||
|
||||
capacity int
|
||||
|
||||
offset int
|
||||
|
||||
backend ml.Backend
|
||||
ctxs map[int]ml.Context
|
||||
keys, values map[int]ml.Tensor
|
||||
|
||||
// TODO is this needed per layer, or will it always be consistent?
|
||||
kHeadDims, vHeadDims, numKVHeads map[int]int
|
||||
}
|
||||
|
||||
func NewCausalCache() *Causal {
|
||||
return &Causal{
|
||||
ctxs: make(map[int]ml.Context),
|
||||
keys: make(map[int]ml.Tensor),
|
||||
values: make(map[int]ml.Tensor),
|
||||
kHeadDims: make(map[int]int),
|
||||
vHeadDims: make(map[int]int),
|
||||
numKVHeads: make(map[int]int),
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Causal) Init(backend ml.Backend, dtype ml.DType, maxSequences, capacity, maxBatch int) {
|
||||
c.DType = dtype
|
||||
c.capacity = capacity
|
||||
c.backend = backend
|
||||
}
|
||||
|
||||
func (c *Causal) SetConfig(config ml.CacheConfig) {}
|
||||
|
||||
func (c *Causal) SetLayer(layer int) {
|
||||
c.curLayer = layer
|
||||
}
|
||||
|
||||
func (c *Causal) Close() {
|
||||
// slog.Info("XXX Causal.Close called", "number of contexts", len(c.ctxs))
|
||||
for _, ctx := range c.ctxs {
|
||||
ctx.Close()
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Causal) StartForward(ctx ml.Context, batch input.Batch, reserve bool) error {
|
||||
locsPut := make([]int32, len(batch.Positions))
|
||||
for i := c.offset; i < len(batch.Positions); i++ {
|
||||
locsPut[i-c.offset] = int32(i)
|
||||
}
|
||||
c.offset += len(batch.Positions)
|
||||
locsGet := make([]int32, c.offset)
|
||||
for i := range c.offset {
|
||||
locsGet[i] = int32(i)
|
||||
}
|
||||
c.curLocGet = ctx.Input().FromInts(locsGet, len(locsGet))
|
||||
c.curLocPut = ctx.Input().FromInts(locsPut, len(locsPut))
|
||||
// slog.Info("XXX Causal.StartForward", "offset", c.offset, "put", locsPut, "get", locsGet)
|
||||
|
||||
return nil
|
||||
}
|
||||
func (c *Causal) Put(ctx ml.Context, key, value ml.Tensor) {
|
||||
kHeadDim := key.Dim(3)
|
||||
vHeadDim := value.Dim(3)
|
||||
numKVHeads := key.Dim(1)
|
||||
batchSize := key.Dim(2)
|
||||
kCellSize := kHeadDim * numKVHeads
|
||||
vCellSize := vHeadDim * numKVHeads
|
||||
// slog.Info("XXX Causal.Put", "kHeadDim", kHeadDim, "vHeadDim", vHeadDim, "numKVHeads", numKVHeads, "batchSize", batchSize, "kCellSize", kCellSize, "vCellSize", vCellSize)
|
||||
|
||||
if _, ok := c.ctxs[c.curLayer]; !ok {
|
||||
// slog.Info("XXX Causal.Put creating new context", "c.curLayer", c.curLayer)
|
||||
c.ctxs[c.curLayer] = c.backend.NewContext().Layer(c.curLayer)
|
||||
}
|
||||
|
||||
if _, ok := c.keys[c.curLayer]; !ok {
|
||||
// slog.Info("XXX Causal.Put allocating keys and values", "c.curLayer", c.curLayer, "shape", []int{c.capacity, kCellSize})
|
||||
c.keys[c.curLayer] = c.ctxs[c.curLayer].Zeros(c.DType, c.capacity, kCellSize)
|
||||
c.values[c.curLayer] = c.ctxs[c.curLayer].Zeros(c.DType, c.capacity, vCellSize)
|
||||
c.kHeadDims[c.curLayer] = kHeadDim
|
||||
c.vHeadDims[c.curLayer] = vHeadDim
|
||||
c.numKVHeads[c.curLayer] = numKVHeads
|
||||
}
|
||||
key = key.Reshape(ctx, batchSize, 1, kCellSize)
|
||||
|
||||
// slog.Info("XXX Causal.Put ", "c.keys[c.curLayer]", c.keys[c.curLayer])
|
||||
// slog.Info("XXX Causal.Put ", "c.curLocPut", c.curLocPut)
|
||||
// slog.Info("XXX Causal.Put ", "key", key)
|
||||
ctx.Forward(c.keys[c.curLayer].Scatter(ctx, []ml.Tensor{c.curLocPut}, key, []int{0}))
|
||||
value = value.Reshape(ctx, batchSize, 1, vCellSize)
|
||||
ctx.Forward(c.values[c.curLayer].Scatter(ctx, []ml.Tensor{c.curLocPut}, value, []int{0}))
|
||||
|
||||
}
|
||||
|
||||
func (c *Causal) Get(ctx ml.Context) (ml.Tensor, ml.Tensor, ml.Tensor) {
|
||||
key := c.keys[c.curLayer]
|
||||
value := c.values[c.curLayer]
|
||||
|
||||
kHeadDim := c.kHeadDims[c.curLayer]
|
||||
vHeadDim := c.vHeadDims[c.curLayer]
|
||||
numKVHeads := c.numKVHeads[c.curLayer]
|
||||
// rowSize := numKVHeads * c.curBatchSize
|
||||
// cachedSize := c.curMask.Dim(1)
|
||||
cachedSize := c.curLocGet.Dim(0)
|
||||
// kCellSize := kHeadDim * numKVHeads
|
||||
// vCellSize := vHeadDim * numKVHeads
|
||||
// slog.Info("XXX Causal.Get", "shape", []int{1, numKVHeads, cachedSize, kHeadDim})
|
||||
|
||||
key = key.TakeAxes(ctx, c.curLocGet, 0).Reshape(ctx, 1, numKVHeads, cachedSize, kHeadDim)
|
||||
value = value.TakeAxes(ctx, c.curLocGet, 0).Reshape(ctx, 1, numKVHeads, cachedSize, vHeadDim)
|
||||
return key, value, nil
|
||||
}
|
||||
|
||||
func (c *Causal) CopyPrefix(srcSeq, dstSeq int, len int32) {
|
||||
panic("not implemented")
|
||||
}
|
||||
|
||||
func (c *Causal) CanResume(seq int, pos int32) bool {
|
||||
panic("not implemented")
|
||||
}
|
||||
|
||||
func (c *Causal) Remove(seq int, beginIndex, endIndex int32) error {
|
||||
panic("not implemented")
|
||||
}
|
||||
|
||||
@@ -1,973 +0,0 @@
|
||||
package kvcache
|
||||
|
||||
// import (
|
||||
// "fmt"
|
||||
// "math"
|
||||
// "slices"
|
||||
// "testing"
|
||||
|
||||
// "github.com/ollama/ollama/ml"
|
||||
// "github.com/ollama/ollama/model/input"
|
||||
// )
|
||||
|
||||
// type testCase struct {
|
||||
// name string
|
||||
// in []float32
|
||||
// inShape []int
|
||||
// seqs []int
|
||||
// pos []int32
|
||||
// expected []float32
|
||||
// expectedShape []int
|
||||
// expectedMask []float32
|
||||
// }
|
||||
|
||||
// func runPermutedVariants(t *testing.T, fn func(t *testing.T, backend *testBackend)) {
|
||||
// t.Helper()
|
||||
// for _, permuted := range []bool{false, true} {
|
||||
// t.Run(fmt.Sprintf("PermutedV=%t", permuted), func(t *testing.T) {
|
||||
// fn(t, &testBackend{permutedV: permuted})
|
||||
// })
|
||||
// }
|
||||
// }
|
||||
|
||||
// func TestStore(t *testing.T) {
|
||||
// runPermutedVariants(t, func(t *testing.T, backend *testBackend) {
|
||||
// cache := NewCausalCache(nil)
|
||||
// defer cache.Close()
|
||||
|
||||
// cache.Init(backend, ml.DTypeF16, 1, 16, 16)
|
||||
|
||||
// tests := []testCase{
|
||||
// {
|
||||
// name: "FirstBatch",
|
||||
// in: []float32{111, 211, 121, 221, 131, 231, 112, 212, 122, 222, 132, 232, 113, 213, 123, 223, 133, 233, 114, 214, 124, 224, 134, 234},
|
||||
// inShape: []int{2, 3, 4},
|
||||
// seqs: []int{0, 0, 0, 0},
|
||||
// pos: []int32{0, 1, 2, 3},
|
||||
// expected: []float32{111, 211, 121, 221, 131, 231, 112, 212, 122, 222, 132, 232, 113, 213, 123, 223, 133, 233, 114, 214, 124, 224, 134, 234},
|
||||
// expectedShape: []int{2, 3, 4},
|
||||
// expectedMask: []float32{0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, 0, float32(math.Inf(-1)), 0, 0, 0, 0},
|
||||
// },
|
||||
// {
|
||||
// name: "SecondBatch",
|
||||
// in: []float32{115, 215, 125, 225, 135, 235},
|
||||
// inShape: []int{2, 3, 1},
|
||||
// seqs: []int{0},
|
||||
// pos: []int32{4},
|
||||
// expected: []float32{111, 211, 121, 221, 131, 231, 112, 212, 122, 222, 132, 232, 113, 213, 123, 223, 133, 233, 114, 214, 124, 224, 134, 234, 115, 215, 125, 225, 135, 235},
|
||||
// expectedShape: []int{2, 3, 5},
|
||||
// expectedMask: []float32{0, 0, 0, 0, 0},
|
||||
// },
|
||||
// }
|
||||
|
||||
// testCache(t, backend, cache, tests)
|
||||
// })
|
||||
// }
|
||||
|
||||
// func TestSWA(t *testing.T) {
|
||||
// runPermutedVariants(t, func(t *testing.T, backend *testBackend) {
|
||||
// cache := NewSWACache(1, nil)
|
||||
// defer cache.Close()
|
||||
|
||||
// cache.Init(backend, ml.DTypeF16, 1, 16, 16)
|
||||
|
||||
// x := float32(math.Inf(-1))
|
||||
|
||||
// tests := []testCase{
|
||||
// {
|
||||
// name: "FirstBatch",
|
||||
// in: []float32{1, 2, 3, 4},
|
||||
// inShape: []int{1, 1, 4},
|
||||
// seqs: []int{0, 0, 0, 0},
|
||||
// pos: []int32{0, 1, 2, 3},
|
||||
// expected: []float32{1, 2, 3, 4},
|
||||
// expectedShape: []int{1, 1, 4},
|
||||
// expectedMask: []float32{
|
||||
// 0, x, x, x,
|
||||
// 0, 0, x, x,
|
||||
// x, 0, 0, x,
|
||||
// x, x, 0, 0,
|
||||
// },
|
||||
// },
|
||||
// {
|
||||
// name: "SecondBatch",
|
||||
// in: []float32{5, 6},
|
||||
// inShape: []int{1, 1, 2},
|
||||
// seqs: []int{0, 0},
|
||||
// pos: []int32{4, 5},
|
||||
// expected: []float32{5, 6, 3, 4},
|
||||
// expectedShape: []int{1, 1, 4},
|
||||
// expectedMask: []float32{
|
||||
// 0, x, x, 0,
|
||||
// 0, 0, x, x,
|
||||
// },
|
||||
// },
|
||||
// }
|
||||
|
||||
// testCache(t, backend, cache, tests)
|
||||
// })
|
||||
// }
|
||||
|
||||
// func TestSWASeparateBatches(t *testing.T) {
|
||||
// runPermutedVariants(t, func(t *testing.T, backend *testBackend) {
|
||||
// cache := NewSWACache(1, nil)
|
||||
// defer cache.Close()
|
||||
|
||||
// cache.Init(backend, ml.DTypeF16, 2, 16, 2)
|
||||
|
||||
// x := float32(math.Inf(-1))
|
||||
|
||||
// tests := []testCase{
|
||||
// {
|
||||
// name: "First seq 0",
|
||||
// in: []float32{1, 2},
|
||||
// inShape: []int{1, 1, 2},
|
||||
// seqs: []int{0, 0},
|
||||
// pos: []int32{0, 1},
|
||||
// expected: []float32{1, 2},
|
||||
// expectedShape: []int{1, 1, 2},
|
||||
// expectedMask: []float32{
|
||||
// 0, x,
|
||||
// 0, 0,
|
||||
// },
|
||||
// },
|
||||
// {
|
||||
// name: "Second seq 0",
|
||||
// in: []float32{3, 4},
|
||||
// inShape: []int{1, 1, 2},
|
||||
// seqs: []int{0, 0},
|
||||
// pos: []int32{2, 3},
|
||||
// expected: []float32{2, 3, 4},
|
||||
// expectedShape: []int{1, 1, 3},
|
||||
// expectedMask: []float32{
|
||||
// 0, 0, x,
|
||||
// x, 0, 0,
|
||||
// },
|
||||
// },
|
||||
// {
|
||||
// name: "First seq 1",
|
||||
// in: []float32{5, 6},
|
||||
// inShape: []int{1, 1, 2},
|
||||
// seqs: []int{1, 1},
|
||||
// pos: []int32{0, 1},
|
||||
// expected: []float32{5, 6},
|
||||
// expectedShape: []int{1, 1, 2},
|
||||
// expectedMask: []float32{
|
||||
// 0, x,
|
||||
// 0, 0,
|
||||
// },
|
||||
// },
|
||||
// {
|
||||
// name: "Second seq 1",
|
||||
// in: []float32{7, 8},
|
||||
// inShape: []int{1, 1, 2},
|
||||
// seqs: []int{1, 1},
|
||||
// pos: []int32{2, 3},
|
||||
// expected: []float32{6, 3, 4, 7, 8},
|
||||
// expectedShape: []int{1, 1, 5},
|
||||
// expectedMask: []float32{
|
||||
// 0, x, x, 0, x,
|
||||
// x, x, x, 0, 0,
|
||||
// },
|
||||
// },
|
||||
// {
|
||||
// name: "Third seq 0",
|
||||
// in: []float32{9, 10},
|
||||
// inShape: []int{1, 1, 2},
|
||||
// seqs: []int{0, 0},
|
||||
// pos: []int32{4, 5},
|
||||
// expected: []float32{9, 10, 3, 4},
|
||||
// expectedShape: []int{1, 1, 4},
|
||||
// expectedMask: []float32{
|
||||
// 0, x, x, 0,
|
||||
// 0, 0, x, x,
|
||||
// },
|
||||
// },
|
||||
// }
|
||||
|
||||
// testCache(t, backend, cache, tests)
|
||||
// })
|
||||
// }
|
||||
|
||||
// func TestSWAMem(t *testing.T) {
|
||||
// runPermutedVariants(t, func(t *testing.T, backend *testBackend) {
|
||||
// cache := NewSWAMemCache(1, 3, nil)
|
||||
// defer cache.Close()
|
||||
|
||||
// cache.Init(backend, ml.DTypeF16, 1, 16, 16)
|
||||
|
||||
// x := float32(math.Inf(-1))
|
||||
|
||||
// tests := []testCase{
|
||||
// {
|
||||
// name: "FirstBatch",
|
||||
// in: []float32{1, 2, 3, 4},
|
||||
// inShape: []int{1, 1, 4},
|
||||
// seqs: []int{0, 0, 0, 0},
|
||||
// pos: []int32{0, 1, 2, 3},
|
||||
// expected: []float32{1, 2, 3, 4},
|
||||
// expectedShape: []int{1, 1, 4},
|
||||
// expectedMask: []float32{
|
||||
// 0, x, x, x,
|
||||
// 0, 0, x, x,
|
||||
// x, 0, 0, x,
|
||||
// x, x, 0, 0,
|
||||
// },
|
||||
// },
|
||||
// {
|
||||
// name: "SecondBatch",
|
||||
// in: []float32{5, 6},
|
||||
// inShape: []int{1, 1, 2},
|
||||
// seqs: []int{0, 0},
|
||||
// pos: []int32{4, 5},
|
||||
// expected: []float32{5, 2, 3, 4, 6},
|
||||
// expectedShape: []int{1, 1, 5},
|
||||
// expectedMask: []float32{
|
||||
// 0, x, x, 0, x,
|
||||
// 0, x, x, x, 0,
|
||||
// },
|
||||
// },
|
||||
// }
|
||||
|
||||
// testCache(t, backend, cache, tests)
|
||||
// })
|
||||
// }
|
||||
|
||||
// func TestChunkedAttention(t *testing.T) {
|
||||
// runPermutedVariants(t, func(t *testing.T, backend *testBackend) {
|
||||
// cache := NewChunkedAttentionCache(2, nil)
|
||||
// defer cache.Close()
|
||||
|
||||
// cache.Init(backend, ml.DTypeF16, 1, 16, 16)
|
||||
|
||||
// x := float32(math.Inf(-1))
|
||||
|
||||
// testCache(
|
||||
// t, backend, cache,
|
||||
// []testCase{
|
||||
// {
|
||||
// name: "FirstBatch",
|
||||
// in: []float32{1, 2, 3, 4},
|
||||
// inShape: []int{1, 1, 4},
|
||||
// seqs: []int{0, 0, 0, 0},
|
||||
// pos: []int32{0, 1, 2, 3},
|
||||
// expected: []float32{1, 2, 3, 4},
|
||||
// expectedShape: []int{1, 1, 4},
|
||||
// expectedMask: []float32{
|
||||
// 0, x, x, x,
|
||||
// 0, 0, x, x,
|
||||
// x, x, 0, x,
|
||||
// x, x, 0, 0,
|
||||
// },
|
||||
// },
|
||||
// {
|
||||
// name: "SecondBatch",
|
||||
// in: []float32{5, 6, 7},
|
||||
// inShape: []int{1, 1, 3},
|
||||
// seqs: []int{0, 0, 0},
|
||||
// pos: []int32{4, 5, 6},
|
||||
// expected: []float32{1, 2, 3, 4, 5, 6, 7},
|
||||
// expectedShape: []int{1, 1, 7},
|
||||
// expectedMask: []float32{
|
||||
// x, x, x, x, 0, x, x,
|
||||
// x, x, x, x, 0, 0, x,
|
||||
// x, x, x, x, x, x, 0,
|
||||
// },
|
||||
// },
|
||||
// {
|
||||
// name: "ThirdBatch",
|
||||
// in: []float32{8, 9},
|
||||
// inShape: []int{1, 1, 2},
|
||||
// seqs: []int{0, 0},
|
||||
// pos: []int32{7, 8},
|
||||
// expected: []float32{1, 2, 3, 4, 5, 6, 7, 8, 9},
|
||||
// expectedShape: []int{1, 1, 9},
|
||||
// expectedMask: []float32{
|
||||
// x, x, x, x, x, x, 0, 0, x,
|
||||
// x, x, x, x, x, x, x, x, 0,
|
||||
// },
|
||||
// },
|
||||
// },
|
||||
// )
|
||||
// })
|
||||
// }
|
||||
|
||||
// func TestSequences(t *testing.T) {
|
||||
// runPermutedVariants(t, func(t *testing.T, backend *testBackend) {
|
||||
// cache := NewCausalCache(nil)
|
||||
// defer cache.Close()
|
||||
|
||||
// cache.Init(backend, ml.DTypeF16, 1, 16, 16)
|
||||
|
||||
// tests := []testCase{
|
||||
// {
|
||||
// name: "FirstBatch",
|
||||
// in: []float32{1, 2, 3, 4},
|
||||
// inShape: []int{1, 1, 4},
|
||||
// seqs: []int{0, 0, 1, 1},
|
||||
// pos: []int32{0, 1, 0, 1},
|
||||
// expected: []float32{1, 2, 3, 4},
|
||||
// expectedShape: []int{1, 1, 4},
|
||||
// expectedMask: []float32{0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0},
|
||||
// },
|
||||
// {
|
||||
// name: "SecondBatch",
|
||||
// in: []float32{5, 6},
|
||||
// inShape: []int{1, 1, 2},
|
||||
// seqs: []int{0, 1},
|
||||
// pos: []int32{2, 2},
|
||||
// expected: []float32{1, 2, 3, 4, 5, 6},
|
||||
// expectedShape: []int{1, 1, 6},
|
||||
// expectedMask: []float32{0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, float32(math.Inf(-1)), 0},
|
||||
// },
|
||||
// }
|
||||
|
||||
// testCache(t, backend, cache, tests)
|
||||
// })
|
||||
// }
|
||||
|
||||
// func TestRemove(t *testing.T) {
|
||||
// runPermutedVariants(t, func(t *testing.T, backend *testBackend) {
|
||||
// cache := NewCausalCache(func(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
|
||||
// return key.Add(ctx, shift), nil
|
||||
// })
|
||||
// defer cache.Close()
|
||||
|
||||
// cache.Init(backend, ml.DTypeF16, 1, 16, 16)
|
||||
|
||||
// x := float32(math.Inf(-1))
|
||||
|
||||
// tests := []testCase{
|
||||
// {
|
||||
// name: "FirstBatch",
|
||||
// in: []float32{1, 2, 3, 4},
|
||||
// inShape: []int{1, 1, 4},
|
||||
// seqs: []int{0, 0, 1, 1},
|
||||
// pos: []int32{0, 1, 0, 1},
|
||||
// expected: []float32{1, 2, 3, 4},
|
||||
// expectedShape: []int{1, 1, 4},
|
||||
// expectedMask: []float32{
|
||||
// 0, x, x, x,
|
||||
// 0, 0, x, x,
|
||||
// x, x, 0, x,
|
||||
// x, x, 0, 0,
|
||||
// },
|
||||
// },
|
||||
// }
|
||||
|
||||
// testCache(t, backend, cache, tests)
|
||||
|
||||
// err := cache.Remove(0, 1, math.MaxInt32)
|
||||
// if err != nil {
|
||||
// panic(err)
|
||||
// }
|
||||
|
||||
// tests = []testCase{
|
||||
// {
|
||||
// name: "RemoveEnd",
|
||||
// in: []float32{5, 6},
|
||||
// inShape: []int{1, 1, 2},
|
||||
// seqs: []int{0, 1},
|
||||
// pos: []int32{1, 2},
|
||||
// expected: []float32{1, 5, 3, 4, 6},
|
||||
// expectedShape: []int{1, 1, 5},
|
||||
// expectedMask: []float32{
|
||||
// 0, 0, x, x, x,
|
||||
// x, x, 0, 0, 0,
|
||||
// },
|
||||
// },
|
||||
// }
|
||||
|
||||
// testCache(t, backend, cache, tests)
|
||||
|
||||
// err = cache.Remove(0, 0, 1)
|
||||
// if err != nil {
|
||||
// panic(err)
|
||||
// }
|
||||
|
||||
// tests = []testCase{
|
||||
// {
|
||||
// name: "RemoveMiddle",
|
||||
// in: []float32{7, 8},
|
||||
// inShape: []int{1, 1, 2},
|
||||
// seqs: []int{0, 0},
|
||||
// pos: []int32{1, 2},
|
||||
// expected: []float32{7, 4, 3, 4, 6, 8},
|
||||
// expectedShape: []int{1, 1, 6},
|
||||
// expectedMask: []float32{
|
||||
// 0, 0, x, x, x, x,
|
||||
// 0, 0, x, x, x, 0,
|
||||
// },
|
||||
// },
|
||||
// }
|
||||
|
||||
// testCache(t, backend, cache, tests)
|
||||
// })
|
||||
// }
|
||||
|
||||
// func TestCopy(t *testing.T) {
|
||||
// runPermutedVariants(t, func(t *testing.T, backend *testBackend) {
|
||||
// cache := NewCausalCache(func(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) { return key, nil })
|
||||
// defer cache.Close()
|
||||
|
||||
// cache.Init(backend, ml.DTypeF16, 1, 16, 16)
|
||||
|
||||
// tests := []testCase{
|
||||
// {
|
||||
// name: "FirstBatch",
|
||||
// in: []float32{1, 2, 3, 4},
|
||||
// inShape: []int{1, 1, 4},
|
||||
// seqs: []int{0, 0, 0, 0},
|
||||
// pos: []int32{0, 1, 2, 3},
|
||||
// expected: []float32{1, 2, 3, 4},
|
||||
// expectedShape: []int{1, 1, 4},
|
||||
// expectedMask: []float32{0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, 0, float32(math.Inf(-1)), 0, 0, 0, 0},
|
||||
// },
|
||||
// }
|
||||
|
||||
// testCache(t, backend, cache, tests)
|
||||
|
||||
// cache.CopyPrefix(0, 1, 2)
|
||||
|
||||
// tests = []testCase{
|
||||
// {
|
||||
// name: "Copy",
|
||||
// in: []float32{5, 6},
|
||||
// inShape: []int{1, 1, 2},
|
||||
// seqs: []int{1, 1},
|
||||
// pos: []int32{3, 4},
|
||||
// expected: []float32{1, 2, 3, 4, 5, 6},
|
||||
// expectedShape: []int{1, 1, 6},
|
||||
// expectedMask: []float32{0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), 0, float32(math.Inf(-1)), 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0},
|
||||
// },
|
||||
// }
|
||||
|
||||
// testCache(t, backend, cache, tests)
|
||||
// })
|
||||
// }
|
||||
|
||||
// func testCache(t *testing.T, backend ml.Backend, cache Cache, tests []testCase) {
|
||||
// for _, test := range tests {
|
||||
// t.Run(test.name, func(t *testing.T) {
|
||||
// context := backend.NewContext()
|
||||
// defer context.Close()
|
||||
|
||||
// err := cache.StartForward(context, input.Batch{Positions: test.pos, Sequences: test.seqs}, false)
|
||||
// if err != nil {
|
||||
// panic(err)
|
||||
// }
|
||||
|
||||
// cache.SetLayer(0)
|
||||
// tensor := context.FromFloats(test.in, test.inShape...)
|
||||
// cache.Put(context, tensor, tensor)
|
||||
|
||||
// out, _, mask := cache.Get(context)
|
||||
|
||||
// context.Forward(out, mask).Compute(out, mask)
|
||||
|
||||
// if !slices.Equal(out.Floats(), test.expected) {
|
||||
// t.Errorf("TestCache: have %v; want %v", out.Floats(), test.expected)
|
||||
// }
|
||||
|
||||
// if !slices.Equal(out.Shape(), test.expectedShape) {
|
||||
// t.Errorf("TestCache: has shape %v; want %v", out.Shape(), test.expectedShape)
|
||||
// }
|
||||
|
||||
// if !slices.Equal(mask.Floats(), test.expectedMask) {
|
||||
// t.Errorf("TestCache: have mask: have %v want %v", mask.Floats(), test.expectedMask)
|
||||
// }
|
||||
// })
|
||||
// }
|
||||
// }
|
||||
|
||||
// func TestCanResume(t *testing.T) {
|
||||
// runPermutedVariants(t, func(t *testing.T, backend *testBackend) {
|
||||
// windowSize := int32(4)
|
||||
// cache := NewSWACache(windowSize, nil)
|
||||
// defer cache.Close()
|
||||
|
||||
// cache.Init(backend, ml.DTypeF16, 1, 16, 16)
|
||||
|
||||
// context := backend.NewContext()
|
||||
// defer context.Close()
|
||||
|
||||
// err := cache.StartForward(context, input.Batch{
|
||||
// Positions: []int32{0, 1, 2, 3, 4},
|
||||
// Sequences: []int{0, 0, 0, 0, 0},
|
||||
// }, false)
|
||||
// if err != nil {
|
||||
// t.Fatalf("StartForward failed: %v", err)
|
||||
// }
|
||||
|
||||
// cache.SetLayer(0)
|
||||
// tensor := context.FromFloats([]float32{1, 2, 3, 4, 5}, 1, 1, 5)
|
||||
// cache.Put(context, tensor, tensor)
|
||||
|
||||
// // with window size 4, nothing has slid out of the window yet
|
||||
// if !cache.CanResume(0, 0) {
|
||||
// t.Errorf("CanResume(0, 0) = false, want true (within window)")
|
||||
// }
|
||||
// if !cache.CanResume(0, 1) {
|
||||
// t.Errorf("CanResume(0, 1) = false, want true (within window)")
|
||||
// }
|
||||
// if !cache.CanResume(0, 2) {
|
||||
// t.Errorf("CanResume(0, 2) = false, want true (within window)")
|
||||
// }
|
||||
// if !cache.CanResume(0, 3) {
|
||||
// t.Errorf("CanResume(0, 3) = false, want true (latest position)")
|
||||
// }
|
||||
// if !cache.CanResume(0, 4) {
|
||||
// t.Errorf("CanResume(0, 4) = false, want true (latest position)")
|
||||
// }
|
||||
|
||||
// // shift window by adding position 5
|
||||
// err = cache.StartForward(context, input.Batch{
|
||||
// Positions: []int32{5},
|
||||
// Sequences: []int{0},
|
||||
// }, false)
|
||||
// if err != nil {
|
||||
// t.Fatalf("StartForward failed: %v", err)
|
||||
// }
|
||||
|
||||
// cache.SetLayer(0)
|
||||
// tensor = context.FromFloats([]float32{6}, 1, 1, 1)
|
||||
// cache.Put(context, tensor, tensor)
|
||||
|
||||
// // only the latest position has overlapping windows
|
||||
// if cache.CanResume(0, 0) {
|
||||
// t.Errorf("after shift: CanResume(0, 0) = true, want false (outside window)")
|
||||
// }
|
||||
// if cache.CanResume(0, 1) {
|
||||
// t.Errorf("after shift: CanResume(0, 1) = true, want false (outside window)")
|
||||
// }
|
||||
// if cache.CanResume(0, 2) {
|
||||
// t.Errorf("after shift: CanResume(0, 2) = true, want false (outside window)")
|
||||
// }
|
||||
// if cache.CanResume(0, 3) {
|
||||
// t.Errorf("after shift: CanResume(0, 3) = true, want false (outside window)")
|
||||
// }
|
||||
// if cache.CanResume(0, 4) {
|
||||
// t.Errorf("after shift: CanResume(0, 4) = true, want false (outside window)")
|
||||
// }
|
||||
// if !cache.CanResume(0, 5) {
|
||||
// t.Errorf("after shift: CanResume(0, 5) = false, want true (latest position)")
|
||||
// }
|
||||
// })
|
||||
// }
|
||||
|
||||
// func TestCanResumeSWAMem(t *testing.T) {
|
||||
// runPermutedVariants(t, func(t *testing.T, backend *testBackend) {
|
||||
// windowSize := int32(4)
|
||||
// memSize := int32(5)
|
||||
// cache := NewSWAMemCache(windowSize, memSize, nil)
|
||||
// defer cache.Close()
|
||||
|
||||
// cache.Init(backend, ml.DTypeF16, 1, 16, 16)
|
||||
|
||||
// context := backend.NewContext()
|
||||
// defer context.Close()
|
||||
|
||||
// err := cache.StartForward(context, input.Batch{
|
||||
// Positions: []int32{0, 1, 2, 3, 4, 5, 6},
|
||||
// Sequences: []int{0, 0, 0, 0, 0, 0, 0},
|
||||
// }, false)
|
||||
// if err != nil {
|
||||
// t.Fatalf("StartForward failed: %v", err)
|
||||
// }
|
||||
|
||||
// cache.SetLayer(0)
|
||||
// tensor := context.FromFloats([]float32{1, 2, 3, 4, 5, 6, 7}, 1, 1, 7)
|
||||
// cache.Put(context, tensor, tensor)
|
||||
|
||||
// // shift window by adding position 7
|
||||
// err = cache.StartForward(context, input.Batch{
|
||||
// Positions: []int32{7},
|
||||
// Sequences: []int{0},
|
||||
// }, false)
|
||||
// if err != nil {
|
||||
// t.Fatalf("StartForward failed: %v", err)
|
||||
// }
|
||||
|
||||
// cache.SetLayer(0)
|
||||
// tensor = context.FromFloats([]float32{8}, 1, 1, 1)
|
||||
// cache.Put(context, tensor, tensor)
|
||||
|
||||
// // only the latest position has overlapping windows
|
||||
// if cache.CanResume(0, 0) {
|
||||
// t.Errorf("after shift: CanResume(0, 0) = true, want false (outside window)")
|
||||
// }
|
||||
// if cache.CanResume(0, 1) {
|
||||
// t.Errorf("after shift: CanResume(0, 1) = true, want false (outside window)")
|
||||
// }
|
||||
// if cache.CanResume(0, 2) {
|
||||
// t.Errorf("after shift: CanResume(0, 2) = true, want false (outside window)")
|
||||
// }
|
||||
// if cache.CanResume(0, 3) {
|
||||
// t.Errorf("after shift: CanResume(0, 3) = true, want false (outside window)")
|
||||
// }
|
||||
// if cache.CanResume(0, 4) {
|
||||
// t.Errorf("after shift: CanResume(0, 4) = true, want false (outside window)")
|
||||
// }
|
||||
// if cache.CanResume(0, 5) {
|
||||
// t.Errorf("after shift: CanResume(0, 5) = true, want false (outside window)")
|
||||
// }
|
||||
// if !cache.CanResume(0, 6) {
|
||||
// t.Errorf("after shift: CanResume(0, 6) = false, want true (inside window)")
|
||||
// }
|
||||
// if !cache.CanResume(0, 7) {
|
||||
// t.Errorf("after shift: CanResume(0, 7) = false, want true (latest position)")
|
||||
// }
|
||||
// })
|
||||
// }
|
||||
|
||||
// type testBackend struct {
|
||||
// ml.Backend
|
||||
// permutedV bool
|
||||
// }
|
||||
|
||||
// func (b *testBackend) NewContext() ml.Context {
|
||||
// return &testContext{}
|
||||
// }
|
||||
|
||||
// func (b *testBackend) NewContextSize(int) ml.Context {
|
||||
// return &testContext{}
|
||||
// }
|
||||
|
||||
// func (b *testBackend) CacheConfig() ml.CacheConfig {
|
||||
// return ml.CacheConfig{PermutedV: b.permutedV}
|
||||
// }
|
||||
|
||||
// type testContext struct {
|
||||
// ml.Context
|
||||
// }
|
||||
|
||||
// func (c *testContext) Empty(dtype ml.DType, shape ...int) ml.Tensor {
|
||||
// total := 0
|
||||
|
||||
// if len(shape) > 0 {
|
||||
// total = 1
|
||||
// for _, s := range shape {
|
||||
// total *= s
|
||||
// }
|
||||
// }
|
||||
|
||||
// return &testTensor{dtype: dtype, elementSize: 4, data: make([]float32, total), shape: shape}
|
||||
// }
|
||||
|
||||
// func (c *testContext) Zeros(dtype ml.DType, shape ...int) ml.Tensor {
|
||||
// return c.Empty(dtype, shape...)
|
||||
// }
|
||||
|
||||
// func (c *testContext) FromFloats(s []float32, shape ...int) ml.Tensor {
|
||||
// t := c.Empty(ml.DTypeF32, shape...).(*testTensor)
|
||||
|
||||
// copy(t.data, s)
|
||||
|
||||
// return t
|
||||
// }
|
||||
|
||||
// func (c *testContext) FromInts(s []int32, shape ...int) ml.Tensor {
|
||||
// f := make([]float32, len(s))
|
||||
// for i := range f {
|
||||
// f[i] = float32(s[i])
|
||||
// }
|
||||
|
||||
// out := c.FromFloats(f, shape...)
|
||||
// out.(*testTensor).dtype = ml.DTypeI32
|
||||
|
||||
// return out
|
||||
// }
|
||||
|
||||
// func (c *testContext) Arange(start, stop, step float32, dtype ml.DType) ml.Tensor {
|
||||
// s := make([]float32, 0, int((stop-start)/step))
|
||||
// for i := start; i < stop; i += step {
|
||||
// s = append(s, i)
|
||||
// }
|
||||
|
||||
// out := c.FromFloats(s, len(s))
|
||||
// out.(*testTensor).dtype = dtype
|
||||
// return out
|
||||
// }
|
||||
|
||||
// func (c *testContext) Input() ml.Context { return c }
|
||||
// func (c *testContext) Layer(int) ml.Context { return c }
|
||||
|
||||
// func (c *testContext) Forward(...ml.Tensor) ml.Context { return c }
|
||||
|
||||
// func (c *testContext) Compute(...ml.Tensor) {}
|
||||
|
||||
// func (c *testContext) Reserve() {}
|
||||
|
||||
// func (c *testContext) MaxGraphNodes() int {
|
||||
// return 10
|
||||
// }
|
||||
|
||||
// func (c *testContext) Close() {}
|
||||
|
||||
// type testTensor struct {
|
||||
// ml.Tensor
|
||||
|
||||
// dtype ml.DType
|
||||
// elementSize int
|
||||
// data []float32
|
||||
// shape []int
|
||||
// }
|
||||
|
||||
// func (t *testTensor) Dim(n int) int {
|
||||
// return t.shape[n]
|
||||
// }
|
||||
|
||||
// func (t *testTensor) Stride(n int) int {
|
||||
// stride := t.elementSize
|
||||
// for i := range n {
|
||||
// stride *= t.shape[i]
|
||||
// }
|
||||
|
||||
// return stride
|
||||
// }
|
||||
|
||||
// func (t *testTensor) Shape() []int {
|
||||
// return t.shape
|
||||
// }
|
||||
|
||||
// func (t *testTensor) DType() ml.DType {
|
||||
// return t.dtype
|
||||
// }
|
||||
|
||||
// func (t *testTensor) Floats() []float32 {
|
||||
// out := make([]float32, len(t.data))
|
||||
// copy(out, t.data)
|
||||
// return out
|
||||
// }
|
||||
|
||||
// func (t *testTensor) Neg(ctx ml.Context) ml.Tensor {
|
||||
// out := ctx.Empty(t.DType(), t.Shape()...).(*testTensor)
|
||||
// for i := range out.data {
|
||||
// out.data[i] = -t.data[i]
|
||||
// }
|
||||
// return out
|
||||
// }
|
||||
|
||||
// func (t *testTensor) Add(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
|
||||
// out := ctx.Empty(t.DType(), t.Shape()...).(*testTensor)
|
||||
|
||||
// for i := range out.data {
|
||||
// out.data[i] = t.data[i] + t2.(*testTensor).data[i]
|
||||
// }
|
||||
|
||||
// return out
|
||||
// }
|
||||
|
||||
// func (t *testTensor) Reshape(ctx ml.Context, shape ...int) ml.Tensor {
|
||||
// return &testTensor{
|
||||
// dtype: t.dtype,
|
||||
// elementSize: t.elementSize,
|
||||
// data: t.data,
|
||||
// shape: shape,
|
||||
// }
|
||||
// }
|
||||
|
||||
// func (t *testTensor) View(ctx ml.Context, offset int, shape ...int) ml.Tensor {
|
||||
// offset /= t.elementSize
|
||||
|
||||
// var s []int
|
||||
|
||||
// switch len(shape) {
|
||||
// case 1:
|
||||
// s = []int{shape[0]}
|
||||
// case 3:
|
||||
// s = []int{shape[0], shape[2]}
|
||||
// case 5:
|
||||
// s = []int{shape[0], shape[2], shape[4]}
|
||||
// default:
|
||||
// panic("unsupported number of dimensions")
|
||||
// }
|
||||
|
||||
// context := &testContext{}
|
||||
|
||||
// view := context.Empty(t.dtype, s...).(*testTensor)
|
||||
// view.data = t.data[offset : offset+len(view.data)]
|
||||
|
||||
// return view
|
||||
// }
|
||||
|
||||
// func (t *testTensor) Permute(ctx ml.Context, order ...int) ml.Tensor {
|
||||
// if len(t.shape) > 4 || len(order) > 4 {
|
||||
// panic("permute only supports up to 4 dimensions")
|
||||
// }
|
||||
|
||||
// if len(order) != len(t.shape) && len(order) != 4 {
|
||||
// panic("invalid number of dimensions for permute")
|
||||
// }
|
||||
|
||||
// // ggml_permute expects 4 axes, so fill in any missing dimensions.
|
||||
// orderFull := append(make([]int, 0, 4), order...)
|
||||
// for len(orderFull) < 4 {
|
||||
// orderFull = append(orderFull, len(orderFull))
|
||||
// }
|
||||
|
||||
// seen := [4]bool{}
|
||||
|
||||
// shape4 := [4]int{1, 1, 1, 1}
|
||||
// for i := 0; i < len(t.shape) && i < 4; i++ {
|
||||
// shape4[i] = t.shape[i]
|
||||
// }
|
||||
|
||||
// newShape4 := [4]int{1, 1, 1, 1}
|
||||
// for axis := range 4 {
|
||||
// dst := orderFull[axis]
|
||||
// if dst < 0 || dst >= 4 {
|
||||
// panic("invalid axis for permute")
|
||||
// }
|
||||
// if seen[dst] {
|
||||
// panic("duplicate axis for permute")
|
||||
// }
|
||||
// seen[dst] = true
|
||||
// newShape4[dst] = shape4[axis]
|
||||
// }
|
||||
|
||||
// total := len(t.data)
|
||||
// newData := make([]float32, total)
|
||||
|
||||
// if total > 0 {
|
||||
// oldDims := shape4
|
||||
// newDims := newShape4
|
||||
|
||||
// oldStride := [4]int{1, 1, 1, 1}
|
||||
// newStride := [4]int{1, 1, 1, 1}
|
||||
// for i := 1; i < 4; i++ {
|
||||
// oldStride[i] = oldStride[i-1] * oldDims[i-1]
|
||||
// newStride[i] = newStride[i-1] * newDims[i-1]
|
||||
// }
|
||||
|
||||
// var coords [4]int
|
||||
// var newCoords [4]int
|
||||
|
||||
// for idx := range total {
|
||||
// remainder := idx
|
||||
// for axis := range 4 {
|
||||
// dim := oldDims[axis]
|
||||
// if dim == 0 {
|
||||
// coords[axis] = 0
|
||||
// continue
|
||||
// }
|
||||
// coords[axis] = remainder % dim
|
||||
// remainder /= dim
|
||||
// }
|
||||
|
||||
// for axis := range 4 {
|
||||
// newCoords[orderFull[axis]] = coords[axis]
|
||||
// }
|
||||
|
||||
// newIndex := 0
|
||||
// for axis := range 4 {
|
||||
// if newDims[axis] == 0 {
|
||||
// continue
|
||||
// }
|
||||
// newIndex += newCoords[axis] * newStride[axis]
|
||||
// }
|
||||
|
||||
// newData[newIndex] = t.data[idx]
|
||||
// }
|
||||
// }
|
||||
|
||||
// numDims := 4
|
||||
// for numDims > 1 && newShape4[numDims-1] <= 1 {
|
||||
// numDims--
|
||||
// }
|
||||
|
||||
// newShape := make([]int, numDims)
|
||||
// copy(newShape, newShape4[:numDims])
|
||||
|
||||
// return &testTensor{
|
||||
// dtype: t.dtype,
|
||||
// elementSize: t.elementSize,
|
||||
// data: newData,
|
||||
// shape: newShape,
|
||||
// }
|
||||
// }
|
||||
|
||||
// func (t *testTensor) SetRows(ctx ml.Context, src ml.Tensor, idxs ml.Tensor) ml.Tensor {
|
||||
// dst := t
|
||||
// srcTensor := src.(*testTensor)
|
||||
// idxTensor := idxs.(*testTensor)
|
||||
|
||||
// shapeTo4D := func(shape []int) [4]int {
|
||||
// out := [4]int{1, 1, 1, 1}
|
||||
// for i := 0; i < len(shape) && i < 4; i++ {
|
||||
// out[i] = shape[i]
|
||||
// }
|
||||
// return out
|
||||
// }
|
||||
|
||||
// computeStrides := func(shape [4]int) [4]int {
|
||||
// out := [4]int{1, 1, 1, 1}
|
||||
// for i := 1; i < 4; i++ {
|
||||
// out[i] = out[i-1] * shape[i-1]
|
||||
// }
|
||||
// return out
|
||||
// }
|
||||
|
||||
// dstShape4D := shapeTo4D(dst.shape)
|
||||
// srcShape4D := shapeTo4D(srcTensor.shape)
|
||||
// idxShape4D := shapeTo4D(idxTensor.shape)
|
||||
|
||||
// if dstShape4D[0] != srcShape4D[0] || dstShape4D[2] != srcShape4D[2] || dstShape4D[3] != srcShape4D[3] {
|
||||
// panic("SetRows requires matching tensor shapes")
|
||||
// }
|
||||
|
||||
// if srcShape4D[1] != idxShape4D[0] {
|
||||
// panic("SetRows rows/index mismatch")
|
||||
// }
|
||||
|
||||
// if srcShape4D[2]%idxShape4D[1] != 0 || srcShape4D[3]%idxShape4D[2] != 0 {
|
||||
// panic("SetRows cannot broadcast indices")
|
||||
// }
|
||||
|
||||
// if idxShape4D[3] != 1 {
|
||||
// panic("SetRows expects 1D or 2D index tensors")
|
||||
// }
|
||||
|
||||
// dstStride := computeStrides(dstShape4D)
|
||||
// srcStride := computeStrides(srcShape4D)
|
||||
// idxStride := computeStrides(idxShape4D)
|
||||
|
||||
// numColumns := srcShape4D[0]
|
||||
// numRows := srcShape4D[1]
|
||||
|
||||
// for dim3Index := range dstShape4D[3] {
|
||||
// for dim2Index := range dstShape4D[2] {
|
||||
// idxDim2 := 0
|
||||
// idxDim3 := 0
|
||||
// if idxShape4D[1] > 0 {
|
||||
// idxDim2 = dim2Index % idxShape4D[1]
|
||||
// }
|
||||
// if idxShape4D[2] > 0 {
|
||||
// idxDim3 = dim3Index % idxShape4D[2]
|
||||
// }
|
||||
|
||||
// idxBase := idxDim3*idxStride[2] + idxDim2*idxStride[1]
|
||||
// srcBase := dim3Index*srcStride[3] + dim2Index*srcStride[2]
|
||||
// dstBase := dim3Index*dstStride[3] + dim2Index*dstStride[2]
|
||||
|
||||
// for row := range numRows {
|
||||
// idx := int(idxTensor.data[idxBase+row*idxStride[0]])
|
||||
// if idx < 0 || idx >= dstShape4D[1] {
|
||||
// panic("SetRows index out of range")
|
||||
// }
|
||||
|
||||
// srcOffset := srcBase + row*srcStride[1]
|
||||
// dstOffset := dstBase + idx*dstStride[1]
|
||||
|
||||
// copy(dst.data[dstOffset:dstOffset+numColumns], srcTensor.data[srcOffset:srcOffset+numColumns])
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
|
||||
// return dst
|
||||
// }
|
||||
|
||||
// func (t *testTensor) Copy(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
|
||||
// copy(t2.(*testTensor).data, t.data)
|
||||
// return nil
|
||||
// }
|
||||
144
x/kvcache/mlx.go
144
x/kvcache/mlx.go
@@ -1,144 +0,0 @@
|
||||
//go:build mlx
|
||||
|
||||
package kvcache
|
||||
|
||||
import (
|
||||
"github.com/ollama/ollama/x/ml"
|
||||
"github.com/ollama/ollama/x/model/input"
|
||||
)
|
||||
|
||||
// Causal cache stores K and V tensors according to their position in the
|
||||
// sequence. Returns the history and a mask for attending to past tokens
|
||||
type MLXCausal struct {
|
||||
DType ml.DType
|
||||
|
||||
// locations for data storage for this batch
|
||||
curLocPut ml.Tensor
|
||||
|
||||
// locations for data storage for this batch
|
||||
curLocGet ml.Tensor
|
||||
|
||||
// the active layer for Get and Put
|
||||
curLayer int
|
||||
|
||||
capacity int
|
||||
|
||||
offset int
|
||||
|
||||
backend ml.Backend
|
||||
ctxs map[int]ml.Context
|
||||
keys, values map[int]ml.Tensor
|
||||
|
||||
// TODO is this needed per layer, or will it always be consistent?
|
||||
kHeadDims, vHeadDims, numKVHeads map[int]int
|
||||
}
|
||||
|
||||
func NewMLXCausalCache() *MLXCausal {
|
||||
return &MLXCausal{
|
||||
ctxs: make(map[int]ml.Context),
|
||||
keys: make(map[int]ml.Tensor),
|
||||
values: make(map[int]ml.Tensor),
|
||||
kHeadDims: make(map[int]int),
|
||||
vHeadDims: make(map[int]int),
|
||||
numKVHeads: make(map[int]int),
|
||||
}
|
||||
}
|
||||
|
||||
func (c *MLXCausal) Init(backend ml.Backend, dtype ml.DType, maxSequences, capacity, maxBatch int) {
|
||||
c.DType = dtype
|
||||
c.capacity = capacity
|
||||
c.backend = backend
|
||||
}
|
||||
|
||||
func (c *MLXCausal) SetConfig(config ml.CacheConfig) {}
|
||||
|
||||
func (c *MLXCausal) SetLayer(layer int) {
|
||||
c.curLayer = layer
|
||||
}
|
||||
|
||||
func (c *MLXCausal) Close() {
|
||||
// slog.Info("XXX MLXCausal.Close called", "number of contexts", len(c.ctxs))
|
||||
for _, ctx := range c.ctxs {
|
||||
ctx.Close()
|
||||
}
|
||||
}
|
||||
|
||||
func (c *MLXCausal) StartForward(ctx ml.Context, batch input.Batch, reserve bool) error {
|
||||
locsPut := make([]int32, len(batch.Positions))
|
||||
for i := c.offset; i < len(batch.Positions); i++ {
|
||||
locsPut[i-c.offset] = int32(i)
|
||||
}
|
||||
c.offset += len(batch.Positions)
|
||||
locsGet := make([]int32, c.offset)
|
||||
for i := range c.offset {
|
||||
locsGet[i] = int32(i)
|
||||
}
|
||||
c.curLocGet = ctx.Input().FromInts(locsGet, len(locsGet))
|
||||
c.curLocPut = ctx.Input().FromInts(locsPut, len(locsPut))
|
||||
// slog.Info("XXX MLXCausal.StartForward", "offset", c.offset, "put", locsPut, "get", locsGet)
|
||||
|
||||
return nil
|
||||
}
|
||||
func (c *MLXCausal) Put(ctx ml.Context, key, value ml.Tensor) {
|
||||
kHeadDim := key.Dim(3)
|
||||
vHeadDim := value.Dim(3)
|
||||
numKVHeads := key.Dim(1)
|
||||
batchSize := key.Dim(2)
|
||||
kCellSize := kHeadDim * numKVHeads
|
||||
vCellSize := vHeadDim * numKVHeads
|
||||
// slog.Info("XXX Causal.Put", "kHeadDim", kHeadDim, "vHeadDim", vHeadDim, "numKVHeads", numKVHeads, "batchSize", batchSize, "kCellSize", kCellSize, "vCellSize", vCellSize)
|
||||
|
||||
if _, ok := c.ctxs[c.curLayer]; !ok {
|
||||
// slog.Info("XXX Causal.Put creating new context", "c.curLayer", c.curLayer)
|
||||
c.ctxs[c.curLayer] = c.backend.NewContext().Layer(c.curLayer)
|
||||
}
|
||||
|
||||
if _, ok := c.keys[c.curLayer]; !ok {
|
||||
// slog.Info("XXX MLXCausal.Put allocating keys and values", "c.curLayer", c.curLayer, "shape", []int{c.capacity, kCellSize})
|
||||
c.keys[c.curLayer] = c.ctxs[c.curLayer].Zeros(c.DType, c.capacity, kCellSize)
|
||||
c.values[c.curLayer] = c.ctxs[c.curLayer].Zeros(c.DType, c.capacity, vCellSize)
|
||||
c.kHeadDims[c.curLayer] = kHeadDim
|
||||
c.vHeadDims[c.curLayer] = vHeadDim
|
||||
c.numKVHeads[c.curLayer] = numKVHeads
|
||||
}
|
||||
key = key.Reshape(ctx, batchSize, 1, kCellSize)
|
||||
|
||||
// slog.Info("XXX MLXCausal.Put ", "c.keys[c.curLayer]", c.keys[c.curLayer])
|
||||
// slog.Info("XXX MLXCausal.Put ", "c.curLocPut", c.curLocPut)
|
||||
// slog.Info("XXX MLXCausal.Put ", "key", key)
|
||||
ctx.Forward(c.keys[c.curLayer].Scatter(ctx, []ml.Tensor{c.curLocPut}, key, []int{0}))
|
||||
value = value.Reshape(ctx, batchSize, 1, vCellSize)
|
||||
ctx.Forward(c.values[c.curLayer].Scatter(ctx, []ml.Tensor{c.curLocPut}, value, []int{0}))
|
||||
|
||||
}
|
||||
|
||||
func (c *MLXCausal) Get(ctx ml.Context) (ml.Tensor, ml.Tensor, ml.Tensor) {
|
||||
key := c.keys[c.curLayer]
|
||||
value := c.values[c.curLayer]
|
||||
|
||||
kHeadDim := c.kHeadDims[c.curLayer]
|
||||
vHeadDim := c.vHeadDims[c.curLayer]
|
||||
numKVHeads := c.numKVHeads[c.curLayer]
|
||||
// rowSize := numKVHeads * c.curBatchSize
|
||||
// cachedSize := c.curMask.Dim(1)
|
||||
cachedSize := c.curLocGet.Dim(0)
|
||||
// kCellSize := kHeadDim * numKVHeads
|
||||
// vCellSize := vHeadDim * numKVHeads
|
||||
// slog.Info("XXX MLXCausal.Get", "shape", []int{1, numKVHeads, cachedSize, kHeadDim})
|
||||
|
||||
key = key.TakeAxes(ctx, c.curLocGet, 0).Reshape(ctx, 1, numKVHeads, cachedSize, kHeadDim)
|
||||
value = value.TakeAxes(ctx, c.curLocGet, 0).Reshape(ctx, 1, numKVHeads, cachedSize, vHeadDim)
|
||||
return key, value, nil
|
||||
}
|
||||
|
||||
func (c *MLXCausal) CopyPrefix(srcSeq, dstSeq int, len int32) {
|
||||
panic("not implemented")
|
||||
}
|
||||
|
||||
func (c *MLXCausal) CanResume(seq int, pos int32) bool {
|
||||
panic("not implemented")
|
||||
}
|
||||
|
||||
func (c *MLXCausal) Remove(seq int, beginIndex, endIndex int32) error {
|
||||
panic("not implemented")
|
||||
}
|
||||
134
x/mlxrunner/imagegen.go
Normal file
134
x/mlxrunner/imagegen.go
Normal file
@@ -0,0 +1,134 @@
|
||||
//go:build mlx
|
||||
|
||||
package mlxrunner
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/ollama/ollama/x/imagegen"
|
||||
"github.com/ollama/ollama/x/imagegen/mlx"
|
||||
"github.com/ollama/ollama/x/imagegen/models/flux2"
|
||||
"github.com/ollama/ollama/x/imagegen/models/zimage"
|
||||
)
|
||||
|
||||
// ImageModel is the interface for image generation models.
|
||||
type ImageModel interface {
|
||||
GenerateImage(ctx context.Context, prompt string, width, height int32, steps int, seed int64, progress func(step, total int)) (*mlx.Array, error)
|
||||
}
|
||||
|
||||
var imageGenMu sync.Mutex
|
||||
|
||||
// loadImageModel loads an image generation model.
|
||||
func (s *server) loadImageModel() error {
|
||||
// Check memory requirements before loading
|
||||
var requiredMemory uint64
|
||||
if manifest, err := imagegen.LoadManifest(s.modelName); err == nil {
|
||||
requiredMemory = uint64(manifest.TotalTensorSize())
|
||||
}
|
||||
availableMemory := mlx.GetMemoryLimit()
|
||||
if availableMemory > 0 && requiredMemory > 0 && availableMemory < requiredMemory {
|
||||
return fmt.Errorf("insufficient memory for image generation: need %d GB, have %d GB",
|
||||
requiredMemory/(1024*1024*1024), availableMemory/(1024*1024*1024))
|
||||
}
|
||||
|
||||
// Detect model type and load appropriate model
|
||||
modelType := imagegen.DetectModelType(s.modelName)
|
||||
slog.Info("detected image model type", "type", modelType)
|
||||
|
||||
var model ImageModel
|
||||
switch modelType {
|
||||
case "Flux2KleinPipeline":
|
||||
m := &flux2.Model{}
|
||||
if err := m.Load(s.modelName); err != nil {
|
||||
return fmt.Errorf("failed to load flux2 model: %w", err)
|
||||
}
|
||||
model = m
|
||||
default:
|
||||
// Default to Z-Image for ZImagePipeline, FluxPipeline, etc.
|
||||
m := &zimage.Model{}
|
||||
if err := m.Load(s.modelName); err != nil {
|
||||
return fmt.Errorf("failed to load zimage model: %w", err)
|
||||
}
|
||||
model = m
|
||||
}
|
||||
|
||||
s.imageModel = model
|
||||
return nil
|
||||
}
|
||||
|
||||
// handleImageCompletion handles image generation requests.
|
||||
func (s *server) handleImageCompletion(w http.ResponseWriter, r *http.Request, req Request) {
|
||||
// Serialize generation requests - MLX model may not handle concurrent generation
|
||||
imageGenMu.Lock()
|
||||
defer imageGenMu.Unlock()
|
||||
|
||||
// Set seed if not provided
|
||||
if req.Seed <= 0 {
|
||||
req.Seed = time.Now().UnixNano()
|
||||
}
|
||||
|
||||
// Set up streaming response
|
||||
w.Header().Set("Content-Type", "application/x-ndjson")
|
||||
w.Header().Set("Transfer-Encoding", "chunked")
|
||||
flusher, ok := w.(http.Flusher)
|
||||
if !ok {
|
||||
http.Error(w, "streaming not supported", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
ctx := r.Context()
|
||||
enc := json.NewEncoder(w)
|
||||
|
||||
// Progress callback streams step updates
|
||||
progress := func(step, total int) {
|
||||
resp := Response{Step: step, Total: total}
|
||||
enc.Encode(resp)
|
||||
w.Write([]byte("\n"))
|
||||
flusher.Flush()
|
||||
}
|
||||
|
||||
// Generate image
|
||||
img, err := s.imageModel.GenerateImage(ctx, req.Prompt, req.Width, req.Height, req.Steps, req.Seed, progress)
|
||||
if err != nil {
|
||||
// Don't send error for cancellation
|
||||
if ctx.Err() != nil {
|
||||
return
|
||||
}
|
||||
resp := Response{Content: fmt.Sprintf("error: %v", err), Done: true}
|
||||
data, _ := json.Marshal(resp)
|
||||
w.Write(data)
|
||||
w.Write([]byte("\n"))
|
||||
return
|
||||
}
|
||||
|
||||
// Encode image as base64 PNG
|
||||
imageData, err := imagegen.EncodeImageBase64(img)
|
||||
if err != nil {
|
||||
resp := Response{Content: fmt.Sprintf("error encoding: %v", err), Done: true}
|
||||
data, _ := json.Marshal(resp)
|
||||
w.Write(data)
|
||||
w.Write([]byte("\n"))
|
||||
return
|
||||
}
|
||||
|
||||
// Free the generated image array and clean up MLX state
|
||||
img.Free()
|
||||
mlx.ClearCache()
|
||||
mlx.MetalResetPeakMemory()
|
||||
|
||||
// Send final response with image data
|
||||
resp := Response{
|
||||
Image: imageData,
|
||||
Done: true,
|
||||
}
|
||||
data, _ := json.Marshal(resp)
|
||||
w.Write(data)
|
||||
w.Write([]byte("\n"))
|
||||
flusher.Flush()
|
||||
}
|
||||
420
x/mlxrunner/llm.go
Normal file
420
x/mlxrunner/llm.go
Normal file
@@ -0,0 +1,420 @@
|
||||
//go:build mlx
|
||||
|
||||
package mlxrunner
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/ollama/ollama/x/imagegen"
|
||||
"github.com/ollama/ollama/x/imagegen/cache"
|
||||
"github.com/ollama/ollama/x/imagegen/mlx"
|
||||
"github.com/ollama/ollama/x/imagegen/models/glm4_moe_lite"
|
||||
"github.com/ollama/ollama/x/imagegen/tokenizer"
|
||||
)
|
||||
|
||||
// TextModel is the interface for LLM text generation models.
|
||||
type TextModel interface {
|
||||
Forward(tokens *mlx.Array, caches []cache.Cache) *mlx.Array
|
||||
NewCache(maxSeqLen int32) []cache.Cache
|
||||
Tokenizer() *tokenizer.Tokenizer
|
||||
VocabSize() int32
|
||||
MaxContextLength() int32
|
||||
NumLayers() int
|
||||
}
|
||||
|
||||
// llmState holds the state for LLM generation
|
||||
type llmState struct {
|
||||
model TextModel
|
||||
}
|
||||
|
||||
var llmMu sync.Mutex
|
||||
|
||||
// Dedicated stream for generation (like mlx-lm's generation_stream)
|
||||
var generationStream *mlx.Stream
|
||||
|
||||
// withStream runs fn with the generation stream as default
|
||||
func withStream(fn func()) {
|
||||
// Lazy initialization of generationStream
|
||||
if generationStream == nil {
|
||||
generationStream = mlx.NewStream()
|
||||
}
|
||||
orig := mlx.GetDefaultStream()
|
||||
mlx.SetDefaultStream(generationStream)
|
||||
fn()
|
||||
mlx.SetDefaultStream(orig)
|
||||
}
|
||||
|
||||
// Decoder wraps model + cache for autoregressive generation.
|
||||
// This matches the pattern from cmd/engine/generate.go
|
||||
type Decoder struct {
|
||||
model TextModel
|
||||
caches []cache.Cache
|
||||
vocabSize int32
|
||||
temp float32
|
||||
token *mlx.Array // Current token (kept across iterations)
|
||||
oldCacheState []*mlx.Array // Preallocated slice for old cache state
|
||||
}
|
||||
|
||||
func NewDecoder(m TextModel, temp float32) *Decoder {
|
||||
caches := m.NewCache(0)
|
||||
return &Decoder{
|
||||
model: m,
|
||||
caches: caches,
|
||||
vocabSize: m.VocabSize(),
|
||||
temp: temp,
|
||||
oldCacheState: make([]*mlx.Array, 0, len(caches)*2),
|
||||
}
|
||||
}
|
||||
|
||||
func (d *Decoder) prefill(inputIDs []int32) int {
|
||||
processed := 0
|
||||
|
||||
// Track old cache state to free after each chunk
|
||||
var oldCacheState []*mlx.Array
|
||||
|
||||
// Process all-but-1 tokens in chunks, eval cache state for memory management
|
||||
for len(inputIDs) > 1 {
|
||||
chunkSize := min(2048, len(inputIDs)-1)
|
||||
if chunkSize <= 0 {
|
||||
break
|
||||
}
|
||||
chunk := inputIDs[:chunkSize]
|
||||
|
||||
// Save old cache state before forward
|
||||
oldCacheState = oldCacheState[:0]
|
||||
for _, c := range d.caches {
|
||||
oldCacheState = append(oldCacheState, c.State()...)
|
||||
}
|
||||
|
||||
var cacheState []*mlx.Array
|
||||
withStream(func() {
|
||||
x := mlx.NewArrayInt32(chunk, []int32{1, int32(len(chunk))})
|
||||
d.model.Forward(x, d.caches)
|
||||
for _, c := range d.caches {
|
||||
cacheState = append(cacheState, c.State()...)
|
||||
}
|
||||
})
|
||||
mlx.Eval(cacheState...)
|
||||
|
||||
// Free old cache state
|
||||
for _, arr := range oldCacheState {
|
||||
if arr != nil {
|
||||
arr.Free()
|
||||
}
|
||||
}
|
||||
|
||||
inputIDs = inputIDs[chunkSize:]
|
||||
processed += chunkSize
|
||||
}
|
||||
|
||||
// Save old cache state before final step
|
||||
oldCacheState = oldCacheState[:0]
|
||||
for _, c := range d.caches {
|
||||
oldCacheState = append(oldCacheState, c.State()...)
|
||||
}
|
||||
|
||||
// Final token + sampling
|
||||
withStream(func() {
|
||||
x := mlx.NewArrayInt32(inputIDs, []int32{1, int32(len(inputIDs))})
|
||||
mlx.Eval(x) // Materialize before any other evals
|
||||
logits := d.model.Forward(x, d.caches)
|
||||
d.token = sample(logits, d.temp, d.vocabSize)
|
||||
})
|
||||
// Keep cache state (token auto-kept by AsyncEval)
|
||||
for _, c := range d.caches {
|
||||
mlx.Keep(c.State()...)
|
||||
}
|
||||
mlx.AsyncEval(d.token)
|
||||
|
||||
// Free old cache state from before final step
|
||||
for _, arr := range oldCacheState {
|
||||
if arr != nil {
|
||||
arr.Free()
|
||||
}
|
||||
}
|
||||
|
||||
mlx.ClearCache()
|
||||
|
||||
return processed + len(inputIDs)
|
||||
}
|
||||
|
||||
func (d *Decoder) step() int32 {
|
||||
prevToken := d.token
|
||||
|
||||
// Save old cache state (reuse preallocated slice)
|
||||
d.oldCacheState = d.oldCacheState[:0]
|
||||
for _, c := range d.caches {
|
||||
d.oldCacheState = append(d.oldCacheState, c.State()...)
|
||||
}
|
||||
|
||||
withStream(func() {
|
||||
logits := d.model.Forward(mlx.Reshape(prevToken, 1, 1), d.caches)
|
||||
d.token = sample(logits, d.temp, d.vocabSize)
|
||||
})
|
||||
// Keep token and new cache state so they survive cleanup
|
||||
mlx.Keep(d.token)
|
||||
for _, c := range d.caches {
|
||||
mlx.Keep(c.State()...)
|
||||
}
|
||||
mlx.AsyncEval(d.token)
|
||||
|
||||
// Sync on previous token (GPU already working on next step)
|
||||
val := prevToken.ItemInt32()
|
||||
|
||||
// Free old token and old cache state
|
||||
prevToken.Free()
|
||||
for _, arr := range d.oldCacheState {
|
||||
arr.Free()
|
||||
}
|
||||
return val
|
||||
}
|
||||
|
||||
// sample samples from logits using temperature scaling
|
||||
func sample(logits *mlx.Array, temp float32, vocabSize int32) *mlx.Array {
|
||||
// Get last position logits: [1, L, vocab] -> [vocab]
|
||||
shape := logits.Shape()
|
||||
seqLen := shape[1]
|
||||
lastLogits := mlx.Slice(logits, []int32{0, seqLen - 1, 0}, []int32{1, seqLen, vocabSize})
|
||||
lastLogits = mlx.Reshape(lastLogits, vocabSize)
|
||||
|
||||
if temp <= 0 || temp < 0.01 {
|
||||
// Greedy decoding
|
||||
return mlx.Argmax(lastLogits, -1, false)
|
||||
}
|
||||
|
||||
// Apply temperature scaling
|
||||
scaled := mlx.DivScalar(lastLogits, temp)
|
||||
return mlx.RandomCategorical(scaled, -1, 1)
|
||||
}
|
||||
|
||||
// loadLLMModel loads a safetensors LLM model and its tokenizer from manifest storage.
|
||||
func (s *server) loadLLMModel() error {
|
||||
// Load the manifest to get model information
|
||||
manifest, err := imagegen.LoadManifest(s.modelName)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to load manifest: %w", err)
|
||||
}
|
||||
|
||||
// Detect model architecture from config.json
|
||||
configData, err := manifest.ReadConfig("config.json")
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to read config.json: %w", err)
|
||||
}
|
||||
|
||||
var modelConfig struct {
|
||||
Architectures []string `json:"architectures"`
|
||||
ModelType string `json:"model_type"`
|
||||
}
|
||||
if err := json.Unmarshal(configData, &modelConfig); err != nil {
|
||||
return fmt.Errorf("failed to parse config.json: %w", err)
|
||||
}
|
||||
|
||||
arch := ""
|
||||
if len(modelConfig.Architectures) > 0 {
|
||||
arch = modelConfig.Architectures[0]
|
||||
}
|
||||
if arch == "" {
|
||||
arch = modelConfig.ModelType
|
||||
}
|
||||
|
||||
slog.Info("detected LLM architecture", "architecture", arch, "model_type", modelConfig.ModelType)
|
||||
|
||||
// Load the appropriate model based on architecture
|
||||
var model TextModel
|
||||
archLower := strings.ToLower(arch)
|
||||
|
||||
switch {
|
||||
case strings.Contains(archLower, "glm4moelite"):
|
||||
m, err := glm4_moe_lite.LoadFromManifest(manifest)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to load glm4-moe-lite model: %w", err)
|
||||
}
|
||||
model = m
|
||||
slog.Info("loaded glm4-moe-lite model", "vocab_size", m.VocabSize(), "layers", m.NumLayers())
|
||||
|
||||
default:
|
||||
return fmt.Errorf("LLM architecture %q is not yet supported. "+
|
||||
"Supported architectures: glm4-moe-lite. "+
|
||||
"Please convert your model to GGUF format or use a supported architecture", arch)
|
||||
}
|
||||
|
||||
s.llmModel = &llmState{
|
||||
model: model,
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// handleLLMCompletion handles LLM text generation requests.
|
||||
func (s *server) handleLLMCompletion(w http.ResponseWriter, r *http.Request, req Request) {
|
||||
if s.llmModel == nil {
|
||||
http.Error(w, "LLM model not loaded", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
// Serialize generation requests
|
||||
llmMu.Lock()
|
||||
defer llmMu.Unlock()
|
||||
|
||||
if err := s.llmGenerate(w, r, req); err != nil {
|
||||
slog.Error("LLM generation failed", "error", err)
|
||||
// Don't send error if we've already started streaming
|
||||
}
|
||||
}
|
||||
|
||||
// llmGenerate runs the generation loop using the Decoder pattern from cmd/engine
|
||||
func (s *server) llmGenerate(w http.ResponseWriter, r *http.Request, req Request) error {
|
||||
state := s.llmModel
|
||||
|
||||
// Set up streaming response
|
||||
w.Header().Set("Content-Type", "application/x-ndjson")
|
||||
w.Header().Set("Transfer-Encoding", "chunked")
|
||||
flusher, ok := w.(http.Flusher)
|
||||
if !ok {
|
||||
return errors.New("streaming not supported")
|
||||
}
|
||||
|
||||
tok := state.model.Tokenizer()
|
||||
|
||||
// The prompt is already formatted by the server using the model's renderer
|
||||
// (see server/prompt.go renderPrompt), so we don't apply FormatPrompt here.
|
||||
prompt := req.Prompt
|
||||
|
||||
// Tokenize the prompt
|
||||
inputIDs := tok.Encode(prompt, true)
|
||||
slog.Debug("tokenized prompt", "num_tokens", len(inputIDs))
|
||||
|
||||
// Generation parameters
|
||||
maxTokens := int(state.model.MaxContextLength())
|
||||
if maxTokens <= 0 {
|
||||
maxTokens = 4096
|
||||
}
|
||||
if req.Options != nil && req.Options.NumPredict > 0 {
|
||||
maxTokens = req.Options.NumPredict
|
||||
}
|
||||
|
||||
temperature := float32(0.7)
|
||||
if req.Options != nil && req.Options.Temperature > 0 {
|
||||
temperature = float32(req.Options.Temperature)
|
||||
}
|
||||
|
||||
// Enable MLX compilation for better performance
|
||||
mlx.EnableCompile()
|
||||
|
||||
// Create decoder with fresh caches
|
||||
dec := NewDecoder(state.model, temperature)
|
||||
|
||||
prefillStart := time.Now()
|
||||
prefillTokens := dec.prefill(inputIDs)
|
||||
// Prefill measurement includes time to first token
|
||||
firstToken := dec.step()
|
||||
prefillDuration := time.Since(prefillStart)
|
||||
promptEvalDuration := prefillDuration
|
||||
|
||||
enc := json.NewEncoder(w)
|
||||
ctx := r.Context()
|
||||
generated := 0
|
||||
stopReason := "max_tokens"
|
||||
|
||||
// Handle first token
|
||||
generated++
|
||||
if tok.IsEOS(firstToken) {
|
||||
resp := Response{
|
||||
Done: true,
|
||||
StopReason: fmt.Sprintf("first_token_eos:%d", firstToken),
|
||||
PromptEvalCount: prefillTokens,
|
||||
PromptEvalDuration: int(promptEvalDuration.Nanoseconds()),
|
||||
}
|
||||
enc.Encode(resp)
|
||||
flusher.Flush()
|
||||
return nil
|
||||
}
|
||||
|
||||
text := tok.Decode([]int32{firstToken})
|
||||
resp := Response{Content: text}
|
||||
enc.Encode(resp)
|
||||
flusher.Flush()
|
||||
|
||||
genStart := time.Now()
|
||||
|
||||
// Generation loop
|
||||
for n := 1; n < maxTokens; n++ {
|
||||
// Check for cancellation
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
stopReason = fmt.Sprintf("context_cancelled:%d", generated)
|
||||
break
|
||||
default:
|
||||
}
|
||||
if stopReason != "max_tokens" {
|
||||
break
|
||||
}
|
||||
|
||||
token := dec.step()
|
||||
generated++
|
||||
|
||||
if tok.IsEOS(token) {
|
||||
stopReason = fmt.Sprintf("eos_token:%d", token)
|
||||
break
|
||||
}
|
||||
|
||||
text := tok.Decode([]int32{token})
|
||||
|
||||
// Check for stop sequences
|
||||
if req.Options != nil && len(req.Options.Stop) > 0 {
|
||||
shouldStop := false
|
||||
var matchedStop string
|
||||
for _, stop := range req.Options.Stop {
|
||||
if strings.Contains(text, stop) {
|
||||
text = strings.Split(text, stop)[0]
|
||||
shouldStop = true
|
||||
matchedStop = stop
|
||||
break
|
||||
}
|
||||
}
|
||||
if shouldStop {
|
||||
if text != "" {
|
||||
resp := Response{Content: text}
|
||||
enc.Encode(resp)
|
||||
flusher.Flush()
|
||||
}
|
||||
stopReason = fmt.Sprintf("stop_sequence:%s", matchedStop)
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
resp := Response{Content: text}
|
||||
enc.Encode(resp)
|
||||
flusher.Flush()
|
||||
|
||||
// Periodically clear MLX cache
|
||||
if n%256 == 0 {
|
||||
mlx.ClearCache()
|
||||
}
|
||||
}
|
||||
|
||||
// Clean up
|
||||
mlx.ClearCache()
|
||||
|
||||
// Send final response with stats
|
||||
evalDuration := time.Since(genStart)
|
||||
resp = Response{
|
||||
Done: true,
|
||||
StopReason: fmt.Sprintf("%s:generated=%d", stopReason, generated),
|
||||
PromptEvalCount: prefillTokens,
|
||||
PromptEvalDuration: int(promptEvalDuration.Nanoseconds()),
|
||||
EvalCount: generated,
|
||||
EvalDuration: int(evalDuration.Nanoseconds()),
|
||||
}
|
||||
enc.Encode(resp)
|
||||
flusher.Flush()
|
||||
|
||||
return nil
|
||||
}
|
||||
204
x/mlxrunner/runner.go
Normal file
204
x/mlxrunner/runner.go
Normal file
@@ -0,0 +1,204 @@
|
||||
//go:build mlx
|
||||
|
||||
// Package mlxrunner provides a unified MLX runner for both LLM and image generation models.
|
||||
package mlxrunner
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"flag"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"os"
|
||||
"os/signal"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"github.com/ollama/ollama/envconfig"
|
||||
"github.com/ollama/ollama/x/imagegen"
|
||||
"github.com/ollama/ollama/x/imagegen/mlx"
|
||||
)
|
||||
|
||||
// Execute is the entry point for the unified MLX runner subprocess.
|
||||
func Execute(args []string) error {
|
||||
// Set up logging with appropriate level from environment
|
||||
slog.SetDefault(slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: envconfig.LogLevel()})))
|
||||
|
||||
fs := flag.NewFlagSet("mlx-runner", flag.ExitOnError)
|
||||
modelName := fs.String("model", "", "path to model")
|
||||
port := fs.Int("port", 0, "port to listen on")
|
||||
|
||||
if err := fs.Parse(args); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if *modelName == "" {
|
||||
return fmt.Errorf("--model is required")
|
||||
}
|
||||
if *port == 0 {
|
||||
return fmt.Errorf("--port is required")
|
||||
}
|
||||
|
||||
// Initialize MLX
|
||||
if err := mlx.InitMLX(); err != nil {
|
||||
slog.Error("unable to initialize MLX", "error", err)
|
||||
return err
|
||||
}
|
||||
slog.Info("MLX library initialized")
|
||||
|
||||
// Detect model type from capabilities
|
||||
mode := detectModelMode(*modelName)
|
||||
slog.Info("starting mlx runner", "model", *modelName, "port", *port, "mode", mode)
|
||||
|
||||
// Create and start server
|
||||
server, err := newServer(*modelName, *port, mode)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create server: %w", err)
|
||||
}
|
||||
|
||||
// Set up HTTP handlers
|
||||
mux := http.NewServeMux()
|
||||
mux.HandleFunc("/health", server.healthHandler)
|
||||
mux.HandleFunc("/completion", server.completionHandler)
|
||||
|
||||
// LLM-specific endpoints
|
||||
if mode == ModeLLM {
|
||||
mux.HandleFunc("/tokenize", server.tokenizeHandler)
|
||||
mux.HandleFunc("/embedding", server.embeddingHandler)
|
||||
}
|
||||
|
||||
httpServer := &http.Server{
|
||||
Addr: fmt.Sprintf("127.0.0.1:%d", *port),
|
||||
Handler: mux,
|
||||
}
|
||||
|
||||
// Handle shutdown
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
sigCh := make(chan os.Signal, 1)
|
||||
signal.Notify(sigCh, syscall.SIGINT, syscall.SIGTERM)
|
||||
<-sigCh
|
||||
slog.Info("shutting down mlx runner")
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
httpServer.Shutdown(ctx)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
slog.Info("mlx runner listening", "addr", httpServer.Addr)
|
||||
if err := httpServer.ListenAndServe(); err != http.ErrServerClosed {
|
||||
return err
|
||||
}
|
||||
|
||||
<-done
|
||||
return nil
|
||||
}
|
||||
|
||||
// detectModelMode determines whether a model is an LLM or image generation model.
|
||||
func detectModelMode(modelName string) ModelMode {
|
||||
// Check for image generation model by looking at model_index.json
|
||||
modelType := imagegen.DetectModelType(modelName)
|
||||
if modelType != "" {
|
||||
// Known image generation model types
|
||||
switch modelType {
|
||||
case "ZImagePipeline", "FluxPipeline", "Flux2KleinPipeline":
|
||||
return ModeImageGen
|
||||
}
|
||||
}
|
||||
|
||||
// Default to LLM mode for safetensors models without known image gen types
|
||||
return ModeLLM
|
||||
}
|
||||
|
||||
// server holds the model and handles HTTP requests.
|
||||
type server struct {
|
||||
mode ModelMode
|
||||
modelName string
|
||||
port int
|
||||
|
||||
// Image generation model (when mode == ModeImageGen)
|
||||
imageModel ImageModel
|
||||
|
||||
// LLM model (when mode == ModeLLM)
|
||||
llmModel *llmState
|
||||
}
|
||||
|
||||
// newServer creates a new server instance and loads the appropriate model.
|
||||
func newServer(modelName string, port int, mode ModelMode) (*server, error) {
|
||||
s := &server{
|
||||
mode: mode,
|
||||
modelName: modelName,
|
||||
port: port,
|
||||
}
|
||||
|
||||
switch mode {
|
||||
case ModeImageGen:
|
||||
if err := s.loadImageModel(); err != nil {
|
||||
return nil, fmt.Errorf("failed to load image model: %w", err)
|
||||
}
|
||||
case ModeLLM:
|
||||
if err := s.loadLLMModel(); err != nil {
|
||||
return nil, fmt.Errorf("failed to load LLM model: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return s, nil
|
||||
}
|
||||
|
||||
func (s *server) healthHandler(w http.ResponseWriter, r *http.Request) {
|
||||
resp := HealthResponse{Status: "ok"}
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(resp)
|
||||
}
|
||||
|
||||
func (s *server) completionHandler(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodPost {
|
||||
http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
|
||||
return
|
||||
}
|
||||
|
||||
var req Request
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
http.Error(w, err.Error(), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
switch s.mode {
|
||||
case ModeImageGen:
|
||||
s.handleImageCompletion(w, r, req)
|
||||
case ModeLLM:
|
||||
s.handleLLMCompletion(w, r, req)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *server) tokenizeHandler(w http.ResponseWriter, r *http.Request) {
|
||||
if s.llmModel == nil {
|
||||
http.Error(w, "LLM model not loaded", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
var req struct {
|
||||
Content string `json:"content"`
|
||||
}
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
http.Error(w, err.Error(), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
tok := s.llmModel.model.Tokenizer()
|
||||
tokens := tok.Encode(req.Content, false)
|
||||
|
||||
// Convert int32 to int for JSON response
|
||||
intTokens := make([]int, len(tokens))
|
||||
for i, t := range tokens {
|
||||
intTokens[i] = int(t)
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(map[string][]int{"tokens": intTokens})
|
||||
}
|
||||
|
||||
func (s *server) embeddingHandler(w http.ResponseWriter, r *http.Request) {
|
||||
http.Error(w, "embeddings not yet implemented for MLX models", http.StatusNotImplemented)
|
||||
}
|
||||
@@ -1,10 +1,10 @@
|
||||
//go:build !mlx
|
||||
|
||||
package runner
|
||||
package mlxrunner
|
||||
|
||||
import "errors"
|
||||
|
||||
// Execute returns an error when not built with MLX support.
|
||||
func Execute(args []string) error {
|
||||
return errors.New("image generation not available: build with mlx tag")
|
||||
return errors.New("MLX runner not available: build with mlx tag")
|
||||
}
|
||||
@@ -1,4 +1,4 @@
|
||||
package imagegen
|
||||
package mlxrunner
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
@@ -7,6 +7,7 @@ import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"log/slog"
|
||||
"math/rand"
|
||||
"net"
|
||||
@@ -22,19 +23,19 @@ import (
|
||||
|
||||
"github.com/ollama/ollama/llm"
|
||||
"github.com/ollama/ollama/ml"
|
||||
"github.com/ollama/ollama/x/imagegen"
|
||||
)
|
||||
|
||||
// Server wraps an image generation subprocess to implement llm.LlamaServer.
|
||||
// Server wraps an MLX runner subprocess to implement llm.LlamaServer.
|
||||
//
|
||||
// This implementation is compatible with Ollama's scheduler and can be loaded/unloaded
|
||||
// like any other model. The plan is to eventually bring this into the llm/ package
|
||||
// and evolve llm/ to support MLX and multimodal models. For now, keeping the code
|
||||
// separate allows for independent iteration on image generation support.
|
||||
// like any other model. It supports both LLM (safetensors) and image generation models.
|
||||
type Server struct {
|
||||
mu sync.Mutex
|
||||
cmd *exec.Cmd
|
||||
port int
|
||||
modelName string
|
||||
mode ModelMode
|
||||
vramSize uint64
|
||||
done chan error
|
||||
client *http.Client
|
||||
@@ -42,10 +43,10 @@ type Server struct {
|
||||
lastErrLock sync.Mutex
|
||||
}
|
||||
|
||||
// NewServer spawns a new image generation subprocess and waits until it's ready.
|
||||
func NewServer(modelName string) (*Server, error) {
|
||||
// NewServer spawns a new MLX runner subprocess and waits until it's ready.
|
||||
func NewServer(modelName string, mode ModelMode) (*Server, error) {
|
||||
// Validate platform support before attempting to start
|
||||
if err := CheckPlatformSupport(); err != nil {
|
||||
if err := imagegen.CheckPlatformSupport(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -70,8 +71,8 @@ func NewServer(modelName string) (*Server, error) {
|
||||
exe = eval
|
||||
}
|
||||
|
||||
// Spawn subprocess: ollama runner --image-engine --model <path> --port <port>
|
||||
cmd := exec.Command(exe, "runner", "--image-engine", "--model", modelName, "--port", strconv.Itoa(port))
|
||||
// Spawn subprocess: ollama runner --mlx-engine --model <path> --port <port>
|
||||
cmd := exec.Command(exe, "runner", "--mlx-engine", "--model", modelName, "--port", strconv.Itoa(port))
|
||||
cmd.Env = os.Environ()
|
||||
|
||||
// On Linux, set LD_LIBRARY_PATH to include MLX library directories
|
||||
@@ -104,11 +105,21 @@ func NewServer(modelName string) (*Server, error) {
|
||||
slog.Debug("mlx subprocess library path", "LD_LIBRARY_PATH", pathEnvVal)
|
||||
}
|
||||
|
||||
// Estimate VRAM based on tensor size from manifest
|
||||
var vramSize uint64
|
||||
if manifest, err := imagegen.LoadManifest(modelName); err == nil {
|
||||
vramSize = uint64(manifest.TotalTensorSize())
|
||||
} else {
|
||||
// Fallback: default to 8GB if manifest can't be loaded
|
||||
vramSize = 8 * 1024 * 1024 * 1024
|
||||
}
|
||||
|
||||
s := &Server{
|
||||
cmd: cmd,
|
||||
port: port,
|
||||
modelName: modelName,
|
||||
vramSize: EstimateVRAM(modelName),
|
||||
mode: mode,
|
||||
vramSize: vramSize,
|
||||
done: make(chan error, 1),
|
||||
client: &http.Client{Timeout: 10 * time.Minute},
|
||||
}
|
||||
@@ -119,23 +130,23 @@ func NewServer(modelName string) (*Server, error) {
|
||||
go func() {
|
||||
scanner := bufio.NewScanner(stdout)
|
||||
for scanner.Scan() {
|
||||
slog.Info("image-runner", "msg", scanner.Text())
|
||||
slog.Info("mlx-runner", "msg", scanner.Text())
|
||||
}
|
||||
}()
|
||||
go func() {
|
||||
scanner := bufio.NewScanner(stderr)
|
||||
for scanner.Scan() {
|
||||
line := scanner.Text()
|
||||
slog.Warn("image-runner", "msg", line)
|
||||
slog.Warn("mlx-runner", "msg", line)
|
||||
s.lastErrLock.Lock()
|
||||
s.lastErr = line
|
||||
s.lastErrLock.Unlock()
|
||||
}
|
||||
}()
|
||||
|
||||
slog.Info("starting image runner subprocess", "exe", exe, "model", modelName, "port", port)
|
||||
slog.Info("starting mlx runner subprocess", "exe", exe, "model", modelName, "port", port, "mode", mode)
|
||||
if err := cmd.Start(); err != nil {
|
||||
return nil, fmt.Errorf("failed to start image runner: %w", err)
|
||||
return nil, fmt.Errorf("failed to start mlx runner: %w", err)
|
||||
}
|
||||
|
||||
// Reap subprocess when it exits
|
||||
@@ -158,6 +169,7 @@ func (s *Server) ModelPath() string {
|
||||
return s.modelName
|
||||
}
|
||||
|
||||
// Load satisfies the LlamaServer interface. MLX models don't need GPU layer assignment.
|
||||
func (s *Server) Load(ctx context.Context, systemInfo ml.SystemInfo, gpus []ml.DeviceInfo, requireFull bool) ([]ml.DeviceID, error) {
|
||||
return nil, nil
|
||||
}
|
||||
@@ -193,18 +205,18 @@ func (s *Server) waitUntilRunning() error {
|
||||
// Include recent stderr lines for better error context
|
||||
errMsg := s.getLastErr()
|
||||
if errMsg != "" {
|
||||
return fmt.Errorf("image runner failed: %s (exit: %v)", errMsg, err)
|
||||
return fmt.Errorf("mlx runner failed: %s (exit: %v)", errMsg, err)
|
||||
}
|
||||
return fmt.Errorf("image runner exited unexpectedly: %w", err)
|
||||
return fmt.Errorf("mlx runner exited unexpectedly: %w", err)
|
||||
case <-timeout:
|
||||
errMsg := s.getLastErr()
|
||||
if errMsg != "" {
|
||||
return fmt.Errorf("timeout waiting for image runner: %s", errMsg)
|
||||
return fmt.Errorf("timeout waiting for mlx runner: %s", errMsg)
|
||||
}
|
||||
return errors.New("timeout waiting for image runner to start")
|
||||
return errors.New("timeout waiting for mlx runner to start")
|
||||
case <-ticker.C:
|
||||
if err := s.Ping(ctx); err == nil {
|
||||
slog.Info("image runner is ready", "port", s.port)
|
||||
slog.Info("mlx runner is ready", "port", s.port)
|
||||
return nil
|
||||
}
|
||||
}
|
||||
@@ -218,27 +230,43 @@ func (s *Server) getLastErr() string {
|
||||
return s.lastErr
|
||||
}
|
||||
|
||||
func (s *Server) WaitUntilRunning(ctx context.Context) error { return nil }
|
||||
// WaitUntilRunning satisfies the LlamaServer interface.
|
||||
func (s *Server) WaitUntilRunning(ctx context.Context) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Completion handles both text and image generation requests.
|
||||
func (s *Server) Completion(ctx context.Context, req llm.CompletionRequest, fn func(llm.CompletionResponse)) error {
|
||||
seed := req.Seed
|
||||
if seed == 0 {
|
||||
seed = time.Now().UnixNano()
|
||||
}
|
||||
|
||||
// Extract raw image bytes from llm.ImageData slice
|
||||
var images [][]byte
|
||||
for _, img := range req.Images {
|
||||
images = append(images, img.Data)
|
||||
}
|
||||
|
||||
// Build request for subprocess
|
||||
creq := struct {
|
||||
Prompt string `json:"prompt"`
|
||||
Width int32 `json:"width,omitempty"`
|
||||
Height int32 `json:"height,omitempty"`
|
||||
Steps int32 `json:"steps,omitempty"`
|
||||
Seed int64 `json:"seed,omitempty"`
|
||||
}{
|
||||
creq := Request{
|
||||
Prompt: req.Prompt,
|
||||
Width: req.Width,
|
||||
Height: req.Height,
|
||||
Steps: req.Steps,
|
||||
Steps: int(req.Steps),
|
||||
Seed: seed,
|
||||
Images: images,
|
||||
}
|
||||
|
||||
// Pass LLM options if present
|
||||
if req.Options != nil {
|
||||
creq.Options = &RequestOptions{
|
||||
NumPredict: req.Options.NumPredict,
|
||||
Temperature: float64(req.Options.Temperature),
|
||||
TopP: float64(req.Options.TopP),
|
||||
TopK: req.Options.TopK,
|
||||
Stop: req.Options.Stop,
|
||||
}
|
||||
}
|
||||
|
||||
body, err := json.Marshal(creq)
|
||||
@@ -260,31 +288,47 @@ func (s *Server) Completion(ctx context.Context, req llm.CompletionRequest, fn f
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return fmt.Errorf("request failed: %d", resp.StatusCode)
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
return fmt.Errorf("%s", strings.TrimSpace(string(body)))
|
||||
}
|
||||
|
||||
scanner := bufio.NewScanner(resp.Body)
|
||||
scanner.Buffer(make([]byte, 1024*1024), 16*1024*1024) // 16MB max
|
||||
for scanner.Scan() {
|
||||
// Parse subprocess response (has singular "image" field)
|
||||
// Parse subprocess response
|
||||
var raw struct {
|
||||
Image string `json:"image,omitempty"`
|
||||
Content string `json:"content,omitempty"`
|
||||
Done bool `json:"done"`
|
||||
Step int `json:"step,omitempty"`
|
||||
Total int `json:"total,omitempty"`
|
||||
Image string `json:"image,omitempty"`
|
||||
Content string `json:"content,omitempty"`
|
||||
Done bool `json:"done"`
|
||||
Step int `json:"step,omitempty"`
|
||||
Total int `json:"total,omitempty"`
|
||||
StopReason string `json:"stop_reason,omitempty"`
|
||||
PromptEvalCount int `json:"prompt_eval_count,omitempty"`
|
||||
PromptEvalDuration int `json:"prompt_eval_duration,omitempty"`
|
||||
EvalCount int `json:"eval_count,omitempty"`
|
||||
EvalDuration int `json:"eval_duration,omitempty"`
|
||||
}
|
||||
if err := json.Unmarshal(scanner.Bytes(), &raw); err != nil {
|
||||
slog.Debug("mlx response parse error", "error", err, "line", string(scanner.Bytes()))
|
||||
continue
|
||||
}
|
||||
|
||||
// Log stop reason when generation completes
|
||||
if raw.Done && raw.StopReason != "" {
|
||||
slog.Info("mlx generation completed", "stop_reason", raw.StopReason)
|
||||
}
|
||||
|
||||
// Convert to llm.CompletionResponse
|
||||
cresp := llm.CompletionResponse{
|
||||
Content: raw.Content,
|
||||
Done: raw.Done,
|
||||
Step: raw.Step,
|
||||
TotalSteps: raw.Total,
|
||||
Image: raw.Image,
|
||||
Content: raw.Content,
|
||||
Done: raw.Done,
|
||||
Step: raw.Step,
|
||||
TotalSteps: raw.Total,
|
||||
Image: raw.Image,
|
||||
PromptEvalCount: raw.PromptEvalCount,
|
||||
PromptEvalDuration: time.Duration(raw.PromptEvalDuration),
|
||||
EvalCount: raw.EvalCount,
|
||||
EvalDuration: time.Duration(raw.EvalDuration),
|
||||
}
|
||||
|
||||
fn(cresp)
|
||||
@@ -293,7 +337,20 @@ func (s *Server) Completion(ctx context.Context, req llm.CompletionRequest, fn f
|
||||
}
|
||||
}
|
||||
|
||||
return scanner.Err()
|
||||
// Scanner exited without receiving Done - connection was likely closed
|
||||
scanErr := scanner.Err()
|
||||
if scanErr != nil {
|
||||
slog.Error("mlx scanner error", "error", scanErr)
|
||||
} else {
|
||||
slog.Warn("mlx scanner EOF without Done response - subprocess may have crashed")
|
||||
}
|
||||
|
||||
// Check if subprocess is still alive
|
||||
if s.HasExited() {
|
||||
slog.Error("mlx subprocess has exited unexpectedly")
|
||||
}
|
||||
|
||||
return scanErr
|
||||
}
|
||||
|
||||
// Close terminates the subprocess.
|
||||
@@ -302,7 +359,7 @@ func (s *Server) Close() error {
|
||||
defer s.mu.Unlock()
|
||||
|
||||
if s.cmd != nil && s.cmd.Process != nil {
|
||||
slog.Info("stopping image runner subprocess", "pid", s.cmd.Process.Pid)
|
||||
slog.Info("stopping mlx runner subprocess", "pid", s.cmd.Process.Pid)
|
||||
s.cmd.Process.Signal(os.Interrupt)
|
||||
|
||||
// Wait briefly for graceful shutdown
|
||||
@@ -331,18 +388,51 @@ func (s *Server) VRAMByGPU(id ml.DeviceID) uint64 {
|
||||
return s.vramSize
|
||||
}
|
||||
|
||||
// Embedding returns embeddings for the input.
|
||||
func (s *Server) Embedding(ctx context.Context, input string) ([]float32, int, error) {
|
||||
return nil, 0, errors.New("not supported")
|
||||
return nil, 0, errors.New("embeddings not supported for MLX models")
|
||||
}
|
||||
|
||||
// Tokenize tokenizes the input content.
|
||||
func (s *Server) Tokenize(ctx context.Context, content string) ([]int, error) {
|
||||
return nil, errors.New("not supported")
|
||||
body, err := json.Marshal(map[string]string{"content": content})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
url := fmt.Sprintf("http://127.0.0.1:%d/tokenize", s.port)
|
||||
req, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewReader(body))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
resp, err := s.client.Do(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, fmt.Errorf("tokenize failed: %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
var result struct {
|
||||
Tokens []int `json:"tokens"`
|
||||
}
|
||||
if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return result.Tokens, nil
|
||||
}
|
||||
|
||||
// Detokenize converts tokens back to text.
|
||||
func (s *Server) Detokenize(ctx context.Context, tokens []int) (string, error) {
|
||||
return "", errors.New("not supported")
|
||||
return "", errors.New("detokenization not supported for MLX models")
|
||||
}
|
||||
|
||||
// Pid returns the process ID of the subprocess.
|
||||
func (s *Server) Pid() int {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
@@ -352,9 +442,17 @@ func (s *Server) Pid() int {
|
||||
return -1
|
||||
}
|
||||
|
||||
func (s *Server) GetPort() int { return s.port }
|
||||
func (s *Server) GetDeviceInfos(ctx context.Context) []ml.DeviceInfo { return nil }
|
||||
// GetPort returns the port the subprocess is listening on.
|
||||
func (s *Server) GetPort() int {
|
||||
return s.port
|
||||
}
|
||||
|
||||
// GetDeviceInfos returns device information.
|
||||
func (s *Server) GetDeviceInfos(ctx context.Context) []ml.DeviceInfo {
|
||||
return nil
|
||||
}
|
||||
|
||||
// HasExited returns whether the subprocess has exited.
|
||||
func (s *Server) HasExited() bool {
|
||||
select {
|
||||
case <-s.done:
|
||||
81
x/mlxrunner/types.go
Normal file
81
x/mlxrunner/types.go
Normal file
@@ -0,0 +1,81 @@
|
||||
// Package mlxrunner provides a unified MLX runner for both LLM and image generation models.
|
||||
//
|
||||
// This package handles safetensors models created with `ollama create --experimental`,
|
||||
// supporting both text generation (LLM) and image generation (diffusion) models
|
||||
// through a single unified interface.
|
||||
package mlxrunner
|
||||
|
||||
// Request is the request format for completion requests.
|
||||
type Request struct {
|
||||
Prompt string `json:"prompt"`
|
||||
|
||||
// LLM-specific fields
|
||||
Options *RequestOptions `json:"options,omitempty"`
|
||||
|
||||
// Image generation fields
|
||||
Width int32 `json:"width,omitempty"`
|
||||
Height int32 `json:"height,omitempty"`
|
||||
Steps int `json:"steps,omitempty"`
|
||||
Seed int64 `json:"seed,omitempty"`
|
||||
Images [][]byte `json:"images,omitempty"` // Input images for image editing/conditioning
|
||||
}
|
||||
|
||||
// RequestOptions contains LLM-specific generation options.
|
||||
type RequestOptions struct {
|
||||
NumPredict int `json:"num_predict,omitempty"`
|
||||
Temperature float64 `json:"temperature,omitempty"`
|
||||
TopP float64 `json:"top_p,omitempty"`
|
||||
TopK int `json:"top_k,omitempty"`
|
||||
Stop []string `json:"stop,omitempty"`
|
||||
}
|
||||
|
||||
// Response is streamed back for each progress update.
|
||||
type Response struct {
|
||||
// Text generation response
|
||||
Content string `json:"content,omitempty"`
|
||||
|
||||
// Image generation response
|
||||
Image string `json:"image,omitempty"` // Base64-encoded PNG
|
||||
|
||||
// Common fields
|
||||
Done bool `json:"done"`
|
||||
DoneReason int `json:"done_reason,omitempty"`
|
||||
StopReason string `json:"stop_reason,omitempty"` // Debug: why generation stopped
|
||||
|
||||
// Progress fields
|
||||
Step int `json:"step,omitempty"`
|
||||
Total int `json:"total,omitempty"`
|
||||
|
||||
// Statistics
|
||||
PromptEvalCount int `json:"prompt_eval_count,omitempty"`
|
||||
PromptEvalDuration int `json:"prompt_eval_duration,omitempty"`
|
||||
EvalCount int `json:"eval_count,omitempty"`
|
||||
EvalDuration int `json:"eval_duration,omitempty"`
|
||||
}
|
||||
|
||||
// HealthResponse is returned by the health endpoint.
|
||||
type HealthResponse struct {
|
||||
Status string `json:"status"`
|
||||
Progress float32 `json:"progress,omitempty"`
|
||||
}
|
||||
|
||||
// ModelMode represents the type of model being run.
|
||||
type ModelMode int
|
||||
|
||||
const (
|
||||
// ModeLLM indicates a text generation model.
|
||||
ModeLLM ModelMode = iota
|
||||
// ModeImageGen indicates an image generation model.
|
||||
ModeImageGen
|
||||
)
|
||||
|
||||
func (m ModelMode) String() string {
|
||||
switch m {
|
||||
case ModeLLM:
|
||||
return "llm"
|
||||
case ModeImageGen:
|
||||
return "imagegen"
|
||||
default:
|
||||
return "unknown"
|
||||
}
|
||||
}
|
||||
@@ -87,7 +87,7 @@ func New(c fs.Config) (model.Model, error) {
|
||||
// m.Cache = kvcache.NewWrapperCache(kvcache.NewSWACache(slidingWindowLen, m.Shift), kvcache.NewCausalCache(m.Shift))
|
||||
|
||||
// TODO need to implement sliding window...
|
||||
m.Cache = kvcache.NewMLXCausalCache()
|
||||
m.Cache = kvcache.NewCausalCache()
|
||||
|
||||
return &m, nil
|
||||
}
|
||||
|
||||
127
x/server/show.go
127
x/server/show.go
@@ -163,9 +163,18 @@ func GetSafetensorsTensorInfo(name model.Name) ([]api.Tensor, error) {
|
||||
|
||||
// getTensorInfoFromManifest extracts tensor info from a manifest.
|
||||
// This is separated for testability.
|
||||
// For quantized models, groups weight/scale/qbias into single entries with detected quantization type.
|
||||
func getTensorInfoFromManifest(mf *manifest.Manifest) ([]api.Tensor, error) {
|
||||
var tensors []api.Tensor
|
||||
|
||||
// First pass: collect all tensor info and identify scale tensors
|
||||
type tensorData struct {
|
||||
info *safetensorsTensorInfo
|
||||
digest string
|
||||
}
|
||||
tensorMap := make(map[string]*tensorData)
|
||||
scaleMap := make(map[string]*tensorData) // base name -> scale tensor info
|
||||
|
||||
for _, layer := range mf.Layers {
|
||||
if layer.MediaType != manifest.MediaTypeImageTensor {
|
||||
continue
|
||||
@@ -178,28 +187,96 @@ func getTensorInfoFromManifest(mf *manifest.Manifest) ([]api.Tensor, error) {
|
||||
}
|
||||
info, err := readSafetensorsHeader(blobPath)
|
||||
if err != nil {
|
||||
// Skip tensors we can't read
|
||||
continue
|
||||
}
|
||||
|
||||
// Convert shape from int to uint64
|
||||
shape := make([]uint64, len(info.Shape))
|
||||
for i, s := range info.Shape {
|
||||
shape[i] = uint64(s)
|
||||
td := &tensorData{info: info, digest: layer.Digest}
|
||||
|
||||
if strings.HasSuffix(layer.Name, "_scale") {
|
||||
baseName := strings.TrimSuffix(layer.Name, "_scale")
|
||||
scaleMap[baseName] = td
|
||||
} else if strings.HasSuffix(layer.Name, "_qbias") {
|
||||
// Skip qbias tensors - they're included with the quantized weight
|
||||
continue
|
||||
} else {
|
||||
tensorMap[layer.Name] = td
|
||||
}
|
||||
}
|
||||
|
||||
// Second pass: build tensor list with quantization info
|
||||
for _, layer := range mf.Layers {
|
||||
if layer.MediaType != manifest.MediaTypeImageTensor {
|
||||
continue
|
||||
}
|
||||
|
||||
tensors = append(tensors, api.Tensor{
|
||||
Name: layer.Name,
|
||||
Type: info.Dtype,
|
||||
Shape: shape,
|
||||
})
|
||||
// Skip scale and qbias tensors
|
||||
if strings.HasSuffix(layer.Name, "_scale") || strings.HasSuffix(layer.Name, "_qbias") {
|
||||
continue
|
||||
}
|
||||
|
||||
td := tensorMap[layer.Name]
|
||||
if td == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
// Check if this tensor has a corresponding scale tensor (quantized)
|
||||
scaleTd := scaleMap[layer.Name]
|
||||
if scaleTd != nil && len(td.info.Shape) >= 2 && len(scaleTd.info.Shape) >= 2 {
|
||||
// Quantized tensor - detect bits from shapes
|
||||
weightCols := td.info.Shape[len(td.info.Shape)-1]
|
||||
scaleCols := scaleTd.info.Shape[len(scaleTd.info.Shape)-1]
|
||||
|
||||
// Detect quantization: Q4 has pack_factor=8, Q8 has pack_factor=4
|
||||
// Q4 uses group_size=32: weightCols * 8 / scaleCols = 32
|
||||
// Q8 uses group_size=64: weightCols * 4 / scaleCols = 64
|
||||
var bits int
|
||||
var quantType string
|
||||
if weightCols*8/scaleCols == 32 {
|
||||
bits = 4
|
||||
quantType = "Q4"
|
||||
} else if weightCols*4/scaleCols == 64 {
|
||||
bits = 8
|
||||
quantType = "Q8"
|
||||
} else {
|
||||
// Unknown quantization, show raw
|
||||
quantType = td.info.Dtype
|
||||
}
|
||||
|
||||
// Calculate unpacked shape
|
||||
shape := make([]uint64, len(td.info.Shape))
|
||||
for i, s := range td.info.Shape {
|
||||
shape[i] = uint64(s)
|
||||
}
|
||||
if bits > 0 {
|
||||
packFactor := int64(32 / bits)
|
||||
shape[len(shape)-1] = uint64(td.info.Shape[len(td.info.Shape)-1] * packFactor)
|
||||
}
|
||||
|
||||
tensors = append(tensors, api.Tensor{
|
||||
Name: layer.Name,
|
||||
Type: quantType,
|
||||
Shape: shape,
|
||||
})
|
||||
} else {
|
||||
// Non-quantized tensor
|
||||
shape := make([]uint64, len(td.info.Shape))
|
||||
for i, s := range td.info.Shape {
|
||||
shape[i] = uint64(s)
|
||||
}
|
||||
|
||||
tensors = append(tensors, api.Tensor{
|
||||
Name: layer.Name,
|
||||
Type: td.info.Dtype,
|
||||
Shape: shape,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
return tensors, nil
|
||||
}
|
||||
|
||||
// GetSafetensorsDtype returns the quantization type for a safetensors model.
|
||||
// If the model is quantized (has _scale tensors), returns the quantization type (e.g., "FP8").
|
||||
// Reads from model_index.json first, falls back to detection from tensor names.
|
||||
// Otherwise returns the torch_dtype from config.json.
|
||||
func GetSafetensorsDtype(name model.Name) (string, error) {
|
||||
mf, err := manifest.ParseNamedManifest(name)
|
||||
@@ -207,16 +284,38 @@ func GetSafetensorsDtype(name model.Name) (string, error) {
|
||||
return "", fmt.Errorf("failed to load manifest: %w", err)
|
||||
}
|
||||
|
||||
// Check if model is quantized by looking for _scale tensors
|
||||
// First try to read quantization from model_index.json
|
||||
var modelIndex struct {
|
||||
Quantization string `json:"quantization"`
|
||||
}
|
||||
if err := mf.ReadConfigJSON("model_index.json", &modelIndex); err == nil && modelIndex.Quantization != "" {
|
||||
return modelIndex.Quantization, nil
|
||||
}
|
||||
|
||||
// Fallback: detect from tensor names
|
||||
hasScales := false
|
||||
hasQBias := false
|
||||
for _, layer := range mf.Layers {
|
||||
if layer.MediaType == manifest.MediaTypeImageTensor {
|
||||
if strings.HasSuffix(layer.Name, "_scale") {
|
||||
// Model is quantized - return FP8 (affine quantization)
|
||||
return "FP8", nil
|
||||
hasScales = true
|
||||
}
|
||||
if strings.HasSuffix(layer.Name, "_qbias") {
|
||||
hasQBias = true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if hasScales {
|
||||
if hasQBias {
|
||||
// Affine mode (has scale + qbias) - could be Q4 or Q8
|
||||
// Default to Q4 as it's more common
|
||||
return "Q4", nil
|
||||
}
|
||||
// No qbias = NVFP4
|
||||
return "NVFP4", nil
|
||||
}
|
||||
|
||||
// Not quantized - return torch_dtype from config.json
|
||||
var cfg struct {
|
||||
TorchDtype string `json:"torch_dtype"`
|
||||
|
||||
Reference in New Issue
Block a user