cmd: ollama config command to help configure integrations to use Ollama (#13712)

This commit is contained in:
Parth Sareen
2026-01-22 23:17:11 -05:00
committed by GitHub
parent 3b3bf6c217
commit 199c41e16e
17 changed files with 4487 additions and 0 deletions

View File

@@ -35,6 +35,7 @@ import (
"golang.org/x/term"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/cmd/config"
"github.com/ollama/ollama/envconfig"
"github.com/ollama/ollama/format"
"github.com/ollama/ollama/parser"
@@ -2026,6 +2027,7 @@ func NewCLI() *cobra.Command {
copyCmd,
deleteCmd,
runnerCmd,
config.ConfigCmd(checkServerHeartbeat),
)
return rootCmd

36
cmd/config/claude.go Normal file
View File

@@ -0,0 +1,36 @@
package config
import (
"fmt"
"os"
"os/exec"
)
// Claude implements Runner for Claude Code integration
type Claude struct{}
func (c *Claude) String() string { return "Claude Code" }
func (c *Claude) args(model string) []string {
if model != "" {
return []string{"--model", model}
}
return nil
}
func (c *Claude) Run(model string) error {
if _, err := exec.LookPath("claude"); err != nil {
return fmt.Errorf("claude is not installed, install from https://code.claude.com/docs/en/quickstart")
}
cmd := exec.Command("claude", c.args(model)...)
cmd.Stdin = os.Stdin
cmd.Stdout = os.Stdout
cmd.Stderr = os.Stderr
cmd.Env = append(os.Environ(),
"ANTHROPIC_BASE_URL=http://localhost:11434",
"ANTHROPIC_API_KEY=",
"ANTHROPIC_AUTH_TOKEN=ollama",
)
return cmd.Run()
}

42
cmd/config/claude_test.go Normal file
View File

@@ -0,0 +1,42 @@
package config
import (
"slices"
"testing"
)
func TestClaudeIntegration(t *testing.T) {
c := &Claude{}
t.Run("String", func(t *testing.T) {
if got := c.String(); got != "Claude Code" {
t.Errorf("String() = %q, want %q", got, "Claude Code")
}
})
t.Run("implements Runner", func(t *testing.T) {
var _ Runner = c
})
}
func TestClaudeArgs(t *testing.T) {
c := &Claude{}
tests := []struct {
name string
model string
want []string
}{
{"with model", "llama3.2", []string{"--model", "llama3.2"}},
{"empty model", "", nil},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := c.args(tt.model)
if !slices.Equal(got, tt.want) {
t.Errorf("args(%q) = %v, want %v", tt.model, got, tt.want)
}
})
}
}

61
cmd/config/codex.go Normal file
View File

@@ -0,0 +1,61 @@
package config
import (
"fmt"
"os"
"os/exec"
"strings"
"golang.org/x/mod/semver"
)
// Codex implements Runner for Codex integration
type Codex struct{}
func (c *Codex) String() string { return "Codex" }
func (c *Codex) args(model string) []string {
args := []string{"--oss"}
if model != "" {
args = append(args, "-m", model)
}
return args
}
func (c *Codex) Run(model string) error {
if err := checkCodexVersion(); err != nil {
return err
}
cmd := exec.Command("codex", c.args(model)...)
cmd.Stdin = os.Stdin
cmd.Stdout = os.Stdout
cmd.Stderr = os.Stderr
return cmd.Run()
}
func checkCodexVersion() error {
if _, err := exec.LookPath("codex"); err != nil {
return fmt.Errorf("codex is not installed, install with: npm install -g @openai/codex")
}
out, err := exec.Command("codex", "--version").Output()
if err != nil {
return fmt.Errorf("failed to get codex version: %w", err)
}
// Parse output like "codex-cli 0.87.0"
fields := strings.Fields(strings.TrimSpace(string(out)))
if len(fields) < 2 {
return fmt.Errorf("unexpected codex version output: %s", string(out))
}
version := "v" + fields[len(fields)-1]
minVersion := "v0.81.0"
if semver.Compare(version, minVersion) < 0 {
return fmt.Errorf("codex version %s is too old, minimum required is %s, update with: npm update -g @openai/codex", fields[len(fields)-1], "0.81.0")
}
return nil
}

28
cmd/config/codex_test.go Normal file
View File

@@ -0,0 +1,28 @@
package config
import (
"slices"
"testing"
)
func TestCodexArgs(t *testing.T) {
c := &Codex{}
tests := []struct {
name string
model string
want []string
}{
{"with model", "llama3.2", []string{"--oss", "-m", "llama3.2"}},
{"empty model", "", []string{"--oss"}},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := c.args(tt.model)
if !slices.Equal(got, tt.want) {
t.Errorf("args(%q) = %v, want %v", tt.model, got, tt.want)
}
})
}
}

115
cmd/config/config.go Normal file
View File

@@ -0,0 +1,115 @@
// Package config provides integration configuration for external coding tools
// (Claude Code, Codex, Droid, OpenCode) to use Ollama models.
package config
import (
"encoding/json"
"errors"
"fmt"
"os"
"path/filepath"
"strings"
)
type integration struct {
Models []string `json:"models"`
}
type config struct {
Integrations map[string]*integration `json:"integrations"`
}
func configPath() (string, error) {
home, err := os.UserHomeDir()
if err != nil {
return "", err
}
return filepath.Join(home, ".ollama", "config", "config.json"), nil
}
func load() (*config, error) {
path, err := configPath()
if err != nil {
return nil, err
}
data, err := os.ReadFile(path)
if err != nil {
if os.IsNotExist(err) {
return &config{Integrations: make(map[string]*integration)}, nil
}
return nil, err
}
var cfg config
if err := json.Unmarshal(data, &cfg); err != nil {
return nil, fmt.Errorf("failed to parse config: %w, at: %s", err, path)
}
if cfg.Integrations == nil {
cfg.Integrations = make(map[string]*integration)
}
return &cfg, nil
}
func save(cfg *config) error {
path, err := configPath()
if err != nil {
return err
}
if err := os.MkdirAll(filepath.Dir(path), 0o755); err != nil {
return err
}
data, err := json.MarshalIndent(cfg, "", " ")
if err != nil {
return err
}
return writeWithBackup(path, data)
}
func saveIntegration(appName string, models []string) error {
if appName == "" {
return errors.New("app name cannot be empty")
}
cfg, err := load()
if err != nil {
return err
}
cfg.Integrations[strings.ToLower(appName)] = &integration{
Models: models,
}
return save(cfg)
}
func loadIntegration(appName string) (*integration, error) {
cfg, err := load()
if err != nil {
return nil, err
}
ic, ok := cfg.Integrations[strings.ToLower(appName)]
if !ok {
return nil, os.ErrNotExist
}
return ic, nil
}
func listIntegrations() ([]integration, error) {
cfg, err := load()
if err != nil {
return nil, err
}
result := make([]integration, 0, len(cfg.Integrations))
for _, ic := range cfg.Integrations {
result = append(result, *ic)
}
return result, nil
}

373
cmd/config/config_test.go Normal file
View File

@@ -0,0 +1,373 @@
package config
import (
"os"
"path/filepath"
"strings"
"testing"
)
// setTestHome sets both HOME (Unix) and USERPROFILE (Windows) for cross-platform tests
func setTestHome(t *testing.T, dir string) {
t.Setenv("HOME", dir)
t.Setenv("USERPROFILE", dir)
}
// editorPaths is a test helper that safely calls Paths if the runner implements Editor
func editorPaths(r Runner) []string {
if editor, ok := r.(Editor); ok {
return editor.Paths()
}
return nil
}
func TestIntegrationConfig(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
t.Run("save and load round-trip", func(t *testing.T) {
models := []string{"llama3.2", "mistral", "qwen2.5"}
if err := saveIntegration("claude", models); err != nil {
t.Fatal(err)
}
config, err := loadIntegration("claude")
if err != nil {
t.Fatal(err)
}
if len(config.Models) != len(models) {
t.Errorf("expected %d models, got %d", len(models), len(config.Models))
}
for i, m := range models {
if config.Models[i] != m {
t.Errorf("model %d: expected %s, got %s", i, m, config.Models[i])
}
}
})
t.Run("defaultModel returns first model", func(t *testing.T) {
saveIntegration("codex", []string{"model-a", "model-b"})
config, _ := loadIntegration("codex")
defaultModel := ""
if len(config.Models) > 0 {
defaultModel = config.Models[0]
}
if defaultModel != "model-a" {
t.Errorf("expected model-a, got %s", defaultModel)
}
})
t.Run("defaultModel returns empty for no models", func(t *testing.T) {
config := &integration{Models: []string{}}
defaultModel := ""
if len(config.Models) > 0 {
defaultModel = config.Models[0]
}
if defaultModel != "" {
t.Errorf("expected empty string, got %s", defaultModel)
}
})
t.Run("app name is case-insensitive", func(t *testing.T) {
saveIntegration("Claude", []string{"model-x"})
config, err := loadIntegration("claude")
if err != nil {
t.Fatal(err)
}
defaultModel := ""
if len(config.Models) > 0 {
defaultModel = config.Models[0]
}
if defaultModel != "model-x" {
t.Errorf("expected model-x, got %s", defaultModel)
}
})
t.Run("multiple integrations in single file", func(t *testing.T) {
saveIntegration("app1", []string{"model-1"})
saveIntegration("app2", []string{"model-2"})
config1, _ := loadIntegration("app1")
config2, _ := loadIntegration("app2")
defaultModel1 := ""
if len(config1.Models) > 0 {
defaultModel1 = config1.Models[0]
}
defaultModel2 := ""
if len(config2.Models) > 0 {
defaultModel2 = config2.Models[0]
}
if defaultModel1 != "model-1" {
t.Errorf("expected model-1, got %s", defaultModel1)
}
if defaultModel2 != "model-2" {
t.Errorf("expected model-2, got %s", defaultModel2)
}
})
}
func TestListIntegrations(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
t.Run("returns empty when no integrations", func(t *testing.T) {
configs, err := listIntegrations()
if err != nil {
t.Fatal(err)
}
if len(configs) != 0 {
t.Errorf("expected 0 integrations, got %d", len(configs))
}
})
t.Run("returns all saved integrations", func(t *testing.T) {
saveIntegration("claude", []string{"model-1"})
saveIntegration("droid", []string{"model-2"})
configs, err := listIntegrations()
if err != nil {
t.Fatal(err)
}
if len(configs) != 2 {
t.Errorf("expected 2 integrations, got %d", len(configs))
}
})
}
func TestEditorPaths(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
t.Run("returns empty for claude (no Editor)", func(t *testing.T) {
r := integrations["claude"]
paths := editorPaths(r)
if len(paths) != 0 {
t.Errorf("expected no paths for claude, got %v", paths)
}
})
t.Run("returns empty for codex (no Editor)", func(t *testing.T) {
r := integrations["codex"]
paths := editorPaths(r)
if len(paths) != 0 {
t.Errorf("expected no paths for codex, got %v", paths)
}
})
t.Run("returns empty for droid when no config exists", func(t *testing.T) {
r := integrations["droid"]
paths := editorPaths(r)
if len(paths) != 0 {
t.Errorf("expected no paths, got %v", paths)
}
})
t.Run("returns path for droid when config exists", func(t *testing.T) {
settingsDir, _ := os.UserHomeDir()
settingsDir = filepath.Join(settingsDir, ".factory")
os.MkdirAll(settingsDir, 0o755)
os.WriteFile(filepath.Join(settingsDir, "settings.json"), []byte(`{}`), 0o644)
r := integrations["droid"]
paths := editorPaths(r)
if len(paths) != 1 {
t.Errorf("expected 1 path, got %d", len(paths))
}
})
t.Run("returns paths for opencode when configs exist", func(t *testing.T) {
home, _ := os.UserHomeDir()
configDir := filepath.Join(home, ".config", "opencode")
stateDir := filepath.Join(home, ".local", "state", "opencode")
os.MkdirAll(configDir, 0o755)
os.MkdirAll(stateDir, 0o755)
os.WriteFile(filepath.Join(configDir, "opencode.json"), []byte(`{}`), 0o644)
os.WriteFile(filepath.Join(stateDir, "model.json"), []byte(`{}`), 0o644)
r := integrations["opencode"]
paths := editorPaths(r)
if len(paths) != 2 {
t.Errorf("expected 2 paths, got %d: %v", len(paths), paths)
}
})
}
func TestLoadIntegration_CorruptedJSON(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
// Create corrupted config.json file
dir := filepath.Join(tmpDir, ".ollama", "config")
os.MkdirAll(dir, 0o755)
os.WriteFile(filepath.Join(dir, "config.json"), []byte(`{corrupted json`), 0o644)
// Corrupted file is treated as empty, so loadIntegration returns not found
_, err := loadIntegration("test")
if err == nil {
t.Error("expected error for nonexistent integration in corrupted file")
}
}
func TestSaveIntegration_NilModels(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
if err := saveIntegration("test", nil); err != nil {
t.Fatalf("saveIntegration with nil models failed: %v", err)
}
config, err := loadIntegration("test")
if err != nil {
t.Fatalf("loadIntegration failed: %v", err)
}
if config.Models == nil {
// nil is acceptable
} else if len(config.Models) != 0 {
t.Errorf("expected empty or nil models, got %v", config.Models)
}
}
func TestSaveIntegration_EmptyAppName(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
err := saveIntegration("", []string{"model"})
if err == nil {
t.Error("expected error for empty app name, got nil")
}
if err != nil && !strings.Contains(err.Error(), "app name cannot be empty") {
t.Errorf("expected 'app name cannot be empty' error, got: %v", err)
}
}
func TestLoadIntegration_NonexistentIntegration(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
_, err := loadIntegration("nonexistent")
if err == nil {
t.Error("expected error for nonexistent integration, got nil")
}
if !os.IsNotExist(err) {
t.Logf("error type is os.ErrNotExist as expected: %v", err)
}
}
func TestConfigPath(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
path, err := configPath()
if err != nil {
t.Fatal(err)
}
expected := filepath.Join(tmpDir, ".ollama", "config", "config.json")
if path != expected {
t.Errorf("expected %s, got %s", expected, path)
}
}
func TestLoad(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
t.Run("returns empty config when file does not exist", func(t *testing.T) {
cfg, err := load()
if err != nil {
t.Fatal(err)
}
if cfg == nil {
t.Fatal("expected non-nil config")
}
if cfg.Integrations == nil {
t.Error("expected non-nil Integrations map")
}
if len(cfg.Integrations) != 0 {
t.Errorf("expected empty Integrations, got %d", len(cfg.Integrations))
}
})
t.Run("loads existing config", func(t *testing.T) {
path, _ := configPath()
os.MkdirAll(filepath.Dir(path), 0o755)
os.WriteFile(path, []byte(`{"integrations":{"test":{"models":["model-a"]}}}`), 0o644)
cfg, err := load()
if err != nil {
t.Fatal(err)
}
if cfg.Integrations["test"] == nil {
t.Fatal("expected test integration")
}
if len(cfg.Integrations["test"].Models) != 1 {
t.Errorf("expected 1 model, got %d", len(cfg.Integrations["test"].Models))
}
})
t.Run("returns error for corrupted JSON", func(t *testing.T) {
path, _ := configPath()
os.MkdirAll(filepath.Dir(path), 0o755)
os.WriteFile(path, []byte(`{corrupted`), 0o644)
_, err := load()
if err == nil {
t.Error("expected error for corrupted JSON")
}
})
}
func TestSave(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
t.Run("creates config file", func(t *testing.T) {
cfg := &config{
Integrations: map[string]*integration{
"test": {Models: []string{"model-a", "model-b"}},
},
}
if err := save(cfg); err != nil {
t.Fatal(err)
}
path, _ := configPath()
if _, err := os.Stat(path); os.IsNotExist(err) {
t.Error("config file was not created")
}
})
t.Run("round-trip preserves data", func(t *testing.T) {
cfg := &config{
Integrations: map[string]*integration{
"claude": {Models: []string{"llama3.2", "mistral"}},
"codex": {Models: []string{"qwen2.5"}},
},
}
if err := save(cfg); err != nil {
t.Fatal(err)
}
loaded, err := load()
if err != nil {
t.Fatal(err)
}
if len(loaded.Integrations) != 2 {
t.Errorf("expected 2 integrations, got %d", len(loaded.Integrations))
}
if loaded.Integrations["claude"] == nil {
t.Error("missing claude integration")
}
if len(loaded.Integrations["claude"].Models) != 2 {
t.Errorf("expected 2 models for claude, got %d", len(loaded.Integrations["claude"].Models))
}
})
}

174
cmd/config/droid.go Normal file
View File

@@ -0,0 +1,174 @@
package config
import (
"encoding/json"
"fmt"
"os"
"os/exec"
"path/filepath"
"slices"
"strings"
)
// Droid implements Runner and Editor for Droid integration
type Droid struct{}
// droidModelEntry represents a custom model entry in Droid's settings.json
type droidModelEntry struct {
Model string `json:"model"`
DisplayName string `json:"displayName"`
BaseURL string `json:"baseUrl"`
APIKey string `json:"apiKey"`
Provider string `json:"provider"`
MaxOutputTokens int `json:"maxOutputTokens"`
SupportsImages bool `json:"supportsImages"`
ID string `json:"id"`
Index int `json:"index"`
}
func (d *Droid) String() string { return "Droid" }
func (d *Droid) Run(model string) error {
if _, err := exec.LookPath("droid"); err != nil {
return fmt.Errorf("droid is not installed, install from https://docs.factory.ai/cli/getting-started/quickstart")
}
// Call Edit() to ensure config is up-to-date before launch
models := []string{model}
if config, err := loadIntegration("droid"); err == nil && len(config.Models) > 0 {
models = config.Models
}
if err := d.Edit(models); err != nil {
return fmt.Errorf("setup failed: %w", err)
}
cmd := exec.Command("droid")
cmd.Stdin = os.Stdin
cmd.Stdout = os.Stdout
cmd.Stderr = os.Stderr
return cmd.Run()
}
func (d *Droid) Paths() []string {
home, err := os.UserHomeDir()
if err != nil {
return nil
}
p := filepath.Join(home, ".factory", "settings.json")
if _, err := os.Stat(p); err == nil {
return []string{p}
}
return nil
}
func (d *Droid) Edit(models []string) error {
if len(models) == 0 {
return nil
}
home, err := os.UserHomeDir()
if err != nil {
return err
}
settingsPath := filepath.Join(home, ".factory", "settings.json")
if err := os.MkdirAll(filepath.Dir(settingsPath), 0o755); err != nil {
return err
}
settings := make(map[string]any)
if data, err := os.ReadFile(settingsPath); err == nil {
if err := json.Unmarshal(data, &settings); err != nil {
return fmt.Errorf("failed to parse settings file: %w, at: %s", err, settingsPath)
}
}
customModels, _ := settings["customModels"].([]any)
// Keep only non-Ollama models (we'll rebuild Ollama models fresh)
nonOllamaModels := slices.DeleteFunc(slices.Clone(customModels), isOllamaModelEntry)
// Build new Ollama model entries with sequential indices (0, 1, 2, ...)
var ollamaModels []any
var defaultModelID string
for i, model := range models {
modelID := fmt.Sprintf("custom:%s-[Ollama]-%d", model, i)
ollamaModels = append(ollamaModels, droidModelEntry{
Model: model,
DisplayName: model,
BaseURL: "http://localhost:11434/v1",
APIKey: "ollama",
Provider: "generic-chat-completion-api",
MaxOutputTokens: 64000,
SupportsImages: false,
ID: modelID,
Index: i,
})
if i == 0 {
defaultModelID = modelID
}
}
settings["customModels"] = append(ollamaModels, nonOllamaModels...)
sessionSettings, ok := settings["sessionDefaultSettings"].(map[string]any)
if !ok {
sessionSettings = make(map[string]any)
}
sessionSettings["model"] = defaultModelID
if effort, ok := sessionSettings["reasoningEffort"].(string); !ok || !isValidReasoningEffort(effort) {
sessionSettings["reasoningEffort"] = "none"
}
settings["sessionDefaultSettings"] = sessionSettings
data, err := json.MarshalIndent(settings, "", " ")
if err != nil {
return err
}
return writeWithBackup(settingsPath, data)
}
func (d *Droid) Models() []string {
home, err := os.UserHomeDir()
if err != nil {
return nil
}
settings, err := readJSONFile(filepath.Join(home, ".factory", "settings.json"))
if err != nil {
return nil
}
customModels, _ := settings["customModels"].([]any)
var result []string
for _, m := range customModels {
if !isOllamaModelEntry(m) {
continue
}
entry, ok := m.(map[string]any)
if !ok {
continue
}
if model, _ := entry["model"].(string); model != "" {
result = append(result, model)
}
}
return result
}
var validReasoningEfforts = []string{"high", "medium", "low", "none"}
func isValidReasoningEffort(effort string) bool {
return slices.Contains(validReasoningEfforts, effort)
}
func isOllamaModelEntry(m any) bool {
entry, ok := m.(map[string]any)
if !ok {
return false
}
id, _ := entry["id"].(string)
return strings.Contains(id, "-[Ollama]-")
}

454
cmd/config/droid_test.go Normal file
View File

@@ -0,0 +1,454 @@
package config
import (
"encoding/json"
"os"
"path/filepath"
"testing"
)
func TestDroidIntegration(t *testing.T) {
d := &Droid{}
t.Run("String", func(t *testing.T) {
if got := d.String(); got != "Droid" {
t.Errorf("String() = %q, want %q", got, "Droid")
}
})
t.Run("implements Runner", func(t *testing.T) {
var _ Runner = d
})
t.Run("implements Editor", func(t *testing.T) {
var _ Editor = d
})
}
func TestDroidEdit(t *testing.T) {
d := &Droid{}
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
settingsDir := filepath.Join(tmpDir, ".factory")
settingsPath := filepath.Join(settingsDir, "settings.json")
cleanup := func() {
os.RemoveAll(settingsDir)
}
readSettings := func() map[string]any {
data, _ := os.ReadFile(settingsPath)
var settings map[string]any
json.Unmarshal(data, &settings)
return settings
}
getCustomModels := func(settings map[string]any) []map[string]any {
models, ok := settings["customModels"].([]any)
if !ok {
return nil
}
var result []map[string]any
for _, m := range models {
if entry, ok := m.(map[string]any); ok {
result = append(result, entry)
}
}
return result
}
t.Run("fresh install creates models with sequential indices", func(t *testing.T) {
cleanup()
if err := d.Edit([]string{"model-a", "model-b"}); err != nil {
t.Fatal(err)
}
settings := readSettings()
models := getCustomModels(settings)
if len(models) != 2 {
t.Fatalf("expected 2 models, got %d", len(models))
}
// Check first model
if models[0]["model"] != "model-a" {
t.Errorf("expected model-a, got %s", models[0]["model"])
}
if models[0]["id"] != "custom:model-a-[Ollama]-0" {
t.Errorf("expected custom:model-a-[Ollama]-0, got %s", models[0]["id"])
}
if models[0]["index"] != float64(0) {
t.Errorf("expected index 0, got %v", models[0]["index"])
}
// Check second model
if models[1]["model"] != "model-b" {
t.Errorf("expected model-b, got %s", models[1]["model"])
}
if models[1]["id"] != "custom:model-b-[Ollama]-1" {
t.Errorf("expected custom:model-b-[Ollama]-1, got %s", models[1]["id"])
}
if models[1]["index"] != float64(1) {
t.Errorf("expected index 1, got %v", models[1]["index"])
}
})
t.Run("sets sessionDefaultSettings.model to first model ID", func(t *testing.T) {
cleanup()
if err := d.Edit([]string{"model-a", "model-b"}); err != nil {
t.Fatal(err)
}
settings := readSettings()
session, ok := settings["sessionDefaultSettings"].(map[string]any)
if !ok {
t.Fatal("sessionDefaultSettings not found")
}
if session["model"] != "custom:model-a-[Ollama]-0" {
t.Errorf("expected custom:model-a-[Ollama]-0, got %s", session["model"])
}
})
t.Run("re-indexes when models removed", func(t *testing.T) {
cleanup()
// Add three models
d.Edit([]string{"model-a", "model-b", "model-c"})
// Remove middle model
d.Edit([]string{"model-a", "model-c"})
settings := readSettings()
models := getCustomModels(settings)
if len(models) != 2 {
t.Fatalf("expected 2 models, got %d", len(models))
}
// Check indices are sequential 0, 1
if models[0]["index"] != float64(0) {
t.Errorf("expected index 0, got %v", models[0]["index"])
}
if models[1]["index"] != float64(1) {
t.Errorf("expected index 1, got %v", models[1]["index"])
}
// Check IDs match new indices
if models[0]["id"] != "custom:model-a-[Ollama]-0" {
t.Errorf("expected custom:model-a-[Ollama]-0, got %s", models[0]["id"])
}
if models[1]["id"] != "custom:model-c-[Ollama]-1" {
t.Errorf("expected custom:model-c-[Ollama]-1, got %s", models[1]["id"])
}
})
t.Run("preserves non-Ollama custom models", func(t *testing.T) {
cleanup()
os.MkdirAll(settingsDir, 0o755)
// Pre-existing non-Ollama model
os.WriteFile(settingsPath, []byte(`{
"customModels": [
{"model": "gpt-4", "displayName": "GPT-4", "provider": "openai"}
]
}`), 0o644)
d.Edit([]string{"model-a"})
settings := readSettings()
models := getCustomModels(settings)
if len(models) != 2 {
t.Fatalf("expected 2 models (1 Ollama + 1 non-Ollama), got %d", len(models))
}
// Ollama model should be first
if models[0]["model"] != "model-a" {
t.Errorf("expected Ollama model first, got %s", models[0]["model"])
}
// Non-Ollama model should be preserved at end
if models[1]["model"] != "gpt-4" {
t.Errorf("expected gpt-4 preserved, got %s", models[1]["model"])
}
})
t.Run("preserves other settings", func(t *testing.T) {
cleanup()
os.MkdirAll(settingsDir, 0o755)
os.WriteFile(settingsPath, []byte(`{
"theme": "dark",
"enableHooks": true,
"sessionDefaultSettings": {"autonomyMode": "auto-high"}
}`), 0o644)
d.Edit([]string{"model-a"})
settings := readSettings()
if settings["theme"] != "dark" {
t.Error("theme was not preserved")
}
if settings["enableHooks"] != true {
t.Error("enableHooks was not preserved")
}
session := settings["sessionDefaultSettings"].(map[string]any)
if session["autonomyMode"] != "auto-high" {
t.Error("autonomyMode was not preserved")
}
})
t.Run("required fields present", func(t *testing.T) {
cleanup()
d.Edit([]string{"test-model"})
settings := readSettings()
models := getCustomModels(settings)
if len(models) != 1 {
t.Fatal("expected 1 model")
}
model := models[0]
requiredFields := []string{"model", "displayName", "baseUrl", "apiKey", "provider", "maxOutputTokens", "id", "index"}
for _, field := range requiredFields {
if model[field] == nil {
t.Errorf("missing required field: %s", field)
}
}
if model["baseUrl"] != "http://localhost:11434/v1" {
t.Errorf("unexpected baseUrl: %s", model["baseUrl"])
}
if model["apiKey"] != "ollama" {
t.Errorf("unexpected apiKey: %s", model["apiKey"])
}
if model["provider"] != "generic-chat-completion-api" {
t.Errorf("unexpected provider: %s", model["provider"])
}
})
t.Run("fixes invalid reasoningEffort", func(t *testing.T) {
cleanup()
os.MkdirAll(settingsDir, 0o755)
// Pre-existing settings with invalid reasoningEffort
os.WriteFile(settingsPath, []byte(`{
"sessionDefaultSettings": {"reasoningEffort": "off"}
}`), 0o644)
d.Edit([]string{"model-a"})
settings := readSettings()
session := settings["sessionDefaultSettings"].(map[string]any)
if session["reasoningEffort"] != "none" {
t.Errorf("expected reasoningEffort to be fixed to 'none', got %s", session["reasoningEffort"])
}
})
t.Run("preserves valid reasoningEffort", func(t *testing.T) {
cleanup()
os.MkdirAll(settingsDir, 0o755)
os.WriteFile(settingsPath, []byte(`{
"sessionDefaultSettings": {"reasoningEffort": "high"}
}`), 0o644)
d.Edit([]string{"model-a"})
settings := readSettings()
session := settings["sessionDefaultSettings"].(map[string]any)
if session["reasoningEffort"] != "high" {
t.Errorf("expected reasoningEffort to remain 'high', got %s", session["reasoningEffort"])
}
})
}
// Edge case tests for droid.go
func TestDroidEdit_CorruptedJSON(t *testing.T) {
d := &Droid{}
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
settingsDir := filepath.Join(tmpDir, ".factory")
settingsPath := filepath.Join(settingsDir, "settings.json")
os.MkdirAll(settingsDir, 0o755)
os.WriteFile(settingsPath, []byte(`{corrupted json content`), 0o644)
// Corrupted JSON should return an error so user knows something is wrong
err := d.Edit([]string{"model-a"})
if err == nil {
t.Fatal("expected error for corrupted JSON, got nil")
}
// Original corrupted file should be preserved (not overwritten)
data, _ := os.ReadFile(settingsPath)
if string(data) != `{corrupted json content` {
t.Errorf("corrupted file was modified: got %s", string(data))
}
}
func TestDroidEdit_WrongTypeCustomModels(t *testing.T) {
d := &Droid{}
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
settingsDir := filepath.Join(tmpDir, ".factory")
settingsPath := filepath.Join(settingsDir, "settings.json")
os.MkdirAll(settingsDir, 0o755)
// customModels is a string instead of array
os.WriteFile(settingsPath, []byte(`{"customModels": "not an array"}`), 0o644)
// Should not panic - wrong type should be handled gracefully
err := d.Edit([]string{"model-a"})
if err != nil {
t.Fatalf("Edit failed with wrong type customModels: %v", err)
}
// Verify models were added correctly
data, _ := os.ReadFile(settingsPath)
var settings map[string]any
json.Unmarshal(data, &settings)
customModels, ok := settings["customModels"].([]any)
if !ok {
t.Fatalf("customModels should be array after setup, got %T", settings["customModels"])
}
if len(customModels) != 1 {
t.Errorf("expected 1 model, got %d", len(customModels))
}
}
func TestDroidEdit_EmptyModels(t *testing.T) {
d := &Droid{}
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
settingsDir := filepath.Join(tmpDir, ".factory")
settingsPath := filepath.Join(settingsDir, "settings.json")
os.MkdirAll(settingsDir, 0o755)
originalContent := `{"customModels": [{"model": "existing"}]}`
os.WriteFile(settingsPath, []byte(originalContent), 0o644)
// Empty models should be no-op
err := d.Edit([]string{})
if err != nil {
t.Fatalf("Edit with empty models failed: %v", err)
}
// Original content should be preserved (file not modified)
data, _ := os.ReadFile(settingsPath)
if string(data) != originalContent {
t.Errorf("empty models should not modify file, but content changed")
}
}
func TestDroidEdit_DuplicateModels(t *testing.T) {
d := &Droid{}
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
settingsDir := filepath.Join(tmpDir, ".factory")
settingsPath := filepath.Join(settingsDir, "settings.json")
// Add same model twice
err := d.Edit([]string{"model-a", "model-a"})
if err != nil {
t.Fatalf("Edit with duplicates failed: %v", err)
}
settings, err := readJSONFile(settingsPath)
if err != nil {
t.Fatalf("readJSONFile failed: %v", err)
}
customModels, _ := settings["customModels"].([]any)
// Document current behavior: duplicates are kept as separate entries
if len(customModels) != 2 {
t.Logf("Note: duplicates result in %d entries (documenting behavior)", len(customModels))
}
}
func TestDroidEdit_MalformedModelEntry(t *testing.T) {
d := &Droid{}
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
settingsDir := filepath.Join(tmpDir, ".factory")
settingsPath := filepath.Join(settingsDir, "settings.json")
os.MkdirAll(settingsDir, 0o755)
// Model entry is a string instead of a map
os.WriteFile(settingsPath, []byte(`{"customModels": ["not a map", 123]}`), 0o644)
err := d.Edit([]string{"model-a"})
if err != nil {
t.Fatalf("Edit with malformed entries failed: %v", err)
}
// Malformed entries should be preserved in nonOllamaModels
settings, _ := readJSONFile(settingsPath)
customModels, _ := settings["customModels"].([]any)
// Should have: 1 new Ollama model + 2 preserved malformed entries
if len(customModels) != 3 {
t.Errorf("expected 3 entries (1 new + 2 preserved malformed), got %d", len(customModels))
}
}
func TestDroidEdit_WrongTypeSessionSettings(t *testing.T) {
d := &Droid{}
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
settingsDir := filepath.Join(tmpDir, ".factory")
settingsPath := filepath.Join(settingsDir, "settings.json")
os.MkdirAll(settingsDir, 0o755)
// sessionDefaultSettings is a string instead of map
os.WriteFile(settingsPath, []byte(`{"sessionDefaultSettings": "not a map"}`), 0o644)
err := d.Edit([]string{"model-a"})
if err != nil {
t.Fatalf("Edit with wrong type sessionDefaultSettings failed: %v", err)
}
// Should create proper sessionDefaultSettings
settings, _ := readJSONFile(settingsPath)
session, ok := settings["sessionDefaultSettings"].(map[string]any)
if !ok {
t.Fatalf("sessionDefaultSettings should be map after setup, got %T", settings["sessionDefaultSettings"])
}
if session["model"] == nil {
t.Error("expected model to be set in sessionDefaultSettings")
}
}
func TestIsValidReasoningEffort(t *testing.T) {
tests := []struct {
effort string
valid bool
}{
{"high", true},
{"medium", true},
{"low", true},
{"none", true},
{"off", false},
{"", false},
{"HIGH", false}, // case sensitive
{"max", false},
}
for _, tt := range tests {
t.Run(tt.effort, func(t *testing.T) {
got := isValidReasoningEffort(tt.effort)
if got != tt.valid {
t.Errorf("isValidReasoningEffort(%q) = %v, want %v", tt.effort, got, tt.valid)
}
})
}
}

99
cmd/config/files.go Normal file
View File

@@ -0,0 +1,99 @@
package config
import (
"bytes"
"encoding/json"
"fmt"
"os"
"path/filepath"
"time"
)
func readJSONFile(path string) (map[string]any, error) {
data, err := os.ReadFile(path)
if err != nil {
return nil, err
}
var result map[string]any
if err := json.Unmarshal(data, &result); err != nil {
return nil, err
}
return result, nil
}
func copyFile(src, dst string) error {
info, err := os.Stat(src)
if err != nil {
return err
}
data, err := os.ReadFile(src)
if err != nil {
return err
}
return os.WriteFile(dst, data, info.Mode().Perm())
}
func backupDir() string {
return filepath.Join(os.TempDir(), "ollama-backups")
}
func backupToTmp(srcPath string) (string, error) {
dir := backupDir()
if err := os.MkdirAll(dir, 0o755); err != nil {
return "", err
}
backupPath := filepath.Join(dir, fmt.Sprintf("%s.%d", filepath.Base(srcPath), time.Now().Unix()))
if err := copyFile(srcPath, backupPath); err != nil {
return "", err
}
return backupPath, nil
}
// writeWithBackup writes data to path via temp file + rename, backing up any existing file first
func writeWithBackup(path string, data []byte) error {
var backupPath string
// backup must be created before any writes to the target file
if existingContent, err := os.ReadFile(path); err == nil {
if !bytes.Equal(existingContent, data) {
backupPath, err = backupToTmp(path)
if err != nil {
return fmt.Errorf("backup failed: %w", err)
}
}
} else if !os.IsNotExist(err) {
return fmt.Errorf("read existing file: %w", err)
}
dir := filepath.Dir(path)
tmp, err := os.CreateTemp(dir, ".tmp-*")
if err != nil {
return fmt.Errorf("create temp failed: %w", err)
}
tmpPath := tmp.Name()
if _, err := tmp.Write(data); err != nil {
_ = tmp.Close()
_ = os.Remove(tmpPath)
return fmt.Errorf("write failed: %w", err)
}
if err := tmp.Sync(); err != nil {
_ = tmp.Close()
_ = os.Remove(tmpPath)
return fmt.Errorf("sync failed: %w", err)
}
if err := tmp.Close(); err != nil {
_ = os.Remove(tmpPath)
return fmt.Errorf("close failed: %w", err)
}
if err := os.Rename(tmpPath, path); err != nil {
_ = os.Remove(tmpPath)
if backupPath != "" {
_ = copyFile(backupPath, path)
}
return fmt.Errorf("rename failed: %w", err)
}
return nil
}

502
cmd/config/files_test.go Normal file
View File

@@ -0,0 +1,502 @@
package config
import (
"encoding/json"
"fmt"
"os"
"path/filepath"
"runtime"
"testing"
)
func mustMarshal(t *testing.T, v any) []byte {
t.Helper()
data, err := json.MarshalIndent(v, "", " ")
if err != nil {
t.Fatal(err)
}
return data
}
func TestWriteWithBackup(t *testing.T) {
tmpDir := t.TempDir()
t.Run("creates file", func(t *testing.T) {
path := filepath.Join(tmpDir, "new.json")
data := mustMarshal(t, map[string]string{"key": "value"})
if err := writeWithBackup(path, data); err != nil {
t.Fatal(err)
}
content, err := os.ReadFile(path)
if err != nil {
t.Fatal(err)
}
var result map[string]string
if err := json.Unmarshal(content, &result); err != nil {
t.Fatal(err)
}
if result["key"] != "value" {
t.Errorf("expected value, got %s", result["key"])
}
})
t.Run("creates backup in /tmp/ollama-backups", func(t *testing.T) {
path := filepath.Join(tmpDir, "backup.json")
os.WriteFile(path, []byte(`{"original": true}`), 0o644)
data := mustMarshal(t, map[string]bool{"updated": true})
if err := writeWithBackup(path, data); err != nil {
t.Fatal(err)
}
entries, err := os.ReadDir(backupDir())
if err != nil {
t.Fatal("backup directory not created")
}
var foundBackup bool
for _, entry := range entries {
if filepath.Ext(entry.Name()) != ".json" {
name := entry.Name()
if len(name) > len("backup.json.") && name[:len("backup.json.")] == "backup.json." {
backupPath := filepath.Join(backupDir(), name)
backup, err := os.ReadFile(backupPath)
if err == nil {
var backupData map[string]bool
json.Unmarshal(backup, &backupData)
if backupData["original"] {
foundBackup = true
os.Remove(backupPath)
break
}
}
}
}
}
if !foundBackup {
t.Error("backup file not created in /tmp/ollama-backups")
}
current, _ := os.ReadFile(path)
var currentData map[string]bool
json.Unmarshal(current, &currentData)
if !currentData["updated"] {
t.Error("file doesn't contain updated data")
}
})
t.Run("no backup for new file", func(t *testing.T) {
path := filepath.Join(tmpDir, "nobak.json")
data := mustMarshal(t, map[string]string{"new": "file"})
if err := writeWithBackup(path, data); err != nil {
t.Fatal(err)
}
entries, _ := os.ReadDir(backupDir())
for _, entry := range entries {
if len(entry.Name()) > len("nobak.json.") && entry.Name()[:len("nobak.json.")] == "nobak.json." {
t.Error("backup should not exist for new file")
}
}
})
t.Run("no backup when content unchanged", func(t *testing.T) {
path := filepath.Join(tmpDir, "unchanged.json")
data := mustMarshal(t, map[string]string{"key": "value"})
if err := writeWithBackup(path, data); err != nil {
t.Fatal(err)
}
entries1, _ := os.ReadDir(backupDir())
countBefore := 0
for _, e := range entries1 {
if len(e.Name()) > len("unchanged.json.") && e.Name()[:len("unchanged.json.")] == "unchanged.json." {
countBefore++
}
}
if err := writeWithBackup(path, data); err != nil {
t.Fatal(err)
}
entries2, _ := os.ReadDir(backupDir())
countAfter := 0
for _, e := range entries2 {
if len(e.Name()) > len("unchanged.json.") && e.Name()[:len("unchanged.json.")] == "unchanged.json." {
countAfter++
}
}
if countAfter != countBefore {
t.Errorf("backup was created when content unchanged (before=%d, after=%d)", countBefore, countAfter)
}
})
t.Run("backup filename contains unix timestamp", func(t *testing.T) {
path := filepath.Join(tmpDir, "timestamped.json")
os.WriteFile(path, []byte(`{"v": 1}`), 0o644)
data := mustMarshal(t, map[string]int{"v": 2})
if err := writeWithBackup(path, data); err != nil {
t.Fatal(err)
}
entries, _ := os.ReadDir(backupDir())
var found bool
for _, entry := range entries {
name := entry.Name()
if len(name) > len("timestamped.json.") && name[:len("timestamped.json.")] == "timestamped.json." {
timestamp := name[len("timestamped.json."):]
for _, c := range timestamp {
if c < '0' || c > '9' {
t.Errorf("backup filename timestamp contains non-numeric character: %s", name)
}
}
found = true
os.Remove(filepath.Join(backupDir(), name))
break
}
}
if !found {
t.Error("backup file with timestamp not found")
}
})
}
// Edge case tests for files.go
// TestWriteWithBackup_FailsIfBackupFails documents critical behavior: if backup fails, we must not proceed.
// User could lose their config with no way to recover.
func TestWriteWithBackup_FailsIfBackupFails(t *testing.T) {
if runtime.GOOS == "windows" {
t.Skip("permission tests unreliable on Windows")
}
tmpDir := t.TempDir()
path := filepath.Join(tmpDir, "config.json")
// Create original file
originalContent := []byte(`{"original": true}`)
os.WriteFile(path, originalContent, 0o644)
// Make backup directory read-only to force backup failure
backupDir := backupDir()
os.MkdirAll(backupDir, 0o755)
os.Chmod(backupDir, 0o444) // Read-only
defer os.Chmod(backupDir, 0o755)
newContent := []byte(`{"updated": true}`)
err := writeWithBackup(path, newContent)
// Should fail because backup couldn't be created
if err == nil {
t.Error("expected error when backup fails, got nil")
}
// Original file should be preserved
current, _ := os.ReadFile(path)
if string(current) != string(originalContent) {
t.Errorf("original file was modified despite backup failure: got %s", string(current))
}
}
// TestWriteWithBackup_PermissionDenied verifies clear error when target file has wrong permissions.
// Common issue when config owned by root or wrong perms.
func TestWriteWithBackup_PermissionDenied(t *testing.T) {
if runtime.GOOS == "windows" {
t.Skip("permission tests unreliable on Windows")
}
tmpDir := t.TempDir()
// Create a read-only directory
readOnlyDir := filepath.Join(tmpDir, "readonly")
os.MkdirAll(readOnlyDir, 0o755)
os.Chmod(readOnlyDir, 0o444)
defer os.Chmod(readOnlyDir, 0o755)
path := filepath.Join(readOnlyDir, "config.json")
err := writeWithBackup(path, []byte(`{"test": true}`))
if err == nil {
t.Error("expected permission error, got nil")
}
}
// TestWriteWithBackup_DirectoryDoesNotExist verifies behavior when target directory doesn't exist.
// writeWithBackup doesn't create directories - caller is responsible.
func TestWriteWithBackup_DirectoryDoesNotExist(t *testing.T) {
tmpDir := t.TempDir()
path := filepath.Join(tmpDir, "nonexistent", "subdir", "config.json")
err := writeWithBackup(path, []byte(`{"test": true}`))
// Should fail because directory doesn't exist
if err == nil {
t.Error("expected error for nonexistent directory, got nil")
}
}
// TestWriteWithBackup_SymlinkTarget documents behavior when target is a symlink.
// Documents what happens if user symlinks their config file.
func TestWriteWithBackup_SymlinkTarget(t *testing.T) {
if runtime.GOOS == "windows" {
t.Skip("symlink tests may require admin on Windows")
}
tmpDir := t.TempDir()
realFile := filepath.Join(tmpDir, "real.json")
symlink := filepath.Join(tmpDir, "link.json")
// Create real file and symlink
os.WriteFile(realFile, []byte(`{"v": 1}`), 0o644)
os.Symlink(realFile, symlink)
// Write through symlink
err := writeWithBackup(symlink, []byte(`{"v": 2}`))
if err != nil {
t.Fatalf("writeWithBackup through symlink failed: %v", err)
}
// The real file should be updated (symlink followed for temp file creation)
content, _ := os.ReadFile(symlink)
if string(content) != `{"v": 2}` {
t.Errorf("symlink target not updated correctly: got %s", string(content))
}
}
// TestBackupToTmp_SpecialCharsInFilename verifies backup works with special characters.
// User may have config files with unusual names.
func TestBackupToTmp_SpecialCharsInFilename(t *testing.T) {
tmpDir := t.TempDir()
// File with spaces and special chars
path := filepath.Join(tmpDir, "my config (backup).json")
os.WriteFile(path, []byte(`{"test": true}`), 0o644)
backupPath, err := backupToTmp(path)
if err != nil {
t.Fatalf("backupToTmp with special chars failed: %v", err)
}
// Verify backup exists and has correct content
content, err := os.ReadFile(backupPath)
if err != nil {
t.Fatalf("could not read backup: %v", err)
}
if string(content) != `{"test": true}` {
t.Errorf("backup content mismatch: got %s", string(content))
}
os.Remove(backupPath)
}
// TestCopyFile_PreservesPermissions verifies that copyFile preserves file permissions.
func TestCopyFile_PreservesPermissions(t *testing.T) {
if runtime.GOOS == "windows" {
t.Skip("permission preservation tests unreliable on Windows")
}
tmpDir := t.TempDir()
src := filepath.Join(tmpDir, "src.json")
dst := filepath.Join(tmpDir, "dst.json")
// Create source with specific permissions
os.WriteFile(src, []byte(`{"test": true}`), 0o600)
err := copyFile(src, dst)
if err != nil {
t.Fatalf("copyFile failed: %v", err)
}
srcInfo, _ := os.Stat(src)
dstInfo, _ := os.Stat(dst)
if srcInfo.Mode().Perm() != dstInfo.Mode().Perm() {
t.Errorf("permissions not preserved: src=%v, dst=%v", srcInfo.Mode().Perm(), dstInfo.Mode().Perm())
}
}
// TestCopyFile_SourceNotFound verifies clear error when source doesn't exist.
func TestCopyFile_SourceNotFound(t *testing.T) {
tmpDir := t.TempDir()
src := filepath.Join(tmpDir, "nonexistent.json")
dst := filepath.Join(tmpDir, "dst.json")
err := copyFile(src, dst)
if err == nil {
t.Error("expected error for nonexistent source, got nil")
}
}
// TestWriteWithBackup_TargetIsDirectory verifies error when path points to a directory.
func TestWriteWithBackup_TargetIsDirectory(t *testing.T) {
tmpDir := t.TempDir()
dirPath := filepath.Join(tmpDir, "actualdir")
os.MkdirAll(dirPath, 0o755)
err := writeWithBackup(dirPath, []byte(`{"test": true}`))
if err == nil {
t.Error("expected error when target is a directory, got nil")
}
}
// TestWriteWithBackup_EmptyData verifies writing zero bytes works correctly.
func TestWriteWithBackup_EmptyData(t *testing.T) {
tmpDir := t.TempDir()
path := filepath.Join(tmpDir, "empty.json")
err := writeWithBackup(path, []byte{})
if err != nil {
t.Fatalf("writeWithBackup with empty data failed: %v", err)
}
content, err := os.ReadFile(path)
if err != nil {
t.Fatalf("could not read file: %v", err)
}
if len(content) != 0 {
t.Errorf("expected empty file, got %d bytes", len(content))
}
}
// TestWriteWithBackup_FileUnreadableButDirWritable verifies behavior when existing file
// cannot be read (for backup comparison) but directory is writable.
func TestWriteWithBackup_FileUnreadableButDirWritable(t *testing.T) {
if runtime.GOOS == "windows" {
t.Skip("permission tests unreliable on Windows")
}
tmpDir := t.TempDir()
path := filepath.Join(tmpDir, "unreadable.json")
// Create file and make it unreadable
os.WriteFile(path, []byte(`{"original": true}`), 0o644)
os.Chmod(path, 0o000)
defer os.Chmod(path, 0o644)
// Should fail because we can't read the file to compare/backup
err := writeWithBackup(path, []byte(`{"updated": true}`))
if err == nil {
t.Error("expected error when file is unreadable, got nil")
}
}
// TestWriteWithBackup_RapidSuccessiveWrites verifies backup works with multiple writes
// within the same second (timestamp collision scenario).
func TestWriteWithBackup_RapidSuccessiveWrites(t *testing.T) {
tmpDir := t.TempDir()
path := filepath.Join(tmpDir, "rapid.json")
// Create initial file
os.WriteFile(path, []byte(`{"v": 0}`), 0o644)
// Rapid successive writes
for i := 1; i <= 3; i++ {
data := []byte(fmt.Sprintf(`{"v": %d}`, i))
if err := writeWithBackup(path, data); err != nil {
t.Fatalf("write %d failed: %v", i, err)
}
}
// Verify final content
content, _ := os.ReadFile(path)
if string(content) != `{"v": 3}` {
t.Errorf("expected final content {\"v\": 3}, got %s", string(content))
}
// Verify at least one backup exists
entries, _ := os.ReadDir(backupDir())
var backupCount int
for _, e := range entries {
if len(e.Name()) > len("rapid.json.") && e.Name()[:len("rapid.json.")] == "rapid.json." {
backupCount++
}
}
if backupCount == 0 {
t.Error("expected at least one backup file from rapid writes")
}
}
// TestWriteWithBackup_BackupDirIsFile verifies error when backup directory path is a file.
func TestWriteWithBackup_BackupDirIsFile(t *testing.T) {
if runtime.GOOS == "windows" {
t.Skip("test modifies system temp directory")
}
// Create a file at the backup directory path
backupPath := backupDir()
// Clean up any existing directory first
os.RemoveAll(backupPath)
// Create a file instead of directory
os.WriteFile(backupPath, []byte("not a directory"), 0o644)
defer func() {
os.Remove(backupPath)
os.MkdirAll(backupPath, 0o755)
}()
tmpDir := t.TempDir()
path := filepath.Join(tmpDir, "test.json")
os.WriteFile(path, []byte(`{"original": true}`), 0o644)
err := writeWithBackup(path, []byte(`{"updated": true}`))
if err == nil {
t.Error("expected error when backup dir is a file, got nil")
}
}
// TestWriteWithBackup_NoOrphanTempFiles verifies temp files are cleaned up on failure.
func TestWriteWithBackup_NoOrphanTempFiles(t *testing.T) {
if runtime.GOOS == "windows" {
t.Skip("permission tests unreliable on Windows")
}
tmpDir := t.TempDir()
// Count existing temp files
countTempFiles := func() int {
entries, _ := os.ReadDir(tmpDir)
count := 0
for _, e := range entries {
if len(e.Name()) > 4 && e.Name()[:4] == ".tmp" {
count++
}
}
return count
}
before := countTempFiles()
// Create a file, then make directory read-only to cause rename failure
path := filepath.Join(tmpDir, "orphan.json")
os.WriteFile(path, []byte(`{"v": 1}`), 0o644)
// Make a subdirectory and try to write there after making parent read-only
subDir := filepath.Join(tmpDir, "subdir")
os.MkdirAll(subDir, 0o755)
subPath := filepath.Join(subDir, "config.json")
os.WriteFile(subPath, []byte(`{"v": 1}`), 0o644)
// Make subdir read-only after creating temp file would succeed but rename would fail
// This is tricky to test - the temp file is created in the same dir, so if we can't
// rename, we also couldn't create. Let's just verify normal failure cleanup works.
// Force a failure by making the target a directory
badPath := filepath.Join(tmpDir, "isdir")
os.MkdirAll(badPath, 0o755)
_ = writeWithBackup(badPath, []byte(`{"test": true}`))
after := countTempFiles()
if after > before {
t.Errorf("orphan temp files left behind: before=%d, after=%d", before, after)
}
}

361
cmd/config/integrations.go Normal file
View File

@@ -0,0 +1,361 @@
package config
import (
"context"
"errors"
"fmt"
"maps"
"os"
"os/exec"
"runtime"
"slices"
"strings"
"time"
"github.com/ollama/ollama/api"
"github.com/spf13/cobra"
)
// Runners execute the launching of a model with the integration - claude, codex
// Editors can edit config files (supports multi-model selection) - opencode, droid
// They are composable interfaces where in some cases an editor is also a runner - opencode, droid
// Runner can run an integration with a model.
type Runner interface {
Run(model string) error
// String returns the human-readable name of the integration
String() string
}
// Editor can edit config files (supports multi-model selection)
type Editor interface {
// Paths returns the paths to the config files for the integration
Paths() []string
// Edit updates the config files for the integration with the given models
Edit(models []string) error
// Models returns the models currently configured for the integration
Models() []string
}
// integrations is the registry of available integrations.
var integrations = map[string]Runner{
"claude": &Claude{},
"codex": &Codex{},
"droid": &Droid{},
"opencode": &OpenCode{},
}
func selectIntegration() (string, error) {
if len(integrations) == 0 {
return "", fmt.Errorf("no integrations available")
}
names := slices.Sorted(maps.Keys(integrations))
var items []selectItem
for _, name := range names {
r := integrations[name]
description := r.String()
if conn, err := loadIntegration(name); err == nil && len(conn.Models) > 0 {
description = fmt.Sprintf("%s (%s)", r.String(), conn.Models[0])
}
items = append(items, selectItem{Name: name, Description: description})
}
return selectPrompt("Select integration:", items)
}
// selectModels lets the user select models for an integration
func selectModels(ctx context.Context, name, current string) ([]string, error) {
r, ok := integrations[name]
if !ok {
return nil, fmt.Errorf("unknown integration: %s", name)
}
client, err := api.ClientFromEnvironment()
if err != nil {
return nil, err
}
models, err := client.List(ctx)
if err != nil {
return nil, err
}
if len(models.Models) == 0 {
return nil, fmt.Errorf("no models available, run 'ollama pull <model>' first")
}
var items []selectItem
cloudModels := make(map[string]bool)
for _, m := range models.Models {
if m.RemoteModel != "" {
cloudModels[m.Name] = true
}
items = append(items, selectItem{Name: m.Name})
}
if len(items) == 0 {
return nil, fmt.Errorf("no local models available, run 'ollama pull <model>' first")
}
// Get previously configured models (saved config takes precedence)
var preChecked []string
if saved, err := loadIntegration(name); err == nil {
preChecked = saved.Models
} else if editor, ok := r.(Editor); ok {
preChecked = editor.Models()
}
checked := make(map[string]bool, len(preChecked))
for _, n := range preChecked {
checked[n] = true
}
// Resolve current to full name (e.g., "llama3.2" -> "llama3.2:latest")
for _, item := range items {
if item.Name == current || strings.HasPrefix(item.Name, current+":") {
current = item.Name
break
}
}
// If current model is configured, move to front of preChecked
if checked[current] {
preChecked = append([]string{current}, slices.DeleteFunc(preChecked, func(m string) bool { return m == current })...)
}
// Sort: checked first, then alphabetical
slices.SortFunc(items, func(a, b selectItem) int {
ac, bc := checked[a.Name], checked[b.Name]
if ac != bc {
if ac {
return -1
}
return 1
}
return strings.Compare(strings.ToLower(a.Name), strings.ToLower(b.Name))
})
var selected []string
// only editors support multi-model selection
if _, ok := r.(Editor); ok {
selected, err = multiSelectPrompt(fmt.Sprintf("Select models for %s:", r), items, preChecked)
if err != nil {
return nil, err
}
} else {
model, err := selectPrompt(fmt.Sprintf("Select model for %s:", r), items)
if err != nil {
return nil, err
}
selected = []string{model}
}
// if any model in selected is a cloud model, ensure signed in
var selectedCloudModels []string
for _, m := range selected {
if cloudModels[m] {
selectedCloudModels = append(selectedCloudModels, m)
}
}
if len(selectedCloudModels) > 0 {
// ensure user is signed in
user, err := client.Whoami(ctx)
if err == nil && user != nil && user.Name != "" {
return selected, nil
}
var aErr api.AuthorizationError
if !errors.As(err, &aErr) || aErr.SigninURL == "" {
return nil, err
}
modelList := strings.Join(selectedCloudModels, ", ")
yes, err := confirmPrompt(fmt.Sprintf("sign in to use %s?", modelList))
if err != nil || !yes {
return nil, fmt.Errorf("%s requires sign in", modelList)
}
fmt.Fprintf(os.Stderr, "\nTo sign in, navigate to:\n %s\n\n", aErr.SigninURL)
// TODO(parthsareen): extract into auth package for cmd
// Auto-open browser (best effort, fail silently)
switch runtime.GOOS {
case "darwin":
_ = exec.Command("open", aErr.SigninURL).Start()
case "linux":
_ = exec.Command("xdg-open", aErr.SigninURL).Start()
case "windows":
_ = exec.Command("rundll32", "url.dll,FileProtocolHandler", aErr.SigninURL).Start()
}
spinnerFrames := []string{"|", "/", "-", "\\"}
frame := 0
fmt.Fprintf(os.Stderr, "\033[90mwaiting for sign in to complete... %s\033[0m", spinnerFrames[0])
ticker := time.NewTicker(200 * time.Millisecond)
defer ticker.Stop()
for {
select {
case <-ctx.Done():
fmt.Fprintf(os.Stderr, "\r\033[K")
return nil, ctx.Err()
case <-ticker.C:
frame++
fmt.Fprintf(os.Stderr, "\r\033[90mwaiting for sign in to complete... %s\033[0m", spinnerFrames[frame%len(spinnerFrames)])
// poll every 10th frame (~2 seconds)
if frame%10 == 0 {
u, err := client.Whoami(ctx)
if err == nil && u != nil && u.Name != "" {
fmt.Fprintf(os.Stderr, "\r\033[K\033[A\r\033[K\033[1msigned in:\033[0m %s\n", u.Name)
return selected, nil
}
}
}
}
}
return selected, nil
}
func runIntegration(name, modelName string) error {
r, ok := integrations[name]
if !ok {
return fmt.Errorf("unknown integration: %s", name)
}
fmt.Fprintf(os.Stderr, "\nLaunching %s with %s...\n", r, modelName)
return r.Run(modelName)
}
// ConfigCmd returns the cobra command for configuring integrations.
func ConfigCmd(checkServerHeartbeat func(cmd *cobra.Command, args []string) error) *cobra.Command {
var modelFlag string
var launchFlag bool
cmd := &cobra.Command{
Use: "config [INTEGRATION]",
Short: "Configure an external integration to use Ollama",
Long: `Configure an external application to use Ollama models.
Supported integrations:
claude Claude Code
codex Codex
droid Droid
opencode OpenCode
Examples:
ollama config
ollama config claude
ollama config droid --launch`,
Args: cobra.MaximumNArgs(1),
PreRunE: checkServerHeartbeat,
RunE: func(cmd *cobra.Command, args []string) error {
var name string
if len(args) > 0 {
name = args[0]
} else {
var err error
name, err = selectIntegration()
if errors.Is(err, errCancelled) {
return nil
}
if err != nil {
return err
}
}
r, ok := integrations[strings.ToLower(name)]
if !ok {
return fmt.Errorf("unknown integration: %s", name)
}
// If --launch without --model, use saved config if available
if launchFlag && modelFlag == "" {
if config, err := loadIntegration(name); err == nil && len(config.Models) > 0 {
return runIntegration(name, config.Models[0])
}
}
var models []string
if modelFlag != "" {
// When --model is specified, merge with existing models (new model becomes default)
models = []string{modelFlag}
if existing, err := loadIntegration(name); err == nil && len(existing.Models) > 0 {
for _, m := range existing.Models {
if m != modelFlag {
models = append(models, m)
}
}
}
} else {
var err error
models, err = selectModels(cmd.Context(), name, "")
if errors.Is(err, errCancelled) {
return nil
}
if err != nil {
return err
}
}
if editor, isEditor := r.(Editor); isEditor {
paths := editor.Paths()
if len(paths) > 0 {
fmt.Fprintf(os.Stderr, "This will modify your %s configuration:\n", r)
for _, p := range paths {
fmt.Fprintf(os.Stderr, " %s\n", p)
}
fmt.Fprintf(os.Stderr, "Backups will be saved to %s/\n\n", backupDir())
if ok, _ := confirmPrompt("Proceed?"); !ok {
return nil
}
}
}
if err := saveIntegration(name, models); err != nil {
return fmt.Errorf("failed to save: %w", err)
}
if editor, isEditor := r.(Editor); isEditor {
if err := editor.Edit(models); err != nil {
return fmt.Errorf("setup failed: %w", err)
}
}
if _, isEditor := r.(Editor); isEditor {
if len(models) == 1 {
fmt.Fprintf(os.Stderr, "Added %s to %s\n", models[0], r)
} else {
fmt.Fprintf(os.Stderr, "Added %d models to %s (default: %s)\n", len(models), r, models[0])
}
}
if slices.ContainsFunc(models, func(m string) bool {
return !strings.HasSuffix(m, "cloud")
}) {
fmt.Fprintln(os.Stderr)
fmt.Fprintln(os.Stderr, "Coding agents work best with at least 64k context. Either:")
fmt.Fprintln(os.Stderr, " - Set the context slider in Ollama app settings")
fmt.Fprintln(os.Stderr, " - Run: OLLAMA_CONTEXT_LENGTH=64000 ollama serve")
}
if launchFlag {
return runIntegration(name, models[0])
}
if launch, _ := confirmPrompt(fmt.Sprintf("\nLaunch %s now?", r)); launch {
return runIntegration(name, models[0])
}
fmt.Fprintf(os.Stderr, "Run 'ollama config %s --launch' to start with %s\n", strings.ToLower(name), models[0])
return nil
},
}
cmd.Flags().StringVar(&modelFlag, "model", "", "Model to use")
cmd.Flags().BoolVar(&launchFlag, "launch", false, "Launch the integration after configuring")
return cmd
}

View File

@@ -0,0 +1,188 @@
package config
import (
"slices"
"strings"
"testing"
"github.com/spf13/cobra"
)
func TestIntegrationLookup(t *testing.T) {
tests := []struct {
name string
input string
wantFound bool
wantName string
}{
{"claude lowercase", "claude", true, "Claude Code"},
{"claude uppercase", "CLAUDE", true, "Claude Code"},
{"claude mixed case", "Claude", true, "Claude Code"},
{"codex", "codex", true, "Codex"},
{"droid", "droid", true, "Droid"},
{"opencode", "opencode", true, "OpenCode"},
{"unknown integration", "unknown", false, ""},
{"empty string", "", false, ""},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
r, found := integrations[strings.ToLower(tt.input)]
if found != tt.wantFound {
t.Errorf("integrations[%q] found = %v, want %v", tt.input, found, tt.wantFound)
}
if found && r.String() != tt.wantName {
t.Errorf("integrations[%q].String() = %q, want %q", tt.input, r.String(), tt.wantName)
}
})
}
}
func TestIntegrationRegistry(t *testing.T) {
expectedIntegrations := []string{"claude", "codex", "droid", "opencode"}
for _, name := range expectedIntegrations {
t.Run(name, func(t *testing.T) {
r, ok := integrations[name]
if !ok {
t.Fatalf("integration %q not found in registry", name)
}
if r.String() == "" {
t.Error("integration.String() should not be empty")
}
})
}
}
func TestHasLocalModel(t *testing.T) {
tests := []struct {
name string
models []string
want bool
}{
{"empty list", []string{}, false},
{"single local model", []string{"llama3.2"}, true},
{"single cloud model", []string{"cloud-model"}, false},
{"mixed models", []string{"cloud-model", "llama3.2"}, true},
{"multiple local models", []string{"llama3.2", "qwen2.5"}, true},
{"multiple cloud models", []string{"cloud-a", "cloud-b"}, false},
{"local model first", []string{"llama3.2", "cloud-model"}, true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := slices.ContainsFunc(tt.models, func(m string) bool {
return !strings.Contains(m, "cloud")
})
if got != tt.want {
t.Errorf("hasLocalModel(%v) = %v, want %v", tt.models, got, tt.want)
}
})
}
}
func TestConfigCmd(t *testing.T) {
// Mock checkServerHeartbeat that always succeeds
mockCheck := func(cmd *cobra.Command, args []string) error {
return nil
}
cmd := ConfigCmd(mockCheck)
t.Run("command structure", func(t *testing.T) {
if cmd.Use != "config [INTEGRATION]" {
t.Errorf("Use = %q, want %q", cmd.Use, "config [INTEGRATION]")
}
if cmd.Short == "" {
t.Error("Short description should not be empty")
}
if cmd.Long == "" {
t.Error("Long description should not be empty")
}
})
t.Run("flags exist", func(t *testing.T) {
modelFlag := cmd.Flags().Lookup("model")
if modelFlag == nil {
t.Error("--model flag should exist")
}
launchFlag := cmd.Flags().Lookup("launch")
if launchFlag == nil {
t.Error("--launch flag should exist")
}
})
t.Run("PreRunE is set", func(t *testing.T) {
if cmd.PreRunE == nil {
t.Error("PreRunE should be set to checkServerHeartbeat")
}
})
}
func TestRunIntegration_UnknownIntegration(t *testing.T) {
err := runIntegration("unknown-integration", "model")
if err == nil {
t.Error("expected error for unknown integration, got nil")
}
if !strings.Contains(err.Error(), "unknown integration") {
t.Errorf("error should mention 'unknown integration', got: %v", err)
}
}
func TestHasLocalModel_DocumentsHeuristic(t *testing.T) {
tests := []struct {
name string
models []string
want bool
reason string
}{
{"empty list", []string{}, false, "empty list has no local models"},
{"contains-cloud-substring", []string{"deepseek-r1:cloud"}, false, "model with 'cloud' substring is considered cloud"},
{"cloud-in-name", []string{"my-cloud-model"}, false, "'cloud' anywhere in name = cloud model"},
{"cloudless", []string{"cloudless-model"}, false, "'cloudless' still contains 'cloud'"},
{"local-model", []string{"llama3.2"}, true, "no 'cloud' = local"},
{"mixed", []string{"cloud-model", "llama3.2"}, true, "one local model = hasLocalModel true"},
{"all-cloud", []string{"cloud-a", "cloud-b"}, false, "all contain 'cloud'"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := slices.ContainsFunc(tt.models, func(m string) bool {
return !strings.Contains(m, "cloud")
})
if got != tt.want {
t.Errorf("hasLocalModel(%v) = %v, want %v (%s)", tt.models, got, tt.want, tt.reason)
}
})
}
}
func TestConfigCmd_NilHeartbeat(t *testing.T) {
// This should not panic - cmd creation should work even with nil
cmd := ConfigCmd(nil)
if cmd == nil {
t.Fatal("ConfigCmd returned nil")
}
// PreRunE should be nil when passed nil
if cmd.PreRunE != nil {
t.Log("Note: PreRunE is set even when nil is passed (acceptable)")
}
}
func TestAllIntegrations_HaveRequiredMethods(t *testing.T) {
for name, r := range integrations {
t.Run(name, func(t *testing.T) {
// Test String() doesn't panic and returns non-empty
displayName := r.String()
if displayName == "" {
t.Error("String() should not return empty")
}
// Test Run() exists (we can't call it without actually running the command)
// Just verify the method is available
var _ func(string) error = r.Run
})
}
}

203
cmd/config/opencode.go Normal file
View File

@@ -0,0 +1,203 @@
package config
import (
"encoding/json"
"fmt"
"maps"
"os"
"os/exec"
"path/filepath"
"slices"
"strings"
)
// OpenCode implements Runner and Editor for OpenCode integration
type OpenCode struct{}
func (o *OpenCode) String() string { return "OpenCode" }
func (o *OpenCode) Run(model string) error {
if _, err := exec.LookPath("opencode"); err != nil {
return fmt.Errorf("opencode is not installed, install from https://opencode.ai")
}
// Call Edit() to ensure config is up-to-date before launch
models := []string{model}
if config, err := loadIntegration("opencode"); err == nil && len(config.Models) > 0 {
models = config.Models
}
if err := o.Edit(models); err != nil {
return fmt.Errorf("setup failed: %w", err)
}
cmd := exec.Command("opencode")
cmd.Stdin = os.Stdin
cmd.Stdout = os.Stdout
cmd.Stderr = os.Stderr
return cmd.Run()
}
func (o *OpenCode) Paths() []string {
home, err := os.UserHomeDir()
if err != nil {
return nil
}
var paths []string
p := filepath.Join(home, ".config", "opencode", "opencode.json")
if _, err := os.Stat(p); err == nil {
paths = append(paths, p)
}
sp := filepath.Join(home, ".local", "state", "opencode", "model.json")
if _, err := os.Stat(sp); err == nil {
paths = append(paths, sp)
}
return paths
}
func (o *OpenCode) Edit(modelList []string) error {
if len(modelList) == 0 {
return nil
}
home, err := os.UserHomeDir()
if err != nil {
return err
}
configPath := filepath.Join(home, ".config", "opencode", "opencode.json")
if err := os.MkdirAll(filepath.Dir(configPath), 0o755); err != nil {
return err
}
config := make(map[string]any)
if data, err := os.ReadFile(configPath); err == nil {
_ = json.Unmarshal(data, &config) // Ignore parse errors; treat missing/corrupt files as empty
}
config["$schema"] = "https://opencode.ai/config.json"
provider, ok := config["provider"].(map[string]any)
if !ok {
provider = make(map[string]any)
}
ollama, ok := provider["ollama"].(map[string]any)
if !ok {
ollama = map[string]any{
"npm": "@ai-sdk/openai-compatible",
"name": "Ollama (local)",
"options": map[string]any{
"baseURL": "http://localhost:11434/v1",
},
}
}
models, ok := ollama["models"].(map[string]any)
if !ok {
models = make(map[string]any)
}
selectedSet := make(map[string]bool)
for _, m := range modelList {
selectedSet[m] = true
}
for name, cfg := range models {
if cfgMap, ok := cfg.(map[string]any); ok {
if displayName, ok := cfgMap["name"].(string); ok {
if strings.HasSuffix(displayName, "[Ollama]") && !selectedSet[name] {
delete(models, name)
}
}
}
}
for _, model := range modelList {
models[model] = map[string]any{
"name": fmt.Sprintf("%s [Ollama]", model),
}
}
ollama["models"] = models
provider["ollama"] = ollama
config["provider"] = provider
configData, err := json.MarshalIndent(config, "", " ")
if err != nil {
return err
}
if err := writeWithBackup(configPath, configData); err != nil {
return err
}
statePath := filepath.Join(home, ".local", "state", "opencode", "model.json")
if err := os.MkdirAll(filepath.Dir(statePath), 0o755); err != nil {
return err
}
state := map[string]any{
"recent": []any{},
"favorite": []any{},
"variant": map[string]any{},
}
if data, err := os.ReadFile(statePath); err == nil {
_ = json.Unmarshal(data, &state) // Ignore parse errors; use defaults
}
recent, _ := state["recent"].([]any)
modelSet := make(map[string]bool)
for _, m := range modelList {
modelSet[m] = true
}
// Filter out existing Ollama models we're about to re-add
newRecent := slices.DeleteFunc(slices.Clone(recent), func(entry any) bool {
e, ok := entry.(map[string]any)
if !ok || e["providerID"] != "ollama" {
return false
}
modelID, _ := e["modelID"].(string)
return modelSet[modelID]
})
// Prepend models in reverse order so first model ends up first
for _, model := range slices.Backward(modelList) {
newRecent = slices.Insert(newRecent, 0, any(map[string]any{
"providerID": "ollama",
"modelID": model,
}))
}
const maxRecentModels = 10
newRecent = newRecent[:min(len(newRecent), maxRecentModels)]
state["recent"] = newRecent
stateData, err := json.MarshalIndent(state, "", " ")
if err != nil {
return err
}
return writeWithBackup(statePath, stateData)
}
func (o *OpenCode) Models() []string {
home, err := os.UserHomeDir()
if err != nil {
return nil
}
config, err := readJSONFile(filepath.Join(home, ".config", "opencode", "opencode.json"))
if err != nil {
return nil
}
provider, _ := config["provider"].(map[string]any)
ollama, _ := provider["ollama"].(map[string]any)
models, _ := ollama["models"].(map[string]any)
if len(models) == 0 {
return nil
}
keys := slices.Collect(maps.Keys(models))
slices.Sort(keys)
return keys
}

437
cmd/config/opencode_test.go Normal file
View File

@@ -0,0 +1,437 @@
package config
import (
"encoding/json"
"os"
"path/filepath"
"testing"
)
func TestOpenCodeIntegration(t *testing.T) {
o := &OpenCode{}
t.Run("String", func(t *testing.T) {
if got := o.String(); got != "OpenCode" {
t.Errorf("String() = %q, want %q", got, "OpenCode")
}
})
t.Run("implements Runner", func(t *testing.T) {
var _ Runner = o
})
t.Run("implements Editor", func(t *testing.T) {
var _ Editor = o
})
}
func TestOpenCodeEdit(t *testing.T) {
o := &OpenCode{}
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
configDir := filepath.Join(tmpDir, ".config", "opencode")
configPath := filepath.Join(configDir, "opencode.json")
stateDir := filepath.Join(tmpDir, ".local", "state", "opencode")
statePath := filepath.Join(stateDir, "model.json")
cleanup := func() {
os.RemoveAll(configDir)
os.RemoveAll(stateDir)
}
t.Run("fresh install", func(t *testing.T) {
cleanup()
if err := o.Edit([]string{"llama3.2"}); err != nil {
t.Fatal(err)
}
assertOpenCodeModelExists(t, configPath, "llama3.2")
assertOpenCodeRecentModel(t, statePath, 0, "ollama", "llama3.2")
})
t.Run("preserve other providers", func(t *testing.T) {
cleanup()
os.MkdirAll(configDir, 0o755)
os.WriteFile(configPath, []byte(`{"provider":{"anthropic":{"apiKey":"xxx"}}}`), 0o644)
if err := o.Edit([]string{"llama3.2"}); err != nil {
t.Fatal(err)
}
data, _ := os.ReadFile(configPath)
var cfg map[string]any
json.Unmarshal(data, &cfg)
provider := cfg["provider"].(map[string]any)
if provider["anthropic"] == nil {
t.Error("anthropic provider was removed")
}
assertOpenCodeModelExists(t, configPath, "llama3.2")
})
t.Run("preserve other models", func(t *testing.T) {
cleanup()
os.MkdirAll(configDir, 0o755)
os.WriteFile(configPath, []byte(`{"provider":{"ollama":{"models":{"mistral":{"name":"Mistral"}}}}}`), 0o644)
if err := o.Edit([]string{"llama3.2"}); err != nil {
t.Fatal(err)
}
assertOpenCodeModelExists(t, configPath, "mistral")
assertOpenCodeModelExists(t, configPath, "llama3.2")
})
t.Run("update existing model", func(t *testing.T) {
cleanup()
o.Edit([]string{"llama3.2"})
o.Edit([]string{"llama3.2"})
assertOpenCodeModelExists(t, configPath, "llama3.2")
})
t.Run("preserve top-level keys", func(t *testing.T) {
cleanup()
os.MkdirAll(configDir, 0o755)
os.WriteFile(configPath, []byte(`{"theme":"dark","keybindings":{}}`), 0o644)
if err := o.Edit([]string{"llama3.2"}); err != nil {
t.Fatal(err)
}
data, _ := os.ReadFile(configPath)
var cfg map[string]any
json.Unmarshal(data, &cfg)
if cfg["theme"] != "dark" {
t.Error("theme was removed")
}
if cfg["keybindings"] == nil {
t.Error("keybindings was removed")
}
})
t.Run("model state - insert at index 0", func(t *testing.T) {
cleanup()
os.MkdirAll(stateDir, 0o755)
os.WriteFile(statePath, []byte(`{"recent":[{"providerID":"anthropic","modelID":"claude"}],"favorite":[],"variant":{}}`), 0o644)
if err := o.Edit([]string{"llama3.2"}); err != nil {
t.Fatal(err)
}
assertOpenCodeRecentModel(t, statePath, 0, "ollama", "llama3.2")
assertOpenCodeRecentModel(t, statePath, 1, "anthropic", "claude")
})
t.Run("model state - preserve favorites and variants", func(t *testing.T) {
cleanup()
os.MkdirAll(stateDir, 0o755)
os.WriteFile(statePath, []byte(`{"recent":[],"favorite":[{"providerID":"x","modelID":"y"}],"variant":{"a":"b"}}`), 0o644)
if err := o.Edit([]string{"llama3.2"}); err != nil {
t.Fatal(err)
}
data, _ := os.ReadFile(statePath)
var state map[string]any
json.Unmarshal(data, &state)
if len(state["favorite"].([]any)) != 1 {
t.Error("favorite was modified")
}
if state["variant"].(map[string]any)["a"] != "b" {
t.Error("variant was modified")
}
})
t.Run("model state - deduplicate on re-add", func(t *testing.T) {
cleanup()
os.MkdirAll(stateDir, 0o755)
os.WriteFile(statePath, []byte(`{"recent":[{"providerID":"ollama","modelID":"llama3.2"},{"providerID":"anthropic","modelID":"claude"}],"favorite":[],"variant":{}}`), 0o644)
if err := o.Edit([]string{"llama3.2"}); err != nil {
t.Fatal(err)
}
data, _ := os.ReadFile(statePath)
var state map[string]any
json.Unmarshal(data, &state)
recent := state["recent"].([]any)
if len(recent) != 2 {
t.Errorf("expected 2 recent entries, got %d", len(recent))
}
assertOpenCodeRecentModel(t, statePath, 0, "ollama", "llama3.2")
})
t.Run("remove model", func(t *testing.T) {
cleanup()
// First add two models
o.Edit([]string{"llama3.2", "mistral"})
assertOpenCodeModelExists(t, configPath, "llama3.2")
assertOpenCodeModelExists(t, configPath, "mistral")
// Then remove one by only selecting the other
o.Edit([]string{"llama3.2"})
assertOpenCodeModelExists(t, configPath, "llama3.2")
assertOpenCodeModelNotExists(t, configPath, "mistral")
})
t.Run("remove model preserves non-ollama models", func(t *testing.T) {
cleanup()
os.MkdirAll(configDir, 0o755)
// Add a non-Ollama model manually
os.WriteFile(configPath, []byte(`{"provider":{"ollama":{"models":{"external":{"name":"External Model"}}}}}`), 0o644)
o.Edit([]string{"llama3.2"})
assertOpenCodeModelExists(t, configPath, "llama3.2")
assertOpenCodeModelExists(t, configPath, "external") // Should be preserved
})
}
func assertOpenCodeModelExists(t *testing.T, path, model string) {
t.Helper()
data, err := os.ReadFile(path)
if err != nil {
t.Fatal(err)
}
var cfg map[string]any
if err := json.Unmarshal(data, &cfg); err != nil {
t.Fatal(err)
}
provider, ok := cfg["provider"].(map[string]any)
if !ok {
t.Fatal("provider not found")
}
ollama, ok := provider["ollama"].(map[string]any)
if !ok {
t.Fatal("ollama provider not found")
}
models, ok := ollama["models"].(map[string]any)
if !ok {
t.Fatal("models not found")
}
if models[model] == nil {
t.Errorf("model %s not found", model)
}
}
func assertOpenCodeModelNotExists(t *testing.T, path, model string) {
t.Helper()
data, err := os.ReadFile(path)
if err != nil {
t.Fatal(err)
}
var cfg map[string]any
if err := json.Unmarshal(data, &cfg); err != nil {
t.Fatal(err)
}
provider, ok := cfg["provider"].(map[string]any)
if !ok {
return // No provider means no model
}
ollama, ok := provider["ollama"].(map[string]any)
if !ok {
return // No ollama means no model
}
models, ok := ollama["models"].(map[string]any)
if !ok {
return // No models means no model
}
if models[model] != nil {
t.Errorf("model %s should not exist but was found", model)
}
}
func assertOpenCodeRecentModel(t *testing.T, path string, index int, providerID, modelID string) {
t.Helper()
data, err := os.ReadFile(path)
if err != nil {
t.Fatal(err)
}
var state map[string]any
if err := json.Unmarshal(data, &state); err != nil {
t.Fatal(err)
}
recent, ok := state["recent"].([]any)
if !ok {
t.Fatal("recent not found")
}
if index >= len(recent) {
t.Fatalf("index %d out of range (len=%d)", index, len(recent))
}
entry, ok := recent[index].(map[string]any)
if !ok {
t.Fatal("entry is not a map")
}
if entry["providerID"] != providerID {
t.Errorf("expected providerID %s, got %s", providerID, entry["providerID"])
}
if entry["modelID"] != modelID {
t.Errorf("expected modelID %s, got %s", modelID, entry["modelID"])
}
}
// Edge case tests for opencode.go
func TestOpenCodeEdit_CorruptedConfigJSON(t *testing.T) {
o := &OpenCode{}
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
configDir := filepath.Join(tmpDir, ".config", "opencode")
configPath := filepath.Join(configDir, "opencode.json")
os.MkdirAll(configDir, 0o755)
os.WriteFile(configPath, []byte(`{corrupted json content`), 0o644)
// Should not panic - corrupted JSON should be treated as empty
err := o.Edit([]string{"llama3.2"})
if err != nil {
t.Fatalf("Edit failed with corrupted config: %v", err)
}
// Verify valid JSON was created
data, _ := os.ReadFile(configPath)
var cfg map[string]any
if err := json.Unmarshal(data, &cfg); err != nil {
t.Errorf("resulting config is not valid JSON: %v", err)
}
}
func TestOpenCodeEdit_CorruptedStateJSON(t *testing.T) {
o := &OpenCode{}
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
stateDir := filepath.Join(tmpDir, ".local", "state", "opencode")
statePath := filepath.Join(stateDir, "model.json")
os.MkdirAll(stateDir, 0o755)
os.WriteFile(statePath, []byte(`{corrupted state`), 0o644)
err := o.Edit([]string{"llama3.2"})
if err != nil {
t.Fatalf("Edit failed with corrupted state: %v", err)
}
// Verify valid state was created
data, _ := os.ReadFile(statePath)
var state map[string]any
if err := json.Unmarshal(data, &state); err != nil {
t.Errorf("resulting state is not valid JSON: %v", err)
}
}
func TestOpenCodeEdit_WrongTypeProvider(t *testing.T) {
o := &OpenCode{}
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
configDir := filepath.Join(tmpDir, ".config", "opencode")
configPath := filepath.Join(configDir, "opencode.json")
os.MkdirAll(configDir, 0o755)
os.WriteFile(configPath, []byte(`{"provider": "not a map"}`), 0o644)
err := o.Edit([]string{"llama3.2"})
if err != nil {
t.Fatalf("Edit with wrong type provider failed: %v", err)
}
// Verify provider is now correct type
data, _ := os.ReadFile(configPath)
var cfg map[string]any
json.Unmarshal(data, &cfg)
provider, ok := cfg["provider"].(map[string]any)
if !ok {
t.Fatalf("provider should be map after setup, got %T", cfg["provider"])
}
if provider["ollama"] == nil {
t.Error("ollama provider should be created")
}
}
func TestOpenCodeEdit_WrongTypeRecent(t *testing.T) {
o := &OpenCode{}
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
stateDir := filepath.Join(tmpDir, ".local", "state", "opencode")
statePath := filepath.Join(stateDir, "model.json")
os.MkdirAll(stateDir, 0o755)
os.WriteFile(statePath, []byte(`{"recent": "not an array", "favorite": [], "variant": {}}`), 0o644)
err := o.Edit([]string{"llama3.2"})
if err != nil {
t.Fatalf("Edit with wrong type recent failed: %v", err)
}
// The function should handle this gracefully
data, _ := os.ReadFile(statePath)
var state map[string]any
json.Unmarshal(data, &state)
// recent should be properly set after setup
recent, ok := state["recent"].([]any)
if !ok {
t.Logf("Note: recent type after setup is %T (documenting behavior)", state["recent"])
} else if len(recent) == 0 {
t.Logf("Note: recent is empty (documenting behavior)")
}
}
func TestOpenCodeEdit_EmptyModels(t *testing.T) {
o := &OpenCode{}
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
configDir := filepath.Join(tmpDir, ".config", "opencode")
configPath := filepath.Join(configDir, "opencode.json")
os.MkdirAll(configDir, 0o755)
originalContent := `{"provider":{"ollama":{"models":{"existing":{}}}}}`
os.WriteFile(configPath, []byte(originalContent), 0o644)
// Empty models should be no-op
err := o.Edit([]string{})
if err != nil {
t.Fatalf("Edit with empty models failed: %v", err)
}
// Original content should be preserved (file not modified)
data, _ := os.ReadFile(configPath)
if string(data) != originalContent {
t.Errorf("empty models should not modify file, but content changed")
}
}
func TestOpenCodeEdit_SpecialCharsInModelName(t *testing.T) {
o := &OpenCode{}
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
// Model name with special characters (though unusual)
specialModel := `model-with-"quotes"`
err := o.Edit([]string{specialModel})
if err != nil {
t.Fatalf("Edit with special chars failed: %v", err)
}
// Verify it was stored correctly
configDir := filepath.Join(tmpDir, ".config", "opencode")
configPath := filepath.Join(configDir, "opencode.json")
data, _ := os.ReadFile(configPath)
var cfg map[string]any
if err := json.Unmarshal(data, &cfg); err != nil {
t.Fatalf("resulting config is invalid JSON: %v", err)
}
// Model should be accessible
provider, _ := cfg["provider"].(map[string]any)
ollama, _ := provider["ollama"].(map[string]any)
models, _ := ollama["models"].(map[string]any)
if models[specialModel] == nil {
t.Errorf("model with special chars not found in config")
}
}
func TestOpenCodeModels_NoConfig(t *testing.T) {
o := &OpenCode{}
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
models := o.Models()
if len(models) > 0 {
t.Errorf("expected nil/empty for missing config, got %v", models)
}
}

499
cmd/config/selector.go Normal file
View File

@@ -0,0 +1,499 @@
package config
import (
"errors"
"fmt"
"io"
"os"
"strings"
"golang.org/x/term"
)
// ANSI escape sequences for terminal formatting.
const (
ansiHideCursor = "\033[?25l"
ansiShowCursor = "\033[?25h"
ansiBold = "\033[1m"
ansiReset = "\033[0m"
ansiGray = "\033[37m"
ansiClearDown = "\033[J"
)
const maxDisplayedItems = 10
var errCancelled = errors.New("cancelled")
type selectItem struct {
Name string
Description string
}
type inputEvent int
const (
eventNone inputEvent = iota
eventEnter
eventEscape
eventUp
eventDown
eventTab
eventBackspace
eventChar
)
type selectState struct {
items []selectItem
filter string
selected int
scrollOffset int
}
func newSelectState(items []selectItem) *selectState {
return &selectState{items: items}
}
func (s *selectState) filtered() []selectItem {
return filterItems(s.items, s.filter)
}
func (s *selectState) handleInput(event inputEvent, char byte) (done bool, result string, err error) {
filtered := s.filtered()
switch event {
case eventEnter:
if len(filtered) > 0 && s.selected < len(filtered) {
return true, filtered[s.selected].Name, nil
}
case eventEscape:
return true, "", errCancelled
case eventBackspace:
if len(s.filter) > 0 {
s.filter = s.filter[:len(s.filter)-1]
s.selected = 0
s.scrollOffset = 0
}
case eventUp:
if s.selected > 0 {
s.selected--
if s.selected < s.scrollOffset {
s.scrollOffset = s.selected
}
}
case eventDown:
if s.selected < len(filtered)-1 {
s.selected++
if s.selected >= s.scrollOffset+maxDisplayedItems {
s.scrollOffset = s.selected - maxDisplayedItems + 1
}
}
case eventChar:
s.filter += string(char)
s.selected = 0
s.scrollOffset = 0
}
return false, "", nil
}
type multiSelectState struct {
items []selectItem
itemIndex map[string]int
filter string
highlighted int
scrollOffset int
checked map[int]bool
checkOrder []int
focusOnButton bool
}
func newMultiSelectState(items []selectItem, preChecked []string) *multiSelectState {
s := &multiSelectState{
items: items,
itemIndex: make(map[string]int, len(items)),
checked: make(map[int]bool),
}
for i, item := range items {
s.itemIndex[item.Name] = i
}
for _, name := range preChecked {
if idx, ok := s.itemIndex[name]; ok {
s.checked[idx] = true
s.checkOrder = append(s.checkOrder, idx)
}
}
return s
}
func (s *multiSelectState) filtered() []selectItem {
return filterItems(s.items, s.filter)
}
func (s *multiSelectState) toggleItem() {
filtered := s.filtered()
if len(filtered) == 0 || s.highlighted >= len(filtered) {
return
}
item := filtered[s.highlighted]
origIdx := s.itemIndex[item.Name]
if s.checked[origIdx] {
delete(s.checked, origIdx)
for i, idx := range s.checkOrder {
if idx == origIdx {
s.checkOrder = append(s.checkOrder[:i], s.checkOrder[i+1:]...)
break
}
}
} else {
s.checked[origIdx] = true
s.checkOrder = append(s.checkOrder, origIdx)
}
}
func (s *multiSelectState) handleInput(event inputEvent, char byte) (done bool, result []string, err error) {
filtered := s.filtered()
switch event {
case eventEnter:
if s.focusOnButton && len(s.checkOrder) > 0 {
var res []string
for _, idx := range s.checkOrder {
res = append(res, s.items[idx].Name)
}
return true, res, nil
} else if !s.focusOnButton {
s.toggleItem()
}
case eventTab:
if len(s.checkOrder) > 0 {
s.focusOnButton = !s.focusOnButton
}
case eventEscape:
return true, nil, errCancelled
case eventBackspace:
if len(s.filter) > 0 {
s.filter = s.filter[:len(s.filter)-1]
s.highlighted = 0
s.scrollOffset = 0
s.focusOnButton = false
}
case eventUp:
if s.focusOnButton {
s.focusOnButton = false
} else if s.highlighted > 0 {
s.highlighted--
if s.highlighted < s.scrollOffset {
s.scrollOffset = s.highlighted
}
}
case eventDown:
if s.focusOnButton {
s.focusOnButton = false
} else if s.highlighted < len(filtered)-1 {
s.highlighted++
if s.highlighted >= s.scrollOffset+maxDisplayedItems {
s.scrollOffset = s.highlighted - maxDisplayedItems + 1
}
}
case eventChar:
s.filter += string(char)
s.highlighted = 0
s.scrollOffset = 0
s.focusOnButton = false
}
return false, nil, nil
}
func (s *multiSelectState) selectedCount() int {
return len(s.checkOrder)
}
// Terminal I/O handling
type terminalState struct {
fd int
oldState *term.State
}
func enterRawMode() (*terminalState, error) {
fd := int(os.Stdin.Fd())
oldState, err := term.MakeRaw(fd)
if err != nil {
return nil, err
}
fmt.Fprint(os.Stderr, ansiHideCursor)
return &terminalState{fd: fd, oldState: oldState}, nil
}
func (t *terminalState) restore() {
fmt.Fprint(os.Stderr, ansiShowCursor)
term.Restore(t.fd, t.oldState)
}
func clearLines(n int) {
if n > 0 {
fmt.Fprintf(os.Stderr, "\033[%dA", n)
fmt.Fprint(os.Stderr, ansiClearDown)
}
}
func parseInput(r io.Reader) (inputEvent, byte, error) {
buf := make([]byte, 3)
n, err := r.Read(buf)
if err != nil {
return 0, 0, err
}
switch {
case n == 1 && buf[0] == 13:
return eventEnter, 0, nil
case n == 1 && (buf[0] == 3 || buf[0] == 27):
return eventEscape, 0, nil
case n == 1 && buf[0] == 9:
return eventTab, 0, nil
case n == 1 && buf[0] == 127:
return eventBackspace, 0, nil
case n == 3 && buf[0] == 27 && buf[1] == 91 && buf[2] == 65:
return eventUp, 0, nil
case n == 3 && buf[0] == 27 && buf[1] == 91 && buf[2] == 66:
return eventDown, 0, nil
case n == 1 && buf[0] >= 32 && buf[0] < 127:
return eventChar, buf[0], nil
}
return eventNone, 0, nil
}
// Rendering
func renderSelect(w io.Writer, prompt string, s *selectState) int {
filtered := s.filtered()
fmt.Fprintf(w, "%s %s\r\n", prompt, s.filter)
lineCount := 1
if len(filtered) == 0 {
fmt.Fprintf(w, " %s(no matches)%s\r\n", ansiGray, ansiReset)
lineCount++
} else {
displayCount := min(len(filtered), maxDisplayedItems)
for i := range displayCount {
idx := s.scrollOffset + i
if idx >= len(filtered) {
break
}
item := filtered[idx]
prefix := " "
if idx == s.selected {
prefix = " " + ansiBold + "> "
}
if item.Description != "" {
fmt.Fprintf(w, "%s%s%s %s- %s%s\r\n", prefix, item.Name, ansiReset, ansiGray, item.Description, ansiReset)
} else {
fmt.Fprintf(w, "%s%s%s\r\n", prefix, item.Name, ansiReset)
}
lineCount++
}
if remaining := len(filtered) - s.scrollOffset - displayCount; remaining > 0 {
fmt.Fprintf(w, " %s... and %d more%s\r\n", ansiGray, remaining, ansiReset)
lineCount++
}
}
return lineCount
}
func renderMultiSelect(w io.Writer, prompt string, s *multiSelectState) int {
filtered := s.filtered()
fmt.Fprintf(w, "%s %s\r\n", prompt, s.filter)
lineCount := 1
if len(filtered) == 0 {
fmt.Fprintf(w, " %s(no matches)%s\r\n", ansiGray, ansiReset)
lineCount++
} else {
displayCount := min(len(filtered), maxDisplayedItems)
for i := range displayCount {
idx := s.scrollOffset + i
if idx >= len(filtered) {
break
}
item := filtered[idx]
origIdx := s.itemIndex[item.Name]
checkbox := "[ ]"
if s.checked[origIdx] {
checkbox = "[x]"
}
prefix := " "
suffix := ""
if idx == s.highlighted && !s.focusOnButton {
prefix = "> "
}
if len(s.checkOrder) > 0 && s.checkOrder[0] == origIdx {
suffix = " " + ansiGray + "(default)" + ansiReset
}
if idx == s.highlighted && !s.focusOnButton {
fmt.Fprintf(w, " %s%s %s %s%s%s\r\n", ansiBold, prefix, checkbox, item.Name, ansiReset, suffix)
} else {
fmt.Fprintf(w, " %s %s %s%s\r\n", prefix, checkbox, item.Name, suffix)
}
lineCount++
}
if remaining := len(filtered) - s.scrollOffset - displayCount; remaining > 0 {
fmt.Fprintf(w, " %s... and %d more%s\r\n", ansiGray, remaining, ansiReset)
lineCount++
}
}
fmt.Fprintf(w, "\r\n")
lineCount++
count := s.selectedCount()
switch {
case count == 0:
fmt.Fprintf(w, " %sSelect at least one model.%s\r\n", ansiGray, ansiReset)
case s.focusOnButton:
fmt.Fprintf(w, " %s> [ Continue ]%s %s(%d selected)%s\r\n", ansiBold, ansiReset, ansiGray, count, ansiReset)
default:
fmt.Fprintf(w, " %s[ Continue ] (%d selected) - press Tab%s\r\n", ansiGray, count, ansiReset)
}
lineCount++
return lineCount
}
// selectPrompt prompts the user to select a single item from a list.
func selectPrompt(prompt string, items []selectItem) (string, error) {
if len(items) == 0 {
return "", fmt.Errorf("no items to select from")
}
ts, err := enterRawMode()
if err != nil {
return "", err
}
defer ts.restore()
state := newSelectState(items)
var lastLineCount int
render := func() {
clearLines(lastLineCount)
lastLineCount = renderSelect(os.Stderr, prompt, state)
}
render()
for {
event, char, err := parseInput(os.Stdin)
if err != nil {
return "", err
}
done, result, err := state.handleInput(event, char)
if done {
clearLines(lastLineCount)
if err != nil {
return "", err
}
return result, nil
}
render()
}
}
// multiSelectPrompt prompts the user to select multiple items from a list.
func multiSelectPrompt(prompt string, items []selectItem, preChecked []string) ([]string, error) {
if len(items) == 0 {
return nil, fmt.Errorf("no items to select from")
}
ts, err := enterRawMode()
if err != nil {
return nil, err
}
defer ts.restore()
state := newMultiSelectState(items, preChecked)
var lastLineCount int
render := func() {
clearLines(lastLineCount)
lastLineCount = renderMultiSelect(os.Stderr, prompt, state)
}
render()
for {
event, char, err := parseInput(os.Stdin)
if err != nil {
return nil, err
}
done, result, err := state.handleInput(event, char)
if done {
clearLines(lastLineCount)
if err != nil {
return nil, err
}
return result, nil
}
render()
}
}
func confirmPrompt(prompt string) (bool, error) {
fd := int(os.Stdin.Fd())
oldState, err := term.MakeRaw(fd)
if err != nil {
return false, err
}
defer term.Restore(fd, oldState)
fmt.Fprintf(os.Stderr, "%s [y/n] ", prompt)
buf := make([]byte, 1)
for {
if _, err := os.Stdin.Read(buf); err != nil {
return false, err
}
switch buf[0] {
case 'Y', 'y', 13:
fmt.Fprintf(os.Stderr, "yes\r\n")
return true, nil
case 'N', 'n', 27, 3:
fmt.Fprintf(os.Stderr, "no\r\n")
return false, nil
}
}
}
func filterItems(items []selectItem, filter string) []selectItem {
if filter == "" {
return items
}
var result []selectItem
filterLower := strings.ToLower(filter)
for _, item := range items {
if strings.Contains(strings.ToLower(item.Name), filterLower) {
result = append(result, item)
}
}
return result
}

913
cmd/config/selector_test.go Normal file
View File

@@ -0,0 +1,913 @@
package config
import (
"bytes"
"strings"
"testing"
)
func TestFilterItems(t *testing.T) {
items := []selectItem{
{Name: "llama3.2:latest"},
{Name: "qwen2.5:7b"},
{Name: "deepseek-v3:cloud"},
{Name: "GPT-OSS:20b"},
}
t.Run("EmptyFilter_ReturnsAllItems", func(t *testing.T) {
result := filterItems(items, "")
if len(result) != len(items) {
t.Errorf("expected %d items, got %d", len(items), len(result))
}
})
t.Run("CaseInsensitive_UppercaseFilterMatchesLowercase", func(t *testing.T) {
result := filterItems(items, "LLAMA")
if len(result) != 1 || result[0].Name != "llama3.2:latest" {
t.Errorf("expected llama3.2:latest, got %v", result)
}
})
t.Run("CaseInsensitive_LowercaseFilterMatchesUppercase", func(t *testing.T) {
result := filterItems(items, "gpt")
if len(result) != 1 || result[0].Name != "GPT-OSS:20b" {
t.Errorf("expected GPT-OSS:20b, got %v", result)
}
})
t.Run("PartialMatch", func(t *testing.T) {
result := filterItems(items, "deep")
if len(result) != 1 || result[0].Name != "deepseek-v3:cloud" {
t.Errorf("expected deepseek-v3:cloud, got %v", result)
}
})
t.Run("NoMatch_ReturnsEmpty", func(t *testing.T) {
result := filterItems(items, "nonexistent")
if len(result) != 0 {
t.Errorf("expected 0 items, got %d", len(result))
}
})
}
func TestSelectState(t *testing.T) {
items := []selectItem{
{Name: "item1"},
{Name: "item2"},
{Name: "item3"},
}
t.Run("InitialState", func(t *testing.T) {
s := newSelectState(items)
if s.selected != 0 {
t.Errorf("expected selected=0, got %d", s.selected)
}
if s.filter != "" {
t.Errorf("expected empty filter, got %q", s.filter)
}
if s.scrollOffset != 0 {
t.Errorf("expected scrollOffset=0, got %d", s.scrollOffset)
}
})
t.Run("Enter_SelectsCurrentItem", func(t *testing.T) {
s := newSelectState(items)
done, result, err := s.handleInput(eventEnter, 0)
if !done || result != "item1" || err != nil {
t.Errorf("expected (true, item1, nil), got (%v, %v, %v)", done, result, err)
}
})
t.Run("Enter_WithFilter_SelectsFilteredItem", func(t *testing.T) {
s := newSelectState(items)
s.filter = "item3"
done, result, err := s.handleInput(eventEnter, 0)
if !done || result != "item3" || err != nil {
t.Errorf("expected (true, item3, nil), got (%v, %v, %v)", done, result, err)
}
})
t.Run("Enter_EmptyFilteredList_DoesNothing", func(t *testing.T) {
s := newSelectState(items)
s.filter = "nonexistent"
done, result, err := s.handleInput(eventEnter, 0)
if done || result != "" || err != nil {
t.Errorf("expected (false, '', nil), got (%v, %v, %v)", done, result, err)
}
})
t.Run("Escape_ReturnsCancelledError", func(t *testing.T) {
s := newSelectState(items)
done, result, err := s.handleInput(eventEscape, 0)
if !done || result != "" || err != errCancelled {
t.Errorf("expected (true, '', errCancelled), got (%v, %v, %v)", done, result, err)
}
})
t.Run("Down_MovesSelection", func(t *testing.T) {
s := newSelectState(items)
s.handleInput(eventDown, 0)
if s.selected != 1 {
t.Errorf("expected selected=1, got %d", s.selected)
}
})
t.Run("Down_AtBottom_StaysAtBottom", func(t *testing.T) {
s := newSelectState(items)
s.selected = 2
s.handleInput(eventDown, 0)
if s.selected != 2 {
t.Errorf("expected selected=2 (stayed at bottom), got %d", s.selected)
}
})
t.Run("Up_MovesSelection", func(t *testing.T) {
s := newSelectState(items)
s.selected = 2
s.handleInput(eventUp, 0)
if s.selected != 1 {
t.Errorf("expected selected=1, got %d", s.selected)
}
})
t.Run("Up_AtTop_StaysAtTop", func(t *testing.T) {
s := newSelectState(items)
s.handleInput(eventUp, 0)
if s.selected != 0 {
t.Errorf("expected selected=0 (stayed at top), got %d", s.selected)
}
})
t.Run("Char_AppendsToFilter", func(t *testing.T) {
s := newSelectState(items)
s.handleInput(eventChar, 'i')
s.handleInput(eventChar, 't')
s.handleInput(eventChar, 'e')
s.handleInput(eventChar, 'm')
s.handleInput(eventChar, '2')
if s.filter != "item2" {
t.Errorf("expected filter='item2', got %q", s.filter)
}
filtered := s.filtered()
if len(filtered) != 1 || filtered[0].Name != "item2" {
t.Errorf("expected [item2], got %v", filtered)
}
})
t.Run("Char_ResetsSelectionToZero", func(t *testing.T) {
s := newSelectState(items)
s.selected = 2
s.handleInput(eventChar, 'x')
if s.selected != 0 {
t.Errorf("expected selected=0 after typing, got %d", s.selected)
}
})
t.Run("Backspace_RemovesLastFilterChar", func(t *testing.T) {
s := newSelectState(items)
s.filter = "test"
s.handleInput(eventBackspace, 0)
if s.filter != "tes" {
t.Errorf("expected filter='tes', got %q", s.filter)
}
})
t.Run("Backspace_EmptyFilter_DoesNothing", func(t *testing.T) {
s := newSelectState(items)
s.handleInput(eventBackspace, 0)
if s.filter != "" {
t.Errorf("expected filter='', got %q", s.filter)
}
})
t.Run("Backspace_ResetsSelectionToZero", func(t *testing.T) {
s := newSelectState(items)
s.filter = "test"
s.selected = 2
s.handleInput(eventBackspace, 0)
if s.selected != 0 {
t.Errorf("expected selected=0 after backspace, got %d", s.selected)
}
})
t.Run("Scroll_DownPastVisibleItems_ScrollsViewport", func(t *testing.T) {
// maxDisplayedItems is 10, so with 15 items we need to scroll
manyItems := make([]selectItem, 15)
for i := range manyItems {
manyItems[i] = selectItem{Name: string(rune('a' + i))}
}
s := newSelectState(manyItems)
// move down 12 times (past the 10-item viewport)
for range 12 {
s.handleInput(eventDown, 0)
}
if s.selected != 12 {
t.Errorf("expected selected=12, got %d", s.selected)
}
if s.scrollOffset != 3 {
t.Errorf("expected scrollOffset=3 (12-10+1), got %d", s.scrollOffset)
}
})
t.Run("Scroll_UpPastScrollOffset_ScrollsViewport", func(t *testing.T) {
manyItems := make([]selectItem, 15)
for i := range manyItems {
manyItems[i] = selectItem{Name: string(rune('a' + i))}
}
s := newSelectState(manyItems)
s.selected = 5
s.scrollOffset = 5
s.handleInput(eventUp, 0)
if s.selected != 4 {
t.Errorf("expected selected=4, got %d", s.selected)
}
if s.scrollOffset != 4 {
t.Errorf("expected scrollOffset=4, got %d", s.scrollOffset)
}
})
}
func TestMultiSelectState(t *testing.T) {
items := []selectItem{
{Name: "item1"},
{Name: "item2"},
{Name: "item3"},
}
t.Run("InitialState_NoPrechecked", func(t *testing.T) {
s := newMultiSelectState(items, nil)
if s.highlighted != 0 {
t.Errorf("expected highlighted=0, got %d", s.highlighted)
}
if s.selectedCount() != 0 {
t.Errorf("expected 0 selected, got %d", s.selectedCount())
}
if s.focusOnButton {
t.Error("expected focusOnButton=false initially")
}
})
t.Run("InitialState_WithPrechecked", func(t *testing.T) {
s := newMultiSelectState(items, []string{"item2", "item3"})
if s.selectedCount() != 2 {
t.Errorf("expected 2 selected, got %d", s.selectedCount())
}
if !s.checked[1] || !s.checked[2] {
t.Error("expected item2 and item3 to be checked")
}
})
t.Run("Prechecked_PreservesSelectionOrder", func(t *testing.T) {
// order matters: first checked = default model
s := newMultiSelectState(items, []string{"item3", "item1"})
if len(s.checkOrder) != 2 {
t.Fatalf("expected 2 in checkOrder, got %d", len(s.checkOrder))
}
if s.checkOrder[0] != 2 || s.checkOrder[1] != 0 {
t.Errorf("expected checkOrder=[2,0] (item3 first), got %v", s.checkOrder)
}
})
t.Run("Prechecked_IgnoresInvalidNames", func(t *testing.T) {
s := newMultiSelectState(items, []string{"item1", "nonexistent"})
if s.selectedCount() != 1 {
t.Errorf("expected 1 selected (nonexistent ignored), got %d", s.selectedCount())
}
})
t.Run("Toggle_ChecksUncheckedItem", func(t *testing.T) {
s := newMultiSelectState(items, nil)
s.toggleItem()
if !s.checked[0] {
t.Error("expected item1 to be checked after toggle")
}
})
t.Run("Toggle_UnchecksCheckedItem", func(t *testing.T) {
s := newMultiSelectState(items, []string{"item1"})
s.toggleItem()
if s.checked[0] {
t.Error("expected item1 to be unchecked after toggle")
}
})
t.Run("Toggle_RemovesFromCheckOrder", func(t *testing.T) {
s := newMultiSelectState(items, []string{"item1", "item2", "item3"})
s.highlighted = 1 // toggle item2
s.toggleItem()
if len(s.checkOrder) != 2 {
t.Fatalf("expected 2 in checkOrder, got %d", len(s.checkOrder))
}
// should be [0, 2] (item1, item3) with item2 removed
if s.checkOrder[0] != 0 || s.checkOrder[1] != 2 {
t.Errorf("expected checkOrder=[0,2], got %v", s.checkOrder)
}
})
t.Run("Enter_TogglesWhenNotOnButton", func(t *testing.T) {
s := newMultiSelectState(items, nil)
s.handleInput(eventEnter, 0)
if !s.checked[0] {
t.Error("expected item1 to be checked after enter")
}
})
t.Run("Enter_OnButton_ReturnsSelection", func(t *testing.T) {
s := newMultiSelectState(items, []string{"item2", "item1"})
s.focusOnButton = true
done, result, err := s.handleInput(eventEnter, 0)
if !done || err != nil {
t.Errorf("expected done=true, err=nil, got done=%v, err=%v", done, err)
}
// result should preserve selection order
if len(result) != 2 || result[0] != "item2" || result[1] != "item1" {
t.Errorf("expected [item2, item1], got %v", result)
}
})
t.Run("Enter_OnButton_EmptySelection_DoesNothing", func(t *testing.T) {
s := newMultiSelectState(items, nil)
s.focusOnButton = true
done, result, err := s.handleInput(eventEnter, 0)
if done || result != nil || err != nil {
t.Errorf("expected (false, nil, nil), got (%v, %v, %v)", done, result, err)
}
})
t.Run("Tab_SwitchesToButton_WhenHasSelection", func(t *testing.T) {
s := newMultiSelectState(items, []string{"item1"})
s.handleInput(eventTab, 0)
if !s.focusOnButton {
t.Error("expected focus on button after tab")
}
})
t.Run("Tab_DoesNothing_WhenNoSelection", func(t *testing.T) {
s := newMultiSelectState(items, nil)
s.handleInput(eventTab, 0)
if s.focusOnButton {
t.Error("tab should not focus button when nothing selected")
}
})
t.Run("Tab_TogglesButtonFocus", func(t *testing.T) {
s := newMultiSelectState(items, []string{"item1"})
s.handleInput(eventTab, 0)
if !s.focusOnButton {
t.Error("expected focus on button after first tab")
}
s.handleInput(eventTab, 0)
if s.focusOnButton {
t.Error("expected focus back on list after second tab")
}
})
t.Run("Escape_ReturnsCancelledError", func(t *testing.T) {
s := newMultiSelectState(items, []string{"item1"})
done, result, err := s.handleInput(eventEscape, 0)
if !done || result != nil || err != errCancelled {
t.Errorf("expected (true, nil, errCancelled), got (%v, %v, %v)", done, result, err)
}
})
t.Run("IsDefault_TrueForFirstChecked", func(t *testing.T) {
s := newMultiSelectState(items, []string{"item2", "item1"})
if !(len(s.checkOrder) > 0 && s.checkOrder[0] == 1) {
t.Error("expected item2 (idx 1) to be default (first checked)")
}
if len(s.checkOrder) > 0 && s.checkOrder[0] == 0 {
t.Error("expected item1 (idx 0) to NOT be default")
}
})
t.Run("IsDefault_FalseWhenNothingChecked", func(t *testing.T) {
s := newMultiSelectState(items, nil)
if len(s.checkOrder) > 0 && s.checkOrder[0] == 0 {
t.Error("expected isDefault=false when nothing checked")
}
})
t.Run("Down_MovesHighlight", func(t *testing.T) {
s := newMultiSelectState(items, nil)
s.handleInput(eventDown, 0)
if s.highlighted != 1 {
t.Errorf("expected highlighted=1, got %d", s.highlighted)
}
})
t.Run("Up_MovesHighlight", func(t *testing.T) {
s := newMultiSelectState(items, nil)
s.highlighted = 1
s.handleInput(eventUp, 0)
if s.highlighted != 0 {
t.Errorf("expected highlighted=0, got %d", s.highlighted)
}
})
t.Run("Arrow_ReturnsFocusFromButton", func(t *testing.T) {
s := newMultiSelectState(items, []string{"item1"})
s.focusOnButton = true
s.handleInput(eventDown, 0)
if s.focusOnButton {
t.Error("expected focus to return to list on arrow key")
}
})
t.Run("Char_AppendsToFilter", func(t *testing.T) {
s := newMultiSelectState(items, nil)
s.handleInput(eventChar, 'x')
if s.filter != "x" {
t.Errorf("expected filter='x', got %q", s.filter)
}
})
t.Run("Char_ResetsHighlightAndScroll", func(t *testing.T) {
manyItems := make([]selectItem, 15)
for i := range manyItems {
manyItems[i] = selectItem{Name: string(rune('a' + i))}
}
s := newMultiSelectState(manyItems, nil)
s.highlighted = 10
s.scrollOffset = 5
s.handleInput(eventChar, 'x')
if s.highlighted != 0 {
t.Errorf("expected highlighted=0, got %d", s.highlighted)
}
if s.scrollOffset != 0 {
t.Errorf("expected scrollOffset=0, got %d", s.scrollOffset)
}
})
t.Run("Backspace_RemovesLastFilterChar", func(t *testing.T) {
s := newMultiSelectState(items, nil)
s.filter = "test"
s.handleInput(eventBackspace, 0)
if s.filter != "tes" {
t.Errorf("expected filter='tes', got %q", s.filter)
}
})
t.Run("Backspace_RemovesFocusFromButton", func(t *testing.T) {
s := newMultiSelectState(items, []string{"item1"})
s.filter = "x"
s.focusOnButton = true
s.handleInput(eventBackspace, 0)
if s.focusOnButton {
t.Error("expected focusOnButton=false after backspace")
}
})
}
func TestParseInput(t *testing.T) {
t.Run("Enter", func(t *testing.T) {
event, char, err := parseInput(bytes.NewReader([]byte{13}))
if err != nil || event != eventEnter || char != 0 {
t.Errorf("expected (eventEnter, 0, nil), got (%v, %v, %v)", event, char, err)
}
})
t.Run("Escape", func(t *testing.T) {
event, _, err := parseInput(bytes.NewReader([]byte{27}))
if err != nil || event != eventEscape {
t.Errorf("expected eventEscape, got %v", event)
}
})
t.Run("CtrlC_TreatedAsEscape", func(t *testing.T) {
event, _, err := parseInput(bytes.NewReader([]byte{3}))
if err != nil || event != eventEscape {
t.Errorf("expected eventEscape for Ctrl+C, got %v", event)
}
})
t.Run("Tab", func(t *testing.T) {
event, _, err := parseInput(bytes.NewReader([]byte{9}))
if err != nil || event != eventTab {
t.Errorf("expected eventTab, got %v", event)
}
})
t.Run("Backspace", func(t *testing.T) {
event, _, err := parseInput(bytes.NewReader([]byte{127}))
if err != nil || event != eventBackspace {
t.Errorf("expected eventBackspace, got %v", event)
}
})
t.Run("UpArrow", func(t *testing.T) {
event, _, err := parseInput(bytes.NewReader([]byte{27, 91, 65}))
if err != nil || event != eventUp {
t.Errorf("expected eventUp, got %v", event)
}
})
t.Run("DownArrow", func(t *testing.T) {
event, _, err := parseInput(bytes.NewReader([]byte{27, 91, 66}))
if err != nil || event != eventDown {
t.Errorf("expected eventDown, got %v", event)
}
})
t.Run("PrintableChars", func(t *testing.T) {
tests := []struct {
name string
char byte
}{
{"lowercase", 'a'},
{"uppercase", 'Z'},
{"digit", '5'},
{"space", ' '},
{"tilde", '~'},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
event, char, err := parseInput(bytes.NewReader([]byte{tt.char}))
if err != nil || event != eventChar || char != tt.char {
t.Errorf("expected (eventChar, %q), got (%v, %q)", tt.char, event, char)
}
})
}
})
}
func TestRenderSelect(t *testing.T) {
items := []selectItem{
{Name: "item1", Description: "first item"},
{Name: "item2"},
}
t.Run("ShowsPromptAndItems", func(t *testing.T) {
s := newSelectState(items)
var buf bytes.Buffer
lineCount := renderSelect(&buf, "Select:", s)
output := buf.String()
if !strings.Contains(output, "Select:") {
t.Error("expected prompt in output")
}
if !strings.Contains(output, "item1") {
t.Error("expected item1 in output")
}
if !strings.Contains(output, "first item") {
t.Error("expected description in output")
}
if !strings.Contains(output, "item2") {
t.Error("expected item2 in output")
}
if lineCount != 3 { // 1 prompt + 2 items
t.Errorf("expected 3 lines, got %d", lineCount)
}
})
t.Run("EmptyFilteredList_ShowsNoMatches", func(t *testing.T) {
s := newSelectState(items)
s.filter = "xyz"
var buf bytes.Buffer
renderSelect(&buf, "Select:", s)
if !strings.Contains(buf.String(), "no matches") {
t.Error("expected 'no matches' message")
}
})
t.Run("LongList_ShowsRemainingCount", func(t *testing.T) {
manyItems := make([]selectItem, 15)
for i := range manyItems {
manyItems[i] = selectItem{Name: string(rune('a' + i))}
}
s := newSelectState(manyItems)
var buf bytes.Buffer
renderSelect(&buf, "Select:", s)
// 15 items - 10 displayed = 5 more
if !strings.Contains(buf.String(), "5 more") {
t.Error("expected '5 more' indicator")
}
})
}
func TestRenderMultiSelect(t *testing.T) {
items := []selectItem{
{Name: "item1"},
{Name: "item2"},
}
t.Run("ShowsCheckboxes", func(t *testing.T) {
s := newMultiSelectState(items, []string{"item1"})
var buf bytes.Buffer
renderMultiSelect(&buf, "Select:", s)
output := buf.String()
if !strings.Contains(output, "[x]") {
t.Error("expected checked checkbox [x]")
}
if !strings.Contains(output, "[ ]") {
t.Error("expected unchecked checkbox [ ]")
}
})
t.Run("ShowsDefaultMarker", func(t *testing.T) {
s := newMultiSelectState(items, []string{"item1"})
var buf bytes.Buffer
renderMultiSelect(&buf, "Select:", s)
if !strings.Contains(buf.String(), "(default)") {
t.Error("expected (default) marker for first checked item")
}
})
t.Run("ShowsSelectedCount", func(t *testing.T) {
s := newMultiSelectState(items, []string{"item1", "item2"})
var buf bytes.Buffer
renderMultiSelect(&buf, "Select:", s)
if !strings.Contains(buf.String(), "2 selected") {
t.Error("expected '2 selected' in output")
}
})
t.Run("NoSelection_ShowsHelperText", func(t *testing.T) {
s := newMultiSelectState(items, nil)
var buf bytes.Buffer
renderMultiSelect(&buf, "Select:", s)
if !strings.Contains(buf.String(), "Select at least one") {
t.Error("expected 'Select at least one' helper text")
}
})
}
func TestErrCancelled(t *testing.T) {
t.Run("NotNil", func(t *testing.T) {
if errCancelled == nil {
t.Error("errCancelled should not be nil")
}
})
t.Run("Message", func(t *testing.T) {
if errCancelled.Error() != "cancelled" {
t.Errorf("expected 'cancelled', got %q", errCancelled.Error())
}
})
}
// Edge case tests for selector.go
// TestSelectState_SingleItem verifies that single item list works without crash.
// List with only one item should still work.
func TestSelectState_SingleItem(t *testing.T) {
items := []selectItem{{Name: "only-one"}}
s := newSelectState(items)
// Down should do nothing (already at bottom)
s.handleInput(eventDown, 0)
if s.selected != 0 {
t.Errorf("down on single item: expected selected=0, got %d", s.selected)
}
// Up should do nothing (already at top)
s.handleInput(eventUp, 0)
if s.selected != 0 {
t.Errorf("up on single item: expected selected=0, got %d", s.selected)
}
// Enter should select the only item
done, result, err := s.handleInput(eventEnter, 0)
if !done || result != "only-one" || err != nil {
t.Errorf("enter on single item: expected (true, 'only-one', nil), got (%v, %q, %v)", done, result, err)
}
}
// TestSelectState_ExactlyMaxItems verifies boundary condition at maxDisplayedItems.
// List with exactly maxDisplayedItems items should not scroll.
func TestSelectState_ExactlyMaxItems(t *testing.T) {
items := make([]selectItem, maxDisplayedItems)
for i := range items {
items[i] = selectItem{Name: string(rune('a' + i))}
}
s := newSelectState(items)
// Move to last item
for range maxDisplayedItems - 1 {
s.handleInput(eventDown, 0)
}
if s.selected != maxDisplayedItems-1 {
t.Errorf("expected selected=%d, got %d", maxDisplayedItems-1, s.selected)
}
// Should not scroll when exactly at max
if s.scrollOffset != 0 {
t.Errorf("expected scrollOffset=0 for exactly maxDisplayedItems, got %d", s.scrollOffset)
}
// One more down should do nothing
s.handleInput(eventDown, 0)
if s.selected != maxDisplayedItems-1 {
t.Errorf("down at max: expected selected=%d, got %d", maxDisplayedItems-1, s.selected)
}
}
// TestFilterItems_RegexSpecialChars verifies that filter is literal, not regex.
// User typing "model.v1" shouldn't match "modelsv1".
func TestFilterItems_RegexSpecialChars(t *testing.T) {
items := []selectItem{
{Name: "model.v1"},
{Name: "modelsv1"},
{Name: "model-v1"},
}
// Filter with dot should only match literal dot
result := filterItems(items, "model.v1")
if len(result) != 1 {
t.Errorf("expected 1 exact match, got %d", len(result))
}
if len(result) > 0 && result[0].Name != "model.v1" {
t.Errorf("expected 'model.v1', got %s", result[0].Name)
}
// Other regex special chars should be literal too
items2 := []selectItem{
{Name: "test[0]"},
{Name: "test0"},
{Name: "test(1)"},
}
result2 := filterItems(items2, "test[0]")
if len(result2) != 1 || result2[0].Name != "test[0]" {
t.Errorf("expected only 'test[0]', got %v", result2)
}
}
// TestMultiSelectState_DuplicateNames documents handling of duplicate item names.
// itemIndex uses name as key - duplicates cause collision. This documents
// the current behavior: the last index for a duplicate name is stored
func TestMultiSelectState_DuplicateNames(t *testing.T) {
// Duplicate names - this is an edge case that shouldn't happen in practice
items := []selectItem{
{Name: "duplicate"},
{Name: "duplicate"},
{Name: "unique"},
}
s := newMultiSelectState(items, nil)
// DOCUMENTED BEHAVIOR: itemIndex maps name to LAST index
// When there are duplicates, only the last occurrence's index is stored
if s.itemIndex["duplicate"] != 1 {
t.Errorf("itemIndex should map 'duplicate' to last index (1), got %d", s.itemIndex["duplicate"])
}
// Toggle item at highlighted=0 (first "duplicate")
// Due to name collision, toggleItem uses itemIndex["duplicate"] = 1
// So it actually toggles the SECOND duplicate item, not the first
s.toggleItem()
// This documents the potentially surprising behavior:
// We toggled at highlighted=0, but itemIndex lookup returned 1
if !s.checked[1] {
t.Error("toggle should check index 1 (due to name collision in itemIndex)")
}
if s.checked[0] {
t.Log("Note: index 0 is NOT checked, even though highlighted=0 (name collision behavior)")
}
}
// TestSelectState_FilterReducesBelowSelection verifies selection resets when filter reduces list.
// Prevents index-out-of-bounds on next keystroke
func TestSelectState_FilterReducesBelowSelection(t *testing.T) {
items := []selectItem{
{Name: "apple"},
{Name: "banana"},
{Name: "cherry"},
}
s := newSelectState(items)
s.selected = 2 // Select "cherry"
// Type a filter that removes cherry from results
s.handleInput(eventChar, 'a') // Filter to "a" - matches "apple" and "banana"
// Selection should reset to 0
if s.selected != 0 {
t.Errorf("expected selected=0 after filter, got %d", s.selected)
}
filtered := s.filtered()
if len(filtered) != 2 {
t.Errorf("expected 2 filtered items, got %d", len(filtered))
}
}
// TestFilterItems_UnicodeCharacters verifies filtering works with UTF-8.
// Model names might contain unicode characters
func TestFilterItems_UnicodeCharacters(t *testing.T) {
items := []selectItem{
{Name: "llama-日本語"},
{Name: "模型-chinese"},
{Name: "émoji-🦙"},
{Name: "regular-model"},
}
t.Run("filter japanese", func(t *testing.T) {
result := filterItems(items, "日本")
if len(result) != 1 || result[0].Name != "llama-日本語" {
t.Errorf("expected llama-日本語, got %v", result)
}
})
t.Run("filter chinese", func(t *testing.T) {
result := filterItems(items, "模型")
if len(result) != 1 || result[0].Name != "模型-chinese" {
t.Errorf("expected 模型-chinese, got %v", result)
}
})
t.Run("filter emoji", func(t *testing.T) {
result := filterItems(items, "🦙")
if len(result) != 1 || result[0].Name != "émoji-🦙" {
t.Errorf("expected émoji-🦙, got %v", result)
}
})
t.Run("filter accented char", func(t *testing.T) {
result := filterItems(items, "émoji")
if len(result) != 1 || result[0].Name != "émoji-🦙" {
t.Errorf("expected émoji-🦙, got %v", result)
}
})
}
// TestMultiSelectState_FilterReducesBelowHighlight verifies highlight resets when filter reduces list.
func TestMultiSelectState_FilterReducesBelowHighlight(t *testing.T) {
items := []selectItem{
{Name: "apple"},
{Name: "banana"},
{Name: "cherry"},
}
s := newMultiSelectState(items, nil)
s.highlighted = 2 // Highlight "cherry"
// Type a filter that removes cherry
s.handleInput(eventChar, 'a')
if s.highlighted != 0 {
t.Errorf("expected highlighted=0 after filter, got %d", s.highlighted)
}
}
// TestMultiSelectState_EmptyItems verifies handling of empty item list.
// Empty list should be handled gracefully.
func TestMultiSelectState_EmptyItems(t *testing.T) {
s := newMultiSelectState([]selectItem{}, nil)
// Toggle should not panic on empty list
s.toggleItem()
if s.selectedCount() != 0 {
t.Errorf("expected 0 selected for empty list, got %d", s.selectedCount())
}
// Render should handle empty list
var buf bytes.Buffer
lineCount := renderMultiSelect(&buf, "Select:", s)
if lineCount == 0 {
t.Error("renderMultiSelect should produce output even for empty list")
}
if !strings.Contains(buf.String(), "no matches") {
t.Error("expected 'no matches' for empty list")
}
}
// TestSelectState_RenderWithDescriptions verifies rendering items with descriptions.
func TestSelectState_RenderWithDescriptions(t *testing.T) {
items := []selectItem{
{Name: "item1", Description: "First item description"},
{Name: "item2", Description: ""},
{Name: "item3", Description: "Third item"},
}
s := newSelectState(items)
var buf bytes.Buffer
renderSelect(&buf, "Select:", s)
output := buf.String()
if !strings.Contains(output, "First item description") {
t.Error("expected description to be rendered")
}
if !strings.Contains(output, "item2") {
t.Error("expected item without description to be rendered")
}
}