From 199c41e16edbbf90c38a0d3117c70070e410db48 Mon Sep 17 00:00:00 2001 From: Parth Sareen Date: Thu, 22 Jan 2026 23:17:11 -0500 Subject: [PATCH] cmd: `ollama config` command to help configure integrations to use Ollama (#13712) --- cmd/cmd.go | 2 + cmd/config/claude.go | 36 ++ cmd/config/claude_test.go | 42 ++ cmd/config/codex.go | 61 +++ cmd/config/codex_test.go | 28 + cmd/config/config.go | 115 ++++ cmd/config/config_test.go | 373 +++++++++++++ cmd/config/droid.go | 174 ++++++ cmd/config/droid_test.go | 454 ++++++++++++++++ cmd/config/files.go | 99 ++++ cmd/config/files_test.go | 502 ++++++++++++++++++ cmd/config/integrations.go | 361 +++++++++++++ cmd/config/integrations_test.go | 188 +++++++ cmd/config/opencode.go | 203 +++++++ cmd/config/opencode_test.go | 437 +++++++++++++++ cmd/config/selector.go | 499 +++++++++++++++++ cmd/config/selector_test.go | 913 ++++++++++++++++++++++++++++++++ 17 files changed, 4487 insertions(+) create mode 100644 cmd/config/claude.go create mode 100644 cmd/config/claude_test.go create mode 100644 cmd/config/codex.go create mode 100644 cmd/config/codex_test.go create mode 100644 cmd/config/config.go create mode 100644 cmd/config/config_test.go create mode 100644 cmd/config/droid.go create mode 100644 cmd/config/droid_test.go create mode 100644 cmd/config/files.go create mode 100644 cmd/config/files_test.go create mode 100644 cmd/config/integrations.go create mode 100644 cmd/config/integrations_test.go create mode 100644 cmd/config/opencode.go create mode 100644 cmd/config/opencode_test.go create mode 100644 cmd/config/selector.go create mode 100644 cmd/config/selector_test.go diff --git a/cmd/cmd.go b/cmd/cmd.go index c9c89af56..e4ad0366b 100644 --- a/cmd/cmd.go +++ b/cmd/cmd.go @@ -35,6 +35,7 @@ import ( "golang.org/x/term" "github.com/ollama/ollama/api" + "github.com/ollama/ollama/cmd/config" "github.com/ollama/ollama/envconfig" "github.com/ollama/ollama/format" "github.com/ollama/ollama/parser" @@ -2026,6 +2027,7 @@ func NewCLI() *cobra.Command { copyCmd, deleteCmd, runnerCmd, + config.ConfigCmd(checkServerHeartbeat), ) return rootCmd diff --git a/cmd/config/claude.go b/cmd/config/claude.go new file mode 100644 index 000000000..7320593f5 --- /dev/null +++ b/cmd/config/claude.go @@ -0,0 +1,36 @@ +package config + +import ( + "fmt" + "os" + "os/exec" +) + +// Claude implements Runner for Claude Code integration +type Claude struct{} + +func (c *Claude) String() string { return "Claude Code" } + +func (c *Claude) args(model string) []string { + if model != "" { + return []string{"--model", model} + } + return nil +} + +func (c *Claude) Run(model string) error { + if _, err := exec.LookPath("claude"); err != nil { + return fmt.Errorf("claude is not installed, install from https://code.claude.com/docs/en/quickstart") + } + + cmd := exec.Command("claude", c.args(model)...) + cmd.Stdin = os.Stdin + cmd.Stdout = os.Stdout + cmd.Stderr = os.Stderr + cmd.Env = append(os.Environ(), + "ANTHROPIC_BASE_URL=http://localhost:11434", + "ANTHROPIC_API_KEY=", + "ANTHROPIC_AUTH_TOKEN=ollama", + ) + return cmd.Run() +} diff --git a/cmd/config/claude_test.go b/cmd/config/claude_test.go new file mode 100644 index 000000000..32b3b7ede --- /dev/null +++ b/cmd/config/claude_test.go @@ -0,0 +1,42 @@ +package config + +import ( + "slices" + "testing" +) + +func TestClaudeIntegration(t *testing.T) { + c := &Claude{} + + t.Run("String", func(t *testing.T) { + if got := c.String(); got != "Claude Code" { + t.Errorf("String() = %q, want %q", got, "Claude Code") + } + }) + + t.Run("implements Runner", func(t *testing.T) { + var _ Runner = c + }) +} + +func TestClaudeArgs(t *testing.T) { + c := &Claude{} + + tests := []struct { + name string + model string + want []string + }{ + {"with model", "llama3.2", []string{"--model", "llama3.2"}}, + {"empty model", "", nil}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := c.args(tt.model) + if !slices.Equal(got, tt.want) { + t.Errorf("args(%q) = %v, want %v", tt.model, got, tt.want) + } + }) + } +} diff --git a/cmd/config/codex.go b/cmd/config/codex.go new file mode 100644 index 000000000..f421e1f6a --- /dev/null +++ b/cmd/config/codex.go @@ -0,0 +1,61 @@ +package config + +import ( + "fmt" + "os" + "os/exec" + "strings" + + "golang.org/x/mod/semver" +) + +// Codex implements Runner for Codex integration +type Codex struct{} + +func (c *Codex) String() string { return "Codex" } + +func (c *Codex) args(model string) []string { + args := []string{"--oss"} + if model != "" { + args = append(args, "-m", model) + } + return args +} + +func (c *Codex) Run(model string) error { + if err := checkCodexVersion(); err != nil { + return err + } + + cmd := exec.Command("codex", c.args(model)...) + cmd.Stdin = os.Stdin + cmd.Stdout = os.Stdout + cmd.Stderr = os.Stderr + return cmd.Run() +} + +func checkCodexVersion() error { + if _, err := exec.LookPath("codex"); err != nil { + return fmt.Errorf("codex is not installed, install with: npm install -g @openai/codex") + } + + out, err := exec.Command("codex", "--version").Output() + if err != nil { + return fmt.Errorf("failed to get codex version: %w", err) + } + + // Parse output like "codex-cli 0.87.0" + fields := strings.Fields(strings.TrimSpace(string(out))) + if len(fields) < 2 { + return fmt.Errorf("unexpected codex version output: %s", string(out)) + } + + version := "v" + fields[len(fields)-1] + minVersion := "v0.81.0" + + if semver.Compare(version, minVersion) < 0 { + return fmt.Errorf("codex version %s is too old, minimum required is %s, update with: npm update -g @openai/codex", fields[len(fields)-1], "0.81.0") + } + + return nil +} diff --git a/cmd/config/codex_test.go b/cmd/config/codex_test.go new file mode 100644 index 000000000..2fe614211 --- /dev/null +++ b/cmd/config/codex_test.go @@ -0,0 +1,28 @@ +package config + +import ( + "slices" + "testing" +) + +func TestCodexArgs(t *testing.T) { + c := &Codex{} + + tests := []struct { + name string + model string + want []string + }{ + {"with model", "llama3.2", []string{"--oss", "-m", "llama3.2"}}, + {"empty model", "", []string{"--oss"}}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := c.args(tt.model) + if !slices.Equal(got, tt.want) { + t.Errorf("args(%q) = %v, want %v", tt.model, got, tt.want) + } + }) + } +} diff --git a/cmd/config/config.go b/cmd/config/config.go new file mode 100644 index 000000000..598423696 --- /dev/null +++ b/cmd/config/config.go @@ -0,0 +1,115 @@ +// Package config provides integration configuration for external coding tools +// (Claude Code, Codex, Droid, OpenCode) to use Ollama models. +package config + +import ( + "encoding/json" + "errors" + "fmt" + "os" + "path/filepath" + "strings" +) + +type integration struct { + Models []string `json:"models"` +} + +type config struct { + Integrations map[string]*integration `json:"integrations"` +} + +func configPath() (string, error) { + home, err := os.UserHomeDir() + if err != nil { + return "", err + } + return filepath.Join(home, ".ollama", "config", "config.json"), nil +} + +func load() (*config, error) { + path, err := configPath() + if err != nil { + return nil, err + } + + data, err := os.ReadFile(path) + if err != nil { + if os.IsNotExist(err) { + return &config{Integrations: make(map[string]*integration)}, nil + } + return nil, err + } + + var cfg config + if err := json.Unmarshal(data, &cfg); err != nil { + return nil, fmt.Errorf("failed to parse config: %w, at: %s", err, path) + } + if cfg.Integrations == nil { + cfg.Integrations = make(map[string]*integration) + } + return &cfg, nil +} + +func save(cfg *config) error { + path, err := configPath() + if err != nil { + return err + } + + if err := os.MkdirAll(filepath.Dir(path), 0o755); err != nil { + return err + } + + data, err := json.MarshalIndent(cfg, "", " ") + if err != nil { + return err + } + + return writeWithBackup(path, data) +} + +func saveIntegration(appName string, models []string) error { + if appName == "" { + return errors.New("app name cannot be empty") + } + + cfg, err := load() + if err != nil { + return err + } + + cfg.Integrations[strings.ToLower(appName)] = &integration{ + Models: models, + } + + return save(cfg) +} + +func loadIntegration(appName string) (*integration, error) { + cfg, err := load() + if err != nil { + return nil, err + } + + ic, ok := cfg.Integrations[strings.ToLower(appName)] + if !ok { + return nil, os.ErrNotExist + } + + return ic, nil +} + +func listIntegrations() ([]integration, error) { + cfg, err := load() + if err != nil { + return nil, err + } + + result := make([]integration, 0, len(cfg.Integrations)) + for _, ic := range cfg.Integrations { + result = append(result, *ic) + } + + return result, nil +} diff --git a/cmd/config/config_test.go b/cmd/config/config_test.go new file mode 100644 index 000000000..2f9823f9e --- /dev/null +++ b/cmd/config/config_test.go @@ -0,0 +1,373 @@ +package config + +import ( + "os" + "path/filepath" + "strings" + "testing" +) + +// setTestHome sets both HOME (Unix) and USERPROFILE (Windows) for cross-platform tests +func setTestHome(t *testing.T, dir string) { + t.Setenv("HOME", dir) + t.Setenv("USERPROFILE", dir) +} + +// editorPaths is a test helper that safely calls Paths if the runner implements Editor +func editorPaths(r Runner) []string { + if editor, ok := r.(Editor); ok { + return editor.Paths() + } + return nil +} + +func TestIntegrationConfig(t *testing.T) { + tmpDir := t.TempDir() + setTestHome(t, tmpDir) + + t.Run("save and load round-trip", func(t *testing.T) { + models := []string{"llama3.2", "mistral", "qwen2.5"} + if err := saveIntegration("claude", models); err != nil { + t.Fatal(err) + } + + config, err := loadIntegration("claude") + if err != nil { + t.Fatal(err) + } + + if len(config.Models) != len(models) { + t.Errorf("expected %d models, got %d", len(models), len(config.Models)) + } + for i, m := range models { + if config.Models[i] != m { + t.Errorf("model %d: expected %s, got %s", i, m, config.Models[i]) + } + } + }) + + t.Run("defaultModel returns first model", func(t *testing.T) { + saveIntegration("codex", []string{"model-a", "model-b"}) + + config, _ := loadIntegration("codex") + defaultModel := "" + if len(config.Models) > 0 { + defaultModel = config.Models[0] + } + if defaultModel != "model-a" { + t.Errorf("expected model-a, got %s", defaultModel) + } + }) + + t.Run("defaultModel returns empty for no models", func(t *testing.T) { + config := &integration{Models: []string{}} + defaultModel := "" + if len(config.Models) > 0 { + defaultModel = config.Models[0] + } + if defaultModel != "" { + t.Errorf("expected empty string, got %s", defaultModel) + } + }) + + t.Run("app name is case-insensitive", func(t *testing.T) { + saveIntegration("Claude", []string{"model-x"}) + + config, err := loadIntegration("claude") + if err != nil { + t.Fatal(err) + } + defaultModel := "" + if len(config.Models) > 0 { + defaultModel = config.Models[0] + } + if defaultModel != "model-x" { + t.Errorf("expected model-x, got %s", defaultModel) + } + }) + + t.Run("multiple integrations in single file", func(t *testing.T) { + saveIntegration("app1", []string{"model-1"}) + saveIntegration("app2", []string{"model-2"}) + + config1, _ := loadIntegration("app1") + config2, _ := loadIntegration("app2") + + defaultModel1 := "" + if len(config1.Models) > 0 { + defaultModel1 = config1.Models[0] + } + defaultModel2 := "" + if len(config2.Models) > 0 { + defaultModel2 = config2.Models[0] + } + if defaultModel1 != "model-1" { + t.Errorf("expected model-1, got %s", defaultModel1) + } + if defaultModel2 != "model-2" { + t.Errorf("expected model-2, got %s", defaultModel2) + } + }) +} + +func TestListIntegrations(t *testing.T) { + tmpDir := t.TempDir() + setTestHome(t, tmpDir) + + t.Run("returns empty when no integrations", func(t *testing.T) { + configs, err := listIntegrations() + if err != nil { + t.Fatal(err) + } + if len(configs) != 0 { + t.Errorf("expected 0 integrations, got %d", len(configs)) + } + }) + + t.Run("returns all saved integrations", func(t *testing.T) { + saveIntegration("claude", []string{"model-1"}) + saveIntegration("droid", []string{"model-2"}) + + configs, err := listIntegrations() + if err != nil { + t.Fatal(err) + } + if len(configs) != 2 { + t.Errorf("expected 2 integrations, got %d", len(configs)) + } + }) +} + +func TestEditorPaths(t *testing.T) { + tmpDir := t.TempDir() + setTestHome(t, tmpDir) + + t.Run("returns empty for claude (no Editor)", func(t *testing.T) { + r := integrations["claude"] + paths := editorPaths(r) + if len(paths) != 0 { + t.Errorf("expected no paths for claude, got %v", paths) + } + }) + + t.Run("returns empty for codex (no Editor)", func(t *testing.T) { + r := integrations["codex"] + paths := editorPaths(r) + if len(paths) != 0 { + t.Errorf("expected no paths for codex, got %v", paths) + } + }) + + t.Run("returns empty for droid when no config exists", func(t *testing.T) { + r := integrations["droid"] + paths := editorPaths(r) + if len(paths) != 0 { + t.Errorf("expected no paths, got %v", paths) + } + }) + + t.Run("returns path for droid when config exists", func(t *testing.T) { + settingsDir, _ := os.UserHomeDir() + settingsDir = filepath.Join(settingsDir, ".factory") + os.MkdirAll(settingsDir, 0o755) + os.WriteFile(filepath.Join(settingsDir, "settings.json"), []byte(`{}`), 0o644) + + r := integrations["droid"] + paths := editorPaths(r) + if len(paths) != 1 { + t.Errorf("expected 1 path, got %d", len(paths)) + } + }) + + t.Run("returns paths for opencode when configs exist", func(t *testing.T) { + home, _ := os.UserHomeDir() + configDir := filepath.Join(home, ".config", "opencode") + stateDir := filepath.Join(home, ".local", "state", "opencode") + os.MkdirAll(configDir, 0o755) + os.MkdirAll(stateDir, 0o755) + os.WriteFile(filepath.Join(configDir, "opencode.json"), []byte(`{}`), 0o644) + os.WriteFile(filepath.Join(stateDir, "model.json"), []byte(`{}`), 0o644) + + r := integrations["opencode"] + paths := editorPaths(r) + if len(paths) != 2 { + t.Errorf("expected 2 paths, got %d: %v", len(paths), paths) + } + }) +} + +func TestLoadIntegration_CorruptedJSON(t *testing.T) { + tmpDir := t.TempDir() + setTestHome(t, tmpDir) + + // Create corrupted config.json file + dir := filepath.Join(tmpDir, ".ollama", "config") + os.MkdirAll(dir, 0o755) + os.WriteFile(filepath.Join(dir, "config.json"), []byte(`{corrupted json`), 0o644) + + // Corrupted file is treated as empty, so loadIntegration returns not found + _, err := loadIntegration("test") + if err == nil { + t.Error("expected error for nonexistent integration in corrupted file") + } +} + +func TestSaveIntegration_NilModels(t *testing.T) { + tmpDir := t.TempDir() + setTestHome(t, tmpDir) + + if err := saveIntegration("test", nil); err != nil { + t.Fatalf("saveIntegration with nil models failed: %v", err) + } + + config, err := loadIntegration("test") + if err != nil { + t.Fatalf("loadIntegration failed: %v", err) + } + + if config.Models == nil { + // nil is acceptable + } else if len(config.Models) != 0 { + t.Errorf("expected empty or nil models, got %v", config.Models) + } +} + +func TestSaveIntegration_EmptyAppName(t *testing.T) { + tmpDir := t.TempDir() + setTestHome(t, tmpDir) + + err := saveIntegration("", []string{"model"}) + if err == nil { + t.Error("expected error for empty app name, got nil") + } + if err != nil && !strings.Contains(err.Error(), "app name cannot be empty") { + t.Errorf("expected 'app name cannot be empty' error, got: %v", err) + } +} + +func TestLoadIntegration_NonexistentIntegration(t *testing.T) { + tmpDir := t.TempDir() + setTestHome(t, tmpDir) + + _, err := loadIntegration("nonexistent") + if err == nil { + t.Error("expected error for nonexistent integration, got nil") + } + if !os.IsNotExist(err) { + t.Logf("error type is os.ErrNotExist as expected: %v", err) + } +} + +func TestConfigPath(t *testing.T) { + tmpDir := t.TempDir() + setTestHome(t, tmpDir) + + path, err := configPath() + if err != nil { + t.Fatal(err) + } + + expected := filepath.Join(tmpDir, ".ollama", "config", "config.json") + if path != expected { + t.Errorf("expected %s, got %s", expected, path) + } +} + +func TestLoad(t *testing.T) { + tmpDir := t.TempDir() + setTestHome(t, tmpDir) + + t.Run("returns empty config when file does not exist", func(t *testing.T) { + cfg, err := load() + if err != nil { + t.Fatal(err) + } + if cfg == nil { + t.Fatal("expected non-nil config") + } + if cfg.Integrations == nil { + t.Error("expected non-nil Integrations map") + } + if len(cfg.Integrations) != 0 { + t.Errorf("expected empty Integrations, got %d", len(cfg.Integrations)) + } + }) + + t.Run("loads existing config", func(t *testing.T) { + path, _ := configPath() + os.MkdirAll(filepath.Dir(path), 0o755) + os.WriteFile(path, []byte(`{"integrations":{"test":{"models":["model-a"]}}}`), 0o644) + + cfg, err := load() + if err != nil { + t.Fatal(err) + } + if cfg.Integrations["test"] == nil { + t.Fatal("expected test integration") + } + if len(cfg.Integrations["test"].Models) != 1 { + t.Errorf("expected 1 model, got %d", len(cfg.Integrations["test"].Models)) + } + }) + + t.Run("returns error for corrupted JSON", func(t *testing.T) { + path, _ := configPath() + os.MkdirAll(filepath.Dir(path), 0o755) + os.WriteFile(path, []byte(`{corrupted`), 0o644) + + _, err := load() + if err == nil { + t.Error("expected error for corrupted JSON") + } + }) +} + +func TestSave(t *testing.T) { + tmpDir := t.TempDir() + setTestHome(t, tmpDir) + + t.Run("creates config file", func(t *testing.T) { + cfg := &config{ + Integrations: map[string]*integration{ + "test": {Models: []string{"model-a", "model-b"}}, + }, + } + + if err := save(cfg); err != nil { + t.Fatal(err) + } + + path, _ := configPath() + if _, err := os.Stat(path); os.IsNotExist(err) { + t.Error("config file was not created") + } + }) + + t.Run("round-trip preserves data", func(t *testing.T) { + cfg := &config{ + Integrations: map[string]*integration{ + "claude": {Models: []string{"llama3.2", "mistral"}}, + "codex": {Models: []string{"qwen2.5"}}, + }, + } + + if err := save(cfg); err != nil { + t.Fatal(err) + } + + loaded, err := load() + if err != nil { + t.Fatal(err) + } + + if len(loaded.Integrations) != 2 { + t.Errorf("expected 2 integrations, got %d", len(loaded.Integrations)) + } + if loaded.Integrations["claude"] == nil { + t.Error("missing claude integration") + } + if len(loaded.Integrations["claude"].Models) != 2 { + t.Errorf("expected 2 models for claude, got %d", len(loaded.Integrations["claude"].Models)) + } + }) +} diff --git a/cmd/config/droid.go b/cmd/config/droid.go new file mode 100644 index 000000000..ece40d682 --- /dev/null +++ b/cmd/config/droid.go @@ -0,0 +1,174 @@ +package config + +import ( + "encoding/json" + "fmt" + "os" + "os/exec" + "path/filepath" + "slices" + "strings" +) + +// Droid implements Runner and Editor for Droid integration +type Droid struct{} + +// droidModelEntry represents a custom model entry in Droid's settings.json +type droidModelEntry struct { + Model string `json:"model"` + DisplayName string `json:"displayName"` + BaseURL string `json:"baseUrl"` + APIKey string `json:"apiKey"` + Provider string `json:"provider"` + MaxOutputTokens int `json:"maxOutputTokens"` + SupportsImages bool `json:"supportsImages"` + ID string `json:"id"` + Index int `json:"index"` +} + +func (d *Droid) String() string { return "Droid" } + +func (d *Droid) Run(model string) error { + if _, err := exec.LookPath("droid"); err != nil { + return fmt.Errorf("droid is not installed, install from https://docs.factory.ai/cli/getting-started/quickstart") + } + + // Call Edit() to ensure config is up-to-date before launch + models := []string{model} + if config, err := loadIntegration("droid"); err == nil && len(config.Models) > 0 { + models = config.Models + } + if err := d.Edit(models); err != nil { + return fmt.Errorf("setup failed: %w", err) + } + + cmd := exec.Command("droid") + cmd.Stdin = os.Stdin + cmd.Stdout = os.Stdout + cmd.Stderr = os.Stderr + return cmd.Run() +} + +func (d *Droid) Paths() []string { + home, err := os.UserHomeDir() + if err != nil { + return nil + } + p := filepath.Join(home, ".factory", "settings.json") + if _, err := os.Stat(p); err == nil { + return []string{p} + } + return nil +} + +func (d *Droid) Edit(models []string) error { + if len(models) == 0 { + return nil + } + + home, err := os.UserHomeDir() + if err != nil { + return err + } + + settingsPath := filepath.Join(home, ".factory", "settings.json") + if err := os.MkdirAll(filepath.Dir(settingsPath), 0o755); err != nil { + return err + } + + settings := make(map[string]any) + if data, err := os.ReadFile(settingsPath); err == nil { + if err := json.Unmarshal(data, &settings); err != nil { + return fmt.Errorf("failed to parse settings file: %w, at: %s", err, settingsPath) + } + } + + customModels, _ := settings["customModels"].([]any) + + // Keep only non-Ollama models (we'll rebuild Ollama models fresh) + nonOllamaModels := slices.DeleteFunc(slices.Clone(customModels), isOllamaModelEntry) + + // Build new Ollama model entries with sequential indices (0, 1, 2, ...) + var ollamaModels []any + var defaultModelID string + for i, model := range models { + modelID := fmt.Sprintf("custom:%s-[Ollama]-%d", model, i) + ollamaModels = append(ollamaModels, droidModelEntry{ + Model: model, + DisplayName: model, + BaseURL: "http://localhost:11434/v1", + APIKey: "ollama", + Provider: "generic-chat-completion-api", + MaxOutputTokens: 64000, + SupportsImages: false, + ID: modelID, + Index: i, + }) + if i == 0 { + defaultModelID = modelID + } + } + + settings["customModels"] = append(ollamaModels, nonOllamaModels...) + + sessionSettings, ok := settings["sessionDefaultSettings"].(map[string]any) + if !ok { + sessionSettings = make(map[string]any) + } + sessionSettings["model"] = defaultModelID + + if effort, ok := sessionSettings["reasoningEffort"].(string); !ok || !isValidReasoningEffort(effort) { + sessionSettings["reasoningEffort"] = "none" + } + + settings["sessionDefaultSettings"] = sessionSettings + + data, err := json.MarshalIndent(settings, "", " ") + if err != nil { + return err + } + return writeWithBackup(settingsPath, data) +} + +func (d *Droid) Models() []string { + home, err := os.UserHomeDir() + if err != nil { + return nil + } + settings, err := readJSONFile(filepath.Join(home, ".factory", "settings.json")) + if err != nil { + return nil + } + + customModels, _ := settings["customModels"].([]any) + + var result []string + for _, m := range customModels { + if !isOllamaModelEntry(m) { + continue + } + entry, ok := m.(map[string]any) + if !ok { + continue + } + if model, _ := entry["model"].(string); model != "" { + result = append(result, model) + } + } + return result +} + +var validReasoningEfforts = []string{"high", "medium", "low", "none"} + +func isValidReasoningEffort(effort string) bool { + return slices.Contains(validReasoningEfforts, effort) +} + +func isOllamaModelEntry(m any) bool { + entry, ok := m.(map[string]any) + if !ok { + return false + } + id, _ := entry["id"].(string) + return strings.Contains(id, "-[Ollama]-") +} diff --git a/cmd/config/droid_test.go b/cmd/config/droid_test.go new file mode 100644 index 000000000..d158ad3a5 --- /dev/null +++ b/cmd/config/droid_test.go @@ -0,0 +1,454 @@ +package config + +import ( + "encoding/json" + "os" + "path/filepath" + "testing" +) + +func TestDroidIntegration(t *testing.T) { + d := &Droid{} + + t.Run("String", func(t *testing.T) { + if got := d.String(); got != "Droid" { + t.Errorf("String() = %q, want %q", got, "Droid") + } + }) + + t.Run("implements Runner", func(t *testing.T) { + var _ Runner = d + }) + + t.Run("implements Editor", func(t *testing.T) { + var _ Editor = d + }) +} + +func TestDroidEdit(t *testing.T) { + d := &Droid{} + tmpDir := t.TempDir() + setTestHome(t, tmpDir) + + settingsDir := filepath.Join(tmpDir, ".factory") + settingsPath := filepath.Join(settingsDir, "settings.json") + + cleanup := func() { + os.RemoveAll(settingsDir) + } + + readSettings := func() map[string]any { + data, _ := os.ReadFile(settingsPath) + var settings map[string]any + json.Unmarshal(data, &settings) + return settings + } + + getCustomModels := func(settings map[string]any) []map[string]any { + models, ok := settings["customModels"].([]any) + if !ok { + return nil + } + var result []map[string]any + for _, m := range models { + if entry, ok := m.(map[string]any); ok { + result = append(result, entry) + } + } + return result + } + + t.Run("fresh install creates models with sequential indices", func(t *testing.T) { + cleanup() + if err := d.Edit([]string{"model-a", "model-b"}); err != nil { + t.Fatal(err) + } + + settings := readSettings() + models := getCustomModels(settings) + + if len(models) != 2 { + t.Fatalf("expected 2 models, got %d", len(models)) + } + + // Check first model + if models[0]["model"] != "model-a" { + t.Errorf("expected model-a, got %s", models[0]["model"]) + } + if models[0]["id"] != "custom:model-a-[Ollama]-0" { + t.Errorf("expected custom:model-a-[Ollama]-0, got %s", models[0]["id"]) + } + if models[0]["index"] != float64(0) { + t.Errorf("expected index 0, got %v", models[0]["index"]) + } + + // Check second model + if models[1]["model"] != "model-b" { + t.Errorf("expected model-b, got %s", models[1]["model"]) + } + if models[1]["id"] != "custom:model-b-[Ollama]-1" { + t.Errorf("expected custom:model-b-[Ollama]-1, got %s", models[1]["id"]) + } + if models[1]["index"] != float64(1) { + t.Errorf("expected index 1, got %v", models[1]["index"]) + } + }) + + t.Run("sets sessionDefaultSettings.model to first model ID", func(t *testing.T) { + cleanup() + if err := d.Edit([]string{"model-a", "model-b"}); err != nil { + t.Fatal(err) + } + + settings := readSettings() + session, ok := settings["sessionDefaultSettings"].(map[string]any) + if !ok { + t.Fatal("sessionDefaultSettings not found") + } + if session["model"] != "custom:model-a-[Ollama]-0" { + t.Errorf("expected custom:model-a-[Ollama]-0, got %s", session["model"]) + } + }) + + t.Run("re-indexes when models removed", func(t *testing.T) { + cleanup() + // Add three models + d.Edit([]string{"model-a", "model-b", "model-c"}) + + // Remove middle model + d.Edit([]string{"model-a", "model-c"}) + + settings := readSettings() + models := getCustomModels(settings) + + if len(models) != 2 { + t.Fatalf("expected 2 models, got %d", len(models)) + } + + // Check indices are sequential 0, 1 + if models[0]["index"] != float64(0) { + t.Errorf("expected index 0, got %v", models[0]["index"]) + } + if models[1]["index"] != float64(1) { + t.Errorf("expected index 1, got %v", models[1]["index"]) + } + + // Check IDs match new indices + if models[0]["id"] != "custom:model-a-[Ollama]-0" { + t.Errorf("expected custom:model-a-[Ollama]-0, got %s", models[0]["id"]) + } + if models[1]["id"] != "custom:model-c-[Ollama]-1" { + t.Errorf("expected custom:model-c-[Ollama]-1, got %s", models[1]["id"]) + } + }) + + t.Run("preserves non-Ollama custom models", func(t *testing.T) { + cleanup() + os.MkdirAll(settingsDir, 0o755) + // Pre-existing non-Ollama model + os.WriteFile(settingsPath, []byte(`{ + "customModels": [ + {"model": "gpt-4", "displayName": "GPT-4", "provider": "openai"} + ] + }`), 0o644) + + d.Edit([]string{"model-a"}) + + settings := readSettings() + models := getCustomModels(settings) + + if len(models) != 2 { + t.Fatalf("expected 2 models (1 Ollama + 1 non-Ollama), got %d", len(models)) + } + + // Ollama model should be first + if models[0]["model"] != "model-a" { + t.Errorf("expected Ollama model first, got %s", models[0]["model"]) + } + + // Non-Ollama model should be preserved at end + if models[1]["model"] != "gpt-4" { + t.Errorf("expected gpt-4 preserved, got %s", models[1]["model"]) + } + }) + + t.Run("preserves other settings", func(t *testing.T) { + cleanup() + os.MkdirAll(settingsDir, 0o755) + os.WriteFile(settingsPath, []byte(`{ + "theme": "dark", + "enableHooks": true, + "sessionDefaultSettings": {"autonomyMode": "auto-high"} + }`), 0o644) + + d.Edit([]string{"model-a"}) + + settings := readSettings() + + if settings["theme"] != "dark" { + t.Error("theme was not preserved") + } + if settings["enableHooks"] != true { + t.Error("enableHooks was not preserved") + } + + session := settings["sessionDefaultSettings"].(map[string]any) + if session["autonomyMode"] != "auto-high" { + t.Error("autonomyMode was not preserved") + } + }) + + t.Run("required fields present", func(t *testing.T) { + cleanup() + d.Edit([]string{"test-model"}) + + settings := readSettings() + models := getCustomModels(settings) + + if len(models) != 1 { + t.Fatal("expected 1 model") + } + + model := models[0] + requiredFields := []string{"model", "displayName", "baseUrl", "apiKey", "provider", "maxOutputTokens", "id", "index"} + for _, field := range requiredFields { + if model[field] == nil { + t.Errorf("missing required field: %s", field) + } + } + + if model["baseUrl"] != "http://localhost:11434/v1" { + t.Errorf("unexpected baseUrl: %s", model["baseUrl"]) + } + if model["apiKey"] != "ollama" { + t.Errorf("unexpected apiKey: %s", model["apiKey"]) + } + if model["provider"] != "generic-chat-completion-api" { + t.Errorf("unexpected provider: %s", model["provider"]) + } + }) + + t.Run("fixes invalid reasoningEffort", func(t *testing.T) { + cleanup() + os.MkdirAll(settingsDir, 0o755) + // Pre-existing settings with invalid reasoningEffort + os.WriteFile(settingsPath, []byte(`{ + "sessionDefaultSettings": {"reasoningEffort": "off"} + }`), 0o644) + + d.Edit([]string{"model-a"}) + + settings := readSettings() + session := settings["sessionDefaultSettings"].(map[string]any) + + if session["reasoningEffort"] != "none" { + t.Errorf("expected reasoningEffort to be fixed to 'none', got %s", session["reasoningEffort"]) + } + }) + + t.Run("preserves valid reasoningEffort", func(t *testing.T) { + cleanup() + os.MkdirAll(settingsDir, 0o755) + os.WriteFile(settingsPath, []byte(`{ + "sessionDefaultSettings": {"reasoningEffort": "high"} + }`), 0o644) + + d.Edit([]string{"model-a"}) + + settings := readSettings() + session := settings["sessionDefaultSettings"].(map[string]any) + + if session["reasoningEffort"] != "high" { + t.Errorf("expected reasoningEffort to remain 'high', got %s", session["reasoningEffort"]) + } + }) +} + +// Edge case tests for droid.go + +func TestDroidEdit_CorruptedJSON(t *testing.T) { + d := &Droid{} + tmpDir := t.TempDir() + setTestHome(t, tmpDir) + + settingsDir := filepath.Join(tmpDir, ".factory") + settingsPath := filepath.Join(settingsDir, "settings.json") + + os.MkdirAll(settingsDir, 0o755) + os.WriteFile(settingsPath, []byte(`{corrupted json content`), 0o644) + + // Corrupted JSON should return an error so user knows something is wrong + err := d.Edit([]string{"model-a"}) + if err == nil { + t.Fatal("expected error for corrupted JSON, got nil") + } + + // Original corrupted file should be preserved (not overwritten) + data, _ := os.ReadFile(settingsPath) + if string(data) != `{corrupted json content` { + t.Errorf("corrupted file was modified: got %s", string(data)) + } +} + +func TestDroidEdit_WrongTypeCustomModels(t *testing.T) { + d := &Droid{} + tmpDir := t.TempDir() + setTestHome(t, tmpDir) + + settingsDir := filepath.Join(tmpDir, ".factory") + settingsPath := filepath.Join(settingsDir, "settings.json") + + os.MkdirAll(settingsDir, 0o755) + // customModels is a string instead of array + os.WriteFile(settingsPath, []byte(`{"customModels": "not an array"}`), 0o644) + + // Should not panic - wrong type should be handled gracefully + err := d.Edit([]string{"model-a"}) + if err != nil { + t.Fatalf("Edit failed with wrong type customModels: %v", err) + } + + // Verify models were added correctly + data, _ := os.ReadFile(settingsPath) + var settings map[string]any + json.Unmarshal(data, &settings) + + customModels, ok := settings["customModels"].([]any) + if !ok { + t.Fatalf("customModels should be array after setup, got %T", settings["customModels"]) + } + if len(customModels) != 1 { + t.Errorf("expected 1 model, got %d", len(customModels)) + } +} + +func TestDroidEdit_EmptyModels(t *testing.T) { + d := &Droid{} + tmpDir := t.TempDir() + setTestHome(t, tmpDir) + + settingsDir := filepath.Join(tmpDir, ".factory") + settingsPath := filepath.Join(settingsDir, "settings.json") + + os.MkdirAll(settingsDir, 0o755) + originalContent := `{"customModels": [{"model": "existing"}]}` + os.WriteFile(settingsPath, []byte(originalContent), 0o644) + + // Empty models should be no-op + err := d.Edit([]string{}) + if err != nil { + t.Fatalf("Edit with empty models failed: %v", err) + } + + // Original content should be preserved (file not modified) + data, _ := os.ReadFile(settingsPath) + if string(data) != originalContent { + t.Errorf("empty models should not modify file, but content changed") + } +} + +func TestDroidEdit_DuplicateModels(t *testing.T) { + d := &Droid{} + tmpDir := t.TempDir() + setTestHome(t, tmpDir) + + settingsDir := filepath.Join(tmpDir, ".factory") + settingsPath := filepath.Join(settingsDir, "settings.json") + + // Add same model twice + err := d.Edit([]string{"model-a", "model-a"}) + if err != nil { + t.Fatalf("Edit with duplicates failed: %v", err) + } + + settings, err := readJSONFile(settingsPath) + if err != nil { + t.Fatalf("readJSONFile failed: %v", err) + } + + customModels, _ := settings["customModels"].([]any) + // Document current behavior: duplicates are kept as separate entries + if len(customModels) != 2 { + t.Logf("Note: duplicates result in %d entries (documenting behavior)", len(customModels)) + } +} + +func TestDroidEdit_MalformedModelEntry(t *testing.T) { + d := &Droid{} + tmpDir := t.TempDir() + setTestHome(t, tmpDir) + + settingsDir := filepath.Join(tmpDir, ".factory") + settingsPath := filepath.Join(settingsDir, "settings.json") + + os.MkdirAll(settingsDir, 0o755) + // Model entry is a string instead of a map + os.WriteFile(settingsPath, []byte(`{"customModels": ["not a map", 123]}`), 0o644) + + err := d.Edit([]string{"model-a"}) + if err != nil { + t.Fatalf("Edit with malformed entries failed: %v", err) + } + + // Malformed entries should be preserved in nonOllamaModels + settings, _ := readJSONFile(settingsPath) + customModels, _ := settings["customModels"].([]any) + + // Should have: 1 new Ollama model + 2 preserved malformed entries + if len(customModels) != 3 { + t.Errorf("expected 3 entries (1 new + 2 preserved malformed), got %d", len(customModels)) + } +} + +func TestDroidEdit_WrongTypeSessionSettings(t *testing.T) { + d := &Droid{} + tmpDir := t.TempDir() + setTestHome(t, tmpDir) + + settingsDir := filepath.Join(tmpDir, ".factory") + settingsPath := filepath.Join(settingsDir, "settings.json") + + os.MkdirAll(settingsDir, 0o755) + // sessionDefaultSettings is a string instead of map + os.WriteFile(settingsPath, []byte(`{"sessionDefaultSettings": "not a map"}`), 0o644) + + err := d.Edit([]string{"model-a"}) + if err != nil { + t.Fatalf("Edit with wrong type sessionDefaultSettings failed: %v", err) + } + + // Should create proper sessionDefaultSettings + settings, _ := readJSONFile(settingsPath) + session, ok := settings["sessionDefaultSettings"].(map[string]any) + if !ok { + t.Fatalf("sessionDefaultSettings should be map after setup, got %T", settings["sessionDefaultSettings"]) + } + if session["model"] == nil { + t.Error("expected model to be set in sessionDefaultSettings") + } +} + +func TestIsValidReasoningEffort(t *testing.T) { + tests := []struct { + effort string + valid bool + }{ + {"high", true}, + {"medium", true}, + {"low", true}, + {"none", true}, + {"off", false}, + {"", false}, + {"HIGH", false}, // case sensitive + {"max", false}, + } + + for _, tt := range tests { + t.Run(tt.effort, func(t *testing.T) { + got := isValidReasoningEffort(tt.effort) + if got != tt.valid { + t.Errorf("isValidReasoningEffort(%q) = %v, want %v", tt.effort, got, tt.valid) + } + }) + } +} diff --git a/cmd/config/files.go b/cmd/config/files.go new file mode 100644 index 000000000..545e25c4d --- /dev/null +++ b/cmd/config/files.go @@ -0,0 +1,99 @@ +package config + +import ( + "bytes" + "encoding/json" + "fmt" + "os" + "path/filepath" + "time" +) + +func readJSONFile(path string) (map[string]any, error) { + data, err := os.ReadFile(path) + if err != nil { + return nil, err + } + var result map[string]any + if err := json.Unmarshal(data, &result); err != nil { + return nil, err + } + return result, nil +} + +func copyFile(src, dst string) error { + info, err := os.Stat(src) + if err != nil { + return err + } + data, err := os.ReadFile(src) + if err != nil { + return err + } + return os.WriteFile(dst, data, info.Mode().Perm()) +} + +func backupDir() string { + return filepath.Join(os.TempDir(), "ollama-backups") +} + +func backupToTmp(srcPath string) (string, error) { + dir := backupDir() + if err := os.MkdirAll(dir, 0o755); err != nil { + return "", err + } + + backupPath := filepath.Join(dir, fmt.Sprintf("%s.%d", filepath.Base(srcPath), time.Now().Unix())) + if err := copyFile(srcPath, backupPath); err != nil { + return "", err + } + return backupPath, nil +} + +// writeWithBackup writes data to path via temp file + rename, backing up any existing file first +func writeWithBackup(path string, data []byte) error { + var backupPath string + // backup must be created before any writes to the target file + if existingContent, err := os.ReadFile(path); err == nil { + if !bytes.Equal(existingContent, data) { + backupPath, err = backupToTmp(path) + if err != nil { + return fmt.Errorf("backup failed: %w", err) + } + } + } else if !os.IsNotExist(err) { + return fmt.Errorf("read existing file: %w", err) + } + + dir := filepath.Dir(path) + tmp, err := os.CreateTemp(dir, ".tmp-*") + if err != nil { + return fmt.Errorf("create temp failed: %w", err) + } + tmpPath := tmp.Name() + + if _, err := tmp.Write(data); err != nil { + _ = tmp.Close() + _ = os.Remove(tmpPath) + return fmt.Errorf("write failed: %w", err) + } + if err := tmp.Sync(); err != nil { + _ = tmp.Close() + _ = os.Remove(tmpPath) + return fmt.Errorf("sync failed: %w", err) + } + if err := tmp.Close(); err != nil { + _ = os.Remove(tmpPath) + return fmt.Errorf("close failed: %w", err) + } + + if err := os.Rename(tmpPath, path); err != nil { + _ = os.Remove(tmpPath) + if backupPath != "" { + _ = copyFile(backupPath, path) + } + return fmt.Errorf("rename failed: %w", err) + } + + return nil +} diff --git a/cmd/config/files_test.go b/cmd/config/files_test.go new file mode 100644 index 000000000..e0aaea2b5 --- /dev/null +++ b/cmd/config/files_test.go @@ -0,0 +1,502 @@ +package config + +import ( + "encoding/json" + "fmt" + "os" + "path/filepath" + "runtime" + "testing" +) + +func mustMarshal(t *testing.T, v any) []byte { + t.Helper() + data, err := json.MarshalIndent(v, "", " ") + if err != nil { + t.Fatal(err) + } + return data +} + +func TestWriteWithBackup(t *testing.T) { + tmpDir := t.TempDir() + + t.Run("creates file", func(t *testing.T) { + path := filepath.Join(tmpDir, "new.json") + data := mustMarshal(t, map[string]string{"key": "value"}) + + if err := writeWithBackup(path, data); err != nil { + t.Fatal(err) + } + + content, err := os.ReadFile(path) + if err != nil { + t.Fatal(err) + } + + var result map[string]string + if err := json.Unmarshal(content, &result); err != nil { + t.Fatal(err) + } + if result["key"] != "value" { + t.Errorf("expected value, got %s", result["key"]) + } + }) + + t.Run("creates backup in /tmp/ollama-backups", func(t *testing.T) { + path := filepath.Join(tmpDir, "backup.json") + + os.WriteFile(path, []byte(`{"original": true}`), 0o644) + + data := mustMarshal(t, map[string]bool{"updated": true}) + if err := writeWithBackup(path, data); err != nil { + t.Fatal(err) + } + + entries, err := os.ReadDir(backupDir()) + if err != nil { + t.Fatal("backup directory not created") + } + + var foundBackup bool + for _, entry := range entries { + if filepath.Ext(entry.Name()) != ".json" { + name := entry.Name() + if len(name) > len("backup.json.") && name[:len("backup.json.")] == "backup.json." { + backupPath := filepath.Join(backupDir(), name) + backup, err := os.ReadFile(backupPath) + if err == nil { + var backupData map[string]bool + json.Unmarshal(backup, &backupData) + if backupData["original"] { + foundBackup = true + os.Remove(backupPath) + break + } + } + } + } + } + + if !foundBackup { + t.Error("backup file not created in /tmp/ollama-backups") + } + + current, _ := os.ReadFile(path) + var currentData map[string]bool + json.Unmarshal(current, ¤tData) + if !currentData["updated"] { + t.Error("file doesn't contain updated data") + } + }) + + t.Run("no backup for new file", func(t *testing.T) { + path := filepath.Join(tmpDir, "nobak.json") + + data := mustMarshal(t, map[string]string{"new": "file"}) + if err := writeWithBackup(path, data); err != nil { + t.Fatal(err) + } + + entries, _ := os.ReadDir(backupDir()) + for _, entry := range entries { + if len(entry.Name()) > len("nobak.json.") && entry.Name()[:len("nobak.json.")] == "nobak.json." { + t.Error("backup should not exist for new file") + } + } + }) + + t.Run("no backup when content unchanged", func(t *testing.T) { + path := filepath.Join(tmpDir, "unchanged.json") + + data := mustMarshal(t, map[string]string{"key": "value"}) + + if err := writeWithBackup(path, data); err != nil { + t.Fatal(err) + } + + entries1, _ := os.ReadDir(backupDir()) + countBefore := 0 + for _, e := range entries1 { + if len(e.Name()) > len("unchanged.json.") && e.Name()[:len("unchanged.json.")] == "unchanged.json." { + countBefore++ + } + } + + if err := writeWithBackup(path, data); err != nil { + t.Fatal(err) + } + + entries2, _ := os.ReadDir(backupDir()) + countAfter := 0 + for _, e := range entries2 { + if len(e.Name()) > len("unchanged.json.") && e.Name()[:len("unchanged.json.")] == "unchanged.json." { + countAfter++ + } + } + + if countAfter != countBefore { + t.Errorf("backup was created when content unchanged (before=%d, after=%d)", countBefore, countAfter) + } + }) + + t.Run("backup filename contains unix timestamp", func(t *testing.T) { + path := filepath.Join(tmpDir, "timestamped.json") + + os.WriteFile(path, []byte(`{"v": 1}`), 0o644) + data := mustMarshal(t, map[string]int{"v": 2}) + if err := writeWithBackup(path, data); err != nil { + t.Fatal(err) + } + + entries, _ := os.ReadDir(backupDir()) + var found bool + for _, entry := range entries { + name := entry.Name() + if len(name) > len("timestamped.json.") && name[:len("timestamped.json.")] == "timestamped.json." { + timestamp := name[len("timestamped.json."):] + for _, c := range timestamp { + if c < '0' || c > '9' { + t.Errorf("backup filename timestamp contains non-numeric character: %s", name) + } + } + found = true + os.Remove(filepath.Join(backupDir(), name)) + break + } + } + if !found { + t.Error("backup file with timestamp not found") + } + }) +} + +// Edge case tests for files.go + +// TestWriteWithBackup_FailsIfBackupFails documents critical behavior: if backup fails, we must not proceed. +// User could lose their config with no way to recover. +func TestWriteWithBackup_FailsIfBackupFails(t *testing.T) { + if runtime.GOOS == "windows" { + t.Skip("permission tests unreliable on Windows") + } + + tmpDir := t.TempDir() + path := filepath.Join(tmpDir, "config.json") + + // Create original file + originalContent := []byte(`{"original": true}`) + os.WriteFile(path, originalContent, 0o644) + + // Make backup directory read-only to force backup failure + backupDir := backupDir() + os.MkdirAll(backupDir, 0o755) + os.Chmod(backupDir, 0o444) // Read-only + defer os.Chmod(backupDir, 0o755) + + newContent := []byte(`{"updated": true}`) + err := writeWithBackup(path, newContent) + + // Should fail because backup couldn't be created + if err == nil { + t.Error("expected error when backup fails, got nil") + } + + // Original file should be preserved + current, _ := os.ReadFile(path) + if string(current) != string(originalContent) { + t.Errorf("original file was modified despite backup failure: got %s", string(current)) + } +} + +// TestWriteWithBackup_PermissionDenied verifies clear error when target file has wrong permissions. +// Common issue when config owned by root or wrong perms. +func TestWriteWithBackup_PermissionDenied(t *testing.T) { + if runtime.GOOS == "windows" { + t.Skip("permission tests unreliable on Windows") + } + + tmpDir := t.TempDir() + + // Create a read-only directory + readOnlyDir := filepath.Join(tmpDir, "readonly") + os.MkdirAll(readOnlyDir, 0o755) + os.Chmod(readOnlyDir, 0o444) + defer os.Chmod(readOnlyDir, 0o755) + + path := filepath.Join(readOnlyDir, "config.json") + err := writeWithBackup(path, []byte(`{"test": true}`)) + + if err == nil { + t.Error("expected permission error, got nil") + } +} + +// TestWriteWithBackup_DirectoryDoesNotExist verifies behavior when target directory doesn't exist. +// writeWithBackup doesn't create directories - caller is responsible. +func TestWriteWithBackup_DirectoryDoesNotExist(t *testing.T) { + tmpDir := t.TempDir() + path := filepath.Join(tmpDir, "nonexistent", "subdir", "config.json") + + err := writeWithBackup(path, []byte(`{"test": true}`)) + + // Should fail because directory doesn't exist + if err == nil { + t.Error("expected error for nonexistent directory, got nil") + } +} + +// TestWriteWithBackup_SymlinkTarget documents behavior when target is a symlink. +// Documents what happens if user symlinks their config file. +func TestWriteWithBackup_SymlinkTarget(t *testing.T) { + if runtime.GOOS == "windows" { + t.Skip("symlink tests may require admin on Windows") + } + + tmpDir := t.TempDir() + realFile := filepath.Join(tmpDir, "real.json") + symlink := filepath.Join(tmpDir, "link.json") + + // Create real file and symlink + os.WriteFile(realFile, []byte(`{"v": 1}`), 0o644) + os.Symlink(realFile, symlink) + + // Write through symlink + err := writeWithBackup(symlink, []byte(`{"v": 2}`)) + if err != nil { + t.Fatalf("writeWithBackup through symlink failed: %v", err) + } + + // The real file should be updated (symlink followed for temp file creation) + content, _ := os.ReadFile(symlink) + if string(content) != `{"v": 2}` { + t.Errorf("symlink target not updated correctly: got %s", string(content)) + } +} + +// TestBackupToTmp_SpecialCharsInFilename verifies backup works with special characters. +// User may have config files with unusual names. +func TestBackupToTmp_SpecialCharsInFilename(t *testing.T) { + tmpDir := t.TempDir() + + // File with spaces and special chars + path := filepath.Join(tmpDir, "my config (backup).json") + os.WriteFile(path, []byte(`{"test": true}`), 0o644) + + backupPath, err := backupToTmp(path) + if err != nil { + t.Fatalf("backupToTmp with special chars failed: %v", err) + } + + // Verify backup exists and has correct content + content, err := os.ReadFile(backupPath) + if err != nil { + t.Fatalf("could not read backup: %v", err) + } + if string(content) != `{"test": true}` { + t.Errorf("backup content mismatch: got %s", string(content)) + } + + os.Remove(backupPath) +} + +// TestCopyFile_PreservesPermissions verifies that copyFile preserves file permissions. +func TestCopyFile_PreservesPermissions(t *testing.T) { + if runtime.GOOS == "windows" { + t.Skip("permission preservation tests unreliable on Windows") + } + + tmpDir := t.TempDir() + src := filepath.Join(tmpDir, "src.json") + dst := filepath.Join(tmpDir, "dst.json") + + // Create source with specific permissions + os.WriteFile(src, []byte(`{"test": true}`), 0o600) + + err := copyFile(src, dst) + if err != nil { + t.Fatalf("copyFile failed: %v", err) + } + + srcInfo, _ := os.Stat(src) + dstInfo, _ := os.Stat(dst) + + if srcInfo.Mode().Perm() != dstInfo.Mode().Perm() { + t.Errorf("permissions not preserved: src=%v, dst=%v", srcInfo.Mode().Perm(), dstInfo.Mode().Perm()) + } +} + +// TestCopyFile_SourceNotFound verifies clear error when source doesn't exist. +func TestCopyFile_SourceNotFound(t *testing.T) { + tmpDir := t.TempDir() + src := filepath.Join(tmpDir, "nonexistent.json") + dst := filepath.Join(tmpDir, "dst.json") + + err := copyFile(src, dst) + if err == nil { + t.Error("expected error for nonexistent source, got nil") + } +} + +// TestWriteWithBackup_TargetIsDirectory verifies error when path points to a directory. +func TestWriteWithBackup_TargetIsDirectory(t *testing.T) { + tmpDir := t.TempDir() + dirPath := filepath.Join(tmpDir, "actualdir") + os.MkdirAll(dirPath, 0o755) + + err := writeWithBackup(dirPath, []byte(`{"test": true}`)) + if err == nil { + t.Error("expected error when target is a directory, got nil") + } +} + +// TestWriteWithBackup_EmptyData verifies writing zero bytes works correctly. +func TestWriteWithBackup_EmptyData(t *testing.T) { + tmpDir := t.TempDir() + path := filepath.Join(tmpDir, "empty.json") + + err := writeWithBackup(path, []byte{}) + if err != nil { + t.Fatalf("writeWithBackup with empty data failed: %v", err) + } + + content, err := os.ReadFile(path) + if err != nil { + t.Fatalf("could not read file: %v", err) + } + if len(content) != 0 { + t.Errorf("expected empty file, got %d bytes", len(content)) + } +} + +// TestWriteWithBackup_FileUnreadableButDirWritable verifies behavior when existing file +// cannot be read (for backup comparison) but directory is writable. +func TestWriteWithBackup_FileUnreadableButDirWritable(t *testing.T) { + if runtime.GOOS == "windows" { + t.Skip("permission tests unreliable on Windows") + } + + tmpDir := t.TempDir() + path := filepath.Join(tmpDir, "unreadable.json") + + // Create file and make it unreadable + os.WriteFile(path, []byte(`{"original": true}`), 0o644) + os.Chmod(path, 0o000) + defer os.Chmod(path, 0o644) + + // Should fail because we can't read the file to compare/backup + err := writeWithBackup(path, []byte(`{"updated": true}`)) + if err == nil { + t.Error("expected error when file is unreadable, got nil") + } +} + +// TestWriteWithBackup_RapidSuccessiveWrites verifies backup works with multiple writes +// within the same second (timestamp collision scenario). +func TestWriteWithBackup_RapidSuccessiveWrites(t *testing.T) { + tmpDir := t.TempDir() + path := filepath.Join(tmpDir, "rapid.json") + + // Create initial file + os.WriteFile(path, []byte(`{"v": 0}`), 0o644) + + // Rapid successive writes + for i := 1; i <= 3; i++ { + data := []byte(fmt.Sprintf(`{"v": %d}`, i)) + if err := writeWithBackup(path, data); err != nil { + t.Fatalf("write %d failed: %v", i, err) + } + } + + // Verify final content + content, _ := os.ReadFile(path) + if string(content) != `{"v": 3}` { + t.Errorf("expected final content {\"v\": 3}, got %s", string(content)) + } + + // Verify at least one backup exists + entries, _ := os.ReadDir(backupDir()) + var backupCount int + for _, e := range entries { + if len(e.Name()) > len("rapid.json.") && e.Name()[:len("rapid.json.")] == "rapid.json." { + backupCount++ + } + } + if backupCount == 0 { + t.Error("expected at least one backup file from rapid writes") + } +} + +// TestWriteWithBackup_BackupDirIsFile verifies error when backup directory path is a file. +func TestWriteWithBackup_BackupDirIsFile(t *testing.T) { + if runtime.GOOS == "windows" { + t.Skip("test modifies system temp directory") + } + + // Create a file at the backup directory path + backupPath := backupDir() + // Clean up any existing directory first + os.RemoveAll(backupPath) + // Create a file instead of directory + os.WriteFile(backupPath, []byte("not a directory"), 0o644) + defer func() { + os.Remove(backupPath) + os.MkdirAll(backupPath, 0o755) + }() + + tmpDir := t.TempDir() + path := filepath.Join(tmpDir, "test.json") + os.WriteFile(path, []byte(`{"original": true}`), 0o644) + + err := writeWithBackup(path, []byte(`{"updated": true}`)) + if err == nil { + t.Error("expected error when backup dir is a file, got nil") + } +} + +// TestWriteWithBackup_NoOrphanTempFiles verifies temp files are cleaned up on failure. +func TestWriteWithBackup_NoOrphanTempFiles(t *testing.T) { + if runtime.GOOS == "windows" { + t.Skip("permission tests unreliable on Windows") + } + + tmpDir := t.TempDir() + + // Count existing temp files + countTempFiles := func() int { + entries, _ := os.ReadDir(tmpDir) + count := 0 + for _, e := range entries { + if len(e.Name()) > 4 && e.Name()[:4] == ".tmp" { + count++ + } + } + return count + } + + before := countTempFiles() + + // Create a file, then make directory read-only to cause rename failure + path := filepath.Join(tmpDir, "orphan.json") + os.WriteFile(path, []byte(`{"v": 1}`), 0o644) + + // Make a subdirectory and try to write there after making parent read-only + subDir := filepath.Join(tmpDir, "subdir") + os.MkdirAll(subDir, 0o755) + subPath := filepath.Join(subDir, "config.json") + os.WriteFile(subPath, []byte(`{"v": 1}`), 0o644) + + // Make subdir read-only after creating temp file would succeed but rename would fail + // This is tricky to test - the temp file is created in the same dir, so if we can't + // rename, we also couldn't create. Let's just verify normal failure cleanup works. + + // Force a failure by making the target a directory + badPath := filepath.Join(tmpDir, "isdir") + os.MkdirAll(badPath, 0o755) + + _ = writeWithBackup(badPath, []byte(`{"test": true}`)) + + after := countTempFiles() + if after > before { + t.Errorf("orphan temp files left behind: before=%d, after=%d", before, after) + } +} diff --git a/cmd/config/integrations.go b/cmd/config/integrations.go new file mode 100644 index 000000000..e897317d0 --- /dev/null +++ b/cmd/config/integrations.go @@ -0,0 +1,361 @@ +package config + +import ( + "context" + "errors" + "fmt" + "maps" + "os" + "os/exec" + "runtime" + "slices" + "strings" + "time" + + "github.com/ollama/ollama/api" + "github.com/spf13/cobra" +) + +// Runners execute the launching of a model with the integration - claude, codex +// Editors can edit config files (supports multi-model selection) - opencode, droid +// They are composable interfaces where in some cases an editor is also a runner - opencode, droid +// Runner can run an integration with a model. + +type Runner interface { + Run(model string) error + // String returns the human-readable name of the integration + String() string +} + +// Editor can edit config files (supports multi-model selection) +type Editor interface { + // Paths returns the paths to the config files for the integration + Paths() []string + // Edit updates the config files for the integration with the given models + Edit(models []string) error + // Models returns the models currently configured for the integration + Models() []string +} + +// integrations is the registry of available integrations. +var integrations = map[string]Runner{ + "claude": &Claude{}, + "codex": &Codex{}, + "droid": &Droid{}, + "opencode": &OpenCode{}, +} + +func selectIntegration() (string, error) { + if len(integrations) == 0 { + return "", fmt.Errorf("no integrations available") + } + + names := slices.Sorted(maps.Keys(integrations)) + var items []selectItem + for _, name := range names { + r := integrations[name] + description := r.String() + if conn, err := loadIntegration(name); err == nil && len(conn.Models) > 0 { + description = fmt.Sprintf("%s (%s)", r.String(), conn.Models[0]) + } + items = append(items, selectItem{Name: name, Description: description}) + } + + return selectPrompt("Select integration:", items) +} + +// selectModels lets the user select models for an integration +func selectModels(ctx context.Context, name, current string) ([]string, error) { + r, ok := integrations[name] + if !ok { + return nil, fmt.Errorf("unknown integration: %s", name) + } + + client, err := api.ClientFromEnvironment() + if err != nil { + return nil, err + } + + models, err := client.List(ctx) + if err != nil { + return nil, err + } + + if len(models.Models) == 0 { + return nil, fmt.Errorf("no models available, run 'ollama pull ' first") + } + + var items []selectItem + cloudModels := make(map[string]bool) + for _, m := range models.Models { + if m.RemoteModel != "" { + cloudModels[m.Name] = true + } + items = append(items, selectItem{Name: m.Name}) + } + + if len(items) == 0 { + return nil, fmt.Errorf("no local models available, run 'ollama pull ' first") + } + + // Get previously configured models (saved config takes precedence) + var preChecked []string + if saved, err := loadIntegration(name); err == nil { + preChecked = saved.Models + } else if editor, ok := r.(Editor); ok { + preChecked = editor.Models() + } + checked := make(map[string]bool, len(preChecked)) + for _, n := range preChecked { + checked[n] = true + } + + // Resolve current to full name (e.g., "llama3.2" -> "llama3.2:latest") + for _, item := range items { + if item.Name == current || strings.HasPrefix(item.Name, current+":") { + current = item.Name + break + } + } + + // If current model is configured, move to front of preChecked + if checked[current] { + preChecked = append([]string{current}, slices.DeleteFunc(preChecked, func(m string) bool { return m == current })...) + } + + // Sort: checked first, then alphabetical + slices.SortFunc(items, func(a, b selectItem) int { + ac, bc := checked[a.Name], checked[b.Name] + if ac != bc { + if ac { + return -1 + } + return 1 + } + return strings.Compare(strings.ToLower(a.Name), strings.ToLower(b.Name)) + }) + + var selected []string + // only editors support multi-model selection + if _, ok := r.(Editor); ok { + selected, err = multiSelectPrompt(fmt.Sprintf("Select models for %s:", r), items, preChecked) + if err != nil { + return nil, err + } + } else { + model, err := selectPrompt(fmt.Sprintf("Select model for %s:", r), items) + if err != nil { + return nil, err + } + selected = []string{model} + } + + // if any model in selected is a cloud model, ensure signed in + var selectedCloudModels []string + for _, m := range selected { + if cloudModels[m] { + selectedCloudModels = append(selectedCloudModels, m) + } + } + if len(selectedCloudModels) > 0 { + // ensure user is signed in + user, err := client.Whoami(ctx) + if err == nil && user != nil && user.Name != "" { + return selected, nil + } + + var aErr api.AuthorizationError + if !errors.As(err, &aErr) || aErr.SigninURL == "" { + return nil, err + } + + modelList := strings.Join(selectedCloudModels, ", ") + yes, err := confirmPrompt(fmt.Sprintf("sign in to use %s?", modelList)) + if err != nil || !yes { + return nil, fmt.Errorf("%s requires sign in", modelList) + } + + fmt.Fprintf(os.Stderr, "\nTo sign in, navigate to:\n %s\n\n", aErr.SigninURL) + + // TODO(parthsareen): extract into auth package for cmd + // Auto-open browser (best effort, fail silently) + switch runtime.GOOS { + case "darwin": + _ = exec.Command("open", aErr.SigninURL).Start() + case "linux": + _ = exec.Command("xdg-open", aErr.SigninURL).Start() + case "windows": + _ = exec.Command("rundll32", "url.dll,FileProtocolHandler", aErr.SigninURL).Start() + } + + spinnerFrames := []string{"|", "/", "-", "\\"} + frame := 0 + + fmt.Fprintf(os.Stderr, "\033[90mwaiting for sign in to complete... %s\033[0m", spinnerFrames[0]) + + ticker := time.NewTicker(200 * time.Millisecond) + defer ticker.Stop() + + for { + select { + case <-ctx.Done(): + fmt.Fprintf(os.Stderr, "\r\033[K") + return nil, ctx.Err() + case <-ticker.C: + frame++ + fmt.Fprintf(os.Stderr, "\r\033[90mwaiting for sign in to complete... %s\033[0m", spinnerFrames[frame%len(spinnerFrames)]) + + // poll every 10th frame (~2 seconds) + if frame%10 == 0 { + u, err := client.Whoami(ctx) + if err == nil && u != nil && u.Name != "" { + fmt.Fprintf(os.Stderr, "\r\033[K\033[A\r\033[K\033[1msigned in:\033[0m %s\n", u.Name) + return selected, nil + } + } + } + } + } + + return selected, nil +} + +func runIntegration(name, modelName string) error { + r, ok := integrations[name] + if !ok { + return fmt.Errorf("unknown integration: %s", name) + } + fmt.Fprintf(os.Stderr, "\nLaunching %s with %s...\n", r, modelName) + return r.Run(modelName) +} + +// ConfigCmd returns the cobra command for configuring integrations. +func ConfigCmd(checkServerHeartbeat func(cmd *cobra.Command, args []string) error) *cobra.Command { + var modelFlag string + var launchFlag bool + + cmd := &cobra.Command{ + Use: "config [INTEGRATION]", + Short: "Configure an external integration to use Ollama", + Long: `Configure an external application to use Ollama models. + +Supported integrations: + claude Claude Code + codex Codex + droid Droid + opencode OpenCode + +Examples: + ollama config + ollama config claude + ollama config droid --launch`, + Args: cobra.MaximumNArgs(1), + PreRunE: checkServerHeartbeat, + RunE: func(cmd *cobra.Command, args []string) error { + var name string + if len(args) > 0 { + name = args[0] + } else { + var err error + name, err = selectIntegration() + if errors.Is(err, errCancelled) { + return nil + } + if err != nil { + return err + } + } + + r, ok := integrations[strings.ToLower(name)] + if !ok { + return fmt.Errorf("unknown integration: %s", name) + } + + // If --launch without --model, use saved config if available + if launchFlag && modelFlag == "" { + if config, err := loadIntegration(name); err == nil && len(config.Models) > 0 { + return runIntegration(name, config.Models[0]) + } + } + + var models []string + if modelFlag != "" { + // When --model is specified, merge with existing models (new model becomes default) + models = []string{modelFlag} + if existing, err := loadIntegration(name); err == nil && len(existing.Models) > 0 { + for _, m := range existing.Models { + if m != modelFlag { + models = append(models, m) + } + } + } + } else { + var err error + models, err = selectModels(cmd.Context(), name, "") + if errors.Is(err, errCancelled) { + return nil + } + if err != nil { + return err + } + } + + if editor, isEditor := r.(Editor); isEditor { + paths := editor.Paths() + if len(paths) > 0 { + fmt.Fprintf(os.Stderr, "This will modify your %s configuration:\n", r) + for _, p := range paths { + fmt.Fprintf(os.Stderr, " %s\n", p) + } + fmt.Fprintf(os.Stderr, "Backups will be saved to %s/\n\n", backupDir()) + + if ok, _ := confirmPrompt("Proceed?"); !ok { + return nil + } + } + } + + if err := saveIntegration(name, models); err != nil { + return fmt.Errorf("failed to save: %w", err) + } + + if editor, isEditor := r.(Editor); isEditor { + if err := editor.Edit(models); err != nil { + return fmt.Errorf("setup failed: %w", err) + } + } + + if _, isEditor := r.(Editor); isEditor { + if len(models) == 1 { + fmt.Fprintf(os.Stderr, "Added %s to %s\n", models[0], r) + } else { + fmt.Fprintf(os.Stderr, "Added %d models to %s (default: %s)\n", len(models), r, models[0]) + } + } + + if slices.ContainsFunc(models, func(m string) bool { + return !strings.HasSuffix(m, "cloud") + }) { + fmt.Fprintln(os.Stderr) + fmt.Fprintln(os.Stderr, "Coding agents work best with at least 64k context. Either:") + fmt.Fprintln(os.Stderr, " - Set the context slider in Ollama app settings") + fmt.Fprintln(os.Stderr, " - Run: OLLAMA_CONTEXT_LENGTH=64000 ollama serve") + } + + if launchFlag { + return runIntegration(name, models[0]) + } + + if launch, _ := confirmPrompt(fmt.Sprintf("\nLaunch %s now?", r)); launch { + return runIntegration(name, models[0]) + } + + fmt.Fprintf(os.Stderr, "Run 'ollama config %s --launch' to start with %s\n", strings.ToLower(name), models[0]) + return nil + }, + } + + cmd.Flags().StringVar(&modelFlag, "model", "", "Model to use") + cmd.Flags().BoolVar(&launchFlag, "launch", false, "Launch the integration after configuring") + return cmd +} diff --git a/cmd/config/integrations_test.go b/cmd/config/integrations_test.go new file mode 100644 index 000000000..d8ad10ce2 --- /dev/null +++ b/cmd/config/integrations_test.go @@ -0,0 +1,188 @@ +package config + +import ( + "slices" + "strings" + "testing" + + "github.com/spf13/cobra" +) + +func TestIntegrationLookup(t *testing.T) { + tests := []struct { + name string + input string + wantFound bool + wantName string + }{ + {"claude lowercase", "claude", true, "Claude Code"}, + {"claude uppercase", "CLAUDE", true, "Claude Code"}, + {"claude mixed case", "Claude", true, "Claude Code"}, + {"codex", "codex", true, "Codex"}, + {"droid", "droid", true, "Droid"}, + {"opencode", "opencode", true, "OpenCode"}, + {"unknown integration", "unknown", false, ""}, + {"empty string", "", false, ""}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + r, found := integrations[strings.ToLower(tt.input)] + if found != tt.wantFound { + t.Errorf("integrations[%q] found = %v, want %v", tt.input, found, tt.wantFound) + } + if found && r.String() != tt.wantName { + t.Errorf("integrations[%q].String() = %q, want %q", tt.input, r.String(), tt.wantName) + } + }) + } +} + +func TestIntegrationRegistry(t *testing.T) { + expectedIntegrations := []string{"claude", "codex", "droid", "opencode"} + + for _, name := range expectedIntegrations { + t.Run(name, func(t *testing.T) { + r, ok := integrations[name] + if !ok { + t.Fatalf("integration %q not found in registry", name) + } + if r.String() == "" { + t.Error("integration.String() should not be empty") + } + }) + } +} + +func TestHasLocalModel(t *testing.T) { + tests := []struct { + name string + models []string + want bool + }{ + {"empty list", []string{}, false}, + {"single local model", []string{"llama3.2"}, true}, + {"single cloud model", []string{"cloud-model"}, false}, + {"mixed models", []string{"cloud-model", "llama3.2"}, true}, + {"multiple local models", []string{"llama3.2", "qwen2.5"}, true}, + {"multiple cloud models", []string{"cloud-a", "cloud-b"}, false}, + {"local model first", []string{"llama3.2", "cloud-model"}, true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := slices.ContainsFunc(tt.models, func(m string) bool { + return !strings.Contains(m, "cloud") + }) + if got != tt.want { + t.Errorf("hasLocalModel(%v) = %v, want %v", tt.models, got, tt.want) + } + }) + } +} + +func TestConfigCmd(t *testing.T) { + // Mock checkServerHeartbeat that always succeeds + mockCheck := func(cmd *cobra.Command, args []string) error { + return nil + } + + cmd := ConfigCmd(mockCheck) + + t.Run("command structure", func(t *testing.T) { + if cmd.Use != "config [INTEGRATION]" { + t.Errorf("Use = %q, want %q", cmd.Use, "config [INTEGRATION]") + } + if cmd.Short == "" { + t.Error("Short description should not be empty") + } + if cmd.Long == "" { + t.Error("Long description should not be empty") + } + }) + + t.Run("flags exist", func(t *testing.T) { + modelFlag := cmd.Flags().Lookup("model") + if modelFlag == nil { + t.Error("--model flag should exist") + } + + launchFlag := cmd.Flags().Lookup("launch") + if launchFlag == nil { + t.Error("--launch flag should exist") + } + }) + + t.Run("PreRunE is set", func(t *testing.T) { + if cmd.PreRunE == nil { + t.Error("PreRunE should be set to checkServerHeartbeat") + } + }) +} + +func TestRunIntegration_UnknownIntegration(t *testing.T) { + err := runIntegration("unknown-integration", "model") + if err == nil { + t.Error("expected error for unknown integration, got nil") + } + if !strings.Contains(err.Error(), "unknown integration") { + t.Errorf("error should mention 'unknown integration', got: %v", err) + } +} + +func TestHasLocalModel_DocumentsHeuristic(t *testing.T) { + tests := []struct { + name string + models []string + want bool + reason string + }{ + {"empty list", []string{}, false, "empty list has no local models"}, + {"contains-cloud-substring", []string{"deepseek-r1:cloud"}, false, "model with 'cloud' substring is considered cloud"}, + {"cloud-in-name", []string{"my-cloud-model"}, false, "'cloud' anywhere in name = cloud model"}, + {"cloudless", []string{"cloudless-model"}, false, "'cloudless' still contains 'cloud'"}, + {"local-model", []string{"llama3.2"}, true, "no 'cloud' = local"}, + {"mixed", []string{"cloud-model", "llama3.2"}, true, "one local model = hasLocalModel true"}, + {"all-cloud", []string{"cloud-a", "cloud-b"}, false, "all contain 'cloud'"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := slices.ContainsFunc(tt.models, func(m string) bool { + return !strings.Contains(m, "cloud") + }) + if got != tt.want { + t.Errorf("hasLocalModel(%v) = %v, want %v (%s)", tt.models, got, tt.want, tt.reason) + } + }) + } +} + +func TestConfigCmd_NilHeartbeat(t *testing.T) { + // This should not panic - cmd creation should work even with nil + cmd := ConfigCmd(nil) + if cmd == nil { + t.Fatal("ConfigCmd returned nil") + } + + // PreRunE should be nil when passed nil + if cmd.PreRunE != nil { + t.Log("Note: PreRunE is set even when nil is passed (acceptable)") + } +} + +func TestAllIntegrations_HaveRequiredMethods(t *testing.T) { + for name, r := range integrations { + t.Run(name, func(t *testing.T) { + // Test String() doesn't panic and returns non-empty + displayName := r.String() + if displayName == "" { + t.Error("String() should not return empty") + } + + // Test Run() exists (we can't call it without actually running the command) + // Just verify the method is available + var _ func(string) error = r.Run + }) + } +} diff --git a/cmd/config/opencode.go b/cmd/config/opencode.go new file mode 100644 index 000000000..6ccf45145 --- /dev/null +++ b/cmd/config/opencode.go @@ -0,0 +1,203 @@ +package config + +import ( + "encoding/json" + "fmt" + "maps" + "os" + "os/exec" + "path/filepath" + "slices" + "strings" +) + +// OpenCode implements Runner and Editor for OpenCode integration +type OpenCode struct{} + +func (o *OpenCode) String() string { return "OpenCode" } + +func (o *OpenCode) Run(model string) error { + if _, err := exec.LookPath("opencode"); err != nil { + return fmt.Errorf("opencode is not installed, install from https://opencode.ai") + } + + // Call Edit() to ensure config is up-to-date before launch + models := []string{model} + if config, err := loadIntegration("opencode"); err == nil && len(config.Models) > 0 { + models = config.Models + } + if err := o.Edit(models); err != nil { + return fmt.Errorf("setup failed: %w", err) + } + + cmd := exec.Command("opencode") + cmd.Stdin = os.Stdin + cmd.Stdout = os.Stdout + cmd.Stderr = os.Stderr + return cmd.Run() +} + +func (o *OpenCode) Paths() []string { + home, err := os.UserHomeDir() + if err != nil { + return nil + } + + var paths []string + p := filepath.Join(home, ".config", "opencode", "opencode.json") + if _, err := os.Stat(p); err == nil { + paths = append(paths, p) + } + sp := filepath.Join(home, ".local", "state", "opencode", "model.json") + if _, err := os.Stat(sp); err == nil { + paths = append(paths, sp) + } + return paths +} + +func (o *OpenCode) Edit(modelList []string) error { + if len(modelList) == 0 { + return nil + } + + home, err := os.UserHomeDir() + if err != nil { + return err + } + + configPath := filepath.Join(home, ".config", "opencode", "opencode.json") + if err := os.MkdirAll(filepath.Dir(configPath), 0o755); err != nil { + return err + } + + config := make(map[string]any) + if data, err := os.ReadFile(configPath); err == nil { + _ = json.Unmarshal(data, &config) // Ignore parse errors; treat missing/corrupt files as empty + } + + config["$schema"] = "https://opencode.ai/config.json" + + provider, ok := config["provider"].(map[string]any) + if !ok { + provider = make(map[string]any) + } + + ollama, ok := provider["ollama"].(map[string]any) + if !ok { + ollama = map[string]any{ + "npm": "@ai-sdk/openai-compatible", + "name": "Ollama (local)", + "options": map[string]any{ + "baseURL": "http://localhost:11434/v1", + }, + } + } + + models, ok := ollama["models"].(map[string]any) + if !ok { + models = make(map[string]any) + } + + selectedSet := make(map[string]bool) + for _, m := range modelList { + selectedSet[m] = true + } + + for name, cfg := range models { + if cfgMap, ok := cfg.(map[string]any); ok { + if displayName, ok := cfgMap["name"].(string); ok { + if strings.HasSuffix(displayName, "[Ollama]") && !selectedSet[name] { + delete(models, name) + } + } + } + } + + for _, model := range modelList { + models[model] = map[string]any{ + "name": fmt.Sprintf("%s [Ollama]", model), + } + } + + ollama["models"] = models + provider["ollama"] = ollama + config["provider"] = provider + + configData, err := json.MarshalIndent(config, "", " ") + if err != nil { + return err + } + if err := writeWithBackup(configPath, configData); err != nil { + return err + } + + statePath := filepath.Join(home, ".local", "state", "opencode", "model.json") + if err := os.MkdirAll(filepath.Dir(statePath), 0o755); err != nil { + return err + } + + state := map[string]any{ + "recent": []any{}, + "favorite": []any{}, + "variant": map[string]any{}, + } + if data, err := os.ReadFile(statePath); err == nil { + _ = json.Unmarshal(data, &state) // Ignore parse errors; use defaults + } + + recent, _ := state["recent"].([]any) + + modelSet := make(map[string]bool) + for _, m := range modelList { + modelSet[m] = true + } + + // Filter out existing Ollama models we're about to re-add + newRecent := slices.DeleteFunc(slices.Clone(recent), func(entry any) bool { + e, ok := entry.(map[string]any) + if !ok || e["providerID"] != "ollama" { + return false + } + modelID, _ := e["modelID"].(string) + return modelSet[modelID] + }) + + // Prepend models in reverse order so first model ends up first + for _, model := range slices.Backward(modelList) { + newRecent = slices.Insert(newRecent, 0, any(map[string]any{ + "providerID": "ollama", + "modelID": model, + })) + } + + const maxRecentModels = 10 + newRecent = newRecent[:min(len(newRecent), maxRecentModels)] + + state["recent"] = newRecent + + stateData, err := json.MarshalIndent(state, "", " ") + if err != nil { + return err + } + return writeWithBackup(statePath, stateData) +} + +func (o *OpenCode) Models() []string { + home, err := os.UserHomeDir() + if err != nil { + return nil + } + config, err := readJSONFile(filepath.Join(home, ".config", "opencode", "opencode.json")) + if err != nil { + return nil + } + provider, _ := config["provider"].(map[string]any) + ollama, _ := provider["ollama"].(map[string]any) + models, _ := ollama["models"].(map[string]any) + if len(models) == 0 { + return nil + } + keys := slices.Collect(maps.Keys(models)) + slices.Sort(keys) + return keys +} diff --git a/cmd/config/opencode_test.go b/cmd/config/opencode_test.go new file mode 100644 index 000000000..524dfc59e --- /dev/null +++ b/cmd/config/opencode_test.go @@ -0,0 +1,437 @@ +package config + +import ( + "encoding/json" + "os" + "path/filepath" + "testing" +) + +func TestOpenCodeIntegration(t *testing.T) { + o := &OpenCode{} + + t.Run("String", func(t *testing.T) { + if got := o.String(); got != "OpenCode" { + t.Errorf("String() = %q, want %q", got, "OpenCode") + } + }) + + t.Run("implements Runner", func(t *testing.T) { + var _ Runner = o + }) + + t.Run("implements Editor", func(t *testing.T) { + var _ Editor = o + }) +} + +func TestOpenCodeEdit(t *testing.T) { + o := &OpenCode{} + tmpDir := t.TempDir() + setTestHome(t, tmpDir) + + configDir := filepath.Join(tmpDir, ".config", "opencode") + configPath := filepath.Join(configDir, "opencode.json") + stateDir := filepath.Join(tmpDir, ".local", "state", "opencode") + statePath := filepath.Join(stateDir, "model.json") + + cleanup := func() { + os.RemoveAll(configDir) + os.RemoveAll(stateDir) + } + + t.Run("fresh install", func(t *testing.T) { + cleanup() + if err := o.Edit([]string{"llama3.2"}); err != nil { + t.Fatal(err) + } + assertOpenCodeModelExists(t, configPath, "llama3.2") + assertOpenCodeRecentModel(t, statePath, 0, "ollama", "llama3.2") + }) + + t.Run("preserve other providers", func(t *testing.T) { + cleanup() + os.MkdirAll(configDir, 0o755) + os.WriteFile(configPath, []byte(`{"provider":{"anthropic":{"apiKey":"xxx"}}}`), 0o644) + if err := o.Edit([]string{"llama3.2"}); err != nil { + t.Fatal(err) + } + data, _ := os.ReadFile(configPath) + var cfg map[string]any + json.Unmarshal(data, &cfg) + provider := cfg["provider"].(map[string]any) + if provider["anthropic"] == nil { + t.Error("anthropic provider was removed") + } + assertOpenCodeModelExists(t, configPath, "llama3.2") + }) + + t.Run("preserve other models", func(t *testing.T) { + cleanup() + os.MkdirAll(configDir, 0o755) + os.WriteFile(configPath, []byte(`{"provider":{"ollama":{"models":{"mistral":{"name":"Mistral"}}}}}`), 0o644) + if err := o.Edit([]string{"llama3.2"}); err != nil { + t.Fatal(err) + } + assertOpenCodeModelExists(t, configPath, "mistral") + assertOpenCodeModelExists(t, configPath, "llama3.2") + }) + + t.Run("update existing model", func(t *testing.T) { + cleanup() + o.Edit([]string{"llama3.2"}) + o.Edit([]string{"llama3.2"}) + assertOpenCodeModelExists(t, configPath, "llama3.2") + }) + + t.Run("preserve top-level keys", func(t *testing.T) { + cleanup() + os.MkdirAll(configDir, 0o755) + os.WriteFile(configPath, []byte(`{"theme":"dark","keybindings":{}}`), 0o644) + if err := o.Edit([]string{"llama3.2"}); err != nil { + t.Fatal(err) + } + data, _ := os.ReadFile(configPath) + var cfg map[string]any + json.Unmarshal(data, &cfg) + if cfg["theme"] != "dark" { + t.Error("theme was removed") + } + if cfg["keybindings"] == nil { + t.Error("keybindings was removed") + } + }) + + t.Run("model state - insert at index 0", func(t *testing.T) { + cleanup() + os.MkdirAll(stateDir, 0o755) + os.WriteFile(statePath, []byte(`{"recent":[{"providerID":"anthropic","modelID":"claude"}],"favorite":[],"variant":{}}`), 0o644) + if err := o.Edit([]string{"llama3.2"}); err != nil { + t.Fatal(err) + } + assertOpenCodeRecentModel(t, statePath, 0, "ollama", "llama3.2") + assertOpenCodeRecentModel(t, statePath, 1, "anthropic", "claude") + }) + + t.Run("model state - preserve favorites and variants", func(t *testing.T) { + cleanup() + os.MkdirAll(stateDir, 0o755) + os.WriteFile(statePath, []byte(`{"recent":[],"favorite":[{"providerID":"x","modelID":"y"}],"variant":{"a":"b"}}`), 0o644) + if err := o.Edit([]string{"llama3.2"}); err != nil { + t.Fatal(err) + } + data, _ := os.ReadFile(statePath) + var state map[string]any + json.Unmarshal(data, &state) + if len(state["favorite"].([]any)) != 1 { + t.Error("favorite was modified") + } + if state["variant"].(map[string]any)["a"] != "b" { + t.Error("variant was modified") + } + }) + + t.Run("model state - deduplicate on re-add", func(t *testing.T) { + cleanup() + os.MkdirAll(stateDir, 0o755) + os.WriteFile(statePath, []byte(`{"recent":[{"providerID":"ollama","modelID":"llama3.2"},{"providerID":"anthropic","modelID":"claude"}],"favorite":[],"variant":{}}`), 0o644) + if err := o.Edit([]string{"llama3.2"}); err != nil { + t.Fatal(err) + } + data, _ := os.ReadFile(statePath) + var state map[string]any + json.Unmarshal(data, &state) + recent := state["recent"].([]any) + if len(recent) != 2 { + t.Errorf("expected 2 recent entries, got %d", len(recent)) + } + assertOpenCodeRecentModel(t, statePath, 0, "ollama", "llama3.2") + }) + + t.Run("remove model", func(t *testing.T) { + cleanup() + // First add two models + o.Edit([]string{"llama3.2", "mistral"}) + assertOpenCodeModelExists(t, configPath, "llama3.2") + assertOpenCodeModelExists(t, configPath, "mistral") + + // Then remove one by only selecting the other + o.Edit([]string{"llama3.2"}) + assertOpenCodeModelExists(t, configPath, "llama3.2") + assertOpenCodeModelNotExists(t, configPath, "mistral") + }) + + t.Run("remove model preserves non-ollama models", func(t *testing.T) { + cleanup() + os.MkdirAll(configDir, 0o755) + // Add a non-Ollama model manually + os.WriteFile(configPath, []byte(`{"provider":{"ollama":{"models":{"external":{"name":"External Model"}}}}}`), 0o644) + + o.Edit([]string{"llama3.2"}) + assertOpenCodeModelExists(t, configPath, "llama3.2") + assertOpenCodeModelExists(t, configPath, "external") // Should be preserved + }) +} + +func assertOpenCodeModelExists(t *testing.T, path, model string) { + t.Helper() + data, err := os.ReadFile(path) + if err != nil { + t.Fatal(err) + } + var cfg map[string]any + if err := json.Unmarshal(data, &cfg); err != nil { + t.Fatal(err) + } + provider, ok := cfg["provider"].(map[string]any) + if !ok { + t.Fatal("provider not found") + } + ollama, ok := provider["ollama"].(map[string]any) + if !ok { + t.Fatal("ollama provider not found") + } + models, ok := ollama["models"].(map[string]any) + if !ok { + t.Fatal("models not found") + } + if models[model] == nil { + t.Errorf("model %s not found", model) + } +} + +func assertOpenCodeModelNotExists(t *testing.T, path, model string) { + t.Helper() + data, err := os.ReadFile(path) + if err != nil { + t.Fatal(err) + } + var cfg map[string]any + if err := json.Unmarshal(data, &cfg); err != nil { + t.Fatal(err) + } + provider, ok := cfg["provider"].(map[string]any) + if !ok { + return // No provider means no model + } + ollama, ok := provider["ollama"].(map[string]any) + if !ok { + return // No ollama means no model + } + models, ok := ollama["models"].(map[string]any) + if !ok { + return // No models means no model + } + if models[model] != nil { + t.Errorf("model %s should not exist but was found", model) + } +} + +func assertOpenCodeRecentModel(t *testing.T, path string, index int, providerID, modelID string) { + t.Helper() + data, err := os.ReadFile(path) + if err != nil { + t.Fatal(err) + } + var state map[string]any + if err := json.Unmarshal(data, &state); err != nil { + t.Fatal(err) + } + recent, ok := state["recent"].([]any) + if !ok { + t.Fatal("recent not found") + } + if index >= len(recent) { + t.Fatalf("index %d out of range (len=%d)", index, len(recent)) + } + entry, ok := recent[index].(map[string]any) + if !ok { + t.Fatal("entry is not a map") + } + if entry["providerID"] != providerID { + t.Errorf("expected providerID %s, got %s", providerID, entry["providerID"]) + } + if entry["modelID"] != modelID { + t.Errorf("expected modelID %s, got %s", modelID, entry["modelID"]) + } +} + +// Edge case tests for opencode.go + +func TestOpenCodeEdit_CorruptedConfigJSON(t *testing.T) { + o := &OpenCode{} + tmpDir := t.TempDir() + setTestHome(t, tmpDir) + + configDir := filepath.Join(tmpDir, ".config", "opencode") + configPath := filepath.Join(configDir, "opencode.json") + + os.MkdirAll(configDir, 0o755) + os.WriteFile(configPath, []byte(`{corrupted json content`), 0o644) + + // Should not panic - corrupted JSON should be treated as empty + err := o.Edit([]string{"llama3.2"}) + if err != nil { + t.Fatalf("Edit failed with corrupted config: %v", err) + } + + // Verify valid JSON was created + data, _ := os.ReadFile(configPath) + var cfg map[string]any + if err := json.Unmarshal(data, &cfg); err != nil { + t.Errorf("resulting config is not valid JSON: %v", err) + } +} + +func TestOpenCodeEdit_CorruptedStateJSON(t *testing.T) { + o := &OpenCode{} + tmpDir := t.TempDir() + setTestHome(t, tmpDir) + + stateDir := filepath.Join(tmpDir, ".local", "state", "opencode") + statePath := filepath.Join(stateDir, "model.json") + + os.MkdirAll(stateDir, 0o755) + os.WriteFile(statePath, []byte(`{corrupted state`), 0o644) + + err := o.Edit([]string{"llama3.2"}) + if err != nil { + t.Fatalf("Edit failed with corrupted state: %v", err) + } + + // Verify valid state was created + data, _ := os.ReadFile(statePath) + var state map[string]any + if err := json.Unmarshal(data, &state); err != nil { + t.Errorf("resulting state is not valid JSON: %v", err) + } +} + +func TestOpenCodeEdit_WrongTypeProvider(t *testing.T) { + o := &OpenCode{} + tmpDir := t.TempDir() + setTestHome(t, tmpDir) + + configDir := filepath.Join(tmpDir, ".config", "opencode") + configPath := filepath.Join(configDir, "opencode.json") + + os.MkdirAll(configDir, 0o755) + os.WriteFile(configPath, []byte(`{"provider": "not a map"}`), 0o644) + + err := o.Edit([]string{"llama3.2"}) + if err != nil { + t.Fatalf("Edit with wrong type provider failed: %v", err) + } + + // Verify provider is now correct type + data, _ := os.ReadFile(configPath) + var cfg map[string]any + json.Unmarshal(data, &cfg) + + provider, ok := cfg["provider"].(map[string]any) + if !ok { + t.Fatalf("provider should be map after setup, got %T", cfg["provider"]) + } + if provider["ollama"] == nil { + t.Error("ollama provider should be created") + } +} + +func TestOpenCodeEdit_WrongTypeRecent(t *testing.T) { + o := &OpenCode{} + tmpDir := t.TempDir() + setTestHome(t, tmpDir) + + stateDir := filepath.Join(tmpDir, ".local", "state", "opencode") + statePath := filepath.Join(stateDir, "model.json") + + os.MkdirAll(stateDir, 0o755) + os.WriteFile(statePath, []byte(`{"recent": "not an array", "favorite": [], "variant": {}}`), 0o644) + + err := o.Edit([]string{"llama3.2"}) + if err != nil { + t.Fatalf("Edit with wrong type recent failed: %v", err) + } + + // The function should handle this gracefully + data, _ := os.ReadFile(statePath) + var state map[string]any + json.Unmarshal(data, &state) + + // recent should be properly set after setup + recent, ok := state["recent"].([]any) + if !ok { + t.Logf("Note: recent type after setup is %T (documenting behavior)", state["recent"]) + } else if len(recent) == 0 { + t.Logf("Note: recent is empty (documenting behavior)") + } +} + +func TestOpenCodeEdit_EmptyModels(t *testing.T) { + o := &OpenCode{} + tmpDir := t.TempDir() + setTestHome(t, tmpDir) + + configDir := filepath.Join(tmpDir, ".config", "opencode") + configPath := filepath.Join(configDir, "opencode.json") + + os.MkdirAll(configDir, 0o755) + originalContent := `{"provider":{"ollama":{"models":{"existing":{}}}}}` + os.WriteFile(configPath, []byte(originalContent), 0o644) + + // Empty models should be no-op + err := o.Edit([]string{}) + if err != nil { + t.Fatalf("Edit with empty models failed: %v", err) + } + + // Original content should be preserved (file not modified) + data, _ := os.ReadFile(configPath) + if string(data) != originalContent { + t.Errorf("empty models should not modify file, but content changed") + } +} + +func TestOpenCodeEdit_SpecialCharsInModelName(t *testing.T) { + o := &OpenCode{} + tmpDir := t.TempDir() + setTestHome(t, tmpDir) + + // Model name with special characters (though unusual) + specialModel := `model-with-"quotes"` + + err := o.Edit([]string{specialModel}) + if err != nil { + t.Fatalf("Edit with special chars failed: %v", err) + } + + // Verify it was stored correctly + configDir := filepath.Join(tmpDir, ".config", "opencode") + configPath := filepath.Join(configDir, "opencode.json") + data, _ := os.ReadFile(configPath) + + var cfg map[string]any + if err := json.Unmarshal(data, &cfg); err != nil { + t.Fatalf("resulting config is invalid JSON: %v", err) + } + + // Model should be accessible + provider, _ := cfg["provider"].(map[string]any) + ollama, _ := provider["ollama"].(map[string]any) + models, _ := ollama["models"].(map[string]any) + + if models[specialModel] == nil { + t.Errorf("model with special chars not found in config") + } +} + +func TestOpenCodeModels_NoConfig(t *testing.T) { + o := &OpenCode{} + tmpDir := t.TempDir() + setTestHome(t, tmpDir) + + models := o.Models() + if len(models) > 0 { + t.Errorf("expected nil/empty for missing config, got %v", models) + } +} diff --git a/cmd/config/selector.go b/cmd/config/selector.go new file mode 100644 index 000000000..1ef748599 --- /dev/null +++ b/cmd/config/selector.go @@ -0,0 +1,499 @@ +package config + +import ( + "errors" + "fmt" + "io" + "os" + "strings" + + "golang.org/x/term" +) + +// ANSI escape sequences for terminal formatting. +const ( + ansiHideCursor = "\033[?25l" + ansiShowCursor = "\033[?25h" + ansiBold = "\033[1m" + ansiReset = "\033[0m" + ansiGray = "\033[37m" + ansiClearDown = "\033[J" +) + +const maxDisplayedItems = 10 + +var errCancelled = errors.New("cancelled") + +type selectItem struct { + Name string + Description string +} + +type inputEvent int + +const ( + eventNone inputEvent = iota + eventEnter + eventEscape + eventUp + eventDown + eventTab + eventBackspace + eventChar +) + +type selectState struct { + items []selectItem + filter string + selected int + scrollOffset int +} + +func newSelectState(items []selectItem) *selectState { + return &selectState{items: items} +} + +func (s *selectState) filtered() []selectItem { + return filterItems(s.items, s.filter) +} + +func (s *selectState) handleInput(event inputEvent, char byte) (done bool, result string, err error) { + filtered := s.filtered() + + switch event { + case eventEnter: + if len(filtered) > 0 && s.selected < len(filtered) { + return true, filtered[s.selected].Name, nil + } + case eventEscape: + return true, "", errCancelled + case eventBackspace: + if len(s.filter) > 0 { + s.filter = s.filter[:len(s.filter)-1] + s.selected = 0 + s.scrollOffset = 0 + } + case eventUp: + if s.selected > 0 { + s.selected-- + if s.selected < s.scrollOffset { + s.scrollOffset = s.selected + } + } + case eventDown: + if s.selected < len(filtered)-1 { + s.selected++ + if s.selected >= s.scrollOffset+maxDisplayedItems { + s.scrollOffset = s.selected - maxDisplayedItems + 1 + } + } + case eventChar: + s.filter += string(char) + s.selected = 0 + s.scrollOffset = 0 + } + + return false, "", nil +} + +type multiSelectState struct { + items []selectItem + itemIndex map[string]int + filter string + highlighted int + scrollOffset int + checked map[int]bool + checkOrder []int + focusOnButton bool +} + +func newMultiSelectState(items []selectItem, preChecked []string) *multiSelectState { + s := &multiSelectState{ + items: items, + itemIndex: make(map[string]int, len(items)), + checked: make(map[int]bool), + } + + for i, item := range items { + s.itemIndex[item.Name] = i + } + + for _, name := range preChecked { + if idx, ok := s.itemIndex[name]; ok { + s.checked[idx] = true + s.checkOrder = append(s.checkOrder, idx) + } + } + + return s +} + +func (s *multiSelectState) filtered() []selectItem { + return filterItems(s.items, s.filter) +} + +func (s *multiSelectState) toggleItem() { + filtered := s.filtered() + if len(filtered) == 0 || s.highlighted >= len(filtered) { + return + } + + item := filtered[s.highlighted] + origIdx := s.itemIndex[item.Name] + + if s.checked[origIdx] { + delete(s.checked, origIdx) + for i, idx := range s.checkOrder { + if idx == origIdx { + s.checkOrder = append(s.checkOrder[:i], s.checkOrder[i+1:]...) + break + } + } + } else { + s.checked[origIdx] = true + s.checkOrder = append(s.checkOrder, origIdx) + } +} + +func (s *multiSelectState) handleInput(event inputEvent, char byte) (done bool, result []string, err error) { + filtered := s.filtered() + + switch event { + case eventEnter: + if s.focusOnButton && len(s.checkOrder) > 0 { + var res []string + for _, idx := range s.checkOrder { + res = append(res, s.items[idx].Name) + } + return true, res, nil + } else if !s.focusOnButton { + s.toggleItem() + } + case eventTab: + if len(s.checkOrder) > 0 { + s.focusOnButton = !s.focusOnButton + } + case eventEscape: + return true, nil, errCancelled + case eventBackspace: + if len(s.filter) > 0 { + s.filter = s.filter[:len(s.filter)-1] + s.highlighted = 0 + s.scrollOffset = 0 + s.focusOnButton = false + } + case eventUp: + if s.focusOnButton { + s.focusOnButton = false + } else if s.highlighted > 0 { + s.highlighted-- + if s.highlighted < s.scrollOffset { + s.scrollOffset = s.highlighted + } + } + case eventDown: + if s.focusOnButton { + s.focusOnButton = false + } else if s.highlighted < len(filtered)-1 { + s.highlighted++ + if s.highlighted >= s.scrollOffset+maxDisplayedItems { + s.scrollOffset = s.highlighted - maxDisplayedItems + 1 + } + } + case eventChar: + s.filter += string(char) + s.highlighted = 0 + s.scrollOffset = 0 + s.focusOnButton = false + } + + return false, nil, nil +} + +func (s *multiSelectState) selectedCount() int { + return len(s.checkOrder) +} + +// Terminal I/O handling + +type terminalState struct { + fd int + oldState *term.State +} + +func enterRawMode() (*terminalState, error) { + fd := int(os.Stdin.Fd()) + oldState, err := term.MakeRaw(fd) + if err != nil { + return nil, err + } + fmt.Fprint(os.Stderr, ansiHideCursor) + return &terminalState{fd: fd, oldState: oldState}, nil +} + +func (t *terminalState) restore() { + fmt.Fprint(os.Stderr, ansiShowCursor) + term.Restore(t.fd, t.oldState) +} + +func clearLines(n int) { + if n > 0 { + fmt.Fprintf(os.Stderr, "\033[%dA", n) + fmt.Fprint(os.Stderr, ansiClearDown) + } +} + +func parseInput(r io.Reader) (inputEvent, byte, error) { + buf := make([]byte, 3) + n, err := r.Read(buf) + if err != nil { + return 0, 0, err + } + + switch { + case n == 1 && buf[0] == 13: + return eventEnter, 0, nil + case n == 1 && (buf[0] == 3 || buf[0] == 27): + return eventEscape, 0, nil + case n == 1 && buf[0] == 9: + return eventTab, 0, nil + case n == 1 && buf[0] == 127: + return eventBackspace, 0, nil + case n == 3 && buf[0] == 27 && buf[1] == 91 && buf[2] == 65: + return eventUp, 0, nil + case n == 3 && buf[0] == 27 && buf[1] == 91 && buf[2] == 66: + return eventDown, 0, nil + case n == 1 && buf[0] >= 32 && buf[0] < 127: + return eventChar, buf[0], nil + } + + return eventNone, 0, nil +} + +// Rendering + +func renderSelect(w io.Writer, prompt string, s *selectState) int { + filtered := s.filtered() + + fmt.Fprintf(w, "%s %s\r\n", prompt, s.filter) + lineCount := 1 + + if len(filtered) == 0 { + fmt.Fprintf(w, " %s(no matches)%s\r\n", ansiGray, ansiReset) + lineCount++ + } else { + displayCount := min(len(filtered), maxDisplayedItems) + + for i := range displayCount { + idx := s.scrollOffset + i + if idx >= len(filtered) { + break + } + item := filtered[idx] + prefix := " " + if idx == s.selected { + prefix = " " + ansiBold + "> " + } + if item.Description != "" { + fmt.Fprintf(w, "%s%s%s %s- %s%s\r\n", prefix, item.Name, ansiReset, ansiGray, item.Description, ansiReset) + } else { + fmt.Fprintf(w, "%s%s%s\r\n", prefix, item.Name, ansiReset) + } + lineCount++ + } + + if remaining := len(filtered) - s.scrollOffset - displayCount; remaining > 0 { + fmt.Fprintf(w, " %s... and %d more%s\r\n", ansiGray, remaining, ansiReset) + lineCount++ + } + } + + return lineCount +} + +func renderMultiSelect(w io.Writer, prompt string, s *multiSelectState) int { + filtered := s.filtered() + + fmt.Fprintf(w, "%s %s\r\n", prompt, s.filter) + lineCount := 1 + + if len(filtered) == 0 { + fmt.Fprintf(w, " %s(no matches)%s\r\n", ansiGray, ansiReset) + lineCount++ + } else { + displayCount := min(len(filtered), maxDisplayedItems) + + for i := range displayCount { + idx := s.scrollOffset + i + if idx >= len(filtered) { + break + } + item := filtered[idx] + origIdx := s.itemIndex[item.Name] + + checkbox := "[ ]" + if s.checked[origIdx] { + checkbox = "[x]" + } + + prefix := " " + suffix := "" + if idx == s.highlighted && !s.focusOnButton { + prefix = "> " + } + if len(s.checkOrder) > 0 && s.checkOrder[0] == origIdx { + suffix = " " + ansiGray + "(default)" + ansiReset + } + + if idx == s.highlighted && !s.focusOnButton { + fmt.Fprintf(w, " %s%s %s %s%s%s\r\n", ansiBold, prefix, checkbox, item.Name, ansiReset, suffix) + } else { + fmt.Fprintf(w, " %s %s %s%s\r\n", prefix, checkbox, item.Name, suffix) + } + lineCount++ + } + + if remaining := len(filtered) - s.scrollOffset - displayCount; remaining > 0 { + fmt.Fprintf(w, " %s... and %d more%s\r\n", ansiGray, remaining, ansiReset) + lineCount++ + } + } + + fmt.Fprintf(w, "\r\n") + lineCount++ + count := s.selectedCount() + switch { + case count == 0: + fmt.Fprintf(w, " %sSelect at least one model.%s\r\n", ansiGray, ansiReset) + case s.focusOnButton: + fmt.Fprintf(w, " %s> [ Continue ]%s %s(%d selected)%s\r\n", ansiBold, ansiReset, ansiGray, count, ansiReset) + default: + fmt.Fprintf(w, " %s[ Continue ] (%d selected) - press Tab%s\r\n", ansiGray, count, ansiReset) + } + lineCount++ + + return lineCount +} + +// selectPrompt prompts the user to select a single item from a list. +func selectPrompt(prompt string, items []selectItem) (string, error) { + if len(items) == 0 { + return "", fmt.Errorf("no items to select from") + } + + ts, err := enterRawMode() + if err != nil { + return "", err + } + defer ts.restore() + + state := newSelectState(items) + var lastLineCount int + + render := func() { + clearLines(lastLineCount) + lastLineCount = renderSelect(os.Stderr, prompt, state) + } + + render() + + for { + event, char, err := parseInput(os.Stdin) + if err != nil { + return "", err + } + + done, result, err := state.handleInput(event, char) + if done { + clearLines(lastLineCount) + if err != nil { + return "", err + } + return result, nil + } + + render() + } +} + +// multiSelectPrompt prompts the user to select multiple items from a list. +func multiSelectPrompt(prompt string, items []selectItem, preChecked []string) ([]string, error) { + if len(items) == 0 { + return nil, fmt.Errorf("no items to select from") + } + + ts, err := enterRawMode() + if err != nil { + return nil, err + } + defer ts.restore() + + state := newMultiSelectState(items, preChecked) + var lastLineCount int + + render := func() { + clearLines(lastLineCount) + lastLineCount = renderMultiSelect(os.Stderr, prompt, state) + } + + render() + + for { + event, char, err := parseInput(os.Stdin) + if err != nil { + return nil, err + } + + done, result, err := state.handleInput(event, char) + if done { + clearLines(lastLineCount) + if err != nil { + return nil, err + } + return result, nil + } + + render() + } +} + +func confirmPrompt(prompt string) (bool, error) { + fd := int(os.Stdin.Fd()) + oldState, err := term.MakeRaw(fd) + if err != nil { + return false, err + } + defer term.Restore(fd, oldState) + + fmt.Fprintf(os.Stderr, "%s [y/n] ", prompt) + + buf := make([]byte, 1) + for { + if _, err := os.Stdin.Read(buf); err != nil { + return false, err + } + + switch buf[0] { + case 'Y', 'y', 13: + fmt.Fprintf(os.Stderr, "yes\r\n") + return true, nil + case 'N', 'n', 27, 3: + fmt.Fprintf(os.Stderr, "no\r\n") + return false, nil + } + } +} + +func filterItems(items []selectItem, filter string) []selectItem { + if filter == "" { + return items + } + var result []selectItem + filterLower := strings.ToLower(filter) + for _, item := range items { + if strings.Contains(strings.ToLower(item.Name), filterLower) { + result = append(result, item) + } + } + return result +} diff --git a/cmd/config/selector_test.go b/cmd/config/selector_test.go new file mode 100644 index 000000000..74e8796ee --- /dev/null +++ b/cmd/config/selector_test.go @@ -0,0 +1,913 @@ +package config + +import ( + "bytes" + "strings" + "testing" +) + +func TestFilterItems(t *testing.T) { + items := []selectItem{ + {Name: "llama3.2:latest"}, + {Name: "qwen2.5:7b"}, + {Name: "deepseek-v3:cloud"}, + {Name: "GPT-OSS:20b"}, + } + + t.Run("EmptyFilter_ReturnsAllItems", func(t *testing.T) { + result := filterItems(items, "") + if len(result) != len(items) { + t.Errorf("expected %d items, got %d", len(items), len(result)) + } + }) + + t.Run("CaseInsensitive_UppercaseFilterMatchesLowercase", func(t *testing.T) { + result := filterItems(items, "LLAMA") + if len(result) != 1 || result[0].Name != "llama3.2:latest" { + t.Errorf("expected llama3.2:latest, got %v", result) + } + }) + + t.Run("CaseInsensitive_LowercaseFilterMatchesUppercase", func(t *testing.T) { + result := filterItems(items, "gpt") + if len(result) != 1 || result[0].Name != "GPT-OSS:20b" { + t.Errorf("expected GPT-OSS:20b, got %v", result) + } + }) + + t.Run("PartialMatch", func(t *testing.T) { + result := filterItems(items, "deep") + if len(result) != 1 || result[0].Name != "deepseek-v3:cloud" { + t.Errorf("expected deepseek-v3:cloud, got %v", result) + } + }) + + t.Run("NoMatch_ReturnsEmpty", func(t *testing.T) { + result := filterItems(items, "nonexistent") + if len(result) != 0 { + t.Errorf("expected 0 items, got %d", len(result)) + } + }) +} + +func TestSelectState(t *testing.T) { + items := []selectItem{ + {Name: "item1"}, + {Name: "item2"}, + {Name: "item3"}, + } + + t.Run("InitialState", func(t *testing.T) { + s := newSelectState(items) + if s.selected != 0 { + t.Errorf("expected selected=0, got %d", s.selected) + } + if s.filter != "" { + t.Errorf("expected empty filter, got %q", s.filter) + } + if s.scrollOffset != 0 { + t.Errorf("expected scrollOffset=0, got %d", s.scrollOffset) + } + }) + + t.Run("Enter_SelectsCurrentItem", func(t *testing.T) { + s := newSelectState(items) + done, result, err := s.handleInput(eventEnter, 0) + if !done || result != "item1" || err != nil { + t.Errorf("expected (true, item1, nil), got (%v, %v, %v)", done, result, err) + } + }) + + t.Run("Enter_WithFilter_SelectsFilteredItem", func(t *testing.T) { + s := newSelectState(items) + s.filter = "item3" + done, result, err := s.handleInput(eventEnter, 0) + if !done || result != "item3" || err != nil { + t.Errorf("expected (true, item3, nil), got (%v, %v, %v)", done, result, err) + } + }) + + t.Run("Enter_EmptyFilteredList_DoesNothing", func(t *testing.T) { + s := newSelectState(items) + s.filter = "nonexistent" + done, result, err := s.handleInput(eventEnter, 0) + if done || result != "" || err != nil { + t.Errorf("expected (false, '', nil), got (%v, %v, %v)", done, result, err) + } + }) + + t.Run("Escape_ReturnsCancelledError", func(t *testing.T) { + s := newSelectState(items) + done, result, err := s.handleInput(eventEscape, 0) + if !done || result != "" || err != errCancelled { + t.Errorf("expected (true, '', errCancelled), got (%v, %v, %v)", done, result, err) + } + }) + + t.Run("Down_MovesSelection", func(t *testing.T) { + s := newSelectState(items) + s.handleInput(eventDown, 0) + if s.selected != 1 { + t.Errorf("expected selected=1, got %d", s.selected) + } + }) + + t.Run("Down_AtBottom_StaysAtBottom", func(t *testing.T) { + s := newSelectState(items) + s.selected = 2 + s.handleInput(eventDown, 0) + if s.selected != 2 { + t.Errorf("expected selected=2 (stayed at bottom), got %d", s.selected) + } + }) + + t.Run("Up_MovesSelection", func(t *testing.T) { + s := newSelectState(items) + s.selected = 2 + s.handleInput(eventUp, 0) + if s.selected != 1 { + t.Errorf("expected selected=1, got %d", s.selected) + } + }) + + t.Run("Up_AtTop_StaysAtTop", func(t *testing.T) { + s := newSelectState(items) + s.handleInput(eventUp, 0) + if s.selected != 0 { + t.Errorf("expected selected=0 (stayed at top), got %d", s.selected) + } + }) + + t.Run("Char_AppendsToFilter", func(t *testing.T) { + s := newSelectState(items) + s.handleInput(eventChar, 'i') + s.handleInput(eventChar, 't') + s.handleInput(eventChar, 'e') + s.handleInput(eventChar, 'm') + s.handleInput(eventChar, '2') + if s.filter != "item2" { + t.Errorf("expected filter='item2', got %q", s.filter) + } + filtered := s.filtered() + if len(filtered) != 1 || filtered[0].Name != "item2" { + t.Errorf("expected [item2], got %v", filtered) + } + }) + + t.Run("Char_ResetsSelectionToZero", func(t *testing.T) { + s := newSelectState(items) + s.selected = 2 + s.handleInput(eventChar, 'x') + if s.selected != 0 { + t.Errorf("expected selected=0 after typing, got %d", s.selected) + } + }) + + t.Run("Backspace_RemovesLastFilterChar", func(t *testing.T) { + s := newSelectState(items) + s.filter = "test" + s.handleInput(eventBackspace, 0) + if s.filter != "tes" { + t.Errorf("expected filter='tes', got %q", s.filter) + } + }) + + t.Run("Backspace_EmptyFilter_DoesNothing", func(t *testing.T) { + s := newSelectState(items) + s.handleInput(eventBackspace, 0) + if s.filter != "" { + t.Errorf("expected filter='', got %q", s.filter) + } + }) + + t.Run("Backspace_ResetsSelectionToZero", func(t *testing.T) { + s := newSelectState(items) + s.filter = "test" + s.selected = 2 + s.handleInput(eventBackspace, 0) + if s.selected != 0 { + t.Errorf("expected selected=0 after backspace, got %d", s.selected) + } + }) + + t.Run("Scroll_DownPastVisibleItems_ScrollsViewport", func(t *testing.T) { + // maxDisplayedItems is 10, so with 15 items we need to scroll + manyItems := make([]selectItem, 15) + for i := range manyItems { + manyItems[i] = selectItem{Name: string(rune('a' + i))} + } + s := newSelectState(manyItems) + + // move down 12 times (past the 10-item viewport) + for range 12 { + s.handleInput(eventDown, 0) + } + + if s.selected != 12 { + t.Errorf("expected selected=12, got %d", s.selected) + } + if s.scrollOffset != 3 { + t.Errorf("expected scrollOffset=3 (12-10+1), got %d", s.scrollOffset) + } + }) + + t.Run("Scroll_UpPastScrollOffset_ScrollsViewport", func(t *testing.T) { + manyItems := make([]selectItem, 15) + for i := range manyItems { + manyItems[i] = selectItem{Name: string(rune('a' + i))} + } + s := newSelectState(manyItems) + s.selected = 5 + s.scrollOffset = 5 + + s.handleInput(eventUp, 0) + + if s.selected != 4 { + t.Errorf("expected selected=4, got %d", s.selected) + } + if s.scrollOffset != 4 { + t.Errorf("expected scrollOffset=4, got %d", s.scrollOffset) + } + }) +} + +func TestMultiSelectState(t *testing.T) { + items := []selectItem{ + {Name: "item1"}, + {Name: "item2"}, + {Name: "item3"}, + } + + t.Run("InitialState_NoPrechecked", func(t *testing.T) { + s := newMultiSelectState(items, nil) + if s.highlighted != 0 { + t.Errorf("expected highlighted=0, got %d", s.highlighted) + } + if s.selectedCount() != 0 { + t.Errorf("expected 0 selected, got %d", s.selectedCount()) + } + if s.focusOnButton { + t.Error("expected focusOnButton=false initially") + } + }) + + t.Run("InitialState_WithPrechecked", func(t *testing.T) { + s := newMultiSelectState(items, []string{"item2", "item3"}) + if s.selectedCount() != 2 { + t.Errorf("expected 2 selected, got %d", s.selectedCount()) + } + if !s.checked[1] || !s.checked[2] { + t.Error("expected item2 and item3 to be checked") + } + }) + + t.Run("Prechecked_PreservesSelectionOrder", func(t *testing.T) { + // order matters: first checked = default model + s := newMultiSelectState(items, []string{"item3", "item1"}) + if len(s.checkOrder) != 2 { + t.Fatalf("expected 2 in checkOrder, got %d", len(s.checkOrder)) + } + if s.checkOrder[0] != 2 || s.checkOrder[1] != 0 { + t.Errorf("expected checkOrder=[2,0] (item3 first), got %v", s.checkOrder) + } + }) + + t.Run("Prechecked_IgnoresInvalidNames", func(t *testing.T) { + s := newMultiSelectState(items, []string{"item1", "nonexistent"}) + if s.selectedCount() != 1 { + t.Errorf("expected 1 selected (nonexistent ignored), got %d", s.selectedCount()) + } + }) + + t.Run("Toggle_ChecksUncheckedItem", func(t *testing.T) { + s := newMultiSelectState(items, nil) + s.toggleItem() + if !s.checked[0] { + t.Error("expected item1 to be checked after toggle") + } + }) + + t.Run("Toggle_UnchecksCheckedItem", func(t *testing.T) { + s := newMultiSelectState(items, []string{"item1"}) + s.toggleItem() + if s.checked[0] { + t.Error("expected item1 to be unchecked after toggle") + } + }) + + t.Run("Toggle_RemovesFromCheckOrder", func(t *testing.T) { + s := newMultiSelectState(items, []string{"item1", "item2", "item3"}) + s.highlighted = 1 // toggle item2 + s.toggleItem() + + if len(s.checkOrder) != 2 { + t.Fatalf("expected 2 in checkOrder, got %d", len(s.checkOrder)) + } + // should be [0, 2] (item1, item3) with item2 removed + if s.checkOrder[0] != 0 || s.checkOrder[1] != 2 { + t.Errorf("expected checkOrder=[0,2], got %v", s.checkOrder) + } + }) + + t.Run("Enter_TogglesWhenNotOnButton", func(t *testing.T) { + s := newMultiSelectState(items, nil) + s.handleInput(eventEnter, 0) + if !s.checked[0] { + t.Error("expected item1 to be checked after enter") + } + }) + + t.Run("Enter_OnButton_ReturnsSelection", func(t *testing.T) { + s := newMultiSelectState(items, []string{"item2", "item1"}) + s.focusOnButton = true + + done, result, err := s.handleInput(eventEnter, 0) + + if !done || err != nil { + t.Errorf("expected done=true, err=nil, got done=%v, err=%v", done, err) + } + // result should preserve selection order + if len(result) != 2 || result[0] != "item2" || result[1] != "item1" { + t.Errorf("expected [item2, item1], got %v", result) + } + }) + + t.Run("Enter_OnButton_EmptySelection_DoesNothing", func(t *testing.T) { + s := newMultiSelectState(items, nil) + s.focusOnButton = true + done, result, err := s.handleInput(eventEnter, 0) + if done || result != nil || err != nil { + t.Errorf("expected (false, nil, nil), got (%v, %v, %v)", done, result, err) + } + }) + + t.Run("Tab_SwitchesToButton_WhenHasSelection", func(t *testing.T) { + s := newMultiSelectState(items, []string{"item1"}) + s.handleInput(eventTab, 0) + if !s.focusOnButton { + t.Error("expected focus on button after tab") + } + }) + + t.Run("Tab_DoesNothing_WhenNoSelection", func(t *testing.T) { + s := newMultiSelectState(items, nil) + s.handleInput(eventTab, 0) + if s.focusOnButton { + t.Error("tab should not focus button when nothing selected") + } + }) + + t.Run("Tab_TogglesButtonFocus", func(t *testing.T) { + s := newMultiSelectState(items, []string{"item1"}) + s.handleInput(eventTab, 0) + if !s.focusOnButton { + t.Error("expected focus on button after first tab") + } + s.handleInput(eventTab, 0) + if s.focusOnButton { + t.Error("expected focus back on list after second tab") + } + }) + + t.Run("Escape_ReturnsCancelledError", func(t *testing.T) { + s := newMultiSelectState(items, []string{"item1"}) + done, result, err := s.handleInput(eventEscape, 0) + if !done || result != nil || err != errCancelled { + t.Errorf("expected (true, nil, errCancelled), got (%v, %v, %v)", done, result, err) + } + }) + + t.Run("IsDefault_TrueForFirstChecked", func(t *testing.T) { + s := newMultiSelectState(items, []string{"item2", "item1"}) + if !(len(s.checkOrder) > 0 && s.checkOrder[0] == 1) { + t.Error("expected item2 (idx 1) to be default (first checked)") + } + if len(s.checkOrder) > 0 && s.checkOrder[0] == 0 { + t.Error("expected item1 (idx 0) to NOT be default") + } + }) + + t.Run("IsDefault_FalseWhenNothingChecked", func(t *testing.T) { + s := newMultiSelectState(items, nil) + if len(s.checkOrder) > 0 && s.checkOrder[0] == 0 { + t.Error("expected isDefault=false when nothing checked") + } + }) + + t.Run("Down_MovesHighlight", func(t *testing.T) { + s := newMultiSelectState(items, nil) + s.handleInput(eventDown, 0) + if s.highlighted != 1 { + t.Errorf("expected highlighted=1, got %d", s.highlighted) + } + }) + + t.Run("Up_MovesHighlight", func(t *testing.T) { + s := newMultiSelectState(items, nil) + s.highlighted = 1 + s.handleInput(eventUp, 0) + if s.highlighted != 0 { + t.Errorf("expected highlighted=0, got %d", s.highlighted) + } + }) + + t.Run("Arrow_ReturnsFocusFromButton", func(t *testing.T) { + s := newMultiSelectState(items, []string{"item1"}) + s.focusOnButton = true + s.handleInput(eventDown, 0) + if s.focusOnButton { + t.Error("expected focus to return to list on arrow key") + } + }) + + t.Run("Char_AppendsToFilter", func(t *testing.T) { + s := newMultiSelectState(items, nil) + s.handleInput(eventChar, 'x') + if s.filter != "x" { + t.Errorf("expected filter='x', got %q", s.filter) + } + }) + + t.Run("Char_ResetsHighlightAndScroll", func(t *testing.T) { + manyItems := make([]selectItem, 15) + for i := range manyItems { + manyItems[i] = selectItem{Name: string(rune('a' + i))} + } + s := newMultiSelectState(manyItems, nil) + s.highlighted = 10 + s.scrollOffset = 5 + + s.handleInput(eventChar, 'x') + + if s.highlighted != 0 { + t.Errorf("expected highlighted=0, got %d", s.highlighted) + } + if s.scrollOffset != 0 { + t.Errorf("expected scrollOffset=0, got %d", s.scrollOffset) + } + }) + + t.Run("Backspace_RemovesLastFilterChar", func(t *testing.T) { + s := newMultiSelectState(items, nil) + s.filter = "test" + s.handleInput(eventBackspace, 0) + if s.filter != "tes" { + t.Errorf("expected filter='tes', got %q", s.filter) + } + }) + + t.Run("Backspace_RemovesFocusFromButton", func(t *testing.T) { + s := newMultiSelectState(items, []string{"item1"}) + s.filter = "x" + s.focusOnButton = true + s.handleInput(eventBackspace, 0) + if s.focusOnButton { + t.Error("expected focusOnButton=false after backspace") + } + }) +} + +func TestParseInput(t *testing.T) { + t.Run("Enter", func(t *testing.T) { + event, char, err := parseInput(bytes.NewReader([]byte{13})) + if err != nil || event != eventEnter || char != 0 { + t.Errorf("expected (eventEnter, 0, nil), got (%v, %v, %v)", event, char, err) + } + }) + + t.Run("Escape", func(t *testing.T) { + event, _, err := parseInput(bytes.NewReader([]byte{27})) + if err != nil || event != eventEscape { + t.Errorf("expected eventEscape, got %v", event) + } + }) + + t.Run("CtrlC_TreatedAsEscape", func(t *testing.T) { + event, _, err := parseInput(bytes.NewReader([]byte{3})) + if err != nil || event != eventEscape { + t.Errorf("expected eventEscape for Ctrl+C, got %v", event) + } + }) + + t.Run("Tab", func(t *testing.T) { + event, _, err := parseInput(bytes.NewReader([]byte{9})) + if err != nil || event != eventTab { + t.Errorf("expected eventTab, got %v", event) + } + }) + + t.Run("Backspace", func(t *testing.T) { + event, _, err := parseInput(bytes.NewReader([]byte{127})) + if err != nil || event != eventBackspace { + t.Errorf("expected eventBackspace, got %v", event) + } + }) + + t.Run("UpArrow", func(t *testing.T) { + event, _, err := parseInput(bytes.NewReader([]byte{27, 91, 65})) + if err != nil || event != eventUp { + t.Errorf("expected eventUp, got %v", event) + } + }) + + t.Run("DownArrow", func(t *testing.T) { + event, _, err := parseInput(bytes.NewReader([]byte{27, 91, 66})) + if err != nil || event != eventDown { + t.Errorf("expected eventDown, got %v", event) + } + }) + + t.Run("PrintableChars", func(t *testing.T) { + tests := []struct { + name string + char byte + }{ + {"lowercase", 'a'}, + {"uppercase", 'Z'}, + {"digit", '5'}, + {"space", ' '}, + {"tilde", '~'}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + event, char, err := parseInput(bytes.NewReader([]byte{tt.char})) + if err != nil || event != eventChar || char != tt.char { + t.Errorf("expected (eventChar, %q), got (%v, %q)", tt.char, event, char) + } + }) + } + }) +} + +func TestRenderSelect(t *testing.T) { + items := []selectItem{ + {Name: "item1", Description: "first item"}, + {Name: "item2"}, + } + + t.Run("ShowsPromptAndItems", func(t *testing.T) { + s := newSelectState(items) + var buf bytes.Buffer + lineCount := renderSelect(&buf, "Select:", s) + + output := buf.String() + if !strings.Contains(output, "Select:") { + t.Error("expected prompt in output") + } + if !strings.Contains(output, "item1") { + t.Error("expected item1 in output") + } + if !strings.Contains(output, "first item") { + t.Error("expected description in output") + } + if !strings.Contains(output, "item2") { + t.Error("expected item2 in output") + } + if lineCount != 3 { // 1 prompt + 2 items + t.Errorf("expected 3 lines, got %d", lineCount) + } + }) + + t.Run("EmptyFilteredList_ShowsNoMatches", func(t *testing.T) { + s := newSelectState(items) + s.filter = "xyz" + var buf bytes.Buffer + renderSelect(&buf, "Select:", s) + + if !strings.Contains(buf.String(), "no matches") { + t.Error("expected 'no matches' message") + } + }) + + t.Run("LongList_ShowsRemainingCount", func(t *testing.T) { + manyItems := make([]selectItem, 15) + for i := range manyItems { + manyItems[i] = selectItem{Name: string(rune('a' + i))} + } + s := newSelectState(manyItems) + var buf bytes.Buffer + renderSelect(&buf, "Select:", s) + + // 15 items - 10 displayed = 5 more + if !strings.Contains(buf.String(), "5 more") { + t.Error("expected '5 more' indicator") + } + }) +} + +func TestRenderMultiSelect(t *testing.T) { + items := []selectItem{ + {Name: "item1"}, + {Name: "item2"}, + } + + t.Run("ShowsCheckboxes", func(t *testing.T) { + s := newMultiSelectState(items, []string{"item1"}) + var buf bytes.Buffer + renderMultiSelect(&buf, "Select:", s) + + output := buf.String() + if !strings.Contains(output, "[x]") { + t.Error("expected checked checkbox [x]") + } + if !strings.Contains(output, "[ ]") { + t.Error("expected unchecked checkbox [ ]") + } + }) + + t.Run("ShowsDefaultMarker", func(t *testing.T) { + s := newMultiSelectState(items, []string{"item1"}) + var buf bytes.Buffer + renderMultiSelect(&buf, "Select:", s) + + if !strings.Contains(buf.String(), "(default)") { + t.Error("expected (default) marker for first checked item") + } + }) + + t.Run("ShowsSelectedCount", func(t *testing.T) { + s := newMultiSelectState(items, []string{"item1", "item2"}) + var buf bytes.Buffer + renderMultiSelect(&buf, "Select:", s) + + if !strings.Contains(buf.String(), "2 selected") { + t.Error("expected '2 selected' in output") + } + }) + + t.Run("NoSelection_ShowsHelperText", func(t *testing.T) { + s := newMultiSelectState(items, nil) + var buf bytes.Buffer + renderMultiSelect(&buf, "Select:", s) + + if !strings.Contains(buf.String(), "Select at least one") { + t.Error("expected 'Select at least one' helper text") + } + }) +} + +func TestErrCancelled(t *testing.T) { + t.Run("NotNil", func(t *testing.T) { + if errCancelled == nil { + t.Error("errCancelled should not be nil") + } + }) + + t.Run("Message", func(t *testing.T) { + if errCancelled.Error() != "cancelled" { + t.Errorf("expected 'cancelled', got %q", errCancelled.Error()) + } + }) +} + +// Edge case tests for selector.go + +// TestSelectState_SingleItem verifies that single item list works without crash. +// List with only one item should still work. +func TestSelectState_SingleItem(t *testing.T) { + items := []selectItem{{Name: "only-one"}} + + s := newSelectState(items) + + // Down should do nothing (already at bottom) + s.handleInput(eventDown, 0) + if s.selected != 0 { + t.Errorf("down on single item: expected selected=0, got %d", s.selected) + } + + // Up should do nothing (already at top) + s.handleInput(eventUp, 0) + if s.selected != 0 { + t.Errorf("up on single item: expected selected=0, got %d", s.selected) + } + + // Enter should select the only item + done, result, err := s.handleInput(eventEnter, 0) + if !done || result != "only-one" || err != nil { + t.Errorf("enter on single item: expected (true, 'only-one', nil), got (%v, %q, %v)", done, result, err) + } +} + +// TestSelectState_ExactlyMaxItems verifies boundary condition at maxDisplayedItems. +// List with exactly maxDisplayedItems items should not scroll. +func TestSelectState_ExactlyMaxItems(t *testing.T) { + items := make([]selectItem, maxDisplayedItems) + for i := range items { + items[i] = selectItem{Name: string(rune('a' + i))} + } + + s := newSelectState(items) + + // Move to last item + for range maxDisplayedItems - 1 { + s.handleInput(eventDown, 0) + } + + if s.selected != maxDisplayedItems-1 { + t.Errorf("expected selected=%d, got %d", maxDisplayedItems-1, s.selected) + } + + // Should not scroll when exactly at max + if s.scrollOffset != 0 { + t.Errorf("expected scrollOffset=0 for exactly maxDisplayedItems, got %d", s.scrollOffset) + } + + // One more down should do nothing + s.handleInput(eventDown, 0) + if s.selected != maxDisplayedItems-1 { + t.Errorf("down at max: expected selected=%d, got %d", maxDisplayedItems-1, s.selected) + } +} + +// TestFilterItems_RegexSpecialChars verifies that filter is literal, not regex. +// User typing "model.v1" shouldn't match "modelsv1". +func TestFilterItems_RegexSpecialChars(t *testing.T) { + items := []selectItem{ + {Name: "model.v1"}, + {Name: "modelsv1"}, + {Name: "model-v1"}, + } + + // Filter with dot should only match literal dot + result := filterItems(items, "model.v1") + if len(result) != 1 { + t.Errorf("expected 1 exact match, got %d", len(result)) + } + if len(result) > 0 && result[0].Name != "model.v1" { + t.Errorf("expected 'model.v1', got %s", result[0].Name) + } + + // Other regex special chars should be literal too + items2 := []selectItem{ + {Name: "test[0]"}, + {Name: "test0"}, + {Name: "test(1)"}, + } + + result2 := filterItems(items2, "test[0]") + if len(result2) != 1 || result2[0].Name != "test[0]" { + t.Errorf("expected only 'test[0]', got %v", result2) + } +} + +// TestMultiSelectState_DuplicateNames documents handling of duplicate item names. +// itemIndex uses name as key - duplicates cause collision. This documents +// the current behavior: the last index for a duplicate name is stored +func TestMultiSelectState_DuplicateNames(t *testing.T) { + // Duplicate names - this is an edge case that shouldn't happen in practice + items := []selectItem{ + {Name: "duplicate"}, + {Name: "duplicate"}, + {Name: "unique"}, + } + + s := newMultiSelectState(items, nil) + + // DOCUMENTED BEHAVIOR: itemIndex maps name to LAST index + // When there are duplicates, only the last occurrence's index is stored + if s.itemIndex["duplicate"] != 1 { + t.Errorf("itemIndex should map 'duplicate' to last index (1), got %d", s.itemIndex["duplicate"]) + } + + // Toggle item at highlighted=0 (first "duplicate") + // Due to name collision, toggleItem uses itemIndex["duplicate"] = 1 + // So it actually toggles the SECOND duplicate item, not the first + s.toggleItem() + + // This documents the potentially surprising behavior: + // We toggled at highlighted=0, but itemIndex lookup returned 1 + if !s.checked[1] { + t.Error("toggle should check index 1 (due to name collision in itemIndex)") + } + if s.checked[0] { + t.Log("Note: index 0 is NOT checked, even though highlighted=0 (name collision behavior)") + } +} + +// TestSelectState_FilterReducesBelowSelection verifies selection resets when filter reduces list. +// Prevents index-out-of-bounds on next keystroke +func TestSelectState_FilterReducesBelowSelection(t *testing.T) { + items := []selectItem{ + {Name: "apple"}, + {Name: "banana"}, + {Name: "cherry"}, + } + + s := newSelectState(items) + s.selected = 2 // Select "cherry" + + // Type a filter that removes cherry from results + s.handleInput(eventChar, 'a') // Filter to "a" - matches "apple" and "banana" + + // Selection should reset to 0 + if s.selected != 0 { + t.Errorf("expected selected=0 after filter, got %d", s.selected) + } + + filtered := s.filtered() + if len(filtered) != 2 { + t.Errorf("expected 2 filtered items, got %d", len(filtered)) + } +} + +// TestFilterItems_UnicodeCharacters verifies filtering works with UTF-8. +// Model names might contain unicode characters +func TestFilterItems_UnicodeCharacters(t *testing.T) { + items := []selectItem{ + {Name: "llama-日本語"}, + {Name: "模型-chinese"}, + {Name: "émoji-🦙"}, + {Name: "regular-model"}, + } + + t.Run("filter japanese", func(t *testing.T) { + result := filterItems(items, "日本") + if len(result) != 1 || result[0].Name != "llama-日本語" { + t.Errorf("expected llama-日本語, got %v", result) + } + }) + + t.Run("filter chinese", func(t *testing.T) { + result := filterItems(items, "模型") + if len(result) != 1 || result[0].Name != "模型-chinese" { + t.Errorf("expected 模型-chinese, got %v", result) + } + }) + + t.Run("filter emoji", func(t *testing.T) { + result := filterItems(items, "🦙") + if len(result) != 1 || result[0].Name != "émoji-🦙" { + t.Errorf("expected émoji-🦙, got %v", result) + } + }) + + t.Run("filter accented char", func(t *testing.T) { + result := filterItems(items, "émoji") + if len(result) != 1 || result[0].Name != "émoji-🦙" { + t.Errorf("expected émoji-🦙, got %v", result) + } + }) +} + +// TestMultiSelectState_FilterReducesBelowHighlight verifies highlight resets when filter reduces list. +func TestMultiSelectState_FilterReducesBelowHighlight(t *testing.T) { + items := []selectItem{ + {Name: "apple"}, + {Name: "banana"}, + {Name: "cherry"}, + } + + s := newMultiSelectState(items, nil) + s.highlighted = 2 // Highlight "cherry" + + // Type a filter that removes cherry + s.handleInput(eventChar, 'a') + + if s.highlighted != 0 { + t.Errorf("expected highlighted=0 after filter, got %d", s.highlighted) + } +} + +// TestMultiSelectState_EmptyItems verifies handling of empty item list. +// Empty list should be handled gracefully. +func TestMultiSelectState_EmptyItems(t *testing.T) { + s := newMultiSelectState([]selectItem{}, nil) + + // Toggle should not panic on empty list + s.toggleItem() + + if s.selectedCount() != 0 { + t.Errorf("expected 0 selected for empty list, got %d", s.selectedCount()) + } + + // Render should handle empty list + var buf bytes.Buffer + lineCount := renderMultiSelect(&buf, "Select:", s) + if lineCount == 0 { + t.Error("renderMultiSelect should produce output even for empty list") + } + if !strings.Contains(buf.String(), "no matches") { + t.Error("expected 'no matches' for empty list") + } +} + +// TestSelectState_RenderWithDescriptions verifies rendering items with descriptions. +func TestSelectState_RenderWithDescriptions(t *testing.T) { + items := []selectItem{ + {Name: "item1", Description: "First item description"}, + {Name: "item2", Description: ""}, + {Name: "item3", Description: "Third item"}, + } + + s := newSelectState(items) + var buf bytes.Buffer + renderSelect(&buf, "Select:", s) + + output := buf.String() + if !strings.Contains(output, "First item description") { + t.Error("expected description to be rendered") + } + if !strings.Contains(output, "item2") { + t.Error("expected item without description to be rendered") + } +}