mirror of
https://github.com/ollama/ollama.git
synced 2026-01-29 07:12:03 +03:00
cmd: ollama config command to help configure integrations to use Ollama (#13712)
This commit is contained in:
@@ -35,6 +35,7 @@ import (
|
||||
"golang.org/x/term"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/cmd/config"
|
||||
"github.com/ollama/ollama/envconfig"
|
||||
"github.com/ollama/ollama/format"
|
||||
"github.com/ollama/ollama/parser"
|
||||
@@ -2026,6 +2027,7 @@ func NewCLI() *cobra.Command {
|
||||
copyCmd,
|
||||
deleteCmd,
|
||||
runnerCmd,
|
||||
config.ConfigCmd(checkServerHeartbeat),
|
||||
)
|
||||
|
||||
return rootCmd
|
||||
|
||||
36
cmd/config/claude.go
Normal file
36
cmd/config/claude.go
Normal file
@@ -0,0 +1,36 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"os/exec"
|
||||
)
|
||||
|
||||
// Claude implements Runner for Claude Code integration
|
||||
type Claude struct{}
|
||||
|
||||
func (c *Claude) String() string { return "Claude Code" }
|
||||
|
||||
func (c *Claude) args(model string) []string {
|
||||
if model != "" {
|
||||
return []string{"--model", model}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Claude) Run(model string) error {
|
||||
if _, err := exec.LookPath("claude"); err != nil {
|
||||
return fmt.Errorf("claude is not installed, install from https://code.claude.com/docs/en/quickstart")
|
||||
}
|
||||
|
||||
cmd := exec.Command("claude", c.args(model)...)
|
||||
cmd.Stdin = os.Stdin
|
||||
cmd.Stdout = os.Stdout
|
||||
cmd.Stderr = os.Stderr
|
||||
cmd.Env = append(os.Environ(),
|
||||
"ANTHROPIC_BASE_URL=http://localhost:11434",
|
||||
"ANTHROPIC_API_KEY=",
|
||||
"ANTHROPIC_AUTH_TOKEN=ollama",
|
||||
)
|
||||
return cmd.Run()
|
||||
}
|
||||
42
cmd/config/claude_test.go
Normal file
42
cmd/config/claude_test.go
Normal file
@@ -0,0 +1,42 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"slices"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestClaudeIntegration(t *testing.T) {
|
||||
c := &Claude{}
|
||||
|
||||
t.Run("String", func(t *testing.T) {
|
||||
if got := c.String(); got != "Claude Code" {
|
||||
t.Errorf("String() = %q, want %q", got, "Claude Code")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("implements Runner", func(t *testing.T) {
|
||||
var _ Runner = c
|
||||
})
|
||||
}
|
||||
|
||||
func TestClaudeArgs(t *testing.T) {
|
||||
c := &Claude{}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
model string
|
||||
want []string
|
||||
}{
|
||||
{"with model", "llama3.2", []string{"--model", "llama3.2"}},
|
||||
{"empty model", "", nil},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := c.args(tt.model)
|
||||
if !slices.Equal(got, tt.want) {
|
||||
t.Errorf("args(%q) = %v, want %v", tt.model, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
61
cmd/config/codex.go
Normal file
61
cmd/config/codex.go
Normal file
@@ -0,0 +1,61 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"os/exec"
|
||||
"strings"
|
||||
|
||||
"golang.org/x/mod/semver"
|
||||
)
|
||||
|
||||
// Codex implements Runner for Codex integration
|
||||
type Codex struct{}
|
||||
|
||||
func (c *Codex) String() string { return "Codex" }
|
||||
|
||||
func (c *Codex) args(model string) []string {
|
||||
args := []string{"--oss"}
|
||||
if model != "" {
|
||||
args = append(args, "-m", model)
|
||||
}
|
||||
return args
|
||||
}
|
||||
|
||||
func (c *Codex) Run(model string) error {
|
||||
if err := checkCodexVersion(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
cmd := exec.Command("codex", c.args(model)...)
|
||||
cmd.Stdin = os.Stdin
|
||||
cmd.Stdout = os.Stdout
|
||||
cmd.Stderr = os.Stderr
|
||||
return cmd.Run()
|
||||
}
|
||||
|
||||
func checkCodexVersion() error {
|
||||
if _, err := exec.LookPath("codex"); err != nil {
|
||||
return fmt.Errorf("codex is not installed, install with: npm install -g @openai/codex")
|
||||
}
|
||||
|
||||
out, err := exec.Command("codex", "--version").Output()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get codex version: %w", err)
|
||||
}
|
||||
|
||||
// Parse output like "codex-cli 0.87.0"
|
||||
fields := strings.Fields(strings.TrimSpace(string(out)))
|
||||
if len(fields) < 2 {
|
||||
return fmt.Errorf("unexpected codex version output: %s", string(out))
|
||||
}
|
||||
|
||||
version := "v" + fields[len(fields)-1]
|
||||
minVersion := "v0.81.0"
|
||||
|
||||
if semver.Compare(version, minVersion) < 0 {
|
||||
return fmt.Errorf("codex version %s is too old, minimum required is %s, update with: npm update -g @openai/codex", fields[len(fields)-1], "0.81.0")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
28
cmd/config/codex_test.go
Normal file
28
cmd/config/codex_test.go
Normal file
@@ -0,0 +1,28 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"slices"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestCodexArgs(t *testing.T) {
|
||||
c := &Codex{}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
model string
|
||||
want []string
|
||||
}{
|
||||
{"with model", "llama3.2", []string{"--oss", "-m", "llama3.2"}},
|
||||
{"empty model", "", []string{"--oss"}},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := c.args(tt.model)
|
||||
if !slices.Equal(got, tt.want) {
|
||||
t.Errorf("args(%q) = %v, want %v", tt.model, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
115
cmd/config/config.go
Normal file
115
cmd/config/config.go
Normal file
@@ -0,0 +1,115 @@
|
||||
// Package config provides integration configuration for external coding tools
|
||||
// (Claude Code, Codex, Droid, OpenCode) to use Ollama models.
|
||||
package config
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type integration struct {
|
||||
Models []string `json:"models"`
|
||||
}
|
||||
|
||||
type config struct {
|
||||
Integrations map[string]*integration `json:"integrations"`
|
||||
}
|
||||
|
||||
func configPath() (string, error) {
|
||||
home, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return filepath.Join(home, ".ollama", "config", "config.json"), nil
|
||||
}
|
||||
|
||||
func load() (*config, error) {
|
||||
path, err := configPath()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
data, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
return &config{Integrations: make(map[string]*integration)}, nil
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var cfg config
|
||||
if err := json.Unmarshal(data, &cfg); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse config: %w, at: %s", err, path)
|
||||
}
|
||||
if cfg.Integrations == nil {
|
||||
cfg.Integrations = make(map[string]*integration)
|
||||
}
|
||||
return &cfg, nil
|
||||
}
|
||||
|
||||
func save(cfg *config) error {
|
||||
path, err := configPath()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := os.MkdirAll(filepath.Dir(path), 0o755); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
data, err := json.MarshalIndent(cfg, "", " ")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return writeWithBackup(path, data)
|
||||
}
|
||||
|
||||
func saveIntegration(appName string, models []string) error {
|
||||
if appName == "" {
|
||||
return errors.New("app name cannot be empty")
|
||||
}
|
||||
|
||||
cfg, err := load()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
cfg.Integrations[strings.ToLower(appName)] = &integration{
|
||||
Models: models,
|
||||
}
|
||||
|
||||
return save(cfg)
|
||||
}
|
||||
|
||||
func loadIntegration(appName string) (*integration, error) {
|
||||
cfg, err := load()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
ic, ok := cfg.Integrations[strings.ToLower(appName)]
|
||||
if !ok {
|
||||
return nil, os.ErrNotExist
|
||||
}
|
||||
|
||||
return ic, nil
|
||||
}
|
||||
|
||||
func listIntegrations() ([]integration, error) {
|
||||
cfg, err := load()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
result := make([]integration, 0, len(cfg.Integrations))
|
||||
for _, ic := range cfg.Integrations {
|
||||
result = append(result, *ic)
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
373
cmd/config/config_test.go
Normal file
373
cmd/config/config_test.go
Normal file
@@ -0,0 +1,373 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// setTestHome sets both HOME (Unix) and USERPROFILE (Windows) for cross-platform tests
|
||||
func setTestHome(t *testing.T, dir string) {
|
||||
t.Setenv("HOME", dir)
|
||||
t.Setenv("USERPROFILE", dir)
|
||||
}
|
||||
|
||||
// editorPaths is a test helper that safely calls Paths if the runner implements Editor
|
||||
func editorPaths(r Runner) []string {
|
||||
if editor, ok := r.(Editor); ok {
|
||||
return editor.Paths()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func TestIntegrationConfig(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
t.Run("save and load round-trip", func(t *testing.T) {
|
||||
models := []string{"llama3.2", "mistral", "qwen2.5"}
|
||||
if err := saveIntegration("claude", models); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
config, err := loadIntegration("claude")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if len(config.Models) != len(models) {
|
||||
t.Errorf("expected %d models, got %d", len(models), len(config.Models))
|
||||
}
|
||||
for i, m := range models {
|
||||
if config.Models[i] != m {
|
||||
t.Errorf("model %d: expected %s, got %s", i, m, config.Models[i])
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("defaultModel returns first model", func(t *testing.T) {
|
||||
saveIntegration("codex", []string{"model-a", "model-b"})
|
||||
|
||||
config, _ := loadIntegration("codex")
|
||||
defaultModel := ""
|
||||
if len(config.Models) > 0 {
|
||||
defaultModel = config.Models[0]
|
||||
}
|
||||
if defaultModel != "model-a" {
|
||||
t.Errorf("expected model-a, got %s", defaultModel)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("defaultModel returns empty for no models", func(t *testing.T) {
|
||||
config := &integration{Models: []string{}}
|
||||
defaultModel := ""
|
||||
if len(config.Models) > 0 {
|
||||
defaultModel = config.Models[0]
|
||||
}
|
||||
if defaultModel != "" {
|
||||
t.Errorf("expected empty string, got %s", defaultModel)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("app name is case-insensitive", func(t *testing.T) {
|
||||
saveIntegration("Claude", []string{"model-x"})
|
||||
|
||||
config, err := loadIntegration("claude")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defaultModel := ""
|
||||
if len(config.Models) > 0 {
|
||||
defaultModel = config.Models[0]
|
||||
}
|
||||
if defaultModel != "model-x" {
|
||||
t.Errorf("expected model-x, got %s", defaultModel)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("multiple integrations in single file", func(t *testing.T) {
|
||||
saveIntegration("app1", []string{"model-1"})
|
||||
saveIntegration("app2", []string{"model-2"})
|
||||
|
||||
config1, _ := loadIntegration("app1")
|
||||
config2, _ := loadIntegration("app2")
|
||||
|
||||
defaultModel1 := ""
|
||||
if len(config1.Models) > 0 {
|
||||
defaultModel1 = config1.Models[0]
|
||||
}
|
||||
defaultModel2 := ""
|
||||
if len(config2.Models) > 0 {
|
||||
defaultModel2 = config2.Models[0]
|
||||
}
|
||||
if defaultModel1 != "model-1" {
|
||||
t.Errorf("expected model-1, got %s", defaultModel1)
|
||||
}
|
||||
if defaultModel2 != "model-2" {
|
||||
t.Errorf("expected model-2, got %s", defaultModel2)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestListIntegrations(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
t.Run("returns empty when no integrations", func(t *testing.T) {
|
||||
configs, err := listIntegrations()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if len(configs) != 0 {
|
||||
t.Errorf("expected 0 integrations, got %d", len(configs))
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("returns all saved integrations", func(t *testing.T) {
|
||||
saveIntegration("claude", []string{"model-1"})
|
||||
saveIntegration("droid", []string{"model-2"})
|
||||
|
||||
configs, err := listIntegrations()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if len(configs) != 2 {
|
||||
t.Errorf("expected 2 integrations, got %d", len(configs))
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestEditorPaths(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
t.Run("returns empty for claude (no Editor)", func(t *testing.T) {
|
||||
r := integrations["claude"]
|
||||
paths := editorPaths(r)
|
||||
if len(paths) != 0 {
|
||||
t.Errorf("expected no paths for claude, got %v", paths)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("returns empty for codex (no Editor)", func(t *testing.T) {
|
||||
r := integrations["codex"]
|
||||
paths := editorPaths(r)
|
||||
if len(paths) != 0 {
|
||||
t.Errorf("expected no paths for codex, got %v", paths)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("returns empty for droid when no config exists", func(t *testing.T) {
|
||||
r := integrations["droid"]
|
||||
paths := editorPaths(r)
|
||||
if len(paths) != 0 {
|
||||
t.Errorf("expected no paths, got %v", paths)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("returns path for droid when config exists", func(t *testing.T) {
|
||||
settingsDir, _ := os.UserHomeDir()
|
||||
settingsDir = filepath.Join(settingsDir, ".factory")
|
||||
os.MkdirAll(settingsDir, 0o755)
|
||||
os.WriteFile(filepath.Join(settingsDir, "settings.json"), []byte(`{}`), 0o644)
|
||||
|
||||
r := integrations["droid"]
|
||||
paths := editorPaths(r)
|
||||
if len(paths) != 1 {
|
||||
t.Errorf("expected 1 path, got %d", len(paths))
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("returns paths for opencode when configs exist", func(t *testing.T) {
|
||||
home, _ := os.UserHomeDir()
|
||||
configDir := filepath.Join(home, ".config", "opencode")
|
||||
stateDir := filepath.Join(home, ".local", "state", "opencode")
|
||||
os.MkdirAll(configDir, 0o755)
|
||||
os.MkdirAll(stateDir, 0o755)
|
||||
os.WriteFile(filepath.Join(configDir, "opencode.json"), []byte(`{}`), 0o644)
|
||||
os.WriteFile(filepath.Join(stateDir, "model.json"), []byte(`{}`), 0o644)
|
||||
|
||||
r := integrations["opencode"]
|
||||
paths := editorPaths(r)
|
||||
if len(paths) != 2 {
|
||||
t.Errorf("expected 2 paths, got %d: %v", len(paths), paths)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestLoadIntegration_CorruptedJSON(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
// Create corrupted config.json file
|
||||
dir := filepath.Join(tmpDir, ".ollama", "config")
|
||||
os.MkdirAll(dir, 0o755)
|
||||
os.WriteFile(filepath.Join(dir, "config.json"), []byte(`{corrupted json`), 0o644)
|
||||
|
||||
// Corrupted file is treated as empty, so loadIntegration returns not found
|
||||
_, err := loadIntegration("test")
|
||||
if err == nil {
|
||||
t.Error("expected error for nonexistent integration in corrupted file")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSaveIntegration_NilModels(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
if err := saveIntegration("test", nil); err != nil {
|
||||
t.Fatalf("saveIntegration with nil models failed: %v", err)
|
||||
}
|
||||
|
||||
config, err := loadIntegration("test")
|
||||
if err != nil {
|
||||
t.Fatalf("loadIntegration failed: %v", err)
|
||||
}
|
||||
|
||||
if config.Models == nil {
|
||||
// nil is acceptable
|
||||
} else if len(config.Models) != 0 {
|
||||
t.Errorf("expected empty or nil models, got %v", config.Models)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSaveIntegration_EmptyAppName(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
err := saveIntegration("", []string{"model"})
|
||||
if err == nil {
|
||||
t.Error("expected error for empty app name, got nil")
|
||||
}
|
||||
if err != nil && !strings.Contains(err.Error(), "app name cannot be empty") {
|
||||
t.Errorf("expected 'app name cannot be empty' error, got: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadIntegration_NonexistentIntegration(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
_, err := loadIntegration("nonexistent")
|
||||
if err == nil {
|
||||
t.Error("expected error for nonexistent integration, got nil")
|
||||
}
|
||||
if !os.IsNotExist(err) {
|
||||
t.Logf("error type is os.ErrNotExist as expected: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestConfigPath(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
path, err := configPath()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
expected := filepath.Join(tmpDir, ".ollama", "config", "config.json")
|
||||
if path != expected {
|
||||
t.Errorf("expected %s, got %s", expected, path)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoad(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
t.Run("returns empty config when file does not exist", func(t *testing.T) {
|
||||
cfg, err := load()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if cfg == nil {
|
||||
t.Fatal("expected non-nil config")
|
||||
}
|
||||
if cfg.Integrations == nil {
|
||||
t.Error("expected non-nil Integrations map")
|
||||
}
|
||||
if len(cfg.Integrations) != 0 {
|
||||
t.Errorf("expected empty Integrations, got %d", len(cfg.Integrations))
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("loads existing config", func(t *testing.T) {
|
||||
path, _ := configPath()
|
||||
os.MkdirAll(filepath.Dir(path), 0o755)
|
||||
os.WriteFile(path, []byte(`{"integrations":{"test":{"models":["model-a"]}}}`), 0o644)
|
||||
|
||||
cfg, err := load()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if cfg.Integrations["test"] == nil {
|
||||
t.Fatal("expected test integration")
|
||||
}
|
||||
if len(cfg.Integrations["test"].Models) != 1 {
|
||||
t.Errorf("expected 1 model, got %d", len(cfg.Integrations["test"].Models))
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("returns error for corrupted JSON", func(t *testing.T) {
|
||||
path, _ := configPath()
|
||||
os.MkdirAll(filepath.Dir(path), 0o755)
|
||||
os.WriteFile(path, []byte(`{corrupted`), 0o644)
|
||||
|
||||
_, err := load()
|
||||
if err == nil {
|
||||
t.Error("expected error for corrupted JSON")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestSave(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
t.Run("creates config file", func(t *testing.T) {
|
||||
cfg := &config{
|
||||
Integrations: map[string]*integration{
|
||||
"test": {Models: []string{"model-a", "model-b"}},
|
||||
},
|
||||
}
|
||||
|
||||
if err := save(cfg); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
path, _ := configPath()
|
||||
if _, err := os.Stat(path); os.IsNotExist(err) {
|
||||
t.Error("config file was not created")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("round-trip preserves data", func(t *testing.T) {
|
||||
cfg := &config{
|
||||
Integrations: map[string]*integration{
|
||||
"claude": {Models: []string{"llama3.2", "mistral"}},
|
||||
"codex": {Models: []string{"qwen2.5"}},
|
||||
},
|
||||
}
|
||||
|
||||
if err := save(cfg); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
loaded, err := load()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if len(loaded.Integrations) != 2 {
|
||||
t.Errorf("expected 2 integrations, got %d", len(loaded.Integrations))
|
||||
}
|
||||
if loaded.Integrations["claude"] == nil {
|
||||
t.Error("missing claude integration")
|
||||
}
|
||||
if len(loaded.Integrations["claude"].Models) != 2 {
|
||||
t.Errorf("expected 2 models for claude, got %d", len(loaded.Integrations["claude"].Models))
|
||||
}
|
||||
})
|
||||
}
|
||||
174
cmd/config/droid.go
Normal file
174
cmd/config/droid.go
Normal file
@@ -0,0 +1,174 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"slices"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// Droid implements Runner and Editor for Droid integration
|
||||
type Droid struct{}
|
||||
|
||||
// droidModelEntry represents a custom model entry in Droid's settings.json
|
||||
type droidModelEntry struct {
|
||||
Model string `json:"model"`
|
||||
DisplayName string `json:"displayName"`
|
||||
BaseURL string `json:"baseUrl"`
|
||||
APIKey string `json:"apiKey"`
|
||||
Provider string `json:"provider"`
|
||||
MaxOutputTokens int `json:"maxOutputTokens"`
|
||||
SupportsImages bool `json:"supportsImages"`
|
||||
ID string `json:"id"`
|
||||
Index int `json:"index"`
|
||||
}
|
||||
|
||||
func (d *Droid) String() string { return "Droid" }
|
||||
|
||||
func (d *Droid) Run(model string) error {
|
||||
if _, err := exec.LookPath("droid"); err != nil {
|
||||
return fmt.Errorf("droid is not installed, install from https://docs.factory.ai/cli/getting-started/quickstart")
|
||||
}
|
||||
|
||||
// Call Edit() to ensure config is up-to-date before launch
|
||||
models := []string{model}
|
||||
if config, err := loadIntegration("droid"); err == nil && len(config.Models) > 0 {
|
||||
models = config.Models
|
||||
}
|
||||
if err := d.Edit(models); err != nil {
|
||||
return fmt.Errorf("setup failed: %w", err)
|
||||
}
|
||||
|
||||
cmd := exec.Command("droid")
|
||||
cmd.Stdin = os.Stdin
|
||||
cmd.Stdout = os.Stdout
|
||||
cmd.Stderr = os.Stderr
|
||||
return cmd.Run()
|
||||
}
|
||||
|
||||
func (d *Droid) Paths() []string {
|
||||
home, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
p := filepath.Join(home, ".factory", "settings.json")
|
||||
if _, err := os.Stat(p); err == nil {
|
||||
return []string{p}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (d *Droid) Edit(models []string) error {
|
||||
if len(models) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
home, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
settingsPath := filepath.Join(home, ".factory", "settings.json")
|
||||
if err := os.MkdirAll(filepath.Dir(settingsPath), 0o755); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
settings := make(map[string]any)
|
||||
if data, err := os.ReadFile(settingsPath); err == nil {
|
||||
if err := json.Unmarshal(data, &settings); err != nil {
|
||||
return fmt.Errorf("failed to parse settings file: %w, at: %s", err, settingsPath)
|
||||
}
|
||||
}
|
||||
|
||||
customModels, _ := settings["customModels"].([]any)
|
||||
|
||||
// Keep only non-Ollama models (we'll rebuild Ollama models fresh)
|
||||
nonOllamaModels := slices.DeleteFunc(slices.Clone(customModels), isOllamaModelEntry)
|
||||
|
||||
// Build new Ollama model entries with sequential indices (0, 1, 2, ...)
|
||||
var ollamaModels []any
|
||||
var defaultModelID string
|
||||
for i, model := range models {
|
||||
modelID := fmt.Sprintf("custom:%s-[Ollama]-%d", model, i)
|
||||
ollamaModels = append(ollamaModels, droidModelEntry{
|
||||
Model: model,
|
||||
DisplayName: model,
|
||||
BaseURL: "http://localhost:11434/v1",
|
||||
APIKey: "ollama",
|
||||
Provider: "generic-chat-completion-api",
|
||||
MaxOutputTokens: 64000,
|
||||
SupportsImages: false,
|
||||
ID: modelID,
|
||||
Index: i,
|
||||
})
|
||||
if i == 0 {
|
||||
defaultModelID = modelID
|
||||
}
|
||||
}
|
||||
|
||||
settings["customModels"] = append(ollamaModels, nonOllamaModels...)
|
||||
|
||||
sessionSettings, ok := settings["sessionDefaultSettings"].(map[string]any)
|
||||
if !ok {
|
||||
sessionSettings = make(map[string]any)
|
||||
}
|
||||
sessionSettings["model"] = defaultModelID
|
||||
|
||||
if effort, ok := sessionSettings["reasoningEffort"].(string); !ok || !isValidReasoningEffort(effort) {
|
||||
sessionSettings["reasoningEffort"] = "none"
|
||||
}
|
||||
|
||||
settings["sessionDefaultSettings"] = sessionSettings
|
||||
|
||||
data, err := json.MarshalIndent(settings, "", " ")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return writeWithBackup(settingsPath, data)
|
||||
}
|
||||
|
||||
func (d *Droid) Models() []string {
|
||||
home, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
settings, err := readJSONFile(filepath.Join(home, ".factory", "settings.json"))
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
customModels, _ := settings["customModels"].([]any)
|
||||
|
||||
var result []string
|
||||
for _, m := range customModels {
|
||||
if !isOllamaModelEntry(m) {
|
||||
continue
|
||||
}
|
||||
entry, ok := m.(map[string]any)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
if model, _ := entry["model"].(string); model != "" {
|
||||
result = append(result, model)
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
var validReasoningEfforts = []string{"high", "medium", "low", "none"}
|
||||
|
||||
func isValidReasoningEffort(effort string) bool {
|
||||
return slices.Contains(validReasoningEfforts, effort)
|
||||
}
|
||||
|
||||
func isOllamaModelEntry(m any) bool {
|
||||
entry, ok := m.(map[string]any)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
id, _ := entry["id"].(string)
|
||||
return strings.Contains(id, "-[Ollama]-")
|
||||
}
|
||||
454
cmd/config/droid_test.go
Normal file
454
cmd/config/droid_test.go
Normal file
@@ -0,0 +1,454 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestDroidIntegration(t *testing.T) {
|
||||
d := &Droid{}
|
||||
|
||||
t.Run("String", func(t *testing.T) {
|
||||
if got := d.String(); got != "Droid" {
|
||||
t.Errorf("String() = %q, want %q", got, "Droid")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("implements Runner", func(t *testing.T) {
|
||||
var _ Runner = d
|
||||
})
|
||||
|
||||
t.Run("implements Editor", func(t *testing.T) {
|
||||
var _ Editor = d
|
||||
})
|
||||
}
|
||||
|
||||
func TestDroidEdit(t *testing.T) {
|
||||
d := &Droid{}
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
settingsDir := filepath.Join(tmpDir, ".factory")
|
||||
settingsPath := filepath.Join(settingsDir, "settings.json")
|
||||
|
||||
cleanup := func() {
|
||||
os.RemoveAll(settingsDir)
|
||||
}
|
||||
|
||||
readSettings := func() map[string]any {
|
||||
data, _ := os.ReadFile(settingsPath)
|
||||
var settings map[string]any
|
||||
json.Unmarshal(data, &settings)
|
||||
return settings
|
||||
}
|
||||
|
||||
getCustomModels := func(settings map[string]any) []map[string]any {
|
||||
models, ok := settings["customModels"].([]any)
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
var result []map[string]any
|
||||
for _, m := range models {
|
||||
if entry, ok := m.(map[string]any); ok {
|
||||
result = append(result, entry)
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
t.Run("fresh install creates models with sequential indices", func(t *testing.T) {
|
||||
cleanup()
|
||||
if err := d.Edit([]string{"model-a", "model-b"}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
settings := readSettings()
|
||||
models := getCustomModels(settings)
|
||||
|
||||
if len(models) != 2 {
|
||||
t.Fatalf("expected 2 models, got %d", len(models))
|
||||
}
|
||||
|
||||
// Check first model
|
||||
if models[0]["model"] != "model-a" {
|
||||
t.Errorf("expected model-a, got %s", models[0]["model"])
|
||||
}
|
||||
if models[0]["id"] != "custom:model-a-[Ollama]-0" {
|
||||
t.Errorf("expected custom:model-a-[Ollama]-0, got %s", models[0]["id"])
|
||||
}
|
||||
if models[0]["index"] != float64(0) {
|
||||
t.Errorf("expected index 0, got %v", models[0]["index"])
|
||||
}
|
||||
|
||||
// Check second model
|
||||
if models[1]["model"] != "model-b" {
|
||||
t.Errorf("expected model-b, got %s", models[1]["model"])
|
||||
}
|
||||
if models[1]["id"] != "custom:model-b-[Ollama]-1" {
|
||||
t.Errorf("expected custom:model-b-[Ollama]-1, got %s", models[1]["id"])
|
||||
}
|
||||
if models[1]["index"] != float64(1) {
|
||||
t.Errorf("expected index 1, got %v", models[1]["index"])
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("sets sessionDefaultSettings.model to first model ID", func(t *testing.T) {
|
||||
cleanup()
|
||||
if err := d.Edit([]string{"model-a", "model-b"}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
settings := readSettings()
|
||||
session, ok := settings["sessionDefaultSettings"].(map[string]any)
|
||||
if !ok {
|
||||
t.Fatal("sessionDefaultSettings not found")
|
||||
}
|
||||
if session["model"] != "custom:model-a-[Ollama]-0" {
|
||||
t.Errorf("expected custom:model-a-[Ollama]-0, got %s", session["model"])
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("re-indexes when models removed", func(t *testing.T) {
|
||||
cleanup()
|
||||
// Add three models
|
||||
d.Edit([]string{"model-a", "model-b", "model-c"})
|
||||
|
||||
// Remove middle model
|
||||
d.Edit([]string{"model-a", "model-c"})
|
||||
|
||||
settings := readSettings()
|
||||
models := getCustomModels(settings)
|
||||
|
||||
if len(models) != 2 {
|
||||
t.Fatalf("expected 2 models, got %d", len(models))
|
||||
}
|
||||
|
||||
// Check indices are sequential 0, 1
|
||||
if models[0]["index"] != float64(0) {
|
||||
t.Errorf("expected index 0, got %v", models[0]["index"])
|
||||
}
|
||||
if models[1]["index"] != float64(1) {
|
||||
t.Errorf("expected index 1, got %v", models[1]["index"])
|
||||
}
|
||||
|
||||
// Check IDs match new indices
|
||||
if models[0]["id"] != "custom:model-a-[Ollama]-0" {
|
||||
t.Errorf("expected custom:model-a-[Ollama]-0, got %s", models[0]["id"])
|
||||
}
|
||||
if models[1]["id"] != "custom:model-c-[Ollama]-1" {
|
||||
t.Errorf("expected custom:model-c-[Ollama]-1, got %s", models[1]["id"])
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("preserves non-Ollama custom models", func(t *testing.T) {
|
||||
cleanup()
|
||||
os.MkdirAll(settingsDir, 0o755)
|
||||
// Pre-existing non-Ollama model
|
||||
os.WriteFile(settingsPath, []byte(`{
|
||||
"customModels": [
|
||||
{"model": "gpt-4", "displayName": "GPT-4", "provider": "openai"}
|
||||
]
|
||||
}`), 0o644)
|
||||
|
||||
d.Edit([]string{"model-a"})
|
||||
|
||||
settings := readSettings()
|
||||
models := getCustomModels(settings)
|
||||
|
||||
if len(models) != 2 {
|
||||
t.Fatalf("expected 2 models (1 Ollama + 1 non-Ollama), got %d", len(models))
|
||||
}
|
||||
|
||||
// Ollama model should be first
|
||||
if models[0]["model"] != "model-a" {
|
||||
t.Errorf("expected Ollama model first, got %s", models[0]["model"])
|
||||
}
|
||||
|
||||
// Non-Ollama model should be preserved at end
|
||||
if models[1]["model"] != "gpt-4" {
|
||||
t.Errorf("expected gpt-4 preserved, got %s", models[1]["model"])
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("preserves other settings", func(t *testing.T) {
|
||||
cleanup()
|
||||
os.MkdirAll(settingsDir, 0o755)
|
||||
os.WriteFile(settingsPath, []byte(`{
|
||||
"theme": "dark",
|
||||
"enableHooks": true,
|
||||
"sessionDefaultSettings": {"autonomyMode": "auto-high"}
|
||||
}`), 0o644)
|
||||
|
||||
d.Edit([]string{"model-a"})
|
||||
|
||||
settings := readSettings()
|
||||
|
||||
if settings["theme"] != "dark" {
|
||||
t.Error("theme was not preserved")
|
||||
}
|
||||
if settings["enableHooks"] != true {
|
||||
t.Error("enableHooks was not preserved")
|
||||
}
|
||||
|
||||
session := settings["sessionDefaultSettings"].(map[string]any)
|
||||
if session["autonomyMode"] != "auto-high" {
|
||||
t.Error("autonomyMode was not preserved")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("required fields present", func(t *testing.T) {
|
||||
cleanup()
|
||||
d.Edit([]string{"test-model"})
|
||||
|
||||
settings := readSettings()
|
||||
models := getCustomModels(settings)
|
||||
|
||||
if len(models) != 1 {
|
||||
t.Fatal("expected 1 model")
|
||||
}
|
||||
|
||||
model := models[0]
|
||||
requiredFields := []string{"model", "displayName", "baseUrl", "apiKey", "provider", "maxOutputTokens", "id", "index"}
|
||||
for _, field := range requiredFields {
|
||||
if model[field] == nil {
|
||||
t.Errorf("missing required field: %s", field)
|
||||
}
|
||||
}
|
||||
|
||||
if model["baseUrl"] != "http://localhost:11434/v1" {
|
||||
t.Errorf("unexpected baseUrl: %s", model["baseUrl"])
|
||||
}
|
||||
if model["apiKey"] != "ollama" {
|
||||
t.Errorf("unexpected apiKey: %s", model["apiKey"])
|
||||
}
|
||||
if model["provider"] != "generic-chat-completion-api" {
|
||||
t.Errorf("unexpected provider: %s", model["provider"])
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("fixes invalid reasoningEffort", func(t *testing.T) {
|
||||
cleanup()
|
||||
os.MkdirAll(settingsDir, 0o755)
|
||||
// Pre-existing settings with invalid reasoningEffort
|
||||
os.WriteFile(settingsPath, []byte(`{
|
||||
"sessionDefaultSettings": {"reasoningEffort": "off"}
|
||||
}`), 0o644)
|
||||
|
||||
d.Edit([]string{"model-a"})
|
||||
|
||||
settings := readSettings()
|
||||
session := settings["sessionDefaultSettings"].(map[string]any)
|
||||
|
||||
if session["reasoningEffort"] != "none" {
|
||||
t.Errorf("expected reasoningEffort to be fixed to 'none', got %s", session["reasoningEffort"])
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("preserves valid reasoningEffort", func(t *testing.T) {
|
||||
cleanup()
|
||||
os.MkdirAll(settingsDir, 0o755)
|
||||
os.WriteFile(settingsPath, []byte(`{
|
||||
"sessionDefaultSettings": {"reasoningEffort": "high"}
|
||||
}`), 0o644)
|
||||
|
||||
d.Edit([]string{"model-a"})
|
||||
|
||||
settings := readSettings()
|
||||
session := settings["sessionDefaultSettings"].(map[string]any)
|
||||
|
||||
if session["reasoningEffort"] != "high" {
|
||||
t.Errorf("expected reasoningEffort to remain 'high', got %s", session["reasoningEffort"])
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// Edge case tests for droid.go
|
||||
|
||||
func TestDroidEdit_CorruptedJSON(t *testing.T) {
|
||||
d := &Droid{}
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
settingsDir := filepath.Join(tmpDir, ".factory")
|
||||
settingsPath := filepath.Join(settingsDir, "settings.json")
|
||||
|
||||
os.MkdirAll(settingsDir, 0o755)
|
||||
os.WriteFile(settingsPath, []byte(`{corrupted json content`), 0o644)
|
||||
|
||||
// Corrupted JSON should return an error so user knows something is wrong
|
||||
err := d.Edit([]string{"model-a"})
|
||||
if err == nil {
|
||||
t.Fatal("expected error for corrupted JSON, got nil")
|
||||
}
|
||||
|
||||
// Original corrupted file should be preserved (not overwritten)
|
||||
data, _ := os.ReadFile(settingsPath)
|
||||
if string(data) != `{corrupted json content` {
|
||||
t.Errorf("corrupted file was modified: got %s", string(data))
|
||||
}
|
||||
}
|
||||
|
||||
func TestDroidEdit_WrongTypeCustomModels(t *testing.T) {
|
||||
d := &Droid{}
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
settingsDir := filepath.Join(tmpDir, ".factory")
|
||||
settingsPath := filepath.Join(settingsDir, "settings.json")
|
||||
|
||||
os.MkdirAll(settingsDir, 0o755)
|
||||
// customModels is a string instead of array
|
||||
os.WriteFile(settingsPath, []byte(`{"customModels": "not an array"}`), 0o644)
|
||||
|
||||
// Should not panic - wrong type should be handled gracefully
|
||||
err := d.Edit([]string{"model-a"})
|
||||
if err != nil {
|
||||
t.Fatalf("Edit failed with wrong type customModels: %v", err)
|
||||
}
|
||||
|
||||
// Verify models were added correctly
|
||||
data, _ := os.ReadFile(settingsPath)
|
||||
var settings map[string]any
|
||||
json.Unmarshal(data, &settings)
|
||||
|
||||
customModels, ok := settings["customModels"].([]any)
|
||||
if !ok {
|
||||
t.Fatalf("customModels should be array after setup, got %T", settings["customModels"])
|
||||
}
|
||||
if len(customModels) != 1 {
|
||||
t.Errorf("expected 1 model, got %d", len(customModels))
|
||||
}
|
||||
}
|
||||
|
||||
func TestDroidEdit_EmptyModels(t *testing.T) {
|
||||
d := &Droid{}
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
settingsDir := filepath.Join(tmpDir, ".factory")
|
||||
settingsPath := filepath.Join(settingsDir, "settings.json")
|
||||
|
||||
os.MkdirAll(settingsDir, 0o755)
|
||||
originalContent := `{"customModels": [{"model": "existing"}]}`
|
||||
os.WriteFile(settingsPath, []byte(originalContent), 0o644)
|
||||
|
||||
// Empty models should be no-op
|
||||
err := d.Edit([]string{})
|
||||
if err != nil {
|
||||
t.Fatalf("Edit with empty models failed: %v", err)
|
||||
}
|
||||
|
||||
// Original content should be preserved (file not modified)
|
||||
data, _ := os.ReadFile(settingsPath)
|
||||
if string(data) != originalContent {
|
||||
t.Errorf("empty models should not modify file, but content changed")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDroidEdit_DuplicateModels(t *testing.T) {
|
||||
d := &Droid{}
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
settingsDir := filepath.Join(tmpDir, ".factory")
|
||||
settingsPath := filepath.Join(settingsDir, "settings.json")
|
||||
|
||||
// Add same model twice
|
||||
err := d.Edit([]string{"model-a", "model-a"})
|
||||
if err != nil {
|
||||
t.Fatalf("Edit with duplicates failed: %v", err)
|
||||
}
|
||||
|
||||
settings, err := readJSONFile(settingsPath)
|
||||
if err != nil {
|
||||
t.Fatalf("readJSONFile failed: %v", err)
|
||||
}
|
||||
|
||||
customModels, _ := settings["customModels"].([]any)
|
||||
// Document current behavior: duplicates are kept as separate entries
|
||||
if len(customModels) != 2 {
|
||||
t.Logf("Note: duplicates result in %d entries (documenting behavior)", len(customModels))
|
||||
}
|
||||
}
|
||||
|
||||
func TestDroidEdit_MalformedModelEntry(t *testing.T) {
|
||||
d := &Droid{}
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
settingsDir := filepath.Join(tmpDir, ".factory")
|
||||
settingsPath := filepath.Join(settingsDir, "settings.json")
|
||||
|
||||
os.MkdirAll(settingsDir, 0o755)
|
||||
// Model entry is a string instead of a map
|
||||
os.WriteFile(settingsPath, []byte(`{"customModels": ["not a map", 123]}`), 0o644)
|
||||
|
||||
err := d.Edit([]string{"model-a"})
|
||||
if err != nil {
|
||||
t.Fatalf("Edit with malformed entries failed: %v", err)
|
||||
}
|
||||
|
||||
// Malformed entries should be preserved in nonOllamaModels
|
||||
settings, _ := readJSONFile(settingsPath)
|
||||
customModels, _ := settings["customModels"].([]any)
|
||||
|
||||
// Should have: 1 new Ollama model + 2 preserved malformed entries
|
||||
if len(customModels) != 3 {
|
||||
t.Errorf("expected 3 entries (1 new + 2 preserved malformed), got %d", len(customModels))
|
||||
}
|
||||
}
|
||||
|
||||
func TestDroidEdit_WrongTypeSessionSettings(t *testing.T) {
|
||||
d := &Droid{}
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
settingsDir := filepath.Join(tmpDir, ".factory")
|
||||
settingsPath := filepath.Join(settingsDir, "settings.json")
|
||||
|
||||
os.MkdirAll(settingsDir, 0o755)
|
||||
// sessionDefaultSettings is a string instead of map
|
||||
os.WriteFile(settingsPath, []byte(`{"sessionDefaultSettings": "not a map"}`), 0o644)
|
||||
|
||||
err := d.Edit([]string{"model-a"})
|
||||
if err != nil {
|
||||
t.Fatalf("Edit with wrong type sessionDefaultSettings failed: %v", err)
|
||||
}
|
||||
|
||||
// Should create proper sessionDefaultSettings
|
||||
settings, _ := readJSONFile(settingsPath)
|
||||
session, ok := settings["sessionDefaultSettings"].(map[string]any)
|
||||
if !ok {
|
||||
t.Fatalf("sessionDefaultSettings should be map after setup, got %T", settings["sessionDefaultSettings"])
|
||||
}
|
||||
if session["model"] == nil {
|
||||
t.Error("expected model to be set in sessionDefaultSettings")
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsValidReasoningEffort(t *testing.T) {
|
||||
tests := []struct {
|
||||
effort string
|
||||
valid bool
|
||||
}{
|
||||
{"high", true},
|
||||
{"medium", true},
|
||||
{"low", true},
|
||||
{"none", true},
|
||||
{"off", false},
|
||||
{"", false},
|
||||
{"HIGH", false}, // case sensitive
|
||||
{"max", false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.effort, func(t *testing.T) {
|
||||
got := isValidReasoningEffort(tt.effort)
|
||||
if got != tt.valid {
|
||||
t.Errorf("isValidReasoningEffort(%q) = %v, want %v", tt.effort, got, tt.valid)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
99
cmd/config/files.go
Normal file
99
cmd/config/files.go
Normal file
@@ -0,0 +1,99 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"time"
|
||||
)
|
||||
|
||||
func readJSONFile(path string) (map[string]any, error) {
|
||||
data, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var result map[string]any
|
||||
if err := json.Unmarshal(data, &result); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func copyFile(src, dst string) error {
|
||||
info, err := os.Stat(src)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
data, err := os.ReadFile(src)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return os.WriteFile(dst, data, info.Mode().Perm())
|
||||
}
|
||||
|
||||
func backupDir() string {
|
||||
return filepath.Join(os.TempDir(), "ollama-backups")
|
||||
}
|
||||
|
||||
func backupToTmp(srcPath string) (string, error) {
|
||||
dir := backupDir()
|
||||
if err := os.MkdirAll(dir, 0o755); err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
backupPath := filepath.Join(dir, fmt.Sprintf("%s.%d", filepath.Base(srcPath), time.Now().Unix()))
|
||||
if err := copyFile(srcPath, backupPath); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return backupPath, nil
|
||||
}
|
||||
|
||||
// writeWithBackup writes data to path via temp file + rename, backing up any existing file first
|
||||
func writeWithBackup(path string, data []byte) error {
|
||||
var backupPath string
|
||||
// backup must be created before any writes to the target file
|
||||
if existingContent, err := os.ReadFile(path); err == nil {
|
||||
if !bytes.Equal(existingContent, data) {
|
||||
backupPath, err = backupToTmp(path)
|
||||
if err != nil {
|
||||
return fmt.Errorf("backup failed: %w", err)
|
||||
}
|
||||
}
|
||||
} else if !os.IsNotExist(err) {
|
||||
return fmt.Errorf("read existing file: %w", err)
|
||||
}
|
||||
|
||||
dir := filepath.Dir(path)
|
||||
tmp, err := os.CreateTemp(dir, ".tmp-*")
|
||||
if err != nil {
|
||||
return fmt.Errorf("create temp failed: %w", err)
|
||||
}
|
||||
tmpPath := tmp.Name()
|
||||
|
||||
if _, err := tmp.Write(data); err != nil {
|
||||
_ = tmp.Close()
|
||||
_ = os.Remove(tmpPath)
|
||||
return fmt.Errorf("write failed: %w", err)
|
||||
}
|
||||
if err := tmp.Sync(); err != nil {
|
||||
_ = tmp.Close()
|
||||
_ = os.Remove(tmpPath)
|
||||
return fmt.Errorf("sync failed: %w", err)
|
||||
}
|
||||
if err := tmp.Close(); err != nil {
|
||||
_ = os.Remove(tmpPath)
|
||||
return fmt.Errorf("close failed: %w", err)
|
||||
}
|
||||
|
||||
if err := os.Rename(tmpPath, path); err != nil {
|
||||
_ = os.Remove(tmpPath)
|
||||
if backupPath != "" {
|
||||
_ = copyFile(backupPath, path)
|
||||
}
|
||||
return fmt.Errorf("rename failed: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
502
cmd/config/files_test.go
Normal file
502
cmd/config/files_test.go
Normal file
@@ -0,0 +1,502 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func mustMarshal(t *testing.T, v any) []byte {
|
||||
t.Helper()
|
||||
data, err := json.MarshalIndent(v, "", " ")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
return data
|
||||
}
|
||||
|
||||
func TestWriteWithBackup(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
|
||||
t.Run("creates file", func(t *testing.T) {
|
||||
path := filepath.Join(tmpDir, "new.json")
|
||||
data := mustMarshal(t, map[string]string{"key": "value"})
|
||||
|
||||
if err := writeWithBackup(path, data); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
content, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
var result map[string]string
|
||||
if err := json.Unmarshal(content, &result); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if result["key"] != "value" {
|
||||
t.Errorf("expected value, got %s", result["key"])
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("creates backup in /tmp/ollama-backups", func(t *testing.T) {
|
||||
path := filepath.Join(tmpDir, "backup.json")
|
||||
|
||||
os.WriteFile(path, []byte(`{"original": true}`), 0o644)
|
||||
|
||||
data := mustMarshal(t, map[string]bool{"updated": true})
|
||||
if err := writeWithBackup(path, data); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
entries, err := os.ReadDir(backupDir())
|
||||
if err != nil {
|
||||
t.Fatal("backup directory not created")
|
||||
}
|
||||
|
||||
var foundBackup bool
|
||||
for _, entry := range entries {
|
||||
if filepath.Ext(entry.Name()) != ".json" {
|
||||
name := entry.Name()
|
||||
if len(name) > len("backup.json.") && name[:len("backup.json.")] == "backup.json." {
|
||||
backupPath := filepath.Join(backupDir(), name)
|
||||
backup, err := os.ReadFile(backupPath)
|
||||
if err == nil {
|
||||
var backupData map[string]bool
|
||||
json.Unmarshal(backup, &backupData)
|
||||
if backupData["original"] {
|
||||
foundBackup = true
|
||||
os.Remove(backupPath)
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if !foundBackup {
|
||||
t.Error("backup file not created in /tmp/ollama-backups")
|
||||
}
|
||||
|
||||
current, _ := os.ReadFile(path)
|
||||
var currentData map[string]bool
|
||||
json.Unmarshal(current, ¤tData)
|
||||
if !currentData["updated"] {
|
||||
t.Error("file doesn't contain updated data")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("no backup for new file", func(t *testing.T) {
|
||||
path := filepath.Join(tmpDir, "nobak.json")
|
||||
|
||||
data := mustMarshal(t, map[string]string{"new": "file"})
|
||||
if err := writeWithBackup(path, data); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
entries, _ := os.ReadDir(backupDir())
|
||||
for _, entry := range entries {
|
||||
if len(entry.Name()) > len("nobak.json.") && entry.Name()[:len("nobak.json.")] == "nobak.json." {
|
||||
t.Error("backup should not exist for new file")
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("no backup when content unchanged", func(t *testing.T) {
|
||||
path := filepath.Join(tmpDir, "unchanged.json")
|
||||
|
||||
data := mustMarshal(t, map[string]string{"key": "value"})
|
||||
|
||||
if err := writeWithBackup(path, data); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
entries1, _ := os.ReadDir(backupDir())
|
||||
countBefore := 0
|
||||
for _, e := range entries1 {
|
||||
if len(e.Name()) > len("unchanged.json.") && e.Name()[:len("unchanged.json.")] == "unchanged.json." {
|
||||
countBefore++
|
||||
}
|
||||
}
|
||||
|
||||
if err := writeWithBackup(path, data); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
entries2, _ := os.ReadDir(backupDir())
|
||||
countAfter := 0
|
||||
for _, e := range entries2 {
|
||||
if len(e.Name()) > len("unchanged.json.") && e.Name()[:len("unchanged.json.")] == "unchanged.json." {
|
||||
countAfter++
|
||||
}
|
||||
}
|
||||
|
||||
if countAfter != countBefore {
|
||||
t.Errorf("backup was created when content unchanged (before=%d, after=%d)", countBefore, countAfter)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("backup filename contains unix timestamp", func(t *testing.T) {
|
||||
path := filepath.Join(tmpDir, "timestamped.json")
|
||||
|
||||
os.WriteFile(path, []byte(`{"v": 1}`), 0o644)
|
||||
data := mustMarshal(t, map[string]int{"v": 2})
|
||||
if err := writeWithBackup(path, data); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
entries, _ := os.ReadDir(backupDir())
|
||||
var found bool
|
||||
for _, entry := range entries {
|
||||
name := entry.Name()
|
||||
if len(name) > len("timestamped.json.") && name[:len("timestamped.json.")] == "timestamped.json." {
|
||||
timestamp := name[len("timestamped.json."):]
|
||||
for _, c := range timestamp {
|
||||
if c < '0' || c > '9' {
|
||||
t.Errorf("backup filename timestamp contains non-numeric character: %s", name)
|
||||
}
|
||||
}
|
||||
found = true
|
||||
os.Remove(filepath.Join(backupDir(), name))
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
t.Error("backup file with timestamp not found")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// Edge case tests for files.go
|
||||
|
||||
// TestWriteWithBackup_FailsIfBackupFails documents critical behavior: if backup fails, we must not proceed.
|
||||
// User could lose their config with no way to recover.
|
||||
func TestWriteWithBackup_FailsIfBackupFails(t *testing.T) {
|
||||
if runtime.GOOS == "windows" {
|
||||
t.Skip("permission tests unreliable on Windows")
|
||||
}
|
||||
|
||||
tmpDir := t.TempDir()
|
||||
path := filepath.Join(tmpDir, "config.json")
|
||||
|
||||
// Create original file
|
||||
originalContent := []byte(`{"original": true}`)
|
||||
os.WriteFile(path, originalContent, 0o644)
|
||||
|
||||
// Make backup directory read-only to force backup failure
|
||||
backupDir := backupDir()
|
||||
os.MkdirAll(backupDir, 0o755)
|
||||
os.Chmod(backupDir, 0o444) // Read-only
|
||||
defer os.Chmod(backupDir, 0o755)
|
||||
|
||||
newContent := []byte(`{"updated": true}`)
|
||||
err := writeWithBackup(path, newContent)
|
||||
|
||||
// Should fail because backup couldn't be created
|
||||
if err == nil {
|
||||
t.Error("expected error when backup fails, got nil")
|
||||
}
|
||||
|
||||
// Original file should be preserved
|
||||
current, _ := os.ReadFile(path)
|
||||
if string(current) != string(originalContent) {
|
||||
t.Errorf("original file was modified despite backup failure: got %s", string(current))
|
||||
}
|
||||
}
|
||||
|
||||
// TestWriteWithBackup_PermissionDenied verifies clear error when target file has wrong permissions.
|
||||
// Common issue when config owned by root or wrong perms.
|
||||
func TestWriteWithBackup_PermissionDenied(t *testing.T) {
|
||||
if runtime.GOOS == "windows" {
|
||||
t.Skip("permission tests unreliable on Windows")
|
||||
}
|
||||
|
||||
tmpDir := t.TempDir()
|
||||
|
||||
// Create a read-only directory
|
||||
readOnlyDir := filepath.Join(tmpDir, "readonly")
|
||||
os.MkdirAll(readOnlyDir, 0o755)
|
||||
os.Chmod(readOnlyDir, 0o444)
|
||||
defer os.Chmod(readOnlyDir, 0o755)
|
||||
|
||||
path := filepath.Join(readOnlyDir, "config.json")
|
||||
err := writeWithBackup(path, []byte(`{"test": true}`))
|
||||
|
||||
if err == nil {
|
||||
t.Error("expected permission error, got nil")
|
||||
}
|
||||
}
|
||||
|
||||
// TestWriteWithBackup_DirectoryDoesNotExist verifies behavior when target directory doesn't exist.
|
||||
// writeWithBackup doesn't create directories - caller is responsible.
|
||||
func TestWriteWithBackup_DirectoryDoesNotExist(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
path := filepath.Join(tmpDir, "nonexistent", "subdir", "config.json")
|
||||
|
||||
err := writeWithBackup(path, []byte(`{"test": true}`))
|
||||
|
||||
// Should fail because directory doesn't exist
|
||||
if err == nil {
|
||||
t.Error("expected error for nonexistent directory, got nil")
|
||||
}
|
||||
}
|
||||
|
||||
// TestWriteWithBackup_SymlinkTarget documents behavior when target is a symlink.
|
||||
// Documents what happens if user symlinks their config file.
|
||||
func TestWriteWithBackup_SymlinkTarget(t *testing.T) {
|
||||
if runtime.GOOS == "windows" {
|
||||
t.Skip("symlink tests may require admin on Windows")
|
||||
}
|
||||
|
||||
tmpDir := t.TempDir()
|
||||
realFile := filepath.Join(tmpDir, "real.json")
|
||||
symlink := filepath.Join(tmpDir, "link.json")
|
||||
|
||||
// Create real file and symlink
|
||||
os.WriteFile(realFile, []byte(`{"v": 1}`), 0o644)
|
||||
os.Symlink(realFile, symlink)
|
||||
|
||||
// Write through symlink
|
||||
err := writeWithBackup(symlink, []byte(`{"v": 2}`))
|
||||
if err != nil {
|
||||
t.Fatalf("writeWithBackup through symlink failed: %v", err)
|
||||
}
|
||||
|
||||
// The real file should be updated (symlink followed for temp file creation)
|
||||
content, _ := os.ReadFile(symlink)
|
||||
if string(content) != `{"v": 2}` {
|
||||
t.Errorf("symlink target not updated correctly: got %s", string(content))
|
||||
}
|
||||
}
|
||||
|
||||
// TestBackupToTmp_SpecialCharsInFilename verifies backup works with special characters.
|
||||
// User may have config files with unusual names.
|
||||
func TestBackupToTmp_SpecialCharsInFilename(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
|
||||
// File with spaces and special chars
|
||||
path := filepath.Join(tmpDir, "my config (backup).json")
|
||||
os.WriteFile(path, []byte(`{"test": true}`), 0o644)
|
||||
|
||||
backupPath, err := backupToTmp(path)
|
||||
if err != nil {
|
||||
t.Fatalf("backupToTmp with special chars failed: %v", err)
|
||||
}
|
||||
|
||||
// Verify backup exists and has correct content
|
||||
content, err := os.ReadFile(backupPath)
|
||||
if err != nil {
|
||||
t.Fatalf("could not read backup: %v", err)
|
||||
}
|
||||
if string(content) != `{"test": true}` {
|
||||
t.Errorf("backup content mismatch: got %s", string(content))
|
||||
}
|
||||
|
||||
os.Remove(backupPath)
|
||||
}
|
||||
|
||||
// TestCopyFile_PreservesPermissions verifies that copyFile preserves file permissions.
|
||||
func TestCopyFile_PreservesPermissions(t *testing.T) {
|
||||
if runtime.GOOS == "windows" {
|
||||
t.Skip("permission preservation tests unreliable on Windows")
|
||||
}
|
||||
|
||||
tmpDir := t.TempDir()
|
||||
src := filepath.Join(tmpDir, "src.json")
|
||||
dst := filepath.Join(tmpDir, "dst.json")
|
||||
|
||||
// Create source with specific permissions
|
||||
os.WriteFile(src, []byte(`{"test": true}`), 0o600)
|
||||
|
||||
err := copyFile(src, dst)
|
||||
if err != nil {
|
||||
t.Fatalf("copyFile failed: %v", err)
|
||||
}
|
||||
|
||||
srcInfo, _ := os.Stat(src)
|
||||
dstInfo, _ := os.Stat(dst)
|
||||
|
||||
if srcInfo.Mode().Perm() != dstInfo.Mode().Perm() {
|
||||
t.Errorf("permissions not preserved: src=%v, dst=%v", srcInfo.Mode().Perm(), dstInfo.Mode().Perm())
|
||||
}
|
||||
}
|
||||
|
||||
// TestCopyFile_SourceNotFound verifies clear error when source doesn't exist.
|
||||
func TestCopyFile_SourceNotFound(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
src := filepath.Join(tmpDir, "nonexistent.json")
|
||||
dst := filepath.Join(tmpDir, "dst.json")
|
||||
|
||||
err := copyFile(src, dst)
|
||||
if err == nil {
|
||||
t.Error("expected error for nonexistent source, got nil")
|
||||
}
|
||||
}
|
||||
|
||||
// TestWriteWithBackup_TargetIsDirectory verifies error when path points to a directory.
|
||||
func TestWriteWithBackup_TargetIsDirectory(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
dirPath := filepath.Join(tmpDir, "actualdir")
|
||||
os.MkdirAll(dirPath, 0o755)
|
||||
|
||||
err := writeWithBackup(dirPath, []byte(`{"test": true}`))
|
||||
if err == nil {
|
||||
t.Error("expected error when target is a directory, got nil")
|
||||
}
|
||||
}
|
||||
|
||||
// TestWriteWithBackup_EmptyData verifies writing zero bytes works correctly.
|
||||
func TestWriteWithBackup_EmptyData(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
path := filepath.Join(tmpDir, "empty.json")
|
||||
|
||||
err := writeWithBackup(path, []byte{})
|
||||
if err != nil {
|
||||
t.Fatalf("writeWithBackup with empty data failed: %v", err)
|
||||
}
|
||||
|
||||
content, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
t.Fatalf("could not read file: %v", err)
|
||||
}
|
||||
if len(content) != 0 {
|
||||
t.Errorf("expected empty file, got %d bytes", len(content))
|
||||
}
|
||||
}
|
||||
|
||||
// TestWriteWithBackup_FileUnreadableButDirWritable verifies behavior when existing file
|
||||
// cannot be read (for backup comparison) but directory is writable.
|
||||
func TestWriteWithBackup_FileUnreadableButDirWritable(t *testing.T) {
|
||||
if runtime.GOOS == "windows" {
|
||||
t.Skip("permission tests unreliable on Windows")
|
||||
}
|
||||
|
||||
tmpDir := t.TempDir()
|
||||
path := filepath.Join(tmpDir, "unreadable.json")
|
||||
|
||||
// Create file and make it unreadable
|
||||
os.WriteFile(path, []byte(`{"original": true}`), 0o644)
|
||||
os.Chmod(path, 0o000)
|
||||
defer os.Chmod(path, 0o644)
|
||||
|
||||
// Should fail because we can't read the file to compare/backup
|
||||
err := writeWithBackup(path, []byte(`{"updated": true}`))
|
||||
if err == nil {
|
||||
t.Error("expected error when file is unreadable, got nil")
|
||||
}
|
||||
}
|
||||
|
||||
// TestWriteWithBackup_RapidSuccessiveWrites verifies backup works with multiple writes
|
||||
// within the same second (timestamp collision scenario).
|
||||
func TestWriteWithBackup_RapidSuccessiveWrites(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
path := filepath.Join(tmpDir, "rapid.json")
|
||||
|
||||
// Create initial file
|
||||
os.WriteFile(path, []byte(`{"v": 0}`), 0o644)
|
||||
|
||||
// Rapid successive writes
|
||||
for i := 1; i <= 3; i++ {
|
||||
data := []byte(fmt.Sprintf(`{"v": %d}`, i))
|
||||
if err := writeWithBackup(path, data); err != nil {
|
||||
t.Fatalf("write %d failed: %v", i, err)
|
||||
}
|
||||
}
|
||||
|
||||
// Verify final content
|
||||
content, _ := os.ReadFile(path)
|
||||
if string(content) != `{"v": 3}` {
|
||||
t.Errorf("expected final content {\"v\": 3}, got %s", string(content))
|
||||
}
|
||||
|
||||
// Verify at least one backup exists
|
||||
entries, _ := os.ReadDir(backupDir())
|
||||
var backupCount int
|
||||
for _, e := range entries {
|
||||
if len(e.Name()) > len("rapid.json.") && e.Name()[:len("rapid.json.")] == "rapid.json." {
|
||||
backupCount++
|
||||
}
|
||||
}
|
||||
if backupCount == 0 {
|
||||
t.Error("expected at least one backup file from rapid writes")
|
||||
}
|
||||
}
|
||||
|
||||
// TestWriteWithBackup_BackupDirIsFile verifies error when backup directory path is a file.
|
||||
func TestWriteWithBackup_BackupDirIsFile(t *testing.T) {
|
||||
if runtime.GOOS == "windows" {
|
||||
t.Skip("test modifies system temp directory")
|
||||
}
|
||||
|
||||
// Create a file at the backup directory path
|
||||
backupPath := backupDir()
|
||||
// Clean up any existing directory first
|
||||
os.RemoveAll(backupPath)
|
||||
// Create a file instead of directory
|
||||
os.WriteFile(backupPath, []byte("not a directory"), 0o644)
|
||||
defer func() {
|
||||
os.Remove(backupPath)
|
||||
os.MkdirAll(backupPath, 0o755)
|
||||
}()
|
||||
|
||||
tmpDir := t.TempDir()
|
||||
path := filepath.Join(tmpDir, "test.json")
|
||||
os.WriteFile(path, []byte(`{"original": true}`), 0o644)
|
||||
|
||||
err := writeWithBackup(path, []byte(`{"updated": true}`))
|
||||
if err == nil {
|
||||
t.Error("expected error when backup dir is a file, got nil")
|
||||
}
|
||||
}
|
||||
|
||||
// TestWriteWithBackup_NoOrphanTempFiles verifies temp files are cleaned up on failure.
|
||||
func TestWriteWithBackup_NoOrphanTempFiles(t *testing.T) {
|
||||
if runtime.GOOS == "windows" {
|
||||
t.Skip("permission tests unreliable on Windows")
|
||||
}
|
||||
|
||||
tmpDir := t.TempDir()
|
||||
|
||||
// Count existing temp files
|
||||
countTempFiles := func() int {
|
||||
entries, _ := os.ReadDir(tmpDir)
|
||||
count := 0
|
||||
for _, e := range entries {
|
||||
if len(e.Name()) > 4 && e.Name()[:4] == ".tmp" {
|
||||
count++
|
||||
}
|
||||
}
|
||||
return count
|
||||
}
|
||||
|
||||
before := countTempFiles()
|
||||
|
||||
// Create a file, then make directory read-only to cause rename failure
|
||||
path := filepath.Join(tmpDir, "orphan.json")
|
||||
os.WriteFile(path, []byte(`{"v": 1}`), 0o644)
|
||||
|
||||
// Make a subdirectory and try to write there after making parent read-only
|
||||
subDir := filepath.Join(tmpDir, "subdir")
|
||||
os.MkdirAll(subDir, 0o755)
|
||||
subPath := filepath.Join(subDir, "config.json")
|
||||
os.WriteFile(subPath, []byte(`{"v": 1}`), 0o644)
|
||||
|
||||
// Make subdir read-only after creating temp file would succeed but rename would fail
|
||||
// This is tricky to test - the temp file is created in the same dir, so if we can't
|
||||
// rename, we also couldn't create. Let's just verify normal failure cleanup works.
|
||||
|
||||
// Force a failure by making the target a directory
|
||||
badPath := filepath.Join(tmpDir, "isdir")
|
||||
os.MkdirAll(badPath, 0o755)
|
||||
|
||||
_ = writeWithBackup(badPath, []byte(`{"test": true}`))
|
||||
|
||||
after := countTempFiles()
|
||||
if after > before {
|
||||
t.Errorf("orphan temp files left behind: before=%d, after=%d", before, after)
|
||||
}
|
||||
}
|
||||
361
cmd/config/integrations.go
Normal file
361
cmd/config/integrations.go
Normal file
@@ -0,0 +1,361 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"maps"
|
||||
"os"
|
||||
"os/exec"
|
||||
"runtime"
|
||||
"slices"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/spf13/cobra"
|
||||
)
|
||||
|
||||
// Runners execute the launching of a model with the integration - claude, codex
|
||||
// Editors can edit config files (supports multi-model selection) - opencode, droid
|
||||
// They are composable interfaces where in some cases an editor is also a runner - opencode, droid
|
||||
// Runner can run an integration with a model.
|
||||
|
||||
type Runner interface {
|
||||
Run(model string) error
|
||||
// String returns the human-readable name of the integration
|
||||
String() string
|
||||
}
|
||||
|
||||
// Editor can edit config files (supports multi-model selection)
|
||||
type Editor interface {
|
||||
// Paths returns the paths to the config files for the integration
|
||||
Paths() []string
|
||||
// Edit updates the config files for the integration with the given models
|
||||
Edit(models []string) error
|
||||
// Models returns the models currently configured for the integration
|
||||
Models() []string
|
||||
}
|
||||
|
||||
// integrations is the registry of available integrations.
|
||||
var integrations = map[string]Runner{
|
||||
"claude": &Claude{},
|
||||
"codex": &Codex{},
|
||||
"droid": &Droid{},
|
||||
"opencode": &OpenCode{},
|
||||
}
|
||||
|
||||
func selectIntegration() (string, error) {
|
||||
if len(integrations) == 0 {
|
||||
return "", fmt.Errorf("no integrations available")
|
||||
}
|
||||
|
||||
names := slices.Sorted(maps.Keys(integrations))
|
||||
var items []selectItem
|
||||
for _, name := range names {
|
||||
r := integrations[name]
|
||||
description := r.String()
|
||||
if conn, err := loadIntegration(name); err == nil && len(conn.Models) > 0 {
|
||||
description = fmt.Sprintf("%s (%s)", r.String(), conn.Models[0])
|
||||
}
|
||||
items = append(items, selectItem{Name: name, Description: description})
|
||||
}
|
||||
|
||||
return selectPrompt("Select integration:", items)
|
||||
}
|
||||
|
||||
// selectModels lets the user select models for an integration
|
||||
func selectModels(ctx context.Context, name, current string) ([]string, error) {
|
||||
r, ok := integrations[name]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("unknown integration: %s", name)
|
||||
}
|
||||
|
||||
client, err := api.ClientFromEnvironment()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
models, err := client.List(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if len(models.Models) == 0 {
|
||||
return nil, fmt.Errorf("no models available, run 'ollama pull <model>' first")
|
||||
}
|
||||
|
||||
var items []selectItem
|
||||
cloudModels := make(map[string]bool)
|
||||
for _, m := range models.Models {
|
||||
if m.RemoteModel != "" {
|
||||
cloudModels[m.Name] = true
|
||||
}
|
||||
items = append(items, selectItem{Name: m.Name})
|
||||
}
|
||||
|
||||
if len(items) == 0 {
|
||||
return nil, fmt.Errorf("no local models available, run 'ollama pull <model>' first")
|
||||
}
|
||||
|
||||
// Get previously configured models (saved config takes precedence)
|
||||
var preChecked []string
|
||||
if saved, err := loadIntegration(name); err == nil {
|
||||
preChecked = saved.Models
|
||||
} else if editor, ok := r.(Editor); ok {
|
||||
preChecked = editor.Models()
|
||||
}
|
||||
checked := make(map[string]bool, len(preChecked))
|
||||
for _, n := range preChecked {
|
||||
checked[n] = true
|
||||
}
|
||||
|
||||
// Resolve current to full name (e.g., "llama3.2" -> "llama3.2:latest")
|
||||
for _, item := range items {
|
||||
if item.Name == current || strings.HasPrefix(item.Name, current+":") {
|
||||
current = item.Name
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
// If current model is configured, move to front of preChecked
|
||||
if checked[current] {
|
||||
preChecked = append([]string{current}, slices.DeleteFunc(preChecked, func(m string) bool { return m == current })...)
|
||||
}
|
||||
|
||||
// Sort: checked first, then alphabetical
|
||||
slices.SortFunc(items, func(a, b selectItem) int {
|
||||
ac, bc := checked[a.Name], checked[b.Name]
|
||||
if ac != bc {
|
||||
if ac {
|
||||
return -1
|
||||
}
|
||||
return 1
|
||||
}
|
||||
return strings.Compare(strings.ToLower(a.Name), strings.ToLower(b.Name))
|
||||
})
|
||||
|
||||
var selected []string
|
||||
// only editors support multi-model selection
|
||||
if _, ok := r.(Editor); ok {
|
||||
selected, err = multiSelectPrompt(fmt.Sprintf("Select models for %s:", r), items, preChecked)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
} else {
|
||||
model, err := selectPrompt(fmt.Sprintf("Select model for %s:", r), items)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
selected = []string{model}
|
||||
}
|
||||
|
||||
// if any model in selected is a cloud model, ensure signed in
|
||||
var selectedCloudModels []string
|
||||
for _, m := range selected {
|
||||
if cloudModels[m] {
|
||||
selectedCloudModels = append(selectedCloudModels, m)
|
||||
}
|
||||
}
|
||||
if len(selectedCloudModels) > 0 {
|
||||
// ensure user is signed in
|
||||
user, err := client.Whoami(ctx)
|
||||
if err == nil && user != nil && user.Name != "" {
|
||||
return selected, nil
|
||||
}
|
||||
|
||||
var aErr api.AuthorizationError
|
||||
if !errors.As(err, &aErr) || aErr.SigninURL == "" {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
modelList := strings.Join(selectedCloudModels, ", ")
|
||||
yes, err := confirmPrompt(fmt.Sprintf("sign in to use %s?", modelList))
|
||||
if err != nil || !yes {
|
||||
return nil, fmt.Errorf("%s requires sign in", modelList)
|
||||
}
|
||||
|
||||
fmt.Fprintf(os.Stderr, "\nTo sign in, navigate to:\n %s\n\n", aErr.SigninURL)
|
||||
|
||||
// TODO(parthsareen): extract into auth package for cmd
|
||||
// Auto-open browser (best effort, fail silently)
|
||||
switch runtime.GOOS {
|
||||
case "darwin":
|
||||
_ = exec.Command("open", aErr.SigninURL).Start()
|
||||
case "linux":
|
||||
_ = exec.Command("xdg-open", aErr.SigninURL).Start()
|
||||
case "windows":
|
||||
_ = exec.Command("rundll32", "url.dll,FileProtocolHandler", aErr.SigninURL).Start()
|
||||
}
|
||||
|
||||
spinnerFrames := []string{"|", "/", "-", "\\"}
|
||||
frame := 0
|
||||
|
||||
fmt.Fprintf(os.Stderr, "\033[90mwaiting for sign in to complete... %s\033[0m", spinnerFrames[0])
|
||||
|
||||
ticker := time.NewTicker(200 * time.Millisecond)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
fmt.Fprintf(os.Stderr, "\r\033[K")
|
||||
return nil, ctx.Err()
|
||||
case <-ticker.C:
|
||||
frame++
|
||||
fmt.Fprintf(os.Stderr, "\r\033[90mwaiting for sign in to complete... %s\033[0m", spinnerFrames[frame%len(spinnerFrames)])
|
||||
|
||||
// poll every 10th frame (~2 seconds)
|
||||
if frame%10 == 0 {
|
||||
u, err := client.Whoami(ctx)
|
||||
if err == nil && u != nil && u.Name != "" {
|
||||
fmt.Fprintf(os.Stderr, "\r\033[K\033[A\r\033[K\033[1msigned in:\033[0m %s\n", u.Name)
|
||||
return selected, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return selected, nil
|
||||
}
|
||||
|
||||
func runIntegration(name, modelName string) error {
|
||||
r, ok := integrations[name]
|
||||
if !ok {
|
||||
return fmt.Errorf("unknown integration: %s", name)
|
||||
}
|
||||
fmt.Fprintf(os.Stderr, "\nLaunching %s with %s...\n", r, modelName)
|
||||
return r.Run(modelName)
|
||||
}
|
||||
|
||||
// ConfigCmd returns the cobra command for configuring integrations.
|
||||
func ConfigCmd(checkServerHeartbeat func(cmd *cobra.Command, args []string) error) *cobra.Command {
|
||||
var modelFlag string
|
||||
var launchFlag bool
|
||||
|
||||
cmd := &cobra.Command{
|
||||
Use: "config [INTEGRATION]",
|
||||
Short: "Configure an external integration to use Ollama",
|
||||
Long: `Configure an external application to use Ollama models.
|
||||
|
||||
Supported integrations:
|
||||
claude Claude Code
|
||||
codex Codex
|
||||
droid Droid
|
||||
opencode OpenCode
|
||||
|
||||
Examples:
|
||||
ollama config
|
||||
ollama config claude
|
||||
ollama config droid --launch`,
|
||||
Args: cobra.MaximumNArgs(1),
|
||||
PreRunE: checkServerHeartbeat,
|
||||
RunE: func(cmd *cobra.Command, args []string) error {
|
||||
var name string
|
||||
if len(args) > 0 {
|
||||
name = args[0]
|
||||
} else {
|
||||
var err error
|
||||
name, err = selectIntegration()
|
||||
if errors.Is(err, errCancelled) {
|
||||
return nil
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
r, ok := integrations[strings.ToLower(name)]
|
||||
if !ok {
|
||||
return fmt.Errorf("unknown integration: %s", name)
|
||||
}
|
||||
|
||||
// If --launch without --model, use saved config if available
|
||||
if launchFlag && modelFlag == "" {
|
||||
if config, err := loadIntegration(name); err == nil && len(config.Models) > 0 {
|
||||
return runIntegration(name, config.Models[0])
|
||||
}
|
||||
}
|
||||
|
||||
var models []string
|
||||
if modelFlag != "" {
|
||||
// When --model is specified, merge with existing models (new model becomes default)
|
||||
models = []string{modelFlag}
|
||||
if existing, err := loadIntegration(name); err == nil && len(existing.Models) > 0 {
|
||||
for _, m := range existing.Models {
|
||||
if m != modelFlag {
|
||||
models = append(models, m)
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
var err error
|
||||
models, err = selectModels(cmd.Context(), name, "")
|
||||
if errors.Is(err, errCancelled) {
|
||||
return nil
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
if editor, isEditor := r.(Editor); isEditor {
|
||||
paths := editor.Paths()
|
||||
if len(paths) > 0 {
|
||||
fmt.Fprintf(os.Stderr, "This will modify your %s configuration:\n", r)
|
||||
for _, p := range paths {
|
||||
fmt.Fprintf(os.Stderr, " %s\n", p)
|
||||
}
|
||||
fmt.Fprintf(os.Stderr, "Backups will be saved to %s/\n\n", backupDir())
|
||||
|
||||
if ok, _ := confirmPrompt("Proceed?"); !ok {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if err := saveIntegration(name, models); err != nil {
|
||||
return fmt.Errorf("failed to save: %w", err)
|
||||
}
|
||||
|
||||
if editor, isEditor := r.(Editor); isEditor {
|
||||
if err := editor.Edit(models); err != nil {
|
||||
return fmt.Errorf("setup failed: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
if _, isEditor := r.(Editor); isEditor {
|
||||
if len(models) == 1 {
|
||||
fmt.Fprintf(os.Stderr, "Added %s to %s\n", models[0], r)
|
||||
} else {
|
||||
fmt.Fprintf(os.Stderr, "Added %d models to %s (default: %s)\n", len(models), r, models[0])
|
||||
}
|
||||
}
|
||||
|
||||
if slices.ContainsFunc(models, func(m string) bool {
|
||||
return !strings.HasSuffix(m, "cloud")
|
||||
}) {
|
||||
fmt.Fprintln(os.Stderr)
|
||||
fmt.Fprintln(os.Stderr, "Coding agents work best with at least 64k context. Either:")
|
||||
fmt.Fprintln(os.Stderr, " - Set the context slider in Ollama app settings")
|
||||
fmt.Fprintln(os.Stderr, " - Run: OLLAMA_CONTEXT_LENGTH=64000 ollama serve")
|
||||
}
|
||||
|
||||
if launchFlag {
|
||||
return runIntegration(name, models[0])
|
||||
}
|
||||
|
||||
if launch, _ := confirmPrompt(fmt.Sprintf("\nLaunch %s now?", r)); launch {
|
||||
return runIntegration(name, models[0])
|
||||
}
|
||||
|
||||
fmt.Fprintf(os.Stderr, "Run 'ollama config %s --launch' to start with %s\n", strings.ToLower(name), models[0])
|
||||
return nil
|
||||
},
|
||||
}
|
||||
|
||||
cmd.Flags().StringVar(&modelFlag, "model", "", "Model to use")
|
||||
cmd.Flags().BoolVar(&launchFlag, "launch", false, "Launch the integration after configuring")
|
||||
return cmd
|
||||
}
|
||||
188
cmd/config/integrations_test.go
Normal file
188
cmd/config/integrations_test.go
Normal file
@@ -0,0 +1,188 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"slices"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/spf13/cobra"
|
||||
)
|
||||
|
||||
func TestIntegrationLookup(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
wantFound bool
|
||||
wantName string
|
||||
}{
|
||||
{"claude lowercase", "claude", true, "Claude Code"},
|
||||
{"claude uppercase", "CLAUDE", true, "Claude Code"},
|
||||
{"claude mixed case", "Claude", true, "Claude Code"},
|
||||
{"codex", "codex", true, "Codex"},
|
||||
{"droid", "droid", true, "Droid"},
|
||||
{"opencode", "opencode", true, "OpenCode"},
|
||||
{"unknown integration", "unknown", false, ""},
|
||||
{"empty string", "", false, ""},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
r, found := integrations[strings.ToLower(tt.input)]
|
||||
if found != tt.wantFound {
|
||||
t.Errorf("integrations[%q] found = %v, want %v", tt.input, found, tt.wantFound)
|
||||
}
|
||||
if found && r.String() != tt.wantName {
|
||||
t.Errorf("integrations[%q].String() = %q, want %q", tt.input, r.String(), tt.wantName)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestIntegrationRegistry(t *testing.T) {
|
||||
expectedIntegrations := []string{"claude", "codex", "droid", "opencode"}
|
||||
|
||||
for _, name := range expectedIntegrations {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
r, ok := integrations[name]
|
||||
if !ok {
|
||||
t.Fatalf("integration %q not found in registry", name)
|
||||
}
|
||||
if r.String() == "" {
|
||||
t.Error("integration.String() should not be empty")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestHasLocalModel(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
models []string
|
||||
want bool
|
||||
}{
|
||||
{"empty list", []string{}, false},
|
||||
{"single local model", []string{"llama3.2"}, true},
|
||||
{"single cloud model", []string{"cloud-model"}, false},
|
||||
{"mixed models", []string{"cloud-model", "llama3.2"}, true},
|
||||
{"multiple local models", []string{"llama3.2", "qwen2.5"}, true},
|
||||
{"multiple cloud models", []string{"cloud-a", "cloud-b"}, false},
|
||||
{"local model first", []string{"llama3.2", "cloud-model"}, true},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := slices.ContainsFunc(tt.models, func(m string) bool {
|
||||
return !strings.Contains(m, "cloud")
|
||||
})
|
||||
if got != tt.want {
|
||||
t.Errorf("hasLocalModel(%v) = %v, want %v", tt.models, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestConfigCmd(t *testing.T) {
|
||||
// Mock checkServerHeartbeat that always succeeds
|
||||
mockCheck := func(cmd *cobra.Command, args []string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
cmd := ConfigCmd(mockCheck)
|
||||
|
||||
t.Run("command structure", func(t *testing.T) {
|
||||
if cmd.Use != "config [INTEGRATION]" {
|
||||
t.Errorf("Use = %q, want %q", cmd.Use, "config [INTEGRATION]")
|
||||
}
|
||||
if cmd.Short == "" {
|
||||
t.Error("Short description should not be empty")
|
||||
}
|
||||
if cmd.Long == "" {
|
||||
t.Error("Long description should not be empty")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("flags exist", func(t *testing.T) {
|
||||
modelFlag := cmd.Flags().Lookup("model")
|
||||
if modelFlag == nil {
|
||||
t.Error("--model flag should exist")
|
||||
}
|
||||
|
||||
launchFlag := cmd.Flags().Lookup("launch")
|
||||
if launchFlag == nil {
|
||||
t.Error("--launch flag should exist")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("PreRunE is set", func(t *testing.T) {
|
||||
if cmd.PreRunE == nil {
|
||||
t.Error("PreRunE should be set to checkServerHeartbeat")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestRunIntegration_UnknownIntegration(t *testing.T) {
|
||||
err := runIntegration("unknown-integration", "model")
|
||||
if err == nil {
|
||||
t.Error("expected error for unknown integration, got nil")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "unknown integration") {
|
||||
t.Errorf("error should mention 'unknown integration', got: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHasLocalModel_DocumentsHeuristic(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
models []string
|
||||
want bool
|
||||
reason string
|
||||
}{
|
||||
{"empty list", []string{}, false, "empty list has no local models"},
|
||||
{"contains-cloud-substring", []string{"deepseek-r1:cloud"}, false, "model with 'cloud' substring is considered cloud"},
|
||||
{"cloud-in-name", []string{"my-cloud-model"}, false, "'cloud' anywhere in name = cloud model"},
|
||||
{"cloudless", []string{"cloudless-model"}, false, "'cloudless' still contains 'cloud'"},
|
||||
{"local-model", []string{"llama3.2"}, true, "no 'cloud' = local"},
|
||||
{"mixed", []string{"cloud-model", "llama3.2"}, true, "one local model = hasLocalModel true"},
|
||||
{"all-cloud", []string{"cloud-a", "cloud-b"}, false, "all contain 'cloud'"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := slices.ContainsFunc(tt.models, func(m string) bool {
|
||||
return !strings.Contains(m, "cloud")
|
||||
})
|
||||
if got != tt.want {
|
||||
t.Errorf("hasLocalModel(%v) = %v, want %v (%s)", tt.models, got, tt.want, tt.reason)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestConfigCmd_NilHeartbeat(t *testing.T) {
|
||||
// This should not panic - cmd creation should work even with nil
|
||||
cmd := ConfigCmd(nil)
|
||||
if cmd == nil {
|
||||
t.Fatal("ConfigCmd returned nil")
|
||||
}
|
||||
|
||||
// PreRunE should be nil when passed nil
|
||||
if cmd.PreRunE != nil {
|
||||
t.Log("Note: PreRunE is set even when nil is passed (acceptable)")
|
||||
}
|
||||
}
|
||||
|
||||
func TestAllIntegrations_HaveRequiredMethods(t *testing.T) {
|
||||
for name, r := range integrations {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
// Test String() doesn't panic and returns non-empty
|
||||
displayName := r.String()
|
||||
if displayName == "" {
|
||||
t.Error("String() should not return empty")
|
||||
}
|
||||
|
||||
// Test Run() exists (we can't call it without actually running the command)
|
||||
// Just verify the method is available
|
||||
var _ func(string) error = r.Run
|
||||
})
|
||||
}
|
||||
}
|
||||
203
cmd/config/opencode.go
Normal file
203
cmd/config/opencode.go
Normal file
@@ -0,0 +1,203 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"maps"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"slices"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// OpenCode implements Runner and Editor for OpenCode integration
|
||||
type OpenCode struct{}
|
||||
|
||||
func (o *OpenCode) String() string { return "OpenCode" }
|
||||
|
||||
func (o *OpenCode) Run(model string) error {
|
||||
if _, err := exec.LookPath("opencode"); err != nil {
|
||||
return fmt.Errorf("opencode is not installed, install from https://opencode.ai")
|
||||
}
|
||||
|
||||
// Call Edit() to ensure config is up-to-date before launch
|
||||
models := []string{model}
|
||||
if config, err := loadIntegration("opencode"); err == nil && len(config.Models) > 0 {
|
||||
models = config.Models
|
||||
}
|
||||
if err := o.Edit(models); err != nil {
|
||||
return fmt.Errorf("setup failed: %w", err)
|
||||
}
|
||||
|
||||
cmd := exec.Command("opencode")
|
||||
cmd.Stdin = os.Stdin
|
||||
cmd.Stdout = os.Stdout
|
||||
cmd.Stderr = os.Stderr
|
||||
return cmd.Run()
|
||||
}
|
||||
|
||||
func (o *OpenCode) Paths() []string {
|
||||
home, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
var paths []string
|
||||
p := filepath.Join(home, ".config", "opencode", "opencode.json")
|
||||
if _, err := os.Stat(p); err == nil {
|
||||
paths = append(paths, p)
|
||||
}
|
||||
sp := filepath.Join(home, ".local", "state", "opencode", "model.json")
|
||||
if _, err := os.Stat(sp); err == nil {
|
||||
paths = append(paths, sp)
|
||||
}
|
||||
return paths
|
||||
}
|
||||
|
||||
func (o *OpenCode) Edit(modelList []string) error {
|
||||
if len(modelList) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
home, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
configPath := filepath.Join(home, ".config", "opencode", "opencode.json")
|
||||
if err := os.MkdirAll(filepath.Dir(configPath), 0o755); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
config := make(map[string]any)
|
||||
if data, err := os.ReadFile(configPath); err == nil {
|
||||
_ = json.Unmarshal(data, &config) // Ignore parse errors; treat missing/corrupt files as empty
|
||||
}
|
||||
|
||||
config["$schema"] = "https://opencode.ai/config.json"
|
||||
|
||||
provider, ok := config["provider"].(map[string]any)
|
||||
if !ok {
|
||||
provider = make(map[string]any)
|
||||
}
|
||||
|
||||
ollama, ok := provider["ollama"].(map[string]any)
|
||||
if !ok {
|
||||
ollama = map[string]any{
|
||||
"npm": "@ai-sdk/openai-compatible",
|
||||
"name": "Ollama (local)",
|
||||
"options": map[string]any{
|
||||
"baseURL": "http://localhost:11434/v1",
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
models, ok := ollama["models"].(map[string]any)
|
||||
if !ok {
|
||||
models = make(map[string]any)
|
||||
}
|
||||
|
||||
selectedSet := make(map[string]bool)
|
||||
for _, m := range modelList {
|
||||
selectedSet[m] = true
|
||||
}
|
||||
|
||||
for name, cfg := range models {
|
||||
if cfgMap, ok := cfg.(map[string]any); ok {
|
||||
if displayName, ok := cfgMap["name"].(string); ok {
|
||||
if strings.HasSuffix(displayName, "[Ollama]") && !selectedSet[name] {
|
||||
delete(models, name)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for _, model := range modelList {
|
||||
models[model] = map[string]any{
|
||||
"name": fmt.Sprintf("%s [Ollama]", model),
|
||||
}
|
||||
}
|
||||
|
||||
ollama["models"] = models
|
||||
provider["ollama"] = ollama
|
||||
config["provider"] = provider
|
||||
|
||||
configData, err := json.MarshalIndent(config, "", " ")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := writeWithBackup(configPath, configData); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
statePath := filepath.Join(home, ".local", "state", "opencode", "model.json")
|
||||
if err := os.MkdirAll(filepath.Dir(statePath), 0o755); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
state := map[string]any{
|
||||
"recent": []any{},
|
||||
"favorite": []any{},
|
||||
"variant": map[string]any{},
|
||||
}
|
||||
if data, err := os.ReadFile(statePath); err == nil {
|
||||
_ = json.Unmarshal(data, &state) // Ignore parse errors; use defaults
|
||||
}
|
||||
|
||||
recent, _ := state["recent"].([]any)
|
||||
|
||||
modelSet := make(map[string]bool)
|
||||
for _, m := range modelList {
|
||||
modelSet[m] = true
|
||||
}
|
||||
|
||||
// Filter out existing Ollama models we're about to re-add
|
||||
newRecent := slices.DeleteFunc(slices.Clone(recent), func(entry any) bool {
|
||||
e, ok := entry.(map[string]any)
|
||||
if !ok || e["providerID"] != "ollama" {
|
||||
return false
|
||||
}
|
||||
modelID, _ := e["modelID"].(string)
|
||||
return modelSet[modelID]
|
||||
})
|
||||
|
||||
// Prepend models in reverse order so first model ends up first
|
||||
for _, model := range slices.Backward(modelList) {
|
||||
newRecent = slices.Insert(newRecent, 0, any(map[string]any{
|
||||
"providerID": "ollama",
|
||||
"modelID": model,
|
||||
}))
|
||||
}
|
||||
|
||||
const maxRecentModels = 10
|
||||
newRecent = newRecent[:min(len(newRecent), maxRecentModels)]
|
||||
|
||||
state["recent"] = newRecent
|
||||
|
||||
stateData, err := json.MarshalIndent(state, "", " ")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return writeWithBackup(statePath, stateData)
|
||||
}
|
||||
|
||||
func (o *OpenCode) Models() []string {
|
||||
home, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
config, err := readJSONFile(filepath.Join(home, ".config", "opencode", "opencode.json"))
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
provider, _ := config["provider"].(map[string]any)
|
||||
ollama, _ := provider["ollama"].(map[string]any)
|
||||
models, _ := ollama["models"].(map[string]any)
|
||||
if len(models) == 0 {
|
||||
return nil
|
||||
}
|
||||
keys := slices.Collect(maps.Keys(models))
|
||||
slices.Sort(keys)
|
||||
return keys
|
||||
}
|
||||
437
cmd/config/opencode_test.go
Normal file
437
cmd/config/opencode_test.go
Normal file
@@ -0,0 +1,437 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestOpenCodeIntegration(t *testing.T) {
|
||||
o := &OpenCode{}
|
||||
|
||||
t.Run("String", func(t *testing.T) {
|
||||
if got := o.String(); got != "OpenCode" {
|
||||
t.Errorf("String() = %q, want %q", got, "OpenCode")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("implements Runner", func(t *testing.T) {
|
||||
var _ Runner = o
|
||||
})
|
||||
|
||||
t.Run("implements Editor", func(t *testing.T) {
|
||||
var _ Editor = o
|
||||
})
|
||||
}
|
||||
|
||||
func TestOpenCodeEdit(t *testing.T) {
|
||||
o := &OpenCode{}
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
configDir := filepath.Join(tmpDir, ".config", "opencode")
|
||||
configPath := filepath.Join(configDir, "opencode.json")
|
||||
stateDir := filepath.Join(tmpDir, ".local", "state", "opencode")
|
||||
statePath := filepath.Join(stateDir, "model.json")
|
||||
|
||||
cleanup := func() {
|
||||
os.RemoveAll(configDir)
|
||||
os.RemoveAll(stateDir)
|
||||
}
|
||||
|
||||
t.Run("fresh install", func(t *testing.T) {
|
||||
cleanup()
|
||||
if err := o.Edit([]string{"llama3.2"}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
assertOpenCodeModelExists(t, configPath, "llama3.2")
|
||||
assertOpenCodeRecentModel(t, statePath, 0, "ollama", "llama3.2")
|
||||
})
|
||||
|
||||
t.Run("preserve other providers", func(t *testing.T) {
|
||||
cleanup()
|
||||
os.MkdirAll(configDir, 0o755)
|
||||
os.WriteFile(configPath, []byte(`{"provider":{"anthropic":{"apiKey":"xxx"}}}`), 0o644)
|
||||
if err := o.Edit([]string{"llama3.2"}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
data, _ := os.ReadFile(configPath)
|
||||
var cfg map[string]any
|
||||
json.Unmarshal(data, &cfg)
|
||||
provider := cfg["provider"].(map[string]any)
|
||||
if provider["anthropic"] == nil {
|
||||
t.Error("anthropic provider was removed")
|
||||
}
|
||||
assertOpenCodeModelExists(t, configPath, "llama3.2")
|
||||
})
|
||||
|
||||
t.Run("preserve other models", func(t *testing.T) {
|
||||
cleanup()
|
||||
os.MkdirAll(configDir, 0o755)
|
||||
os.WriteFile(configPath, []byte(`{"provider":{"ollama":{"models":{"mistral":{"name":"Mistral"}}}}}`), 0o644)
|
||||
if err := o.Edit([]string{"llama3.2"}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
assertOpenCodeModelExists(t, configPath, "mistral")
|
||||
assertOpenCodeModelExists(t, configPath, "llama3.2")
|
||||
})
|
||||
|
||||
t.Run("update existing model", func(t *testing.T) {
|
||||
cleanup()
|
||||
o.Edit([]string{"llama3.2"})
|
||||
o.Edit([]string{"llama3.2"})
|
||||
assertOpenCodeModelExists(t, configPath, "llama3.2")
|
||||
})
|
||||
|
||||
t.Run("preserve top-level keys", func(t *testing.T) {
|
||||
cleanup()
|
||||
os.MkdirAll(configDir, 0o755)
|
||||
os.WriteFile(configPath, []byte(`{"theme":"dark","keybindings":{}}`), 0o644)
|
||||
if err := o.Edit([]string{"llama3.2"}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
data, _ := os.ReadFile(configPath)
|
||||
var cfg map[string]any
|
||||
json.Unmarshal(data, &cfg)
|
||||
if cfg["theme"] != "dark" {
|
||||
t.Error("theme was removed")
|
||||
}
|
||||
if cfg["keybindings"] == nil {
|
||||
t.Error("keybindings was removed")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("model state - insert at index 0", func(t *testing.T) {
|
||||
cleanup()
|
||||
os.MkdirAll(stateDir, 0o755)
|
||||
os.WriteFile(statePath, []byte(`{"recent":[{"providerID":"anthropic","modelID":"claude"}],"favorite":[],"variant":{}}`), 0o644)
|
||||
if err := o.Edit([]string{"llama3.2"}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
assertOpenCodeRecentModel(t, statePath, 0, "ollama", "llama3.2")
|
||||
assertOpenCodeRecentModel(t, statePath, 1, "anthropic", "claude")
|
||||
})
|
||||
|
||||
t.Run("model state - preserve favorites and variants", func(t *testing.T) {
|
||||
cleanup()
|
||||
os.MkdirAll(stateDir, 0o755)
|
||||
os.WriteFile(statePath, []byte(`{"recent":[],"favorite":[{"providerID":"x","modelID":"y"}],"variant":{"a":"b"}}`), 0o644)
|
||||
if err := o.Edit([]string{"llama3.2"}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
data, _ := os.ReadFile(statePath)
|
||||
var state map[string]any
|
||||
json.Unmarshal(data, &state)
|
||||
if len(state["favorite"].([]any)) != 1 {
|
||||
t.Error("favorite was modified")
|
||||
}
|
||||
if state["variant"].(map[string]any)["a"] != "b" {
|
||||
t.Error("variant was modified")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("model state - deduplicate on re-add", func(t *testing.T) {
|
||||
cleanup()
|
||||
os.MkdirAll(stateDir, 0o755)
|
||||
os.WriteFile(statePath, []byte(`{"recent":[{"providerID":"ollama","modelID":"llama3.2"},{"providerID":"anthropic","modelID":"claude"}],"favorite":[],"variant":{}}`), 0o644)
|
||||
if err := o.Edit([]string{"llama3.2"}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
data, _ := os.ReadFile(statePath)
|
||||
var state map[string]any
|
||||
json.Unmarshal(data, &state)
|
||||
recent := state["recent"].([]any)
|
||||
if len(recent) != 2 {
|
||||
t.Errorf("expected 2 recent entries, got %d", len(recent))
|
||||
}
|
||||
assertOpenCodeRecentModel(t, statePath, 0, "ollama", "llama3.2")
|
||||
})
|
||||
|
||||
t.Run("remove model", func(t *testing.T) {
|
||||
cleanup()
|
||||
// First add two models
|
||||
o.Edit([]string{"llama3.2", "mistral"})
|
||||
assertOpenCodeModelExists(t, configPath, "llama3.2")
|
||||
assertOpenCodeModelExists(t, configPath, "mistral")
|
||||
|
||||
// Then remove one by only selecting the other
|
||||
o.Edit([]string{"llama3.2"})
|
||||
assertOpenCodeModelExists(t, configPath, "llama3.2")
|
||||
assertOpenCodeModelNotExists(t, configPath, "mistral")
|
||||
})
|
||||
|
||||
t.Run("remove model preserves non-ollama models", func(t *testing.T) {
|
||||
cleanup()
|
||||
os.MkdirAll(configDir, 0o755)
|
||||
// Add a non-Ollama model manually
|
||||
os.WriteFile(configPath, []byte(`{"provider":{"ollama":{"models":{"external":{"name":"External Model"}}}}}`), 0o644)
|
||||
|
||||
o.Edit([]string{"llama3.2"})
|
||||
assertOpenCodeModelExists(t, configPath, "llama3.2")
|
||||
assertOpenCodeModelExists(t, configPath, "external") // Should be preserved
|
||||
})
|
||||
}
|
||||
|
||||
func assertOpenCodeModelExists(t *testing.T, path, model string) {
|
||||
t.Helper()
|
||||
data, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
var cfg map[string]any
|
||||
if err := json.Unmarshal(data, &cfg); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
provider, ok := cfg["provider"].(map[string]any)
|
||||
if !ok {
|
||||
t.Fatal("provider not found")
|
||||
}
|
||||
ollama, ok := provider["ollama"].(map[string]any)
|
||||
if !ok {
|
||||
t.Fatal("ollama provider not found")
|
||||
}
|
||||
models, ok := ollama["models"].(map[string]any)
|
||||
if !ok {
|
||||
t.Fatal("models not found")
|
||||
}
|
||||
if models[model] == nil {
|
||||
t.Errorf("model %s not found", model)
|
||||
}
|
||||
}
|
||||
|
||||
func assertOpenCodeModelNotExists(t *testing.T, path, model string) {
|
||||
t.Helper()
|
||||
data, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
var cfg map[string]any
|
||||
if err := json.Unmarshal(data, &cfg); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
provider, ok := cfg["provider"].(map[string]any)
|
||||
if !ok {
|
||||
return // No provider means no model
|
||||
}
|
||||
ollama, ok := provider["ollama"].(map[string]any)
|
||||
if !ok {
|
||||
return // No ollama means no model
|
||||
}
|
||||
models, ok := ollama["models"].(map[string]any)
|
||||
if !ok {
|
||||
return // No models means no model
|
||||
}
|
||||
if models[model] != nil {
|
||||
t.Errorf("model %s should not exist but was found", model)
|
||||
}
|
||||
}
|
||||
|
||||
func assertOpenCodeRecentModel(t *testing.T, path string, index int, providerID, modelID string) {
|
||||
t.Helper()
|
||||
data, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
var state map[string]any
|
||||
if err := json.Unmarshal(data, &state); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
recent, ok := state["recent"].([]any)
|
||||
if !ok {
|
||||
t.Fatal("recent not found")
|
||||
}
|
||||
if index >= len(recent) {
|
||||
t.Fatalf("index %d out of range (len=%d)", index, len(recent))
|
||||
}
|
||||
entry, ok := recent[index].(map[string]any)
|
||||
if !ok {
|
||||
t.Fatal("entry is not a map")
|
||||
}
|
||||
if entry["providerID"] != providerID {
|
||||
t.Errorf("expected providerID %s, got %s", providerID, entry["providerID"])
|
||||
}
|
||||
if entry["modelID"] != modelID {
|
||||
t.Errorf("expected modelID %s, got %s", modelID, entry["modelID"])
|
||||
}
|
||||
}
|
||||
|
||||
// Edge case tests for opencode.go
|
||||
|
||||
func TestOpenCodeEdit_CorruptedConfigJSON(t *testing.T) {
|
||||
o := &OpenCode{}
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
configDir := filepath.Join(tmpDir, ".config", "opencode")
|
||||
configPath := filepath.Join(configDir, "opencode.json")
|
||||
|
||||
os.MkdirAll(configDir, 0o755)
|
||||
os.WriteFile(configPath, []byte(`{corrupted json content`), 0o644)
|
||||
|
||||
// Should not panic - corrupted JSON should be treated as empty
|
||||
err := o.Edit([]string{"llama3.2"})
|
||||
if err != nil {
|
||||
t.Fatalf("Edit failed with corrupted config: %v", err)
|
||||
}
|
||||
|
||||
// Verify valid JSON was created
|
||||
data, _ := os.ReadFile(configPath)
|
||||
var cfg map[string]any
|
||||
if err := json.Unmarshal(data, &cfg); err != nil {
|
||||
t.Errorf("resulting config is not valid JSON: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpenCodeEdit_CorruptedStateJSON(t *testing.T) {
|
||||
o := &OpenCode{}
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
stateDir := filepath.Join(tmpDir, ".local", "state", "opencode")
|
||||
statePath := filepath.Join(stateDir, "model.json")
|
||||
|
||||
os.MkdirAll(stateDir, 0o755)
|
||||
os.WriteFile(statePath, []byte(`{corrupted state`), 0o644)
|
||||
|
||||
err := o.Edit([]string{"llama3.2"})
|
||||
if err != nil {
|
||||
t.Fatalf("Edit failed with corrupted state: %v", err)
|
||||
}
|
||||
|
||||
// Verify valid state was created
|
||||
data, _ := os.ReadFile(statePath)
|
||||
var state map[string]any
|
||||
if err := json.Unmarshal(data, &state); err != nil {
|
||||
t.Errorf("resulting state is not valid JSON: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpenCodeEdit_WrongTypeProvider(t *testing.T) {
|
||||
o := &OpenCode{}
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
configDir := filepath.Join(tmpDir, ".config", "opencode")
|
||||
configPath := filepath.Join(configDir, "opencode.json")
|
||||
|
||||
os.MkdirAll(configDir, 0o755)
|
||||
os.WriteFile(configPath, []byte(`{"provider": "not a map"}`), 0o644)
|
||||
|
||||
err := o.Edit([]string{"llama3.2"})
|
||||
if err != nil {
|
||||
t.Fatalf("Edit with wrong type provider failed: %v", err)
|
||||
}
|
||||
|
||||
// Verify provider is now correct type
|
||||
data, _ := os.ReadFile(configPath)
|
||||
var cfg map[string]any
|
||||
json.Unmarshal(data, &cfg)
|
||||
|
||||
provider, ok := cfg["provider"].(map[string]any)
|
||||
if !ok {
|
||||
t.Fatalf("provider should be map after setup, got %T", cfg["provider"])
|
||||
}
|
||||
if provider["ollama"] == nil {
|
||||
t.Error("ollama provider should be created")
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpenCodeEdit_WrongTypeRecent(t *testing.T) {
|
||||
o := &OpenCode{}
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
stateDir := filepath.Join(tmpDir, ".local", "state", "opencode")
|
||||
statePath := filepath.Join(stateDir, "model.json")
|
||||
|
||||
os.MkdirAll(stateDir, 0o755)
|
||||
os.WriteFile(statePath, []byte(`{"recent": "not an array", "favorite": [], "variant": {}}`), 0o644)
|
||||
|
||||
err := o.Edit([]string{"llama3.2"})
|
||||
if err != nil {
|
||||
t.Fatalf("Edit with wrong type recent failed: %v", err)
|
||||
}
|
||||
|
||||
// The function should handle this gracefully
|
||||
data, _ := os.ReadFile(statePath)
|
||||
var state map[string]any
|
||||
json.Unmarshal(data, &state)
|
||||
|
||||
// recent should be properly set after setup
|
||||
recent, ok := state["recent"].([]any)
|
||||
if !ok {
|
||||
t.Logf("Note: recent type after setup is %T (documenting behavior)", state["recent"])
|
||||
} else if len(recent) == 0 {
|
||||
t.Logf("Note: recent is empty (documenting behavior)")
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpenCodeEdit_EmptyModels(t *testing.T) {
|
||||
o := &OpenCode{}
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
configDir := filepath.Join(tmpDir, ".config", "opencode")
|
||||
configPath := filepath.Join(configDir, "opencode.json")
|
||||
|
||||
os.MkdirAll(configDir, 0o755)
|
||||
originalContent := `{"provider":{"ollama":{"models":{"existing":{}}}}}`
|
||||
os.WriteFile(configPath, []byte(originalContent), 0o644)
|
||||
|
||||
// Empty models should be no-op
|
||||
err := o.Edit([]string{})
|
||||
if err != nil {
|
||||
t.Fatalf("Edit with empty models failed: %v", err)
|
||||
}
|
||||
|
||||
// Original content should be preserved (file not modified)
|
||||
data, _ := os.ReadFile(configPath)
|
||||
if string(data) != originalContent {
|
||||
t.Errorf("empty models should not modify file, but content changed")
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpenCodeEdit_SpecialCharsInModelName(t *testing.T) {
|
||||
o := &OpenCode{}
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
// Model name with special characters (though unusual)
|
||||
specialModel := `model-with-"quotes"`
|
||||
|
||||
err := o.Edit([]string{specialModel})
|
||||
if err != nil {
|
||||
t.Fatalf("Edit with special chars failed: %v", err)
|
||||
}
|
||||
|
||||
// Verify it was stored correctly
|
||||
configDir := filepath.Join(tmpDir, ".config", "opencode")
|
||||
configPath := filepath.Join(configDir, "opencode.json")
|
||||
data, _ := os.ReadFile(configPath)
|
||||
|
||||
var cfg map[string]any
|
||||
if err := json.Unmarshal(data, &cfg); err != nil {
|
||||
t.Fatalf("resulting config is invalid JSON: %v", err)
|
||||
}
|
||||
|
||||
// Model should be accessible
|
||||
provider, _ := cfg["provider"].(map[string]any)
|
||||
ollama, _ := provider["ollama"].(map[string]any)
|
||||
models, _ := ollama["models"].(map[string]any)
|
||||
|
||||
if models[specialModel] == nil {
|
||||
t.Errorf("model with special chars not found in config")
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpenCodeModels_NoConfig(t *testing.T) {
|
||||
o := &OpenCode{}
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
models := o.Models()
|
||||
if len(models) > 0 {
|
||||
t.Errorf("expected nil/empty for missing config, got %v", models)
|
||||
}
|
||||
}
|
||||
499
cmd/config/selector.go
Normal file
499
cmd/config/selector.go
Normal file
@@ -0,0 +1,499 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"strings"
|
||||
|
||||
"golang.org/x/term"
|
||||
)
|
||||
|
||||
// ANSI escape sequences for terminal formatting.
|
||||
const (
|
||||
ansiHideCursor = "\033[?25l"
|
||||
ansiShowCursor = "\033[?25h"
|
||||
ansiBold = "\033[1m"
|
||||
ansiReset = "\033[0m"
|
||||
ansiGray = "\033[37m"
|
||||
ansiClearDown = "\033[J"
|
||||
)
|
||||
|
||||
const maxDisplayedItems = 10
|
||||
|
||||
var errCancelled = errors.New("cancelled")
|
||||
|
||||
type selectItem struct {
|
||||
Name string
|
||||
Description string
|
||||
}
|
||||
|
||||
type inputEvent int
|
||||
|
||||
const (
|
||||
eventNone inputEvent = iota
|
||||
eventEnter
|
||||
eventEscape
|
||||
eventUp
|
||||
eventDown
|
||||
eventTab
|
||||
eventBackspace
|
||||
eventChar
|
||||
)
|
||||
|
||||
type selectState struct {
|
||||
items []selectItem
|
||||
filter string
|
||||
selected int
|
||||
scrollOffset int
|
||||
}
|
||||
|
||||
func newSelectState(items []selectItem) *selectState {
|
||||
return &selectState{items: items}
|
||||
}
|
||||
|
||||
func (s *selectState) filtered() []selectItem {
|
||||
return filterItems(s.items, s.filter)
|
||||
}
|
||||
|
||||
func (s *selectState) handleInput(event inputEvent, char byte) (done bool, result string, err error) {
|
||||
filtered := s.filtered()
|
||||
|
||||
switch event {
|
||||
case eventEnter:
|
||||
if len(filtered) > 0 && s.selected < len(filtered) {
|
||||
return true, filtered[s.selected].Name, nil
|
||||
}
|
||||
case eventEscape:
|
||||
return true, "", errCancelled
|
||||
case eventBackspace:
|
||||
if len(s.filter) > 0 {
|
||||
s.filter = s.filter[:len(s.filter)-1]
|
||||
s.selected = 0
|
||||
s.scrollOffset = 0
|
||||
}
|
||||
case eventUp:
|
||||
if s.selected > 0 {
|
||||
s.selected--
|
||||
if s.selected < s.scrollOffset {
|
||||
s.scrollOffset = s.selected
|
||||
}
|
||||
}
|
||||
case eventDown:
|
||||
if s.selected < len(filtered)-1 {
|
||||
s.selected++
|
||||
if s.selected >= s.scrollOffset+maxDisplayedItems {
|
||||
s.scrollOffset = s.selected - maxDisplayedItems + 1
|
||||
}
|
||||
}
|
||||
case eventChar:
|
||||
s.filter += string(char)
|
||||
s.selected = 0
|
||||
s.scrollOffset = 0
|
||||
}
|
||||
|
||||
return false, "", nil
|
||||
}
|
||||
|
||||
type multiSelectState struct {
|
||||
items []selectItem
|
||||
itemIndex map[string]int
|
||||
filter string
|
||||
highlighted int
|
||||
scrollOffset int
|
||||
checked map[int]bool
|
||||
checkOrder []int
|
||||
focusOnButton bool
|
||||
}
|
||||
|
||||
func newMultiSelectState(items []selectItem, preChecked []string) *multiSelectState {
|
||||
s := &multiSelectState{
|
||||
items: items,
|
||||
itemIndex: make(map[string]int, len(items)),
|
||||
checked: make(map[int]bool),
|
||||
}
|
||||
|
||||
for i, item := range items {
|
||||
s.itemIndex[item.Name] = i
|
||||
}
|
||||
|
||||
for _, name := range preChecked {
|
||||
if idx, ok := s.itemIndex[name]; ok {
|
||||
s.checked[idx] = true
|
||||
s.checkOrder = append(s.checkOrder, idx)
|
||||
}
|
||||
}
|
||||
|
||||
return s
|
||||
}
|
||||
|
||||
func (s *multiSelectState) filtered() []selectItem {
|
||||
return filterItems(s.items, s.filter)
|
||||
}
|
||||
|
||||
func (s *multiSelectState) toggleItem() {
|
||||
filtered := s.filtered()
|
||||
if len(filtered) == 0 || s.highlighted >= len(filtered) {
|
||||
return
|
||||
}
|
||||
|
||||
item := filtered[s.highlighted]
|
||||
origIdx := s.itemIndex[item.Name]
|
||||
|
||||
if s.checked[origIdx] {
|
||||
delete(s.checked, origIdx)
|
||||
for i, idx := range s.checkOrder {
|
||||
if idx == origIdx {
|
||||
s.checkOrder = append(s.checkOrder[:i], s.checkOrder[i+1:]...)
|
||||
break
|
||||
}
|
||||
}
|
||||
} else {
|
||||
s.checked[origIdx] = true
|
||||
s.checkOrder = append(s.checkOrder, origIdx)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *multiSelectState) handleInput(event inputEvent, char byte) (done bool, result []string, err error) {
|
||||
filtered := s.filtered()
|
||||
|
||||
switch event {
|
||||
case eventEnter:
|
||||
if s.focusOnButton && len(s.checkOrder) > 0 {
|
||||
var res []string
|
||||
for _, idx := range s.checkOrder {
|
||||
res = append(res, s.items[idx].Name)
|
||||
}
|
||||
return true, res, nil
|
||||
} else if !s.focusOnButton {
|
||||
s.toggleItem()
|
||||
}
|
||||
case eventTab:
|
||||
if len(s.checkOrder) > 0 {
|
||||
s.focusOnButton = !s.focusOnButton
|
||||
}
|
||||
case eventEscape:
|
||||
return true, nil, errCancelled
|
||||
case eventBackspace:
|
||||
if len(s.filter) > 0 {
|
||||
s.filter = s.filter[:len(s.filter)-1]
|
||||
s.highlighted = 0
|
||||
s.scrollOffset = 0
|
||||
s.focusOnButton = false
|
||||
}
|
||||
case eventUp:
|
||||
if s.focusOnButton {
|
||||
s.focusOnButton = false
|
||||
} else if s.highlighted > 0 {
|
||||
s.highlighted--
|
||||
if s.highlighted < s.scrollOffset {
|
||||
s.scrollOffset = s.highlighted
|
||||
}
|
||||
}
|
||||
case eventDown:
|
||||
if s.focusOnButton {
|
||||
s.focusOnButton = false
|
||||
} else if s.highlighted < len(filtered)-1 {
|
||||
s.highlighted++
|
||||
if s.highlighted >= s.scrollOffset+maxDisplayedItems {
|
||||
s.scrollOffset = s.highlighted - maxDisplayedItems + 1
|
||||
}
|
||||
}
|
||||
case eventChar:
|
||||
s.filter += string(char)
|
||||
s.highlighted = 0
|
||||
s.scrollOffset = 0
|
||||
s.focusOnButton = false
|
||||
}
|
||||
|
||||
return false, nil, nil
|
||||
}
|
||||
|
||||
func (s *multiSelectState) selectedCount() int {
|
||||
return len(s.checkOrder)
|
||||
}
|
||||
|
||||
// Terminal I/O handling
|
||||
|
||||
type terminalState struct {
|
||||
fd int
|
||||
oldState *term.State
|
||||
}
|
||||
|
||||
func enterRawMode() (*terminalState, error) {
|
||||
fd := int(os.Stdin.Fd())
|
||||
oldState, err := term.MakeRaw(fd)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
fmt.Fprint(os.Stderr, ansiHideCursor)
|
||||
return &terminalState{fd: fd, oldState: oldState}, nil
|
||||
}
|
||||
|
||||
func (t *terminalState) restore() {
|
||||
fmt.Fprint(os.Stderr, ansiShowCursor)
|
||||
term.Restore(t.fd, t.oldState)
|
||||
}
|
||||
|
||||
func clearLines(n int) {
|
||||
if n > 0 {
|
||||
fmt.Fprintf(os.Stderr, "\033[%dA", n)
|
||||
fmt.Fprint(os.Stderr, ansiClearDown)
|
||||
}
|
||||
}
|
||||
|
||||
func parseInput(r io.Reader) (inputEvent, byte, error) {
|
||||
buf := make([]byte, 3)
|
||||
n, err := r.Read(buf)
|
||||
if err != nil {
|
||||
return 0, 0, err
|
||||
}
|
||||
|
||||
switch {
|
||||
case n == 1 && buf[0] == 13:
|
||||
return eventEnter, 0, nil
|
||||
case n == 1 && (buf[0] == 3 || buf[0] == 27):
|
||||
return eventEscape, 0, nil
|
||||
case n == 1 && buf[0] == 9:
|
||||
return eventTab, 0, nil
|
||||
case n == 1 && buf[0] == 127:
|
||||
return eventBackspace, 0, nil
|
||||
case n == 3 && buf[0] == 27 && buf[1] == 91 && buf[2] == 65:
|
||||
return eventUp, 0, nil
|
||||
case n == 3 && buf[0] == 27 && buf[1] == 91 && buf[2] == 66:
|
||||
return eventDown, 0, nil
|
||||
case n == 1 && buf[0] >= 32 && buf[0] < 127:
|
||||
return eventChar, buf[0], nil
|
||||
}
|
||||
|
||||
return eventNone, 0, nil
|
||||
}
|
||||
|
||||
// Rendering
|
||||
|
||||
func renderSelect(w io.Writer, prompt string, s *selectState) int {
|
||||
filtered := s.filtered()
|
||||
|
||||
fmt.Fprintf(w, "%s %s\r\n", prompt, s.filter)
|
||||
lineCount := 1
|
||||
|
||||
if len(filtered) == 0 {
|
||||
fmt.Fprintf(w, " %s(no matches)%s\r\n", ansiGray, ansiReset)
|
||||
lineCount++
|
||||
} else {
|
||||
displayCount := min(len(filtered), maxDisplayedItems)
|
||||
|
||||
for i := range displayCount {
|
||||
idx := s.scrollOffset + i
|
||||
if idx >= len(filtered) {
|
||||
break
|
||||
}
|
||||
item := filtered[idx]
|
||||
prefix := " "
|
||||
if idx == s.selected {
|
||||
prefix = " " + ansiBold + "> "
|
||||
}
|
||||
if item.Description != "" {
|
||||
fmt.Fprintf(w, "%s%s%s %s- %s%s\r\n", prefix, item.Name, ansiReset, ansiGray, item.Description, ansiReset)
|
||||
} else {
|
||||
fmt.Fprintf(w, "%s%s%s\r\n", prefix, item.Name, ansiReset)
|
||||
}
|
||||
lineCount++
|
||||
}
|
||||
|
||||
if remaining := len(filtered) - s.scrollOffset - displayCount; remaining > 0 {
|
||||
fmt.Fprintf(w, " %s... and %d more%s\r\n", ansiGray, remaining, ansiReset)
|
||||
lineCount++
|
||||
}
|
||||
}
|
||||
|
||||
return lineCount
|
||||
}
|
||||
|
||||
func renderMultiSelect(w io.Writer, prompt string, s *multiSelectState) int {
|
||||
filtered := s.filtered()
|
||||
|
||||
fmt.Fprintf(w, "%s %s\r\n", prompt, s.filter)
|
||||
lineCount := 1
|
||||
|
||||
if len(filtered) == 0 {
|
||||
fmt.Fprintf(w, " %s(no matches)%s\r\n", ansiGray, ansiReset)
|
||||
lineCount++
|
||||
} else {
|
||||
displayCount := min(len(filtered), maxDisplayedItems)
|
||||
|
||||
for i := range displayCount {
|
||||
idx := s.scrollOffset + i
|
||||
if idx >= len(filtered) {
|
||||
break
|
||||
}
|
||||
item := filtered[idx]
|
||||
origIdx := s.itemIndex[item.Name]
|
||||
|
||||
checkbox := "[ ]"
|
||||
if s.checked[origIdx] {
|
||||
checkbox = "[x]"
|
||||
}
|
||||
|
||||
prefix := " "
|
||||
suffix := ""
|
||||
if idx == s.highlighted && !s.focusOnButton {
|
||||
prefix = "> "
|
||||
}
|
||||
if len(s.checkOrder) > 0 && s.checkOrder[0] == origIdx {
|
||||
suffix = " " + ansiGray + "(default)" + ansiReset
|
||||
}
|
||||
|
||||
if idx == s.highlighted && !s.focusOnButton {
|
||||
fmt.Fprintf(w, " %s%s %s %s%s%s\r\n", ansiBold, prefix, checkbox, item.Name, ansiReset, suffix)
|
||||
} else {
|
||||
fmt.Fprintf(w, " %s %s %s%s\r\n", prefix, checkbox, item.Name, suffix)
|
||||
}
|
||||
lineCount++
|
||||
}
|
||||
|
||||
if remaining := len(filtered) - s.scrollOffset - displayCount; remaining > 0 {
|
||||
fmt.Fprintf(w, " %s... and %d more%s\r\n", ansiGray, remaining, ansiReset)
|
||||
lineCount++
|
||||
}
|
||||
}
|
||||
|
||||
fmt.Fprintf(w, "\r\n")
|
||||
lineCount++
|
||||
count := s.selectedCount()
|
||||
switch {
|
||||
case count == 0:
|
||||
fmt.Fprintf(w, " %sSelect at least one model.%s\r\n", ansiGray, ansiReset)
|
||||
case s.focusOnButton:
|
||||
fmt.Fprintf(w, " %s> [ Continue ]%s %s(%d selected)%s\r\n", ansiBold, ansiReset, ansiGray, count, ansiReset)
|
||||
default:
|
||||
fmt.Fprintf(w, " %s[ Continue ] (%d selected) - press Tab%s\r\n", ansiGray, count, ansiReset)
|
||||
}
|
||||
lineCount++
|
||||
|
||||
return lineCount
|
||||
}
|
||||
|
||||
// selectPrompt prompts the user to select a single item from a list.
|
||||
func selectPrompt(prompt string, items []selectItem) (string, error) {
|
||||
if len(items) == 0 {
|
||||
return "", fmt.Errorf("no items to select from")
|
||||
}
|
||||
|
||||
ts, err := enterRawMode()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
defer ts.restore()
|
||||
|
||||
state := newSelectState(items)
|
||||
var lastLineCount int
|
||||
|
||||
render := func() {
|
||||
clearLines(lastLineCount)
|
||||
lastLineCount = renderSelect(os.Stderr, prompt, state)
|
||||
}
|
||||
|
||||
render()
|
||||
|
||||
for {
|
||||
event, char, err := parseInput(os.Stdin)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
done, result, err := state.handleInput(event, char)
|
||||
if done {
|
||||
clearLines(lastLineCount)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
render()
|
||||
}
|
||||
}
|
||||
|
||||
// multiSelectPrompt prompts the user to select multiple items from a list.
|
||||
func multiSelectPrompt(prompt string, items []selectItem, preChecked []string) ([]string, error) {
|
||||
if len(items) == 0 {
|
||||
return nil, fmt.Errorf("no items to select from")
|
||||
}
|
||||
|
||||
ts, err := enterRawMode()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer ts.restore()
|
||||
|
||||
state := newMultiSelectState(items, preChecked)
|
||||
var lastLineCount int
|
||||
|
||||
render := func() {
|
||||
clearLines(lastLineCount)
|
||||
lastLineCount = renderMultiSelect(os.Stderr, prompt, state)
|
||||
}
|
||||
|
||||
render()
|
||||
|
||||
for {
|
||||
event, char, err := parseInput(os.Stdin)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
done, result, err := state.handleInput(event, char)
|
||||
if done {
|
||||
clearLines(lastLineCount)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
render()
|
||||
}
|
||||
}
|
||||
|
||||
func confirmPrompt(prompt string) (bool, error) {
|
||||
fd := int(os.Stdin.Fd())
|
||||
oldState, err := term.MakeRaw(fd)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
defer term.Restore(fd, oldState)
|
||||
|
||||
fmt.Fprintf(os.Stderr, "%s [y/n] ", prompt)
|
||||
|
||||
buf := make([]byte, 1)
|
||||
for {
|
||||
if _, err := os.Stdin.Read(buf); err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
switch buf[0] {
|
||||
case 'Y', 'y', 13:
|
||||
fmt.Fprintf(os.Stderr, "yes\r\n")
|
||||
return true, nil
|
||||
case 'N', 'n', 27, 3:
|
||||
fmt.Fprintf(os.Stderr, "no\r\n")
|
||||
return false, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func filterItems(items []selectItem, filter string) []selectItem {
|
||||
if filter == "" {
|
||||
return items
|
||||
}
|
||||
var result []selectItem
|
||||
filterLower := strings.ToLower(filter)
|
||||
for _, item := range items {
|
||||
if strings.Contains(strings.ToLower(item.Name), filterLower) {
|
||||
result = append(result, item)
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
913
cmd/config/selector_test.go
Normal file
913
cmd/config/selector_test.go
Normal file
@@ -0,0 +1,913 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestFilterItems(t *testing.T) {
|
||||
items := []selectItem{
|
||||
{Name: "llama3.2:latest"},
|
||||
{Name: "qwen2.5:7b"},
|
||||
{Name: "deepseek-v3:cloud"},
|
||||
{Name: "GPT-OSS:20b"},
|
||||
}
|
||||
|
||||
t.Run("EmptyFilter_ReturnsAllItems", func(t *testing.T) {
|
||||
result := filterItems(items, "")
|
||||
if len(result) != len(items) {
|
||||
t.Errorf("expected %d items, got %d", len(items), len(result))
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("CaseInsensitive_UppercaseFilterMatchesLowercase", func(t *testing.T) {
|
||||
result := filterItems(items, "LLAMA")
|
||||
if len(result) != 1 || result[0].Name != "llama3.2:latest" {
|
||||
t.Errorf("expected llama3.2:latest, got %v", result)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("CaseInsensitive_LowercaseFilterMatchesUppercase", func(t *testing.T) {
|
||||
result := filterItems(items, "gpt")
|
||||
if len(result) != 1 || result[0].Name != "GPT-OSS:20b" {
|
||||
t.Errorf("expected GPT-OSS:20b, got %v", result)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("PartialMatch", func(t *testing.T) {
|
||||
result := filterItems(items, "deep")
|
||||
if len(result) != 1 || result[0].Name != "deepseek-v3:cloud" {
|
||||
t.Errorf("expected deepseek-v3:cloud, got %v", result)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("NoMatch_ReturnsEmpty", func(t *testing.T) {
|
||||
result := filterItems(items, "nonexistent")
|
||||
if len(result) != 0 {
|
||||
t.Errorf("expected 0 items, got %d", len(result))
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestSelectState(t *testing.T) {
|
||||
items := []selectItem{
|
||||
{Name: "item1"},
|
||||
{Name: "item2"},
|
||||
{Name: "item3"},
|
||||
}
|
||||
|
||||
t.Run("InitialState", func(t *testing.T) {
|
||||
s := newSelectState(items)
|
||||
if s.selected != 0 {
|
||||
t.Errorf("expected selected=0, got %d", s.selected)
|
||||
}
|
||||
if s.filter != "" {
|
||||
t.Errorf("expected empty filter, got %q", s.filter)
|
||||
}
|
||||
if s.scrollOffset != 0 {
|
||||
t.Errorf("expected scrollOffset=0, got %d", s.scrollOffset)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Enter_SelectsCurrentItem", func(t *testing.T) {
|
||||
s := newSelectState(items)
|
||||
done, result, err := s.handleInput(eventEnter, 0)
|
||||
if !done || result != "item1" || err != nil {
|
||||
t.Errorf("expected (true, item1, nil), got (%v, %v, %v)", done, result, err)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Enter_WithFilter_SelectsFilteredItem", func(t *testing.T) {
|
||||
s := newSelectState(items)
|
||||
s.filter = "item3"
|
||||
done, result, err := s.handleInput(eventEnter, 0)
|
||||
if !done || result != "item3" || err != nil {
|
||||
t.Errorf("expected (true, item3, nil), got (%v, %v, %v)", done, result, err)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Enter_EmptyFilteredList_DoesNothing", func(t *testing.T) {
|
||||
s := newSelectState(items)
|
||||
s.filter = "nonexistent"
|
||||
done, result, err := s.handleInput(eventEnter, 0)
|
||||
if done || result != "" || err != nil {
|
||||
t.Errorf("expected (false, '', nil), got (%v, %v, %v)", done, result, err)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Escape_ReturnsCancelledError", func(t *testing.T) {
|
||||
s := newSelectState(items)
|
||||
done, result, err := s.handleInput(eventEscape, 0)
|
||||
if !done || result != "" || err != errCancelled {
|
||||
t.Errorf("expected (true, '', errCancelled), got (%v, %v, %v)", done, result, err)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Down_MovesSelection", func(t *testing.T) {
|
||||
s := newSelectState(items)
|
||||
s.handleInput(eventDown, 0)
|
||||
if s.selected != 1 {
|
||||
t.Errorf("expected selected=1, got %d", s.selected)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Down_AtBottom_StaysAtBottom", func(t *testing.T) {
|
||||
s := newSelectState(items)
|
||||
s.selected = 2
|
||||
s.handleInput(eventDown, 0)
|
||||
if s.selected != 2 {
|
||||
t.Errorf("expected selected=2 (stayed at bottom), got %d", s.selected)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Up_MovesSelection", func(t *testing.T) {
|
||||
s := newSelectState(items)
|
||||
s.selected = 2
|
||||
s.handleInput(eventUp, 0)
|
||||
if s.selected != 1 {
|
||||
t.Errorf("expected selected=1, got %d", s.selected)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Up_AtTop_StaysAtTop", func(t *testing.T) {
|
||||
s := newSelectState(items)
|
||||
s.handleInput(eventUp, 0)
|
||||
if s.selected != 0 {
|
||||
t.Errorf("expected selected=0 (stayed at top), got %d", s.selected)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Char_AppendsToFilter", func(t *testing.T) {
|
||||
s := newSelectState(items)
|
||||
s.handleInput(eventChar, 'i')
|
||||
s.handleInput(eventChar, 't')
|
||||
s.handleInput(eventChar, 'e')
|
||||
s.handleInput(eventChar, 'm')
|
||||
s.handleInput(eventChar, '2')
|
||||
if s.filter != "item2" {
|
||||
t.Errorf("expected filter='item2', got %q", s.filter)
|
||||
}
|
||||
filtered := s.filtered()
|
||||
if len(filtered) != 1 || filtered[0].Name != "item2" {
|
||||
t.Errorf("expected [item2], got %v", filtered)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Char_ResetsSelectionToZero", func(t *testing.T) {
|
||||
s := newSelectState(items)
|
||||
s.selected = 2
|
||||
s.handleInput(eventChar, 'x')
|
||||
if s.selected != 0 {
|
||||
t.Errorf("expected selected=0 after typing, got %d", s.selected)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Backspace_RemovesLastFilterChar", func(t *testing.T) {
|
||||
s := newSelectState(items)
|
||||
s.filter = "test"
|
||||
s.handleInput(eventBackspace, 0)
|
||||
if s.filter != "tes" {
|
||||
t.Errorf("expected filter='tes', got %q", s.filter)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Backspace_EmptyFilter_DoesNothing", func(t *testing.T) {
|
||||
s := newSelectState(items)
|
||||
s.handleInput(eventBackspace, 0)
|
||||
if s.filter != "" {
|
||||
t.Errorf("expected filter='', got %q", s.filter)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Backspace_ResetsSelectionToZero", func(t *testing.T) {
|
||||
s := newSelectState(items)
|
||||
s.filter = "test"
|
||||
s.selected = 2
|
||||
s.handleInput(eventBackspace, 0)
|
||||
if s.selected != 0 {
|
||||
t.Errorf("expected selected=0 after backspace, got %d", s.selected)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Scroll_DownPastVisibleItems_ScrollsViewport", func(t *testing.T) {
|
||||
// maxDisplayedItems is 10, so with 15 items we need to scroll
|
||||
manyItems := make([]selectItem, 15)
|
||||
for i := range manyItems {
|
||||
manyItems[i] = selectItem{Name: string(rune('a' + i))}
|
||||
}
|
||||
s := newSelectState(manyItems)
|
||||
|
||||
// move down 12 times (past the 10-item viewport)
|
||||
for range 12 {
|
||||
s.handleInput(eventDown, 0)
|
||||
}
|
||||
|
||||
if s.selected != 12 {
|
||||
t.Errorf("expected selected=12, got %d", s.selected)
|
||||
}
|
||||
if s.scrollOffset != 3 {
|
||||
t.Errorf("expected scrollOffset=3 (12-10+1), got %d", s.scrollOffset)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Scroll_UpPastScrollOffset_ScrollsViewport", func(t *testing.T) {
|
||||
manyItems := make([]selectItem, 15)
|
||||
for i := range manyItems {
|
||||
manyItems[i] = selectItem{Name: string(rune('a' + i))}
|
||||
}
|
||||
s := newSelectState(manyItems)
|
||||
s.selected = 5
|
||||
s.scrollOffset = 5
|
||||
|
||||
s.handleInput(eventUp, 0)
|
||||
|
||||
if s.selected != 4 {
|
||||
t.Errorf("expected selected=4, got %d", s.selected)
|
||||
}
|
||||
if s.scrollOffset != 4 {
|
||||
t.Errorf("expected scrollOffset=4, got %d", s.scrollOffset)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestMultiSelectState(t *testing.T) {
|
||||
items := []selectItem{
|
||||
{Name: "item1"},
|
||||
{Name: "item2"},
|
||||
{Name: "item3"},
|
||||
}
|
||||
|
||||
t.Run("InitialState_NoPrechecked", func(t *testing.T) {
|
||||
s := newMultiSelectState(items, nil)
|
||||
if s.highlighted != 0 {
|
||||
t.Errorf("expected highlighted=0, got %d", s.highlighted)
|
||||
}
|
||||
if s.selectedCount() != 0 {
|
||||
t.Errorf("expected 0 selected, got %d", s.selectedCount())
|
||||
}
|
||||
if s.focusOnButton {
|
||||
t.Error("expected focusOnButton=false initially")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("InitialState_WithPrechecked", func(t *testing.T) {
|
||||
s := newMultiSelectState(items, []string{"item2", "item3"})
|
||||
if s.selectedCount() != 2 {
|
||||
t.Errorf("expected 2 selected, got %d", s.selectedCount())
|
||||
}
|
||||
if !s.checked[1] || !s.checked[2] {
|
||||
t.Error("expected item2 and item3 to be checked")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Prechecked_PreservesSelectionOrder", func(t *testing.T) {
|
||||
// order matters: first checked = default model
|
||||
s := newMultiSelectState(items, []string{"item3", "item1"})
|
||||
if len(s.checkOrder) != 2 {
|
||||
t.Fatalf("expected 2 in checkOrder, got %d", len(s.checkOrder))
|
||||
}
|
||||
if s.checkOrder[0] != 2 || s.checkOrder[1] != 0 {
|
||||
t.Errorf("expected checkOrder=[2,0] (item3 first), got %v", s.checkOrder)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Prechecked_IgnoresInvalidNames", func(t *testing.T) {
|
||||
s := newMultiSelectState(items, []string{"item1", "nonexistent"})
|
||||
if s.selectedCount() != 1 {
|
||||
t.Errorf("expected 1 selected (nonexistent ignored), got %d", s.selectedCount())
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Toggle_ChecksUncheckedItem", func(t *testing.T) {
|
||||
s := newMultiSelectState(items, nil)
|
||||
s.toggleItem()
|
||||
if !s.checked[0] {
|
||||
t.Error("expected item1 to be checked after toggle")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Toggle_UnchecksCheckedItem", func(t *testing.T) {
|
||||
s := newMultiSelectState(items, []string{"item1"})
|
||||
s.toggleItem()
|
||||
if s.checked[0] {
|
||||
t.Error("expected item1 to be unchecked after toggle")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Toggle_RemovesFromCheckOrder", func(t *testing.T) {
|
||||
s := newMultiSelectState(items, []string{"item1", "item2", "item3"})
|
||||
s.highlighted = 1 // toggle item2
|
||||
s.toggleItem()
|
||||
|
||||
if len(s.checkOrder) != 2 {
|
||||
t.Fatalf("expected 2 in checkOrder, got %d", len(s.checkOrder))
|
||||
}
|
||||
// should be [0, 2] (item1, item3) with item2 removed
|
||||
if s.checkOrder[0] != 0 || s.checkOrder[1] != 2 {
|
||||
t.Errorf("expected checkOrder=[0,2], got %v", s.checkOrder)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Enter_TogglesWhenNotOnButton", func(t *testing.T) {
|
||||
s := newMultiSelectState(items, nil)
|
||||
s.handleInput(eventEnter, 0)
|
||||
if !s.checked[0] {
|
||||
t.Error("expected item1 to be checked after enter")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Enter_OnButton_ReturnsSelection", func(t *testing.T) {
|
||||
s := newMultiSelectState(items, []string{"item2", "item1"})
|
||||
s.focusOnButton = true
|
||||
|
||||
done, result, err := s.handleInput(eventEnter, 0)
|
||||
|
||||
if !done || err != nil {
|
||||
t.Errorf("expected done=true, err=nil, got done=%v, err=%v", done, err)
|
||||
}
|
||||
// result should preserve selection order
|
||||
if len(result) != 2 || result[0] != "item2" || result[1] != "item1" {
|
||||
t.Errorf("expected [item2, item1], got %v", result)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Enter_OnButton_EmptySelection_DoesNothing", func(t *testing.T) {
|
||||
s := newMultiSelectState(items, nil)
|
||||
s.focusOnButton = true
|
||||
done, result, err := s.handleInput(eventEnter, 0)
|
||||
if done || result != nil || err != nil {
|
||||
t.Errorf("expected (false, nil, nil), got (%v, %v, %v)", done, result, err)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Tab_SwitchesToButton_WhenHasSelection", func(t *testing.T) {
|
||||
s := newMultiSelectState(items, []string{"item1"})
|
||||
s.handleInput(eventTab, 0)
|
||||
if !s.focusOnButton {
|
||||
t.Error("expected focus on button after tab")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Tab_DoesNothing_WhenNoSelection", func(t *testing.T) {
|
||||
s := newMultiSelectState(items, nil)
|
||||
s.handleInput(eventTab, 0)
|
||||
if s.focusOnButton {
|
||||
t.Error("tab should not focus button when nothing selected")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Tab_TogglesButtonFocus", func(t *testing.T) {
|
||||
s := newMultiSelectState(items, []string{"item1"})
|
||||
s.handleInput(eventTab, 0)
|
||||
if !s.focusOnButton {
|
||||
t.Error("expected focus on button after first tab")
|
||||
}
|
||||
s.handleInput(eventTab, 0)
|
||||
if s.focusOnButton {
|
||||
t.Error("expected focus back on list after second tab")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Escape_ReturnsCancelledError", func(t *testing.T) {
|
||||
s := newMultiSelectState(items, []string{"item1"})
|
||||
done, result, err := s.handleInput(eventEscape, 0)
|
||||
if !done || result != nil || err != errCancelled {
|
||||
t.Errorf("expected (true, nil, errCancelled), got (%v, %v, %v)", done, result, err)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("IsDefault_TrueForFirstChecked", func(t *testing.T) {
|
||||
s := newMultiSelectState(items, []string{"item2", "item1"})
|
||||
if !(len(s.checkOrder) > 0 && s.checkOrder[0] == 1) {
|
||||
t.Error("expected item2 (idx 1) to be default (first checked)")
|
||||
}
|
||||
if len(s.checkOrder) > 0 && s.checkOrder[0] == 0 {
|
||||
t.Error("expected item1 (idx 0) to NOT be default")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("IsDefault_FalseWhenNothingChecked", func(t *testing.T) {
|
||||
s := newMultiSelectState(items, nil)
|
||||
if len(s.checkOrder) > 0 && s.checkOrder[0] == 0 {
|
||||
t.Error("expected isDefault=false when nothing checked")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Down_MovesHighlight", func(t *testing.T) {
|
||||
s := newMultiSelectState(items, nil)
|
||||
s.handleInput(eventDown, 0)
|
||||
if s.highlighted != 1 {
|
||||
t.Errorf("expected highlighted=1, got %d", s.highlighted)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Up_MovesHighlight", func(t *testing.T) {
|
||||
s := newMultiSelectState(items, nil)
|
||||
s.highlighted = 1
|
||||
s.handleInput(eventUp, 0)
|
||||
if s.highlighted != 0 {
|
||||
t.Errorf("expected highlighted=0, got %d", s.highlighted)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Arrow_ReturnsFocusFromButton", func(t *testing.T) {
|
||||
s := newMultiSelectState(items, []string{"item1"})
|
||||
s.focusOnButton = true
|
||||
s.handleInput(eventDown, 0)
|
||||
if s.focusOnButton {
|
||||
t.Error("expected focus to return to list on arrow key")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Char_AppendsToFilter", func(t *testing.T) {
|
||||
s := newMultiSelectState(items, nil)
|
||||
s.handleInput(eventChar, 'x')
|
||||
if s.filter != "x" {
|
||||
t.Errorf("expected filter='x', got %q", s.filter)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Char_ResetsHighlightAndScroll", func(t *testing.T) {
|
||||
manyItems := make([]selectItem, 15)
|
||||
for i := range manyItems {
|
||||
manyItems[i] = selectItem{Name: string(rune('a' + i))}
|
||||
}
|
||||
s := newMultiSelectState(manyItems, nil)
|
||||
s.highlighted = 10
|
||||
s.scrollOffset = 5
|
||||
|
||||
s.handleInput(eventChar, 'x')
|
||||
|
||||
if s.highlighted != 0 {
|
||||
t.Errorf("expected highlighted=0, got %d", s.highlighted)
|
||||
}
|
||||
if s.scrollOffset != 0 {
|
||||
t.Errorf("expected scrollOffset=0, got %d", s.scrollOffset)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Backspace_RemovesLastFilterChar", func(t *testing.T) {
|
||||
s := newMultiSelectState(items, nil)
|
||||
s.filter = "test"
|
||||
s.handleInput(eventBackspace, 0)
|
||||
if s.filter != "tes" {
|
||||
t.Errorf("expected filter='tes', got %q", s.filter)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Backspace_RemovesFocusFromButton", func(t *testing.T) {
|
||||
s := newMultiSelectState(items, []string{"item1"})
|
||||
s.filter = "x"
|
||||
s.focusOnButton = true
|
||||
s.handleInput(eventBackspace, 0)
|
||||
if s.focusOnButton {
|
||||
t.Error("expected focusOnButton=false after backspace")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestParseInput(t *testing.T) {
|
||||
t.Run("Enter", func(t *testing.T) {
|
||||
event, char, err := parseInput(bytes.NewReader([]byte{13}))
|
||||
if err != nil || event != eventEnter || char != 0 {
|
||||
t.Errorf("expected (eventEnter, 0, nil), got (%v, %v, %v)", event, char, err)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Escape", func(t *testing.T) {
|
||||
event, _, err := parseInput(bytes.NewReader([]byte{27}))
|
||||
if err != nil || event != eventEscape {
|
||||
t.Errorf("expected eventEscape, got %v", event)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("CtrlC_TreatedAsEscape", func(t *testing.T) {
|
||||
event, _, err := parseInput(bytes.NewReader([]byte{3}))
|
||||
if err != nil || event != eventEscape {
|
||||
t.Errorf("expected eventEscape for Ctrl+C, got %v", event)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Tab", func(t *testing.T) {
|
||||
event, _, err := parseInput(bytes.NewReader([]byte{9}))
|
||||
if err != nil || event != eventTab {
|
||||
t.Errorf("expected eventTab, got %v", event)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Backspace", func(t *testing.T) {
|
||||
event, _, err := parseInput(bytes.NewReader([]byte{127}))
|
||||
if err != nil || event != eventBackspace {
|
||||
t.Errorf("expected eventBackspace, got %v", event)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("UpArrow", func(t *testing.T) {
|
||||
event, _, err := parseInput(bytes.NewReader([]byte{27, 91, 65}))
|
||||
if err != nil || event != eventUp {
|
||||
t.Errorf("expected eventUp, got %v", event)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("DownArrow", func(t *testing.T) {
|
||||
event, _, err := parseInput(bytes.NewReader([]byte{27, 91, 66}))
|
||||
if err != nil || event != eventDown {
|
||||
t.Errorf("expected eventDown, got %v", event)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("PrintableChars", func(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
char byte
|
||||
}{
|
||||
{"lowercase", 'a'},
|
||||
{"uppercase", 'Z'},
|
||||
{"digit", '5'},
|
||||
{"space", ' '},
|
||||
{"tilde", '~'},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
event, char, err := parseInput(bytes.NewReader([]byte{tt.char}))
|
||||
if err != nil || event != eventChar || char != tt.char {
|
||||
t.Errorf("expected (eventChar, %q), got (%v, %q)", tt.char, event, char)
|
||||
}
|
||||
})
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestRenderSelect(t *testing.T) {
|
||||
items := []selectItem{
|
||||
{Name: "item1", Description: "first item"},
|
||||
{Name: "item2"},
|
||||
}
|
||||
|
||||
t.Run("ShowsPromptAndItems", func(t *testing.T) {
|
||||
s := newSelectState(items)
|
||||
var buf bytes.Buffer
|
||||
lineCount := renderSelect(&buf, "Select:", s)
|
||||
|
||||
output := buf.String()
|
||||
if !strings.Contains(output, "Select:") {
|
||||
t.Error("expected prompt in output")
|
||||
}
|
||||
if !strings.Contains(output, "item1") {
|
||||
t.Error("expected item1 in output")
|
||||
}
|
||||
if !strings.Contains(output, "first item") {
|
||||
t.Error("expected description in output")
|
||||
}
|
||||
if !strings.Contains(output, "item2") {
|
||||
t.Error("expected item2 in output")
|
||||
}
|
||||
if lineCount != 3 { // 1 prompt + 2 items
|
||||
t.Errorf("expected 3 lines, got %d", lineCount)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("EmptyFilteredList_ShowsNoMatches", func(t *testing.T) {
|
||||
s := newSelectState(items)
|
||||
s.filter = "xyz"
|
||||
var buf bytes.Buffer
|
||||
renderSelect(&buf, "Select:", s)
|
||||
|
||||
if !strings.Contains(buf.String(), "no matches") {
|
||||
t.Error("expected 'no matches' message")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("LongList_ShowsRemainingCount", func(t *testing.T) {
|
||||
manyItems := make([]selectItem, 15)
|
||||
for i := range manyItems {
|
||||
manyItems[i] = selectItem{Name: string(rune('a' + i))}
|
||||
}
|
||||
s := newSelectState(manyItems)
|
||||
var buf bytes.Buffer
|
||||
renderSelect(&buf, "Select:", s)
|
||||
|
||||
// 15 items - 10 displayed = 5 more
|
||||
if !strings.Contains(buf.String(), "5 more") {
|
||||
t.Error("expected '5 more' indicator")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestRenderMultiSelect(t *testing.T) {
|
||||
items := []selectItem{
|
||||
{Name: "item1"},
|
||||
{Name: "item2"},
|
||||
}
|
||||
|
||||
t.Run("ShowsCheckboxes", func(t *testing.T) {
|
||||
s := newMultiSelectState(items, []string{"item1"})
|
||||
var buf bytes.Buffer
|
||||
renderMultiSelect(&buf, "Select:", s)
|
||||
|
||||
output := buf.String()
|
||||
if !strings.Contains(output, "[x]") {
|
||||
t.Error("expected checked checkbox [x]")
|
||||
}
|
||||
if !strings.Contains(output, "[ ]") {
|
||||
t.Error("expected unchecked checkbox [ ]")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("ShowsDefaultMarker", func(t *testing.T) {
|
||||
s := newMultiSelectState(items, []string{"item1"})
|
||||
var buf bytes.Buffer
|
||||
renderMultiSelect(&buf, "Select:", s)
|
||||
|
||||
if !strings.Contains(buf.String(), "(default)") {
|
||||
t.Error("expected (default) marker for first checked item")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("ShowsSelectedCount", func(t *testing.T) {
|
||||
s := newMultiSelectState(items, []string{"item1", "item2"})
|
||||
var buf bytes.Buffer
|
||||
renderMultiSelect(&buf, "Select:", s)
|
||||
|
||||
if !strings.Contains(buf.String(), "2 selected") {
|
||||
t.Error("expected '2 selected' in output")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("NoSelection_ShowsHelperText", func(t *testing.T) {
|
||||
s := newMultiSelectState(items, nil)
|
||||
var buf bytes.Buffer
|
||||
renderMultiSelect(&buf, "Select:", s)
|
||||
|
||||
if !strings.Contains(buf.String(), "Select at least one") {
|
||||
t.Error("expected 'Select at least one' helper text")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestErrCancelled(t *testing.T) {
|
||||
t.Run("NotNil", func(t *testing.T) {
|
||||
if errCancelled == nil {
|
||||
t.Error("errCancelled should not be nil")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Message", func(t *testing.T) {
|
||||
if errCancelled.Error() != "cancelled" {
|
||||
t.Errorf("expected 'cancelled', got %q", errCancelled.Error())
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// Edge case tests for selector.go
|
||||
|
||||
// TestSelectState_SingleItem verifies that single item list works without crash.
|
||||
// List with only one item should still work.
|
||||
func TestSelectState_SingleItem(t *testing.T) {
|
||||
items := []selectItem{{Name: "only-one"}}
|
||||
|
||||
s := newSelectState(items)
|
||||
|
||||
// Down should do nothing (already at bottom)
|
||||
s.handleInput(eventDown, 0)
|
||||
if s.selected != 0 {
|
||||
t.Errorf("down on single item: expected selected=0, got %d", s.selected)
|
||||
}
|
||||
|
||||
// Up should do nothing (already at top)
|
||||
s.handleInput(eventUp, 0)
|
||||
if s.selected != 0 {
|
||||
t.Errorf("up on single item: expected selected=0, got %d", s.selected)
|
||||
}
|
||||
|
||||
// Enter should select the only item
|
||||
done, result, err := s.handleInput(eventEnter, 0)
|
||||
if !done || result != "only-one" || err != nil {
|
||||
t.Errorf("enter on single item: expected (true, 'only-one', nil), got (%v, %q, %v)", done, result, err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestSelectState_ExactlyMaxItems verifies boundary condition at maxDisplayedItems.
|
||||
// List with exactly maxDisplayedItems items should not scroll.
|
||||
func TestSelectState_ExactlyMaxItems(t *testing.T) {
|
||||
items := make([]selectItem, maxDisplayedItems)
|
||||
for i := range items {
|
||||
items[i] = selectItem{Name: string(rune('a' + i))}
|
||||
}
|
||||
|
||||
s := newSelectState(items)
|
||||
|
||||
// Move to last item
|
||||
for range maxDisplayedItems - 1 {
|
||||
s.handleInput(eventDown, 0)
|
||||
}
|
||||
|
||||
if s.selected != maxDisplayedItems-1 {
|
||||
t.Errorf("expected selected=%d, got %d", maxDisplayedItems-1, s.selected)
|
||||
}
|
||||
|
||||
// Should not scroll when exactly at max
|
||||
if s.scrollOffset != 0 {
|
||||
t.Errorf("expected scrollOffset=0 for exactly maxDisplayedItems, got %d", s.scrollOffset)
|
||||
}
|
||||
|
||||
// One more down should do nothing
|
||||
s.handleInput(eventDown, 0)
|
||||
if s.selected != maxDisplayedItems-1 {
|
||||
t.Errorf("down at max: expected selected=%d, got %d", maxDisplayedItems-1, s.selected)
|
||||
}
|
||||
}
|
||||
|
||||
// TestFilterItems_RegexSpecialChars verifies that filter is literal, not regex.
|
||||
// User typing "model.v1" shouldn't match "modelsv1".
|
||||
func TestFilterItems_RegexSpecialChars(t *testing.T) {
|
||||
items := []selectItem{
|
||||
{Name: "model.v1"},
|
||||
{Name: "modelsv1"},
|
||||
{Name: "model-v1"},
|
||||
}
|
||||
|
||||
// Filter with dot should only match literal dot
|
||||
result := filterItems(items, "model.v1")
|
||||
if len(result) != 1 {
|
||||
t.Errorf("expected 1 exact match, got %d", len(result))
|
||||
}
|
||||
if len(result) > 0 && result[0].Name != "model.v1" {
|
||||
t.Errorf("expected 'model.v1', got %s", result[0].Name)
|
||||
}
|
||||
|
||||
// Other regex special chars should be literal too
|
||||
items2 := []selectItem{
|
||||
{Name: "test[0]"},
|
||||
{Name: "test0"},
|
||||
{Name: "test(1)"},
|
||||
}
|
||||
|
||||
result2 := filterItems(items2, "test[0]")
|
||||
if len(result2) != 1 || result2[0].Name != "test[0]" {
|
||||
t.Errorf("expected only 'test[0]', got %v", result2)
|
||||
}
|
||||
}
|
||||
|
||||
// TestMultiSelectState_DuplicateNames documents handling of duplicate item names.
|
||||
// itemIndex uses name as key - duplicates cause collision. This documents
|
||||
// the current behavior: the last index for a duplicate name is stored
|
||||
func TestMultiSelectState_DuplicateNames(t *testing.T) {
|
||||
// Duplicate names - this is an edge case that shouldn't happen in practice
|
||||
items := []selectItem{
|
||||
{Name: "duplicate"},
|
||||
{Name: "duplicate"},
|
||||
{Name: "unique"},
|
||||
}
|
||||
|
||||
s := newMultiSelectState(items, nil)
|
||||
|
||||
// DOCUMENTED BEHAVIOR: itemIndex maps name to LAST index
|
||||
// When there are duplicates, only the last occurrence's index is stored
|
||||
if s.itemIndex["duplicate"] != 1 {
|
||||
t.Errorf("itemIndex should map 'duplicate' to last index (1), got %d", s.itemIndex["duplicate"])
|
||||
}
|
||||
|
||||
// Toggle item at highlighted=0 (first "duplicate")
|
||||
// Due to name collision, toggleItem uses itemIndex["duplicate"] = 1
|
||||
// So it actually toggles the SECOND duplicate item, not the first
|
||||
s.toggleItem()
|
||||
|
||||
// This documents the potentially surprising behavior:
|
||||
// We toggled at highlighted=0, but itemIndex lookup returned 1
|
||||
if !s.checked[1] {
|
||||
t.Error("toggle should check index 1 (due to name collision in itemIndex)")
|
||||
}
|
||||
if s.checked[0] {
|
||||
t.Log("Note: index 0 is NOT checked, even though highlighted=0 (name collision behavior)")
|
||||
}
|
||||
}
|
||||
|
||||
// TestSelectState_FilterReducesBelowSelection verifies selection resets when filter reduces list.
|
||||
// Prevents index-out-of-bounds on next keystroke
|
||||
func TestSelectState_FilterReducesBelowSelection(t *testing.T) {
|
||||
items := []selectItem{
|
||||
{Name: "apple"},
|
||||
{Name: "banana"},
|
||||
{Name: "cherry"},
|
||||
}
|
||||
|
||||
s := newSelectState(items)
|
||||
s.selected = 2 // Select "cherry"
|
||||
|
||||
// Type a filter that removes cherry from results
|
||||
s.handleInput(eventChar, 'a') // Filter to "a" - matches "apple" and "banana"
|
||||
|
||||
// Selection should reset to 0
|
||||
if s.selected != 0 {
|
||||
t.Errorf("expected selected=0 after filter, got %d", s.selected)
|
||||
}
|
||||
|
||||
filtered := s.filtered()
|
||||
if len(filtered) != 2 {
|
||||
t.Errorf("expected 2 filtered items, got %d", len(filtered))
|
||||
}
|
||||
}
|
||||
|
||||
// TestFilterItems_UnicodeCharacters verifies filtering works with UTF-8.
|
||||
// Model names might contain unicode characters
|
||||
func TestFilterItems_UnicodeCharacters(t *testing.T) {
|
||||
items := []selectItem{
|
||||
{Name: "llama-日本語"},
|
||||
{Name: "模型-chinese"},
|
||||
{Name: "émoji-🦙"},
|
||||
{Name: "regular-model"},
|
||||
}
|
||||
|
||||
t.Run("filter japanese", func(t *testing.T) {
|
||||
result := filterItems(items, "日本")
|
||||
if len(result) != 1 || result[0].Name != "llama-日本語" {
|
||||
t.Errorf("expected llama-日本語, got %v", result)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("filter chinese", func(t *testing.T) {
|
||||
result := filterItems(items, "模型")
|
||||
if len(result) != 1 || result[0].Name != "模型-chinese" {
|
||||
t.Errorf("expected 模型-chinese, got %v", result)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("filter emoji", func(t *testing.T) {
|
||||
result := filterItems(items, "🦙")
|
||||
if len(result) != 1 || result[0].Name != "émoji-🦙" {
|
||||
t.Errorf("expected émoji-🦙, got %v", result)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("filter accented char", func(t *testing.T) {
|
||||
result := filterItems(items, "émoji")
|
||||
if len(result) != 1 || result[0].Name != "émoji-🦙" {
|
||||
t.Errorf("expected émoji-🦙, got %v", result)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// TestMultiSelectState_FilterReducesBelowHighlight verifies highlight resets when filter reduces list.
|
||||
func TestMultiSelectState_FilterReducesBelowHighlight(t *testing.T) {
|
||||
items := []selectItem{
|
||||
{Name: "apple"},
|
||||
{Name: "banana"},
|
||||
{Name: "cherry"},
|
||||
}
|
||||
|
||||
s := newMultiSelectState(items, nil)
|
||||
s.highlighted = 2 // Highlight "cherry"
|
||||
|
||||
// Type a filter that removes cherry
|
||||
s.handleInput(eventChar, 'a')
|
||||
|
||||
if s.highlighted != 0 {
|
||||
t.Errorf("expected highlighted=0 after filter, got %d", s.highlighted)
|
||||
}
|
||||
}
|
||||
|
||||
// TestMultiSelectState_EmptyItems verifies handling of empty item list.
|
||||
// Empty list should be handled gracefully.
|
||||
func TestMultiSelectState_EmptyItems(t *testing.T) {
|
||||
s := newMultiSelectState([]selectItem{}, nil)
|
||||
|
||||
// Toggle should not panic on empty list
|
||||
s.toggleItem()
|
||||
|
||||
if s.selectedCount() != 0 {
|
||||
t.Errorf("expected 0 selected for empty list, got %d", s.selectedCount())
|
||||
}
|
||||
|
||||
// Render should handle empty list
|
||||
var buf bytes.Buffer
|
||||
lineCount := renderMultiSelect(&buf, "Select:", s)
|
||||
if lineCount == 0 {
|
||||
t.Error("renderMultiSelect should produce output even for empty list")
|
||||
}
|
||||
if !strings.Contains(buf.String(), "no matches") {
|
||||
t.Error("expected 'no matches' for empty list")
|
||||
}
|
||||
}
|
||||
|
||||
// TestSelectState_RenderWithDescriptions verifies rendering items with descriptions.
|
||||
func TestSelectState_RenderWithDescriptions(t *testing.T) {
|
||||
items := []selectItem{
|
||||
{Name: "item1", Description: "First item description"},
|
||||
{Name: "item2", Description: ""},
|
||||
{Name: "item3", Description: "Third item"},
|
||||
}
|
||||
|
||||
s := newSelectState(items)
|
||||
var buf bytes.Buffer
|
||||
renderSelect(&buf, "Select:", s)
|
||||
|
||||
output := buf.String()
|
||||
if !strings.Contains(output, "First item description") {
|
||||
t.Error("expected description to be rendered")
|
||||
}
|
||||
if !strings.Contains(output, "item2") {
|
||||
t.Error("expected item without description to be rendered")
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user