diff --git a/cmd/cmd.go b/cmd/cmd.go index 5139c05cb..14e018a4e 100644 --- a/cmd/cmd.go +++ b/cmd/cmd.go @@ -43,7 +43,6 @@ import ( "github.com/ollama/ollama/runner" "github.com/ollama/ollama/server" "github.com/ollama/ollama/types/model" - "github.com/ollama/ollama/types/syncmap" "github.com/ollama/ollama/version" xcmd "github.com/ollama/ollama/x/cmd" "github.com/ollama/ollama/x/create" @@ -205,7 +204,6 @@ func CreateHandler(cmd *cobra.Command, args []string) error { if err != nil { return err } - spinner.Stop() req.Model = modelName quantize, _ := cmd.Flags().GetString("quantize") @@ -219,42 +217,29 @@ func CreateHandler(cmd *cobra.Command, args []string) error { } var g errgroup.Group - g.SetLimit(max(runtime.GOMAXPROCS(0)-1, 1)) + g.SetLimit(runtime.GOMAXPROCS(0)) + for blob, err := range createBlobs(req.Files, req.Adapters) { + if err != nil { + return err + } - files := syncmap.NewSyncMap[string, string]() - for f, digest := range req.Files { g.Go(func() error { - if _, err := createBlob(cmd, client, f, digest, p); err != nil { - return err - } - - // TODO: this is incorrect since the file might be in a subdirectory - // instead this should take the path relative to the model directory - // but the current implementation does not allow this - files.Store(filepath.Base(f), digest) - return nil + _, err := createBlob(cmd, client, blob.Abs, blob.Digest, p) + return err }) - } - adapters := syncmap.NewSyncMap[string, string]() - for f, digest := range req.Adapters { - g.Go(func() error { - if _, err := createBlob(cmd, client, f, digest, p); err != nil { - return err - } - - // TODO: same here - adapters.Store(filepath.Base(f), digest) - return nil - }) + if _, ok := req.Files[blob.Rel]; ok { + req.Files[blob.Rel] = blob.Digest + } else if _, ok := req.Adapters[blob.Rel]; ok { + req.Adapters[blob.Rel] = blob.Digest + } } if err := g.Wait(); err != nil { return err } - req.Files = files.Items() - req.Adapters = adapters.Items() + spinner.Stop() bars := 
make(map[string]*progress.Bar) fn := func(resp api.ProgressResponse) error { @@ -292,54 +277,6 @@ func CreateHandler(cmd *cobra.Command, args []string) error { return nil } -func createBlob(cmd *cobra.Command, client *api.Client, path string, digest string, p *progress.Progress) (string, error) { - realPath, err := filepath.EvalSymlinks(path) - if err != nil { - return "", err - } - - bin, err := os.Open(realPath) - if err != nil { - return "", err - } - defer bin.Close() - - // Get file info to retrieve the size - fileInfo, err := bin.Stat() - if err != nil { - return "", err - } - fileSize := fileInfo.Size() - - var pw progressWriter - status := fmt.Sprintf("copying file %s 0%%", digest) - spinner := progress.NewSpinner(status) - p.Add(status, spinner) - defer spinner.Stop() - - done := make(chan struct{}) - defer close(done) - - go func() { - ticker := time.NewTicker(60 * time.Millisecond) - defer ticker.Stop() - for { - select { - case <-ticker.C: - spinner.SetMessage(fmt.Sprintf("copying file %s %d%%", digest, int(100*pw.n.Load()/fileSize))) - case <-done: - spinner.SetMessage(fmt.Sprintf("copying file %s 100%%", digest)) - return - } - } - }() - - if err := client.CreateBlob(cmd.Context(), digest, io.TeeReader(bin, &pw)); err != nil { - return "", err - } - return digest, nil -} - type progressWriter struct { n atomic.Int64 } diff --git a/cmd/create.go b/cmd/create.go new file mode 100644 index 000000000..ed90b7035 --- /dev/null +++ b/cmd/create.go @@ -0,0 +1,103 @@ +package cmd + +import ( + "crypto/sha256" + "fmt" + "io" + "iter" + "os" + "path/filepath" + "strings" + "time" + + "github.com/ollama/ollama/api" + "github.com/ollama/ollama/progress" + "github.com/spf13/cobra" +) + +type blob struct { + Rel, Abs, Digest string +} + +func createBlob(cmd *cobra.Command, client *api.Client, path string, digest string, p *progress.Progress) (string, error) { + realPath, err := filepath.EvalSymlinks(path) + if err != nil { + return "", err + } + + bin, err := 
os.Open(realPath) + if err != nil { + return "", err + } + defer bin.Close() + + // Get file info to retrieve the size + fileInfo, err := bin.Stat() + if err != nil { + return "", err + } + fileSize := fileInfo.Size() + + var pw progressWriter + status := fmt.Sprintf("copying file %s 0%%", digest) + spinner := progress.NewSpinner(status) + p.Add(status, spinner) + defer spinner.Stop() + + done := make(chan struct{}) + defer close(done) + + go func() { + ticker := time.NewTicker(60 * time.Millisecond) + defer ticker.Stop() + for { + select { + case <-ticker.C: + spinner.SetMessage(fmt.Sprintf("copying file %s %d%%", digest, int(100*pw.n.Load()/fileSize))) + case <-done: + spinner.SetMessage(fmt.Sprintf("copying file %s 100%%", digest)) + return + } + } + }() + + if err := client.CreateBlob(cmd.Context(), digest, io.TeeReader(bin, &pw)); err != nil { + return "", err + } + return digest, nil +} + +func createBlobs(mappings ...map[string]string) iter.Seq2[blob, error] { + return func(yield func(blob, error) bool) { + for _, mapping := range mappings { + for rel, abs := range mapping { + if abs, ok := strings.CutPrefix(abs, "abs:"); ok { + f, err := os.Open(abs) + if err != nil { + yield(blob{}, err) + return + } + + h := sha256.New() + if _, err := io.Copy(h, f); err != nil { + yield(blob{}, err) + return + } + + if err := f.Close(); err != nil { + yield(blob{}, err) + return + } + + if !yield(blob{ + Rel: rel, + Abs: abs, + Digest: fmt.Sprintf("sha256:%x", h.Sum(nil)), + }, nil) { + return + } + } + } + } + } +} diff --git a/parser/parser.go b/parser/parser.go index 5ef918bf2..ed5b2c646 100644 --- a/parser/parser.go +++ b/parser/parser.go @@ -3,22 +3,20 @@ package parser import ( "bufio" "bytes" - "crypto/sha256" "errors" "fmt" "io" + "io/fs" + "maps" "net/http" "os" "os/user" "path/filepath" - "runtime" "slices" "strconv" "strings" - "sync" "golang.org/x/mod/semver" - "golang.org/x/sync/errgroup" "golang.org/x/text/encoding/unicode" "golang.org/x/text/transform" @@ 
-54,7 +52,10 @@ var deprecatedParameters = []string{ // CreateRequest creates a new *api.CreateRequest from an existing Modelfile func (f Modelfile) CreateRequest(relativeDir string) (*api.CreateRequest, error) { - req := &api.CreateRequest{} + req := &api.CreateRequest{ + Files: make(map[string]string), + Adapters: make(map[string]string), + } var messages []api.Message var licenses []string @@ -63,12 +64,7 @@ func (f Modelfile) CreateRequest(relativeDir string) (*api.CreateRequest, error) for _, c := range f.Commands { switch c.Name { case "model": - path, err := expandPath(c.Args, relativeDir) - if err != nil { - return nil, err - } - - digestMap, err := fileDigestMap(path) + files, err := filesMap(c.Args, relativeDir) if errors.Is(err, os.ErrNotExist) { req.From = c.Args continue @@ -76,25 +72,14 @@ func (f Modelfile) CreateRequest(relativeDir string) (*api.CreateRequest, error) return nil, err } - if req.Files == nil { - req.Files = digestMap - } else { - for k, v := range digestMap { - req.Files[k] = v - } - } + maps.Copy(req.Files, files) case "adapter": - path, err := expandPath(c.Args, relativeDir) + files, err := filesMap(c.Args, relativeDir) if err != nil { return nil, err } - digestMap, err := fileDigestMap(path) - if err != nil { - return nil, err - } - - req.Adapters = digestMap + maps.Copy(req.Adapters, files) case "template": req.Template = c.Args case "system": @@ -154,106 +139,66 @@ func (f Modelfile) CreateRequest(relativeDir string) (*api.CreateRequest, error) return req, nil } -func fileDigestMap(path string) (map[string]string, error) { - fl := make(map[string]string) +func filesMap(args, base string) (map[string]string, error) { + path, err := expandPath(args, base) + if err != nil { + return nil, err + } fi, err := os.Stat(path) if err != nil { return nil, err } - var files []string - if fi.IsDir() { - fs, err := filesForModel(path) - if err != nil { - return nil, err - } - - for _, f := range fs { - f, err := filepath.EvalSymlinks(f) - if 
err != nil { - return nil, err - } - - rel, err := filepath.Rel(path, f) - if err != nil { - return nil, err - } - - if !filepath.IsLocal(rel) { - return nil, fmt.Errorf("insecure path: %s", rel) - } - - files = append(files, f) - } - } else { - files = []string{path} + mapping := make(map[string]string) + if !fi.IsDir() { + return map[string]string{ + filepath.Base(path): "abs:" + path, + }, nil } - var mu sync.Mutex - var g errgroup.Group - g.SetLimit(max(runtime.GOMAXPROCS(0)-1, 1)) - for _, f := range files { - g.Go(func() error { - digest, err := digestForFile(f) - if err != nil { - return err - } - - mu.Lock() - defer mu.Unlock() - fl[f] = digest - return nil - }) + root, err := os.OpenRoot(path) + if err != nil { + return nil, err } + defer root.Close() - if err := g.Wait(); err != nil { + files, err := filesForModel(root) + if err != nil { return nil, err } - return fl, nil + for _, file := range files { + // create a temporary mapping from relative path to absolute path + mapping[file] = "abs:" + filepath.Join(root.Name(), file) + } + + return mapping, nil } -func digestForFile(filename string) (string, error) { - filepath, err := filepath.EvalSymlinks(filename) - if err != nil { - return "", err - } - - bin, err := os.Open(filepath) - if err != nil { - return "", err - } - defer bin.Close() - - hash := sha256.New() - if _, err := io.Copy(hash, bin); err != nil { - return "", err - } - return fmt.Sprintf("sha256:%x", hash.Sum(nil)), nil -} - -func filesForModel(path string) ([]string, error) { +func filesForModel(root *os.Root) ([]string, error) { detectContentType := func(path string) (string, error) { - f, err := os.Open(path) + f, err := root.Open(path) if err != nil { return "", err } defer f.Close() - var b bytes.Buffer - b.Grow(512) - - if _, err := io.CopyN(&b, f, 512); err != nil && !errors.Is(err, io.EOF) { + bts := make([]byte, 512) + n, err := io.ReadFull(f, bts) + if errors.Is(err, io.ErrUnexpectedEOF) { + // short read, use what we have + bts 
= bts[:n] + } else if err != nil { return "", err } - contentType, _, _ := strings.Cut(http.DetectContentType(b.Bytes()), ";") + contentType, _, _ := strings.Cut(http.DetectContentType(bts), ";") return contentType, nil } glob := func(pattern, contentType string) ([]string, error) { - matches, err := filepath.Glob(pattern) + matches, err := fs.Glob(root.FS(), pattern) if err != nil { return nil, err } @@ -262,7 +207,7 @@ func filesForModel(path string) ([]string, error) { if ct, err := detectContentType(match); err != nil { return nil, err } else if len(contentType) > 0 && ct != contentType { - return nil, fmt.Errorf("invalid content type: expected %s for %s", ct, match) + return nil, fmt.Errorf("invalid content type: expected %s for %s, got %s", contentType, match, ct) } @@ -271,25 +216,25 @@ func filesForModel(path string) ([]string, error) { var files []string // some safetensors files do not properly match "application/octet-stream", so skip checking their contentType - if st, _ := glob(filepath.Join(path, "model*.safetensors"), ""); len(st) > 0 { + if st, _ := glob("model*.safetensors", ""); len(st) > 0 { // safetensors files might be unresolved git lfs references; skip if they are // covers model-x-of-y.safetensors, model.fp32-x-of-y.safetensors, model.safetensors files = append(files, st...) - } else if st, _ := glob(filepath.Join(path, "consolidated*.safetensors"), ""); len(st) > 0 { + } else if st, _ := glob("consolidated*.safetensors", ""); len(st) > 0 { // covers consolidated.safetensors files = append(files, st...) - } else if pt, _ := glob(filepath.Join(path, "pytorch_model*.bin"), "application/zip"); len(pt) > 0 { + } else if pt, _ := glob("pytorch_model*.bin", "application/zip"); len(pt) > 0 { // pytorch files might also be unresolved git lfs references; skip if they are // covers pytorch_model-x-of-y.bin, pytorch_model.fp32-x-of-y.bin, pytorch_model.bin files = append(files, pt...) 
- } else if pt, _ := glob(filepath.Join(path, "consolidated*.pth"), "application/zip"); len(pt) > 0 { + } else if pt, _ := glob("consolidated*.pth", "application/zip"); len(pt) > 0 { // pytorch files might also be unresolved git lfs references; skip if they are // covers consolidated.x.pth, consolidated.pth files = append(files, pt...) - } else if gg, _ := glob(filepath.Join(path, "*.gguf"), "application/octet-stream"); len(gg) > 0 { + } else if gg, _ := glob("*.gguf", "application/octet-stream"); len(gg) > 0 { // covers gguf files ending in .gguf files = append(files, gg...) - } else if gg, _ := glob(filepath.Join(path, "*.bin"), "application/octet-stream"); len(gg) > 0 { + } else if gg, _ := glob("*.bin", "application/octet-stream"); len(gg) > 0 { // covers gguf files ending in .bin files = append(files, gg...) } else { @@ -297,7 +242,7 @@ func filesForModel(path string) ([]string, error) { } // add configuration files, json files are detected as text/plain - js, err := glob(filepath.Join(path, "*.json"), "text/plain") + js, err := glob("*.json", "text/plain") if err != nil { return nil, err } @@ -305,7 +250,7 @@ func filesForModel(path string) ([]string, error) { // bert models require a nested config.json // TODO(mxyng): merge this with the glob above - js, err = glob(filepath.Join(path, "**/*.json"), "text/plain") + js, err = glob("**/*.json", "text/plain") if err != nil { return nil, err } @@ -313,9 +258,9 @@ func filesForModel(path string) ([]string, error) { // add tokenizer.model if it exists (tokenizer.json is automatically picked up by the previous glob) // tokenizer.model might be a unresolved git lfs reference; error if it is - if tks, _ := glob(filepath.Join(path, "tokenizer.model"), "application/octet-stream"); len(tks) > 0 { + if tks, _ := glob("tokenizer.model", "application/octet-stream"); len(tks) > 0 { files = append(files, tks...) 
- } else if tks, _ := glob(filepath.Join(path, "**/tokenizer.model"), "text/plain"); len(tks) > 0 { + } else if tks, _ := glob("**/tokenizer.model", "text/plain"); len(tks) > 0 { // some times tokenizer.model is in a subdirectory (e.g. meta-llama/Meta-Llama-3-8B) files = append(files, tks...) } diff --git a/parser/parser_test.go b/parser/parser_test.go index 4dcfed0cb..a31fb198a 100644 --- a/parser/parser_test.go +++ b/parser/parser_test.go @@ -2,7 +2,6 @@ package parser import ( "bytes" - "crypto/sha256" "encoding/binary" "errors" "fmt" @@ -15,6 +14,7 @@ import ( "unicode/utf16" "github.com/google/go-cmp/cmp" + "github.com/google/go-cmp/cmp/cmpopts" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "golang.org/x/text/encoding" @@ -775,25 +775,13 @@ MESSAGE assistant Hi! How are you? t.Error(err) } - if diff := cmp.Diff(actual, c.expected); diff != "" { + if diff := cmp.Diff(actual, c.expected, cmpopts.EquateEmpty()); diff != "" { t.Errorf("mismatch (-got +want):\n%s", diff) } } } -func getSHA256Digest(t *testing.T, r io.Reader) (string, int64) { - t.Helper() - - h := sha256.New() - n, err := io.Copy(h, r) - if err != nil { - t.Fatal(err) - } - - return fmt.Sprintf("sha256:%x", h.Sum(nil)), n -} - -func createBinFile(t *testing.T, kv map[string]any, ti []*ggml.Tensor) (string, string) { +func createBinFile(t *testing.T, kv map[string]any, ti []*ggml.Tensor) string { t.Helper() f, err := os.CreateTemp(t.TempDir(), "testbin.*.gguf") @@ -808,19 +796,12 @@ func createBinFile(t *testing.T, kv map[string]any, ti []*ggml.Tensor) (string, if err := ggml.WriteGGUF(f, base, ti); err != nil { t.Fatal(err) } - // Calculate sha256 of file - if _, err := f.Seek(0, 0); err != nil { - t.Fatal(err) - } - - digest, _ := getSHA256Digest(t, f) - - return f.Name(), digest + return f.Name() } func TestCreateRequestFiles(t *testing.T) { - n1, d1 := createBinFile(t, nil, nil) - n2, d2 := createBinFile(t, map[string]any{"foo": "bar"}, nil) + n1 := createBinFile(t, 
nil, nil) + n2 := createBinFile(t, map[string]any{"foo": "bar"}, nil) cases := []struct { input string @@ -828,11 +809,20 @@ func TestCreateRequestFiles(t *testing.T) { }{ { fmt.Sprintf("FROM %s", n1), - &api.CreateRequest{Files: map[string]string{n1: d1}}, + &api.CreateRequest{ + Files: map[string]string{ + filepath.Base(n1): "abs:" + n1, + }, + }, }, { fmt.Sprintf("FROM %s\nFROM %s", n1, n2), - &api.CreateRequest{Files: map[string]string{n1: d1, n2: d2}}, + &api.CreateRequest{ + Files: map[string]string{ + filepath.Base(n1): "abs:" + n1, + filepath.Base(n2): "abs:" + n2, + }, + }, }, } @@ -852,7 +842,7 @@ func TestCreateRequestFiles(t *testing.T) { t.Error(err) } - if diff := cmp.Diff(actual, c.expected); diff != "" { + if diff := cmp.Diff(actual, c.expected, cmpopts.EquateEmpty()); diff != "" { t.Errorf("mismatch (-got +want):\n%s", diff) } } @@ -860,15 +850,15 @@ func TestCreateRequestFiles(t *testing.T) { func TestFilesForModel(t *testing.T) { tests := []struct { - name string - setup func(string) error - wantFiles []string - wantErr bool - expectErrType error + name string + setup func(*testing.T, *os.Root) + want []string + wantErr error }{ { name: "safetensors model files", - setup: func(dir string) error { + setup: func(t *testing.T, root *os.Root) { + t.Helper() files := []string{ "model-00001-of-00002.safetensors", "model-00002-of-00002.safetensors", @@ -876,13 +866,12 @@ func TestFilesForModel(t *testing.T) { "tokenizer.json", } for _, file := range files { - if err := os.WriteFile(filepath.Join(dir, file), []byte("test content"), 0o644); err != nil { - return err + if err := root.WriteFile(file, []byte("test content"), 0o644); err != nil { + t.Fatal(err) } } - return nil }, - wantFiles: []string{ + want: []string{ "model-00001-of-00002.safetensors", "model-00002-of-00002.safetensors", "config.json", @@ -891,7 +880,7 @@ func TestFilesForModel(t *testing.T) { }, { name: "safetensors with both tokenizer.json and tokenizer.model", - setup: func(dir string) 
error { + setup: func(t *testing.T, root *os.Root) { // Create binary content for tokenizer.model (application/octet-stream) binaryContent := make([]byte, 512) for i := range binaryContent { @@ -903,17 +892,16 @@ func TestFilesForModel(t *testing.T) { "tokenizer.json", } for _, file := range files { - if err := os.WriteFile(filepath.Join(dir, file), []byte("test content"), 0o644); err != nil { - return err + if err := root.WriteFile(file, []byte("test content"), 0o644); err != nil { + t.Fatal(err) } } // Write tokenizer.model as binary - if err := os.WriteFile(filepath.Join(dir, "tokenizer.model"), binaryContent, 0o644); err != nil { - return err + if err := root.WriteFile("tokenizer.model", binaryContent, 0o644); err != nil { + t.Fatal(err) } - return nil }, - wantFiles: []string{ + want: []string{ "model-00001-of-00001.safetensors", "config.json", "tokenizer.json", @@ -922,46 +910,44 @@ func TestFilesForModel(t *testing.T) { }, { name: "safetensors with consolidated files - prefers model files", - setup: func(dir string) error { + setup: func(t *testing.T, root *os.Root) { files := []string{ "model-00001-of-00001.safetensors", "consolidated.safetensors", "config.json", } for _, file := range files { - if err := os.WriteFile(filepath.Join(dir, file), []byte("test content"), 0o644); err != nil { - return err + if err := root.WriteFile(file, []byte("test content"), 0o644); err != nil { + t.Fatal(err) } } - return nil }, - wantFiles: []string{ + want: []string{ "model-00001-of-00001.safetensors", // consolidated files should be excluded "config.json", }, }, { name: "safetensors without model-.safetensors files - uses consolidated", - setup: func(dir string) error { + setup: func(t *testing.T, root *os.Root) { files := []string{ "consolidated.safetensors", "config.json", } for _, file := range files { - if err := os.WriteFile(filepath.Join(dir, file), []byte("test content"), 0o644); err != nil { - return err + if err := root.WriteFile(file, []byte("test content"), 
0o644); err != nil { + t.Fatal(err) } } - return nil }, - wantFiles: []string{ + want: []string{ "consolidated.safetensors", "config.json", }, }, { name: "pytorch model files", - setup: func(dir string) error { + setup: func(t *testing.T, root *os.Root) { // Create a file that will be detected as application/zip zipHeader := []byte{0x50, 0x4B, 0x03, 0x04} // PK zip header files := []string{ @@ -974,13 +960,12 @@ func TestFilesForModel(t *testing.T) { if file == "config.json" { content = []byte(`{"config": true}`) } - if err := os.WriteFile(filepath.Join(dir, file), content, 0o644); err != nil { - return err + if err := root.WriteFile(file, content, 0o644); err != nil { + t.Fatal(err) } } - return nil }, - wantFiles: []string{ + want: []string{ "pytorch_model-00001-of-00002.bin", "pytorch_model-00002-of-00002.bin", "config.json", @@ -988,7 +973,7 @@ func TestFilesForModel(t *testing.T) { }, { name: "consolidated pth files", - setup: func(dir string) error { + setup: func(t *testing.T, root *os.Root) { zipHeader := []byte{0x50, 0x4B, 0x03, 0x04} files := []string{ "consolidated.00.pth", @@ -1000,13 +985,12 @@ func TestFilesForModel(t *testing.T) { if file == "config.json" { content = []byte(`{"config": true}`) } - if err := os.WriteFile(filepath.Join(dir, file), content, 0o644); err != nil { - return err + if err := root.WriteFile(file, content, 0o644); err != nil { + t.Fatal(err) } } - return nil }, - wantFiles: []string{ + want: []string{ "consolidated.00.pth", "consolidated.01.pth", "config.json", @@ -1014,7 +998,7 @@ func TestFilesForModel(t *testing.T) { }, { name: "gguf files", - setup: func(dir string) error { + setup: func(t *testing.T, root *os.Root) { // Create binary content that will be detected as application/octet-stream binaryContent := make([]byte, 512) for i := range binaryContent { @@ -1029,20 +1013,19 @@ func TestFilesForModel(t *testing.T) { if file == "config.json" { content = []byte(`{"config": true}`) } - if err := 
os.WriteFile(filepath.Join(dir, file), content, 0o644); err != nil { - return err + if err := root.WriteFile(file, content, 0o644); err != nil { + t.Fatal(err) } } - return nil }, - wantFiles: []string{ + want: []string{ "model.gguf", "config.json", }, }, { name: "bin files as gguf", - setup: func(dir string) error { + setup: func(t *testing.T, root *os.Root) { binaryContent := make([]byte, 512) for i := range binaryContent { binaryContent[i] = byte(i % 256) @@ -1056,35 +1039,32 @@ func TestFilesForModel(t *testing.T) { if file == "config.json" { content = []byte(`{"config": true}`) } - if err := os.WriteFile(filepath.Join(dir, file), content, 0o644); err != nil { - return err + if err := root.WriteFile(file, content, 0o644); err != nil { + t.Fatal(err) } } - return nil }, - wantFiles: []string{ + want: []string{ "model.bin", "config.json", }, }, { name: "no model files found", - setup: func(dir string) error { + setup: func(t *testing.T, root *os.Root) { // Only create non-model files files := []string{"README.md", "config.json"} for _, file := range files { - if err := os.WriteFile(filepath.Join(dir, file), []byte("content"), 0o644); err != nil { - return err + if err := root.WriteFile(file, []byte("content"), 0o644); err != nil { + t.Fatal(err) } } - return nil }, - wantErr: true, - expectErrType: ErrModelNotFound, + wantErr: ErrModelNotFound, }, { name: "invalid content type for pytorch model", - setup: func(dir string) error { + setup: func(t *testing.T, root *os.Root) { // Create pytorch model file with wrong content type (text instead of zip) files := []string{ "pytorch_model.bin", @@ -1092,68 +1072,32 @@ func TestFilesForModel(t *testing.T) { } for _, file := range files { content := []byte("plain text content") - if err := os.WriteFile(filepath.Join(dir, file), content, 0o644); err != nil { - return err + if err := root.WriteFile(file, content, 0o644); err != nil { + t.Fatal(err) } } - return nil }, - wantErr: true, + wantErr: ErrModelNotFound, }, } - 
tmpDir := t.TempDir() - for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - testDir := filepath.Join(tmpDir, tt.name) - if err := os.MkdirAll(testDir, 0o755); err != nil { - t.Fatalf("Failed to create test directory: %v", err) - } - - if err := tt.setup(testDir); err != nil { - t.Fatalf("Setup failed: %v", err) - } - - files, err := filesForModel(testDir) - - if tt.wantErr { - if err == nil { - t.Error("Expected error, but got none") - } - if tt.expectErrType != nil && err != tt.expectErrType { - t.Errorf("Expected error type %v, got %v", tt.expectErrType, err) - } - return - } - + root, err := os.OpenRoot(t.TempDir()) if err != nil { - t.Errorf("Unexpected error: %v", err) - return + t.Fatalf("Failed to open root: %v", err) + } + defer root.Close() + + tt.setup(t, root) + + files, err := filesForModel(root) + if !errors.Is(err, tt.wantErr) { + t.Fatalf("want %v error, got %v", tt.wantErr, err) } - var relativeFiles []string - for _, file := range files { - rel, err := filepath.Rel(testDir, file) - if err != nil { - t.Fatalf("Failed to get relative path: %v", err) - } - relativeFiles = append(relativeFiles, rel) - } - - if len(relativeFiles) != len(tt.wantFiles) { - t.Errorf("Expected %d files, got %d: %v", len(tt.wantFiles), len(relativeFiles), relativeFiles) - } - - fileSet := make(map[string]bool) - for _, file := range relativeFiles { - fileSet[file] = true - } - - for _, wantFile := range tt.wantFiles { - if !fileSet[wantFile] { - t.Errorf("Missing expected file: %s", wantFile) - } + if diff := cmp.Diff(tt.want, files); diff != "" { + t.Errorf("filesForModel() mismatch (-want +got):\n%s", diff) } }) } diff --git a/types/syncmap/syncmap.go b/types/syncmap/syncmap.go deleted file mode 100644 index ff21cd999..000000000 --- a/types/syncmap/syncmap.go +++ /dev/null @@ -1,38 +0,0 @@ -package syncmap - -import ( - "maps" - "sync" -) - -// SyncMap is a simple, generic thread-safe map implementation. 
-type SyncMap[K comparable, V any] struct { - mu sync.RWMutex - m map[K]V -} - -func NewSyncMap[K comparable, V any]() *SyncMap[K, V] { - return &SyncMap[K, V]{ - m: make(map[K]V), - } -} - -func (s *SyncMap[K, V]) Load(key K) (V, bool) { - s.mu.RLock() - defer s.mu.RUnlock() - val, ok := s.m[key] - return val, ok -} - -func (s *SyncMap[K, V]) Store(key K, value V) { - s.mu.Lock() - defer s.mu.Unlock() - s.m[key] = value -} - -func (s *SyncMap[K, V]) Items() map[K]V { - s.mu.RLock() - defer s.mu.RUnlock() - // shallow copy map items - return maps.Clone(s.m) -}