diff --git a/cmd/cmd.go b/cmd/cmd.go index 09940bda5..a5d46f90d 100644 --- a/cmd/cmd.go +++ b/cmd/cmd.go @@ -46,8 +46,9 @@ import ( "github.com/ollama/ollama/types/syncmap" "github.com/ollama/ollama/version" xcmd "github.com/ollama/ollama/x/cmd" + "github.com/ollama/ollama/x/create" + xcreateclient "github.com/ollama/ollama/x/create/client" "github.com/ollama/ollama/x/imagegen" - imagegenclient "github.com/ollama/ollama/x/imagegen/client" ) const ConnectInstructions = "To sign in, navigate to:\n %s\n\n" @@ -93,15 +94,87 @@ func CreateHandler(cmd *cobra.Command, args []string) error { p := progress.NewProgress(os.Stderr) defer p.Stop() + // Validate model name early to fail fast + modelName := args[0] + name := model.ParseName(modelName) + if !name.IsValid() { + return fmt.Errorf("invalid model name: %s", modelName) + } + + // Check for --experimental flag for safetensors model creation + experimental, _ := cmd.Flags().GetBool("experimental") + if experimental { + // Get Modelfile content - either from -f flag or default to "FROM ." + var reader io.Reader + filename, err := getModelfileName(cmd) + if os.IsNotExist(err) || filename == "" { + // No Modelfile specified or found - use default + reader = strings.NewReader("FROM .\n") + } else if err != nil { + return err + } else { + f, err := os.Open(filename) + if err != nil { + return err + } + defer f.Close() + reader = f + } + + // Parse the Modelfile + modelfile, err := parser.ParseFile(reader) + if err != nil { + return fmt.Errorf("failed to parse Modelfile: %w", err) + } + + // Extract FROM path and configuration + var modelDir string + mfConfig := &xcreateclient.ModelfileConfig{} + + for _, cmd := range modelfile.Commands { + switch cmd.Name { + case "model": + modelDir = cmd.Args + case "template": + mfConfig.Template = cmd.Args + case "system": + mfConfig.System = cmd.Args + case "license": + mfConfig.License = cmd.Args + } + } + + if modelDir == "" { + modelDir = "." 
+ } + + // Resolve relative paths based on Modelfile location + if !filepath.IsAbs(modelDir) && filename != "" { + modelDir = filepath.Join(filepath.Dir(filename), modelDir) + } + + quantize, _ := cmd.Flags().GetString("quantize") + return xcreateclient.CreateModel(xcreateclient.CreateOptions{ + ModelName: modelName, + ModelDir: modelDir, + Quantize: quantize, + Modelfile: mfConfig, + }, p) + } + var reader io.Reader filename, err := getModelfileName(cmd) if os.IsNotExist(err) { if filename == "" { // No Modelfile found - check if current directory is an image gen model - if imagegen.IsTensorModelDir(".") { + if create.IsTensorModelDir(".") { quantize, _ := cmd.Flags().GetString("quantize") - return imagegenclient.CreateModel(args[0], ".", quantize, p) + return xcreateclient.CreateModel(xcreateclient.CreateOptions{ + ModelName: modelName, + ModelDir: ".", + Quantize: quantize, + }, p) } reader = strings.NewReader("FROM .\n") } else { @@ -134,7 +207,7 @@ func CreateHandler(cmd *cobra.Command, args []string) error { } spinner.Stop() - req.Model = args[0] + req.Model = modelName quantize, _ := cmd.Flags().GetString("quantize") if quantize != "" { req.Quantize = quantize @@ -1742,15 +1815,22 @@ func NewCLI() *cobra.Command { rootCmd.Flags().BoolP("version", "v", false, "Show version information") createCmd := &cobra.Command{ - Use: "create MODEL", - Short: "Create a model", - Args: cobra.ExactArgs(1), - PreRunE: checkServerHeartbeat, - RunE: CreateHandler, + Use: "create MODEL", + Short: "Create a model", + Args: cobra.ExactArgs(1), + PreRunE: func(cmd *cobra.Command, args []string) error { + // Skip server check for experimental mode (writes directly to disk) + if experimental, _ := cmd.Flags().GetBool("experimental"); experimental { + return nil + } + return checkServerHeartbeat(cmd, args) + }, + RunE: CreateHandler, } createCmd.Flags().StringP("file", "f", "", "Name of the Modelfile (default \"Modelfile\")") createCmd.Flags().StringP("quantize", "q", "", "Quantize model to this level (e.g. 
q4_K_M)") + createCmd.Flags().Bool("experimental", false, "Enable experimental safetensors model creation") showCmd := &cobra.Command{ Use: "show MODEL", diff --git a/server/routes.go b/server/routes.go index 75d8aa610..e90b885e0 100644 --- a/server/routes.go +++ b/server/routes.go @@ -51,6 +51,7 @@ import ( "github.com/ollama/ollama/types/model" "github.com/ollama/ollama/version" "github.com/ollama/ollama/x/imagegen" + xserver "github.com/ollama/ollama/x/server" ) const signinURLStr = "https://ollama.com/connect?name=%s&key=%s" @@ -1103,6 +1104,22 @@ func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, error) { } } + // For safetensors LLM models (experimental), populate details from config.json + if m.Config.ModelFormat == "safetensors" && slices.Contains(m.Config.Capabilities, "completion") { + if info, err := xserver.GetSafetensorsLLMInfo(name.String()); err == nil { + if arch, ok := info["general.architecture"].(string); ok && arch != "" { + modelDetails.Family = arch + } + if paramCount, ok := info["general.parameter_count"].(int64); ok && paramCount > 0 { + modelDetails.ParameterSize = format.HumanNumber(uint64(paramCount)) + } + } + // Get torch_dtype directly from config.json for quantization level + if dtype, err := xserver.GetSafetensorsDtype(name.String()); err == nil && dtype != "" { + modelDetails.QuantizationLevel = dtype + } + } + if req.System != "" { m.System = req.System } @@ -1186,6 +1203,26 @@ func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, error) { } if slices.Contains(m.Capabilities(), model.CapabilityImageGeneration) { + // Populate tensor info if verbose + if req.Verbose { + if tensors, err := xserver.GetSafetensorsTensorInfo(name.String()); err == nil { + resp.Tensors = tensors + } + } + return resp, nil + } + + // For safetensors LLM models (experimental), populate ModelInfo from config.json + if m.Config.ModelFormat == "safetensors" && slices.Contains(m.Config.Capabilities, "completion") { + if info, err := xserver.GetSafetensorsLLMInfo(name.String()); err == nil { + resp.ModelInfo = info + } + // Populate tensor info if verbose + if req.Verbose { + if tensors, err := xserver.GetSafetensorsTensorInfo(name.String()); err == nil { + resp.Tensors = tensors + } + } return resp, nil } diff --git a/x/create/client/create.go b/x/create/client/create.go new file mode 100644 index 000000000..7729c6e5f --- /dev/null +++ b/x/create/client/create.go @@ -0,0 +1,282 @@ +// Package client provides client-side model creation for safetensors-based models. +// +// This package is in x/ because the safetensors model storage format is under development. +// It also exists to break an import cycle: server imports x/create, so x/create +// cannot import server. This sub-package can import server because server doesn't +// import it. +package client + +import ( + "bytes" + "encoding/json" + "fmt" + "io" + + "github.com/ollama/ollama/progress" + "github.com/ollama/ollama/server" + "github.com/ollama/ollama/types/model" + "github.com/ollama/ollama/x/create" +) + +// MinOllamaVersion is the minimum Ollama version required for safetensors models. +const MinOllamaVersion = "0.14.0" + +// ModelfileConfig holds configuration extracted from a Modelfile. +type ModelfileConfig struct { + Template string + System string + License string +} + +// CreateOptions holds all options for model creation. 
+type CreateOptions struct { + ModelName string + ModelDir string + Quantize string // "fp8" for quantization + Modelfile *ModelfileConfig // template/system/license from Modelfile +} + +// CreateModel imports a model from a local directory. +// This creates blobs and manifest directly on disk, bypassing the HTTP API. +// Automatically detects model type (safetensors LLM vs image gen) and routes accordingly. +func CreateModel(opts CreateOptions, p *progress.Progress) error { + // Detect model type + isSafetensors := create.IsSafetensorsModelDir(opts.ModelDir) + isImageGen := create.IsTensorModelDir(opts.ModelDir) + + if !isSafetensors && !isImageGen { + return fmt.Errorf("%s is not a supported model directory (needs config.json + *.safetensors or model_index.json)", opts.ModelDir) + } + + // Determine model type settings + var modelType, spinnerKey string + var capabilities []string + if isSafetensors { + modelType = "safetensors model" + spinnerKey = "create" + capabilities = []string{"completion"} + } else { + modelType = "image generation model" + spinnerKey = "imagegen" + capabilities = []string{"image"} + } + + // Set up progress spinner + statusMsg := "importing " + modelType + spinner := progress.NewSpinner(statusMsg) + p.Add(spinnerKey, spinner) + + progressFn := func(msg string) { + spinner.Stop() + statusMsg = msg + spinner = progress.NewSpinner(statusMsg) + p.Add(spinnerKey, spinner) + } + + // Create the model using shared callbacks + var err error + if isSafetensors { + err = create.CreateSafetensorsModel( + opts.ModelName, opts.ModelDir, opts.Quantize, + newLayerCreator(), newTensorLayerCreator(), + newManifestWriter(opts, capabilities), + progressFn, + ) + } else { + err = create.CreateImageGenModel( + opts.ModelName, opts.ModelDir, opts.Quantize, + newLayerCreator(), newTensorLayerCreator(), + newManifestWriter(opts, capabilities), + progressFn, + ) + } + + spinner.Stop() + if err != nil { + return err + } + + fmt.Printf("Created %s '%s'\n", modelType, opts.ModelName) + return nil +} + +// newLayerCreator returns a LayerCreator callback for creating config/JSON layers. +func newLayerCreator() create.LayerCreator { + return func(r io.Reader, mediaType, name string) (create.LayerInfo, error) { + layer, err := server.NewLayer(r, mediaType) + if err != nil { + return create.LayerInfo{}, err + } + + return create.LayerInfo{ + Digest: layer.Digest, + Size: layer.Size, + MediaType: layer.MediaType, + Name: name, + }, nil + } +} + +// newTensorLayerCreator returns a QuantizingTensorLayerCreator callback for creating tensor layers. +// When quantize is non-empty, returns multiple layers (weight + scales + optional qbias). +func newTensorLayerCreator() create.QuantizingTensorLayerCreator { + return func(r io.Reader, name, dtype string, shape []int32, quantize string) ([]create.LayerInfo, error) { + if quantize != "" { + return createQuantizedLayers(r, name, dtype, shape, quantize) + } + return createUnquantizedLayer(r, name) + } +} + +// createQuantizedLayers quantizes a tensor and returns the resulting layers. 
+func createQuantizedLayers(r io.Reader, name, dtype string, shape []int32, quantize string) ([]create.LayerInfo, error) { + if !QuantizeSupported() { + return nil, fmt.Errorf("quantization requires MLX support") + } + + // Quantize the tensor + qweightData, scalesData, qbiasData, _, _, _, err := quantizeTensor(r, name, dtype, shape, quantize) + if err != nil { + return nil, fmt.Errorf("failed to quantize %s: %w", name, err) + } + + // Create layer for quantized weight + weightLayer, err := server.NewLayer(bytes.NewReader(qweightData), server.MediaTypeImageTensor) + if err != nil { + return nil, err + } + + // Create layer for scales + scalesLayer, err := server.NewLayer(bytes.NewReader(scalesData), server.MediaTypeImageTensor) + if err != nil { + return nil, err + } + + layers := []create.LayerInfo{ + { + Digest: weightLayer.Digest, + Size: weightLayer.Size, + MediaType: weightLayer.MediaType, + Name: name, + }, + { + Digest: scalesLayer.Digest, + Size: scalesLayer.Size, + MediaType: scalesLayer.MediaType, + Name: name + "_scale", + }, + } + + // Add qbiases layer if present (affine mode) + if qbiasData != nil { + qbiasLayer, err := server.NewLayer(bytes.NewReader(qbiasData), server.MediaTypeImageTensor) + if err != nil { + return nil, err + } + layers = append(layers, create.LayerInfo{ + Digest: qbiasLayer.Digest, + Size: qbiasLayer.Size, + MediaType: qbiasLayer.MediaType, + Name: name + "_qbias", + }) + } + + return layers, nil +} + +// createUnquantizedLayer creates a single tensor layer without quantization. +func createUnquantizedLayer(r io.Reader, name string) ([]create.LayerInfo, error) { + layer, err := server.NewLayer(r, server.MediaTypeImageTensor) + if err != nil { + return nil, err + } + + return []create.LayerInfo{ + { + Digest: layer.Digest, + Size: layer.Size, + MediaType: layer.MediaType, + Name: name, + }, + }, nil +} + +// newManifestWriter returns a ManifestWriter callback for writing the model manifest. +func newManifestWriter(opts CreateOptions, capabilities []string) create.ManifestWriter { + return func(modelName string, config create.LayerInfo, layers []create.LayerInfo) error { + name := model.ParseName(modelName) + if !name.IsValid() { + return fmt.Errorf("invalid model name: %s", modelName) + } + + // Create config blob with version requirement + configData := model.ConfigV2{ + ModelFormat: "safetensors", + Capabilities: capabilities, + Requires: MinOllamaVersion, + } + configJSON, err := json.Marshal(configData) + if err != nil { + return fmt.Errorf("failed to marshal config: %w", err) + } + + // Create config layer blob + configLayer, err := server.NewLayer(bytes.NewReader(configJSON), "application/vnd.docker.container.image.v1+json") + if err != nil { + return fmt.Errorf("failed to create config layer: %w", err) + } + + // Convert LayerInfo to server.Layer + serverLayers := make([]server.Layer, 0, len(layers)) + for _, l := range layers { + serverLayers = append(serverLayers, server.Layer{ + MediaType: l.MediaType, + Digest: l.Digest, + Size: l.Size, + Name: l.Name, + }) + } + + // Add Modelfile layers if present + if opts.Modelfile != nil { + modelfileLayers, err := createModelfileLayers(opts.Modelfile) + if err != nil { + return err + } + serverLayers = append(serverLayers, modelfileLayers...) + } + + return server.WriteManifest(name, configLayer, serverLayers) + } +} + +// createModelfileLayers creates layers for template, system, and license from Modelfile config. 
+func createModelfileLayers(mf *ModelfileConfig) ([]server.Layer, error) { + var layers []server.Layer + + if mf.Template != "" { + layer, err := server.NewLayer(bytes.NewReader([]byte(mf.Template)), "application/vnd.ollama.image.template") + if err != nil { + return nil, fmt.Errorf("failed to create template layer: %w", err) + } + layers = append(layers, layer) + } + + if mf.System != "" { + layer, err := server.NewLayer(bytes.NewReader([]byte(mf.System)), "application/vnd.ollama.image.system") + if err != nil { + return nil, fmt.Errorf("failed to create system layer: %w", err) + } + layers = append(layers, layer) + } + + if mf.License != "" { + layer, err := server.NewLayer(bytes.NewReader([]byte(mf.License)), "application/vnd.ollama.image.license") + if err != nil { + return nil, fmt.Errorf("failed to create license layer: %w", err) + } + layers = append(layers, layer) + } + + return layers, nil +} diff --git a/x/create/client/create_test.go b/x/create/client/create_test.go new file mode 100644 index 000000000..b41807279 --- /dev/null +++ b/x/create/client/create_test.go @@ -0,0 +1,146 @@ +package client + +import ( + "testing" +) + +func TestModelfileConfig(t *testing.T) { + // Test that ModelfileConfig struct works as expected + config := &ModelfileConfig{ + Template: "{{ .Prompt }}", + System: "You are a helpful assistant.", + License: "MIT", + } + + if config.Template != "{{ .Prompt }}" { + t.Errorf("Template = %q, want %q", config.Template, "{{ .Prompt }}") + } + if config.System != "You are a helpful assistant." { + t.Errorf("System = %q, want %q", config.System, "You are a helpful assistant.") + } + if config.License != "MIT" { + t.Errorf("License = %q, want %q", config.License, "MIT") + } +} + +func TestModelfileConfig_Empty(t *testing.T) { + config := &ModelfileConfig{} + + if config.Template != "" { + t.Errorf("Template should be empty, got %q", config.Template) + } + if config.System != "" { + t.Errorf("System should be empty, got %q", config.System) + } + if config.License != "" { + t.Errorf("License should be empty, got %q", config.License) + } +} + +func TestModelfileConfig_PartialFields(t *testing.T) { + // Test config with only some fields set + config := &ModelfileConfig{ + Template: "{{ .Prompt }}", + // System and License intentionally empty + } + + if config.Template == "" { + t.Error("Template should not be empty") + } + if config.System != "" { + t.Error("System should be empty") + } + if config.License != "" { + t.Error("License should be empty") + } +} + +func TestMinOllamaVersion(t *testing.T) { + // Verify the minimum version constant is set + if MinOllamaVersion == "" { + t.Error("MinOllamaVersion should not be empty") + } + if MinOllamaVersion != "0.14.0" { + t.Errorf("MinOllamaVersion = %q, want %q", MinOllamaVersion, "0.14.0") + } +} + +func TestCreateModel_InvalidDir(t *testing.T) { + // Test that CreateModel returns error for invalid directory + err := CreateModel(CreateOptions{ + ModelName: "test-model", + ModelDir: "/nonexistent/path", + }, nil) + if err == nil { + t.Error("expected error for nonexistent directory, got nil") + } +} + +func TestCreateModel_NotSafetensorsDir(t *testing.T) { + // Test that CreateModel returns error for directory without safetensors + dir := t.TempDir() + + err := CreateModel(CreateOptions{ + ModelName: "test-model", + ModelDir: dir, + }, nil) + if err == nil { + t.Error("expected error for empty directory, got nil") + } +} + +func TestCreateOptions(t *testing.T) { + opts := CreateOptions{ + ModelName: "my-model", + 
ModelDir: "/path/to/model", + Quantize: "fp8", + Modelfile: &ModelfileConfig{ + Template: "test", + System: "system", + License: "MIT", + }, + } + + if opts.ModelName != "my-model" { + t.Errorf("ModelName = %q, want %q", opts.ModelName, "my-model") + } + if opts.ModelDir != "/path/to/model" { + t.Errorf("ModelDir = %q, want %q", opts.ModelDir, "/path/to/model") + } + if opts.Quantize != "fp8" { + t.Errorf("Quantize = %q, want %q", opts.Quantize, "fp8") + } + if opts.Modelfile == nil { + t.Error("Modelfile should not be nil") + } + if opts.Modelfile.Template != "test" { + t.Errorf("Modelfile.Template = %q, want %q", opts.Modelfile.Template, "test") + } +} + +func TestCreateOptions_Defaults(t *testing.T) { + opts := CreateOptions{ + ModelName: "test", + ModelDir: "/tmp", + } + + // Quantize should default to empty + if opts.Quantize != "" { + t.Errorf("Quantize should be empty by default, got %q", opts.Quantize) + } + + // Modelfile should default to nil + if opts.Modelfile != nil { + t.Error("Modelfile should be nil by default") + } +} + +func TestQuantizeSupported(t *testing.T) { + // This just verifies the function exists and returns a boolean + // The actual value depends on build tags (mlx vs non-mlx) + supported := QuantizeSupported() + + // In non-mlx builds, this should be false + // We can't easily test both cases, so just verify it returns something + _ = supported +} diff --git a/x/imagegen/client/quantize.go b/x/create/client/quantize.go similarity index 88% rename from x/imagegen/client/quantize.go rename to x/create/client/quantize.go index 569dc6baf..5a4be59d0 100644 --- a/x/imagegen/client/quantize.go +++ b/x/create/client/quantize.go @@ -11,10 +11,11 @@ import ( "github.com/ollama/ollama/x/imagegen/mlx" ) -// quantizeTensor loads a tensor from safetensors format, quantizes it to affine int8, +// quantizeTensor loads a tensor from safetensors format, quantizes it, // and returns safetensors data for the quantized weights, scales, and biases. +// Supported quantization types: "fp8" (affine 8-bit) // Uses MLX's native SaveSafetensors to ensure correct dtype handling (especially uint32 for quantized weights). 
-func quantizeTensor(r io.Reader, name, dtype string, shape []int32) (qweightData, scalesData, qbiasData []byte, qweightShape, scalesShape, qbiasShape []int32, err error) { +func quantizeTensor(r io.Reader, name, dtype string, shape []int32, quantize string) (qweightData, scalesData, qbiasData []byte, qweightShape, scalesShape, qbiasShape []int32, err error) { tmpDir := ensureTempDir() // Read safetensors data to a temp file (LoadSafetensorsNative needs a path) @@ -50,9 +51,15 @@ func quantizeTensor(r io.Reader, name, dtype string, shape []int32) (qweightData mlx.Eval(arr) } - // Quantize with affine mode: group_size=32, bits=8 - // Note: mxfp8 mode doesn't have matmul kernels in MLX, affine mode does - qweight, scales, qbiases := mlx.Quantize(arr, 32, 8, "affine") + // Quantize based on quantization type + var qweight, scales, qbiases *mlx.Array + switch quantize { + case "fp8": + // affine mode: group_size=32, bits=8 + qweight, scales, qbiases = mlx.Quantize(arr, 32, 8, "affine") + default: + return nil, nil, nil, nil, nil, nil, fmt.Errorf("unsupported quantization type: %s", quantize) + } // Eval and make contiguous for data access qweight = mlx.Contiguous(qweight) diff --git a/x/imagegen/client/quantize_stub.go b/x/create/client/quantize_stub.go similarity index 75% rename from x/imagegen/client/quantize_stub.go rename to x/create/client/quantize_stub.go index cb992dd48..3a85afcc7 100644 --- a/x/imagegen/client/quantize_stub.go +++ b/x/create/client/quantize_stub.go @@ -8,7 +8,7 @@ import ( ) // quantizeTensor is not available without MLX -func quantizeTensor(r io.Reader, name, dtype string, shape []int32) (qweightData, scalesData, qbiasData []byte, qweightShape, scalesShape, qbiasShape []int32, err error) { +func quantizeTensor(r io.Reader, name, dtype string, shape []int32, quantize string) (qweightData, scalesData, qbiasData []byte, qweightShape, scalesShape, qbiasShape []int32, err error) { return nil, nil, nil, nil, nil, nil, fmt.Errorf("quantization requires MLX support (build with mlx tag)") } diff --git a/x/create/create.go b/x/create/create.go new file mode 100644 index 000000000..823d0f842 --- /dev/null +++ b/x/create/create.go @@ -0,0 +1,399 @@ +package create + +import ( + "encoding/json" + "fmt" + "io" + "os" + "path/filepath" + "slices" + "strings" + + "github.com/ollama/ollama/envconfig" + "github.com/ollama/ollama/x/imagegen/safetensors" +) + +// ModelConfig represents the config blob stored with a model. +type ModelConfig struct { + ModelFormat string `json:"model_format"` + Capabilities []string `json:"capabilities"` +} + +// Manifest represents the manifest JSON structure. +type Manifest struct { + SchemaVersion int `json:"schemaVersion"` + MediaType string `json:"mediaType"` + Config ManifestLayer `json:"config"` + Layers []ManifestLayer `json:"layers"` +} + +// ManifestLayer represents a layer in the manifest. +type ManifestLayer struct { + MediaType string `json:"mediaType"` + Digest string `json:"digest"` + Size int64 `json:"size"` + Name string `json:"name,omitempty"` +} + +// defaultManifestDir returns the manifest storage directory. +func defaultManifestDir() string { + return filepath.Join(envconfig.Models(), "manifests") +} + +// defaultBlobDir returns the blob storage directory. +func defaultBlobDir() string { + return filepath.Join(envconfig.Models(), "blobs") +} + +// resolveManifestPath converts a model name to a manifest file path. 
+func resolveManifestPath(modelName string) string { + host := "registry.ollama.ai" + namespace := "library" + name := modelName + tag := "latest" + + if idx := strings.LastIndex(name, ":"); idx != -1 { + tag = name[idx+1:] + name = name[:idx] + } + + parts := strings.Split(name, "/") + switch len(parts) { + case 3: + host = parts[0] + namespace = parts[1] + name = parts[2] + case 2: + namespace = parts[0] + name = parts[1] + } + + return filepath.Join(defaultManifestDir(), host, namespace, name, tag) +} + +// loadManifest loads a manifest for the given model name. +func loadManifest(modelName string) (*Manifest, error) { + manifestPath := resolveManifestPath(modelName) + + data, err := os.ReadFile(manifestPath) + if err != nil { + return nil, err + } + + var manifest Manifest + if err := json.Unmarshal(data, &manifest); err != nil { + return nil, err + } + + return &manifest, nil +} + +// loadModelConfig loads the config blob for a model. +func loadModelConfig(modelName string) (*ModelConfig, error) { + manifest, err := loadManifest(modelName) + if err != nil { + return nil, err + } + + // Read the config blob + blobName := strings.Replace(manifest.Config.Digest, ":", "-", 1) + blobPath := filepath.Join(defaultBlobDir(), blobName) + + data, err := os.ReadFile(blobPath) + if err != nil { + return nil, err + } + + var config ModelConfig + if err := json.Unmarshal(data, &config); err != nil { + return nil, err + } + + return &config, nil +} + +// IsSafetensorsModel checks if a model was created with the experimental +// safetensors builder by checking the model format in the config. +func IsSafetensorsModel(modelName string) bool { + config, err := loadModelConfig(modelName) + if err != nil { + return false + } + return config.ModelFormat == "safetensors" +} + +// IsSafetensorsLLMModel checks if a model is a safetensors LLM model +// (has completion capability, not image generation). +func IsSafetensorsLLMModel(modelName string) bool { + config, err := loadModelConfig(modelName) + if err != nil { + return false + } + return config.ModelFormat == "safetensors" && slices.Contains(config.Capabilities, "completion") +} + +// IsImageGenModel checks if a model is an image generation model +// (has image capability). +func IsImageGenModel(modelName string) bool { + config, err := loadModelConfig(modelName) + if err != nil { + return false + } + return config.ModelFormat == "safetensors" && slices.Contains(config.Capabilities, "image") +} + +// GetModelArchitecture returns the architecture from the model's config.json layer. 
+func GetModelArchitecture(modelName string) (string, error) { + manifest, err := loadManifest(modelName) + if err != nil { + return "", err + } + + // Find the config.json layer + for _, layer := range manifest.Layers { + if layer.Name == "config.json" && layer.MediaType == "application/vnd.ollama.image.json" { + blobName := strings.Replace(layer.Digest, ":", "-", 1) + blobPath := filepath.Join(defaultBlobDir(), blobName) + + data, err := os.ReadFile(blobPath) + if err != nil { + return "", err + } + + var cfg struct { + Architectures []string `json:"architectures"` + ModelType string `json:"model_type"` + } + if err := json.Unmarshal(data, &cfg); err != nil { + return "", err + } + + // Prefer model_type, fall back to first architecture + if cfg.ModelType != "" { + return cfg.ModelType, nil + } + if len(cfg.Architectures) > 0 { + return cfg.Architectures[0], nil + } + } + } + + return "", fmt.Errorf("architecture not found in model config") +} + +// IsTensorModelDir checks if the directory contains a diffusers-style tensor model +// by looking for model_index.json, which is the standard diffusers pipeline config. +func IsTensorModelDir(dir string) bool { + _, err := os.Stat(filepath.Join(dir, "model_index.json")) + return err == nil +} + +// IsSafetensorsModelDir checks if the directory contains a standard safetensors model +// by looking for config.json and at least one .safetensors file. +func IsSafetensorsModelDir(dir string) bool { + // Must have config.json + if _, err := os.Stat(filepath.Join(dir, "config.json")); err != nil { + return false + } + + // Must have at least one .safetensors file + entries, err := os.ReadDir(dir) + if err != nil { + return false + } + + for _, entry := range entries { + if strings.HasSuffix(entry.Name(), ".safetensors") { + return true + } + } + + return false +} + +// LayerInfo holds metadata for a created layer. +type LayerInfo struct { + Digest string + Size int64 + MediaType string + Name string // Path-style name: "component/tensor" or "path/to/config.json" +} + +// LayerCreator is called to create a blob layer. +// name is the path-style name (e.g., "tokenizer/tokenizer.json") +type LayerCreator func(r io.Reader, mediaType, name string) (LayerInfo, error) + +// TensorLayerCreator creates a tensor blob layer with metadata. +// name is the path-style name including component (e.g., "text_encoder/model.embed_tokens.weight") +type TensorLayerCreator func(r io.Reader, name, dtype string, shape []int32) (LayerInfo, error) + +// QuantizingTensorLayerCreator creates tensor layers with optional quantization. +// When quantize is non-empty (e.g., "fp8"), returns multiple layers (weight + scales + biases). +type QuantizingTensorLayerCreator func(r io.Reader, name, dtype string, shape []int32, quantize string) ([]LayerInfo, error) + +// ManifestWriter writes the manifest file. +type ManifestWriter func(modelName string, config LayerInfo, layers []LayerInfo) error + +// ShouldQuantize returns true if a tensor should be quantized. +// For image gen models (component non-empty): quantizes linear weights, skipping VAE, embeddings, norms. +// For LLM models (component empty): quantizes linear weights, skipping embeddings, norms, and small tensors. 
+func ShouldQuantize(name, component string) bool { + // Image gen specific: skip VAE entirely + if component == "vae" { + return false + } + + // Skip embeddings + if strings.Contains(name, "embed") { + return false + } + + // Skip layer norms and RMS norms + if strings.Contains(name, "norm") || strings.Contains(name, "ln_") || strings.Contains(name, "layernorm") { + return false + } + + // Skip biases + if strings.HasSuffix(name, ".bias") { + return false + } + + // Only quantize weights + return strings.HasSuffix(name, ".weight") +} + +// ShouldQuantizeTensor returns true if a tensor should be quantized based on name and shape. +// This is a more detailed check that also considers tensor dimensions. +func ShouldQuantizeTensor(name string, shape []int32) bool { + // Use basic name-based check first + if !ShouldQuantize(name, "") { + return false + } + + // Only quantize 2D tensors (linear layers) - skip 1D (biases, norms) and higher-D (convolutions if any) + if len(shape) != 2 { + return false + } + + // Skip small tensors (less than 1024 elements) - not worth quantizing + if len(shape) >= 2 && int64(shape[0])*int64(shape[1]) < 1024 { + return false + } + + // MLX quantization requires last dimension to be divisible by group size (32) + if shape[len(shape)-1]%32 != 0 { + return false + } + + return true +} + +// CreateSafetensorsModel imports a standard safetensors model from a directory. +// This handles Hugging Face style models with config.json and *.safetensors files. +// Stores each tensor as a separate blob for fine-grained deduplication. +// If quantize is non-empty (e.g., "fp8"), eligible tensors will be quantized. +func CreateSafetensorsModel(modelName, modelDir, quantize string, createLayer LayerCreator, createTensorLayer QuantizingTensorLayerCreator, writeManifest ManifestWriter, fn func(status string)) error { + var layers []LayerInfo + var configLayer LayerInfo + + entries, err := os.ReadDir(modelDir) + if err != nil { + return fmt.Errorf("failed to read directory: %w", err) + } + + // Process all safetensors files + for _, entry := range entries { + if entry.IsDir() || !strings.HasSuffix(entry.Name(), ".safetensors") { + continue + } + + stPath := filepath.Join(modelDir, entry.Name()) + + // Extract individual tensors from safetensors file + extractor, err := safetensors.OpenForExtraction(stPath) + if err != nil { + return fmt.Errorf("failed to open %s: %w", stPath, err) + } + + tensorNames := extractor.ListTensors() + quantizeMsg := "" + if quantize != "" { + quantizeMsg = fmt.Sprintf(", quantizing to %s", quantize) + } + fn(fmt.Sprintf("importing %s (%d tensors%s)", entry.Name(), len(tensorNames), quantizeMsg)) + + for _, tensorName := range tensorNames { + td, err := extractor.GetTensor(tensorName) + if err != nil { + extractor.Close() + return fmt.Errorf("failed to get tensor %s: %w", tensorName, err) + } + + // Determine quantization type for this tensor (empty string if not quantizing) + quantizeType := "" + if quantize != "" && ShouldQuantizeTensor(tensorName, td.Shape) { + quantizeType = quantize + } + + // Store as minimal safetensors format (88 bytes header overhead) + // This enables native mmap loading via mlx_load_safetensors + // createTensorLayer returns multiple layers if quantizing (weight + scales) + newLayers, err := createTensorLayer(td.SafetensorsReader(), tensorName, td.Dtype, td.Shape, quantizeType) + if err != nil { + extractor.Close() + return fmt.Errorf("failed to create layer for %s: %w", tensorName, err) + } + layers = append(layers, 
newLayers...) + } + + extractor.Close() + } + + // Process all JSON config files + for _, entry := range entries { + if entry.IsDir() || !strings.HasSuffix(entry.Name(), ".json") { + continue + } + + // Skip the index file as we don't need it after extraction + if entry.Name() == "model.safetensors.index.json" { + continue + } + + cfgPath := entry.Name() + fullPath := filepath.Join(modelDir, cfgPath) + + fn(fmt.Sprintf("importing config %s", cfgPath)) + + f, err := os.Open(fullPath) + if err != nil { + return fmt.Errorf("failed to open %s: %w", cfgPath, err) + } + + layer, err := createLayer(f, "application/vnd.ollama.image.json", cfgPath) + f.Close() + if err != nil { + return fmt.Errorf("failed to create layer for %s: %w", cfgPath, err) + } + + // Use config.json as the config layer + if cfgPath == "config.json" { + configLayer = layer + } + + layers = append(layers, layer) + } + + if configLayer.Digest == "" { + return fmt.Errorf("config.json not found in %s", modelDir) + } + + fn(fmt.Sprintf("writing manifest for %s", modelName)) + + if err := writeManifest(modelName, configLayer, layers); err != nil { + return fmt.Errorf("failed to write manifest: %w", err) + } + + fn(fmt.Sprintf("successfully imported %s with %d layers", modelName, len(layers))) + return nil +} diff --git a/x/create/create_test.go b/x/create/create_test.go new file mode 100644 index 000000000..c69bb10a8 --- /dev/null +++ b/x/create/create_test.go @@ -0,0 +1,752 @@ +package create + +import ( + "bytes" + "encoding/binary" + "encoding/json" + "io" + "os" + "path/filepath" + "strings" + "testing" +) + +func TestIsTensorModelDir(t *testing.T) { + tests := []struct { + name string + setup func(dir string) error + expected bool + }{ + { + name: "valid diffusers model with model_index.json", + setup: func(dir string) error { + return os.WriteFile(filepath.Join(dir, "model_index.json"), []byte(`{"_class_name": "FluxPipeline"}`), 0o644) + }, + expected: true, + }, + { + name: "empty directory", + setup: func(dir string) error { + return nil + }, + expected: false, + }, + { + name: "directory with other files but no model_index.json", + setup: func(dir string) error { + return os.WriteFile(filepath.Join(dir, "config.json"), []byte(`{}`), 0o644) + }, + expected: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + dir := t.TempDir() + if err := tt.setup(dir); err != nil { + t.Fatalf("setup failed: %v", err) + } + + got := IsTensorModelDir(dir) + if got != tt.expected { + t.Errorf("IsTensorModelDir() = %v, want %v", got, tt.expected) + } + }) + } +} + +func TestIsSafetensorsModelDir(t *testing.T) { + tests := []struct { + name string + setup func(dir string) error + expected bool + }{ + { + name: "valid safetensors model with config.json and .safetensors file", + setup: func(dir string) error { + if err := os.WriteFile(filepath.Join(dir, "config.json"), []byte(`{"model_type": "gemma3"}`), 0o644); err != nil { + return err + } + return os.WriteFile(filepath.Join(dir, "model.safetensors"), []byte("dummy"), 0o644) + }, + expected: true, + }, + { + name: "config.json only, no safetensors files", + setup: func(dir string) error { + return os.WriteFile(filepath.Join(dir, "config.json"), []byte(`{}`), 0o644) + }, + expected: false, + }, + { + name: "safetensors file only, no config.json", + setup: func(dir string) error { + return os.WriteFile(filepath.Join(dir, "model.safetensors"), []byte("dummy"), 0o644) + }, + expected: false, + }, + { + name: "empty directory", + setup: func(dir string) error { + 
return nil + }, + expected: false, + }, + { + name: "multiple safetensors files with config.json", + setup: func(dir string) error { + if err := os.WriteFile(filepath.Join(dir, "config.json"), []byte(`{}`), 0o644); err != nil { + return err + } + if err := os.WriteFile(filepath.Join(dir, "model-00001-of-00002.safetensors"), []byte("dummy"), 0o644); err != nil { + return err + } + return os.WriteFile(filepath.Join(dir, "model-00002-of-00002.safetensors"), []byte("dummy"), 0o644) + }, + expected: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + dir := t.TempDir() + if err := tt.setup(dir); err != nil { + t.Fatalf("setup failed: %v", err) + } + + got := IsSafetensorsModelDir(dir) + if got != tt.expected { + t.Errorf("IsSafetensorsModelDir() = %v, want %v", got, tt.expected) + } + }) + } +} + +func TestIsSafetensorsModelDir_NonexistentDir(t *testing.T) { + got := IsSafetensorsModelDir("/nonexistent/path/that/does/not/exist") + if got != false { + t.Errorf("IsSafetensorsModelDir() = %v for nonexistent dir, want false", got) + } +} + +// createMinimalSafetensors creates a minimal valid safetensors file with one tensor +func createMinimalSafetensors(t *testing.T, path string) { + t.Helper() + + // Create a minimal safetensors file with a single float32 tensor + header := map[string]interface{}{ + "test_tensor": map[string]interface{}{ + "dtype": "F32", + "shape": []int{2, 2}, + "data_offsets": []int{0, 16}, // 4 float32 values = 16 bytes + }, + } + headerJSON, err := json.Marshal(header) + if err != nil { + t.Fatalf("failed to marshal header: %v", err) + } + + // Pad header to 8-byte alignment + padding := (8 - len(headerJSON)%8) % 8 + headerJSON = append(headerJSON, bytes.Repeat([]byte(" "), padding)...) + + // Write file + f, err := os.Create(path) + if err != nil { + t.Fatalf("failed to create file: %v", err) + } + defer f.Close() + + // Write header size (8 bytes, little endian) + if err := binary.Write(f, binary.LittleEndian, uint64(len(headerJSON))); err != nil { + t.Fatalf("failed to write header size: %v", err) + } + + // Write header + if _, err := f.Write(headerJSON); err != nil { + t.Fatalf("failed to write header: %v", err) + } + + // Write tensor data (16 bytes of zeros for 4 float32 values) + if _, err := f.Write(make([]byte, 16)); err != nil { + t.Fatalf("failed to write tensor data: %v", err) + } +} + +func TestCreateSafetensorsModel(t *testing.T) { + dir := t.TempDir() + + // Create config.json + configJSON := `{"model_type": "test", "architectures": ["TestModel"]}` + if err := os.WriteFile(filepath.Join(dir, "config.json"), []byte(configJSON), 0o644); err != nil { + t.Fatalf("failed to write config.json: %v", err) + } + + // Create a minimal safetensors file + createMinimalSafetensors(t, filepath.Join(dir, "model.safetensors")) + + // Track what was created + var createdLayers []LayerInfo + var manifestWritten bool + var manifestModelName string + var manifestConfigLayer LayerInfo + var manifestLayers []LayerInfo + var statusMessages []string + + // Mock callbacks + createLayer := func(r io.Reader, mediaType, name string) (LayerInfo, error) { + data, err := io.ReadAll(r) + if err != nil { + return LayerInfo{}, err + } + layer := LayerInfo{ + Digest: "sha256:test", + Size: int64(len(data)), + MediaType: mediaType, + Name: name, + } + createdLayers = append(createdLayers, layer) + return layer, nil + } + + createTensorLayer := func(r io.Reader, name, dtype string, shape []int32, quantize string) ([]LayerInfo, error) { + data, err := 
io.ReadAll(r) + if err != nil { + return nil, err + } + layer := LayerInfo{ + Digest: "sha256:tensor_" + name, + Size: int64(len(data)), + MediaType: "application/vnd.ollama.image.tensor", + Name: name, + } + createdLayers = append(createdLayers, layer) + return []LayerInfo{layer}, nil + } + + writeManifest := func(modelName string, config LayerInfo, layers []LayerInfo) error { + manifestWritten = true + manifestModelName = modelName + manifestConfigLayer = config + manifestLayers = layers + return nil + } + + progressFn := func(status string) { + statusMessages = append(statusMessages, status) + } + + // Run CreateSafetensorsModel + err := CreateSafetensorsModel("test-model", dir, "", createLayer, createTensorLayer, writeManifest, progressFn) + if err != nil { + t.Fatalf("CreateSafetensorsModel failed: %v", err) + } + + // Verify manifest was written + if !manifestWritten { + t.Error("manifest was not written") + } + + if manifestModelName != "test-model" { + t.Errorf("manifest model name = %q, want %q", manifestModelName, "test-model") + } + + // Verify config layer was set + if manifestConfigLayer.Name != "config.json" { + t.Errorf("config layer name = %q, want %q", manifestConfigLayer.Name, "config.json") + } + + // Verify we have at least one tensor and one config layer + hasTensor := false + hasConfig := false + for _, layer := range manifestLayers { + if layer.Name == "test_tensor" { + hasTensor = true + } + if layer.Name == "config.json" { + hasConfig = true + } + } + + if !hasTensor { + t.Error("no tensor layer found in manifest") + } + if !hasConfig { + t.Error("no config layer found in manifest") + } + + // Verify status messages were sent + if len(statusMessages) == 0 { + t.Error("no status messages received") + } +} + +func TestCreateSafetensorsModel_NoConfigJson(t *testing.T) { + dir := t.TempDir() + + // Create only a safetensors file, no config.json + createMinimalSafetensors(t, filepath.Join(dir, "model.safetensors")) + + // Mock callbacks (minimal) + createLayer := func(r io.Reader, mediaType, name string) (LayerInfo, error) { + io.ReadAll(r) + return LayerInfo{Name: name}, nil + } + createTensorLayer := func(r io.Reader, name, dtype string, shape []int32, quantize string) ([]LayerInfo, error) { + io.ReadAll(r) + return []LayerInfo{{Name: name}}, nil + } + writeManifest := func(modelName string, config LayerInfo, layers []LayerInfo) error { + return nil + } + progressFn := func(status string) {} + + err := CreateSafetensorsModel("test-model", dir, "", createLayer, createTensorLayer, writeManifest, progressFn) + if err == nil { + t.Error("expected error for missing config.json, got nil") + } +} + +func TestCreateSafetensorsModel_EmptyDir(t *testing.T) { + dir := t.TempDir() + + // Mock callbacks + createLayer := func(r io.Reader, mediaType, name string) (LayerInfo, error) { + return LayerInfo{}, nil + } + createTensorLayer := func(r io.Reader, name, dtype string, shape []int32, quantize string) ([]LayerInfo, error) { + return []LayerInfo{{}}, nil + } + writeManifest := func(modelName string, config LayerInfo, layers []LayerInfo) error { + return nil + } + progressFn := func(status string) {} + + err := CreateSafetensorsModel("test-model", dir, "", createLayer, createTensorLayer, writeManifest, progressFn) + if err == nil { + t.Error("expected error for empty directory, got nil") + } +} + +func TestCreateSafetensorsModel_SkipsIndexJson(t *testing.T) { + dir := t.TempDir() + + // Create config.json + if err := os.WriteFile(filepath.Join(dir, "config.json"), []byte(`{}`), 
0o644); err != nil { + t.Fatalf("failed to write config.json: %v", err) + } + + // Create model.safetensors.index.json (should be skipped) + indexJSON := `{"metadata": {"total_size": 100}, "weight_map": {}}` + if err := os.WriteFile(filepath.Join(dir, "model.safetensors.index.json"), []byte(indexJSON), 0o644); err != nil { + t.Fatalf("failed to write index.json: %v", err) + } + + // Create a minimal safetensors file + createMinimalSafetensors(t, filepath.Join(dir, "model.safetensors")) + + var configNames []string + + createLayer := func(r io.Reader, mediaType, name string) (LayerInfo, error) { + io.ReadAll(r) + configNames = append(configNames, name) + return LayerInfo{Name: name, Digest: "sha256:test"}, nil + } + createTensorLayer := func(r io.Reader, name, dtype string, shape []int32, quantize string) ([]LayerInfo, error) { + io.ReadAll(r) + return []LayerInfo{{Name: name}}, nil + } + writeManifest := func(modelName string, config LayerInfo, layers []LayerInfo) error { + return nil + } + progressFn := func(status string) {} + + err := CreateSafetensorsModel("test-model", dir, "", createLayer, createTensorLayer, writeManifest, progressFn) + if err != nil { + t.Fatalf("CreateSafetensorsModel failed: %v", err) + } + + // Verify model.safetensors.index.json was not included + for _, name := range configNames { + if name == "model.safetensors.index.json" { + t.Error("model.safetensors.index.json should have been skipped") + } + } +} + +func TestResolveManifestPath(t *testing.T) { + tests := []struct { + name string + modelName string + wantParts []string // Parts that should appear in the path + }{ + { + name: "simple model name", + modelName: "llama2", + wantParts: []string{"registry.ollama.ai", "library", "llama2", "latest"}, + }, + { + name: "model name with tag", + modelName: "llama2:7b", + wantParts: []string{"registry.ollama.ai", "library", "llama2", "7b"}, + }, + { + name: "model name with namespace", + modelName: "myuser/mymodel", + wantParts: []string{"registry.ollama.ai", "myuser", "mymodel", "latest"}, + }, + { + name: "model name with namespace and tag", + modelName: "myuser/mymodel:v1", + wantParts: []string{"registry.ollama.ai", "myuser", "mymodel", "v1"}, + }, + { + name: "fully qualified model name", + modelName: "registry.example.com/namespace/model:tag", + wantParts: []string{"registry.example.com", "namespace", "model", "tag"}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := resolveManifestPath(tt.modelName) + + for _, part := range tt.wantParts { + if !strings.Contains(got, part) { + t.Errorf("resolveManifestPath(%q) = %q, missing part %q", tt.modelName, got, part) + } + } + }) + } +} + +func TestLayerInfo(t *testing.T) { + layer := LayerInfo{ + Digest: "sha256:abc123", + Size: 1024, + MediaType: "application/vnd.ollama.image.tensor", + Name: "model.weight", + } + + if layer.Digest != "sha256:abc123" { + t.Errorf("Digest = %q, want %q", layer.Digest, "sha256:abc123") + } + if layer.Size != 1024 { + t.Errorf("Size = %d, want %d", layer.Size, 1024) + } + if layer.MediaType != "application/vnd.ollama.image.tensor" { + t.Errorf("MediaType = %q, want %q", layer.MediaType, "application/vnd.ollama.image.tensor") + } + if layer.Name != "model.weight" { + t.Errorf("Name = %q, want %q", layer.Name, "model.weight") + } +} + +func TestModelConfig(t *testing.T) { + config := ModelConfig{ + ModelFormat: "safetensors", + Capabilities: []string{"completion", "chat"}, + } + + if config.ModelFormat != "safetensors" { + t.Errorf("ModelFormat = %q, want 
%q", config.ModelFormat, "safetensors") + } + if len(config.Capabilities) != 2 { + t.Errorf("Capabilities length = %d, want %d", len(config.Capabilities), 2) + } +} + +func TestManifest(t *testing.T) { + manifest := Manifest{ + SchemaVersion: 2, + MediaType: "application/vnd.oci.image.manifest.v1+json", + Config: ManifestLayer{ + MediaType: "application/vnd.docker.container.image.v1+json", + Digest: "sha256:config", + Size: 100, + }, + Layers: []ManifestLayer{ + { + MediaType: "application/vnd.ollama.image.tensor", + Digest: "sha256:layer1", + Size: 1000, + Name: "weight.bin", + }, + }, + } + + if manifest.SchemaVersion != 2 { + t.Errorf("SchemaVersion = %d, want %d", manifest.SchemaVersion, 2) + } + if manifest.Config.Digest != "sha256:config" { + t.Errorf("Config.Digest = %q, want %q", manifest.Config.Digest, "sha256:config") + } + if len(manifest.Layers) != 1 { + t.Errorf("Layers length = %d, want %d", len(manifest.Layers), 1) + } + if manifest.Layers[0].Name != "weight.bin" { + t.Errorf("Layers[0].Name = %q, want %q", manifest.Layers[0].Name, "weight.bin") + } +} + +func TestShouldQuantize(t *testing.T) { + tests := []struct { + name string + tensor string + component string + want bool + }{ + // VAE component should never be quantized + {"vae weight", "decoder.weight", "vae", false}, + {"vae bias", "decoder.bias", "vae", false}, + + // Embeddings should not be quantized + {"embedding weight", "embed_tokens.weight", "", false}, + {"embedding in name", "token_embedding.weight", "", false}, + + // Norms should not be quantized + {"layer norm", "layer_norm.weight", "", false}, + {"rms norm", "rms_norm.weight", "", false}, + {"ln prefix", "ln_1.weight", "", false}, + {"layernorm in name", "input_layernorm.weight", "", false}, + + // Biases should not be quantized + {"bias tensor", "attention.bias", "", false}, + {"proj bias", "o_proj.bias", "", false}, + + // Linear weights should be quantized + {"linear weight", "q_proj.weight", "", true}, + {"attention weight", "self_attn.weight", "", true}, + {"mlp weight", "mlp.gate_proj.weight", "", true}, + + // Transformer component weights should be quantized + {"transformer weight", "layers.0.weight", "transformer", true}, + {"text_encoder weight", "encoder.weight", "text_encoder", true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := ShouldQuantize(tt.tensor, tt.component) + if got != tt.want { + t.Errorf("ShouldQuantize(%q, %q) = %v, want %v", tt.tensor, tt.component, got, tt.want) + } + }) + } +} + +func TestShouldQuantizeTensor(t *testing.T) { + tests := []struct { + name string + tensor string + shape []int32 + want bool + }{ + // 2D tensors with sufficient size should be quantized + {"large 2D weight", "q_proj.weight", []int32{4096, 4096}, true}, + {"medium 2D weight", "small_proj.weight", []int32{128, 128}, true}, + + // Small tensors should not be quantized (< 1024 elements) + {"tiny 2D weight", "tiny.weight", []int32{16, 16}, false}, + {"small 2D weight", "small.weight", []int32{31, 31}, false}, + + // 1D tensors should not be quantized + {"1D tensor", "layer_norm.weight", []int32{4096}, false}, + + // 3D+ tensors should not be quantized + {"3D tensor", "conv.weight", []int32{64, 64, 3}, false}, + {"4D tensor", "conv2d.weight", []int32{64, 64, 3, 3}, false}, + + // Embeddings should not be quantized regardless of shape + {"embedding 2D", "embed_tokens.weight", []int32{32000, 4096}, false}, + + // Norms should not be quantized regardless of shape + {"norm 2D", "layer_norm.weight", []int32{4096, 1}, 
false}, + + // Biases should not be quantized + {"bias 2D", "proj.bias", []int32{4096, 1}, false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := ShouldQuantizeTensor(tt.tensor, tt.shape) + if got != tt.want { + t.Errorf("ShouldQuantizeTensor(%q, %v) = %v, want %v", tt.tensor, tt.shape, got, tt.want) + } + }) + } +} + +func TestCreateSafetensorsModel_WithQuantize(t *testing.T) { + dir := t.TempDir() + + // Create config.json + configJSON := `{"model_type": "test", "architectures": ["TestModel"]}` + if err := os.WriteFile(filepath.Join(dir, "config.json"), []byte(configJSON), 0o644); err != nil { + t.Fatalf("failed to write config.json: %v", err) + } + + // Create a minimal safetensors file + createMinimalSafetensors(t, filepath.Join(dir, "model.safetensors")) + + var quantizeRequested []string + + createLayer := func(r io.Reader, mediaType, name string) (LayerInfo, error) { + io.ReadAll(r) + return LayerInfo{Name: name, Digest: "sha256:test"}, nil + } + + createTensorLayer := func(r io.Reader, name, dtype string, shape []int32, quantize string) ([]LayerInfo, error) { + io.ReadAll(r) + quantizeRequested = append(quantizeRequested, quantize) + return []LayerInfo{{Name: name}}, nil + } + + writeManifest := func(modelName string, config LayerInfo, layers []LayerInfo) error { + return nil + } + + progressFn := func(status string) {} + + // Run with quantize enabled + err := CreateSafetensorsModel("test-model", dir, "fp8", createLayer, createTensorLayer, writeManifest, progressFn) + if err != nil { + t.Fatalf("CreateSafetensorsModel failed: %v", err) + } + + // Verify quantize was passed to callback (will be false for small test tensor) + if len(quantizeRequested) == 0 { + t.Error("no tensors processed") + } +} + +// createMinimalImageGenModel creates a minimal diffusers-style model directory +func createMinimalImageGenModel(t *testing.T, dir string) { + t.Helper() + + // Create model_index.json + modelIndex := `{"_class_name": "FluxPipeline", "_diffusers_version": "0.30.0"}` + if err := os.WriteFile(filepath.Join(dir, "model_index.json"), []byte(modelIndex), 0o644); err != nil { + t.Fatalf("failed to write model_index.json: %v", err) + } + + // Create transformer directory with a safetensors file + transformerDir := filepath.Join(dir, "transformer") + if err := os.MkdirAll(transformerDir, 0o755); err != nil { + t.Fatalf("failed to create transformer dir: %v", err) + } + createMinimalSafetensors(t, filepath.Join(transformerDir, "model.safetensors")) + + // Create transformer config + transformerConfig := `{"hidden_size": 3072}` + if err := os.WriteFile(filepath.Join(transformerDir, "config.json"), []byte(transformerConfig), 0o644); err != nil { + t.Fatalf("failed to write transformer config: %v", err) + } +} + +func TestCreateImageGenModel(t *testing.T) { + dir := t.TempDir() + createMinimalImageGenModel(t, dir) + + var manifestWritten bool + var manifestModelName string + var statusMessages []string + + createLayer := func(r io.Reader, mediaType, name string) (LayerInfo, error) { + io.ReadAll(r) + return LayerInfo{Name: name, Digest: "sha256:test"}, nil + } + + createTensorLayer := func(r io.Reader, name, dtype string, shape []int32, quantize string) ([]LayerInfo, error) { + io.ReadAll(r) + return []LayerInfo{{Name: name, Digest: "sha256:tensor"}}, nil + } + + writeManifest := func(modelName string, config LayerInfo, layers []LayerInfo) error { + manifestWritten = true + manifestModelName = modelName + return nil + } + + progressFn := func(status string) { + 
statusMessages = append(statusMessages, status) + } + + err := CreateImageGenModel("test-imagegen", dir, "", createLayer, createTensorLayer, writeManifest, progressFn) + if err != nil { + t.Fatalf("CreateImageGenModel failed: %v", err) + } + + if !manifestWritten { + t.Error("manifest was not written") + } + + if manifestModelName != "test-imagegen" { + t.Errorf("manifest model name = %q, want %q", manifestModelName, "test-imagegen") + } + + if len(statusMessages) == 0 { + t.Error("no status messages received") + } +} + +func TestCreateImageGenModel_NoModelIndex(t *testing.T) { + dir := t.TempDir() + + // Create only transformer without model_index.json + transformerDir := filepath.Join(dir, "transformer") + if err := os.MkdirAll(transformerDir, 0o755); err != nil { + t.Fatalf("failed to create transformer dir: %v", err) + } + createMinimalSafetensors(t, filepath.Join(transformerDir, "model.safetensors")) + + createLayer := func(r io.Reader, mediaType, name string) (LayerInfo, error) { + io.ReadAll(r) + return LayerInfo{Name: name}, nil + } + createTensorLayer := func(r io.Reader, name, dtype string, shape []int32, quantize string) ([]LayerInfo, error) { + io.ReadAll(r) + return []LayerInfo{{Name: name}}, nil + } + writeManifest := func(modelName string, config LayerInfo, layers []LayerInfo) error { + return nil + } + progressFn := func(status string) {} + + err := CreateImageGenModel("test-imagegen", dir, "", createLayer, createTensorLayer, writeManifest, progressFn) + if err == nil { + t.Error("expected error for missing model_index.json, got nil") + } +} + +func TestCreateImageGenModel_WithQuantize(t *testing.T) { + dir := t.TempDir() + createMinimalImageGenModel(t, dir) + + var quantizeRequested []string + + createLayer := func(r io.Reader, mediaType, name string) (LayerInfo, error) { + io.ReadAll(r) + return LayerInfo{Name: name, Digest: "sha256:test"}, nil + } + + createTensorLayer := func(r io.Reader, name, dtype string, shape []int32, quantize string) ([]LayerInfo, error) { + io.ReadAll(r) + quantizeRequested = append(quantizeRequested, quantize) + return []LayerInfo{{Name: name}}, nil + } + + writeManifest := func(modelName string, config LayerInfo, layers []LayerInfo) error { + return nil + } + + progressFn := func(status string) {} + + err := CreateImageGenModel("test-imagegen", dir, "fp8", createLayer, createTensorLayer, writeManifest, progressFn) + if err != nil { + t.Fatalf("CreateImageGenModel failed: %v", err) + } + + if len(quantizeRequested) == 0 { + t.Error("no tensors processed") + } +} diff --git a/x/imagegen/create.go b/x/create/imagegen.go similarity index 70% rename from x/imagegen/create.go rename to x/create/imagegen.go index c2e22d3df..ad10d8c69 100644 --- a/x/imagegen/create.go +++ b/x/create/imagegen.go @@ -1,4 +1,4 @@ -package imagegen +package create import ( "bytes" @@ -12,40 +12,24 @@ import ( "github.com/ollama/ollama/x/imagegen/safetensors" ) -// IsTensorModelDir checks if the directory contains a tensor model -// by looking for model_index.json, which is the standard diffusers pipeline config. -func IsTensorModelDir(dir string) bool { - _, err := os.Stat(filepath.Join(dir, "model_index.json")) - return err == nil -} - -// LayerInfo holds metadata for a created layer. -type LayerInfo struct { - Digest string - Size int64 - MediaType string - Name string // Path-style name: "component/tensor" or "path/to/config.json" -} - -// LayerCreator is called to create a blob layer. 
-// name is the path-style name (e.g., "tokenizer/tokenizer.json") -type LayerCreator func(r io.Reader, mediaType, name string) (LayerInfo, error) - -// TensorLayerCreator creates a tensor blob layer with metadata. -// name is the path-style name including component (e.g., "text_encoder/model.embed_tokens.weight") -type TensorLayerCreator func(r io.Reader, name, dtype string, shape []int32) (LayerInfo, error) - -// ManifestWriter writes the manifest file. -type ManifestWriter func(modelName string, config LayerInfo, layers []LayerInfo) error - -// CreateModel imports an image generation model from a directory. +// CreateImageGenModel imports an image generation model from a directory. // Stores each tensor as a separate blob for fine-grained deduplication. -// If quantize is "fp8", linear weights in transformer/text_encoder are quantized to mxfp8 format. +// If quantize is specified, linear weights in transformer/text_encoder are quantized. +// Supported quantization types: fp8 (or empty for no quantization). // Layer creation and manifest writing are done via callbacks to avoid import cycles. -func CreateModel(modelName, modelDir, quantize string, createLayer LayerCreator, createTensorLayer QuantizingTensorLayerCreator, writeManifest ManifestWriter, fn func(status string)) error { +func CreateImageGenModel(modelName, modelDir, quantize string, createLayer LayerCreator, createTensorLayer QuantizingTensorLayerCreator, writeManifest ManifestWriter, fn func(status string)) error { + // Validate quantization type + switch quantize { + case "", "fp8": + // valid + default: + return fmt.Errorf("unsupported quantization type %q: supported types are fp8", quantize) + } + var layers []LayerInfo var configLayer LayerInfo var totalParams int64 // Count parameters from original tensor shapes + var torchDtype string // Read from component config for quantization display // Components to process - extract individual tensors from each components := []string{"text_encoder", "transformer", "vae"} @@ -77,8 +61,8 @@ func CreateModel(modelName, modelDir, quantize string, createLayer LayerCreator, tensorNames := extractor.ListTensors() quantizeMsg := "" - if quantize == "fp8" && component != "vae" { - quantizeMsg = ", quantizing to fp8" + if quantize != "" && component != "vae" { + quantizeMsg = ", quantizing to " + quantize } fn(fmt.Sprintf("importing %s/%s (%d tensors%s)", component, entry.Name(), len(tensorNames), quantizeMsg)) @@ -103,11 +87,14 @@ func CreateModel(modelName, modelDir, quantize string, createLayer LayerCreator, // Use path-style name: "component/tensor_name" fullName := component + "/" + tensorName - // Determine if this tensor should be quantized - doQuantize := quantize == "fp8" && ShouldQuantize(tensorName, component) + // Determine quantization type for this tensor (empty string if not quantizing) + quantizeType := "" + if quantize != "" && ShouldQuantize(tensorName, component) && canQuantizeShape(td.Shape) { + quantizeType = quantize + } // createTensorLayer returns multiple layers if quantizing (weight + scales) - newLayers, err := createTensorLayer(td.SafetensorsReader(), fullName, td.Dtype, td.Shape, doQuantize) + newLayers, err := createTensorLayer(td.SafetensorsReader(), fullName, td.Dtype, td.Shape, quantizeType) if err != nil { extractor.Close() return fmt.Errorf("failed to create layer for %s: %w", fullName, err) @@ -119,6 +106,19 @@ func CreateModel(modelName, modelDir, quantize string, createLayer LayerCreator, } } + // Read torch_dtype from text_encoder config for 
quantization display + if torchDtype == "" { + textEncoderConfig := filepath.Join(modelDir, "text_encoder/config.json") + if data, err := os.ReadFile(textEncoderConfig); err == nil { + var cfg struct { + TorchDtype string `json:"torch_dtype"` + } + if json.Unmarshal(data, &cfg) == nil && cfg.TorchDtype != "" { + torchDtype = cfg.TorchDtype + } + } + } + // Import config files configFiles := []string{ "model_index.json", @@ -164,11 +164,11 @@ func CreateModel(modelName, modelDir, quantize string, createLayer LayerCreator, // Add parameter count (counted from tensor shapes during import) cfg["parameter_count"] = totalParams - // Add quantization info - if quantize == "fp8" { - cfg["quantization"] = "FP8" + // Add quantization info - use quantize type if set, otherwise torch_dtype + if quantize != "" { + cfg["quantization"] = strings.ToUpper(quantize) } else { - cfg["quantization"] = "BF16" + cfg["quantization"] = torchDtype } data, err = json.MarshalIndent(cfg, "", " ") @@ -211,3 +211,12 @@ func CreateModel(modelName, modelDir, quantize string, createLayer LayerCreator, fn(fmt.Sprintf("successfully imported %s with %d layers", modelName, len(layers))) return nil } + +// canQuantizeShape returns true if a tensor shape is compatible with MLX quantization. +// MLX requires the last dimension to be divisible by the group size (32). +func canQuantizeShape(shape []int32) bool { + if len(shape) < 2 { + return false + } + return shape[len(shape)-1]%32 == 0 +} diff --git a/x/imagegen/client/create.go b/x/imagegen/client/create.go deleted file mode 100644 index f4a3d50d9..000000000 --- a/x/imagegen/client/create.go +++ /dev/null @@ -1,190 +0,0 @@ -// Package client provides client-side model creation for tensor-based models. -// -// This package is in x/ because the tensor model storage format is under development. -// It also exists to break an import cycle: server imports x/imagegen, so x/imagegen -// cannot import server. This sub-package can import server because server doesn't -// import it. -// -// TODO (jmorganca): This is temporary. When tensor models are promoted to production: -// 1. Add proper API endpoints for tensor model creation -// 2. Move tensor extraction to server-side -// 3. Remove this package -// 4. Follow the same client→server pattern as regular model creation -package client - -import ( - "bytes" - "encoding/json" - "fmt" - "io" - - "github.com/ollama/ollama/progress" - "github.com/ollama/ollama/server" - "github.com/ollama/ollama/types/model" - "github.com/ollama/ollama/x/imagegen" -) - -// MinOllamaVersion is the minimum Ollama version required for image generation models. -const MinOllamaVersion = "0.14.0" - -// CreateModel imports a tensor-based model from a local directory. -// This creates blobs and manifest directly on disk, bypassing the HTTP API. -// If quantize is "fp8", weights will be quantized to mxfp8 format during import. -// -// TODO (jmorganca): Replace with API-based creation when promoted to production. 
-func CreateModel(modelName, modelDir, quantize string, p *progress.Progress) error { - if !imagegen.IsTensorModelDir(modelDir) { - return fmt.Errorf("%s is not an image generation model directory (model_index.json not found)", modelDir) - } - - status := "importing image generation model" - spinner := progress.NewSpinner(status) - p.Add("imagegen", spinner) - - // Create layer callback for config files - createLayer := func(r io.Reader, mediaType, name string) (imagegen.LayerInfo, error) { - layer, err := server.NewLayer(r, mediaType) - if err != nil { - return imagegen.LayerInfo{}, err - } - layer.Name = name - - return imagegen.LayerInfo{ - Digest: layer.Digest, - Size: layer.Size, - MediaType: layer.MediaType, - Name: name, - }, nil - } - - // Create tensor layer callback for individual tensors - // name is path-style: "component/tensor_name" - // When quantize is true, returns multiple layers (weight + scales) - createTensorLayer := func(r io.Reader, name, dtype string, shape []int32, doQuantize bool) ([]imagegen.LayerInfo, error) { - if doQuantize { - // Check if quantization is supported - if !QuantizeSupported() { - return nil, fmt.Errorf("quantization requires MLX support") - } - - // Quantize the tensor (affine mode returns weight, scales, qbiases) - qweightData, scalesData, qbiasData, _, _, _, err := quantizeTensor(r, name, dtype, shape) - if err != nil { - return nil, fmt.Errorf("failed to quantize %s: %w", name, err) - } - - // Create layer for quantized weight - weightLayer, err := server.NewLayer(bytes.NewReader(qweightData), server.MediaTypeImageTensor) - if err != nil { - return nil, err - } - - // Create layer for scales (use _scale suffix convention) - scalesLayer, err := server.NewLayer(bytes.NewReader(scalesData), server.MediaTypeImageTensor) - if err != nil { - return nil, err - } - - layers := []imagegen.LayerInfo{ - { - Digest: weightLayer.Digest, - Size: weightLayer.Size, - MediaType: weightLayer.MediaType, - Name: name, // Keep original name for weight - }, - { - Digest: scalesLayer.Digest, - Size: scalesLayer.Size, - MediaType: scalesLayer.MediaType, - Name: name + "_scale", // Add _scale suffix - }, - } - - // Add qbiases layer if present (affine mode) - if qbiasData != nil { - qbiasLayer, err := server.NewLayer(bytes.NewReader(qbiasData), server.MediaTypeImageTensor) - if err != nil { - return nil, err - } - layers = append(layers, imagegen.LayerInfo{ - Digest: qbiasLayer.Digest, - Size: qbiasLayer.Size, - MediaType: qbiasLayer.MediaType, - Name: name + "_qbias", // Add _qbias suffix - }) - } - - return layers, nil - } - - // Non-quantized path: just create a single layer - layer, err := server.NewLayer(r, server.MediaTypeImageTensor) - if err != nil { - return nil, err - } - - return []imagegen.LayerInfo{ - { - Digest: layer.Digest, - Size: layer.Size, - MediaType: layer.MediaType, - Name: name, - }, - }, nil - } - - // Create manifest writer callback - writeManifest := func(modelName string, config imagegen.LayerInfo, layers []imagegen.LayerInfo) error { - name := model.ParseName(modelName) - if !name.IsValid() { - return fmt.Errorf("invalid model name: %s", modelName) - } - - // Create a proper config blob with version requirement - configData := model.ConfigV2{ - ModelFormat: "safetensors", - Capabilities: []string{"image"}, - Requires: MinOllamaVersion, - } - configJSON, err := json.Marshal(configData) - if err != nil { - return fmt.Errorf("failed to marshal config: %w", err) - } - - // Create config layer blob - configLayer, err := 
server.NewLayer(bytes.NewReader(configJSON), "application/vnd.docker.container.image.v1+json") - if err != nil { - return fmt.Errorf("failed to create config layer: %w", err) - } - - // Convert LayerInfo to server.Layer (include the original model_index.json in layers) - serverLayers := make([]server.Layer, len(layers)) - for i, l := range layers { - serverLayers[i] = server.Layer{ - MediaType: l.MediaType, - Digest: l.Digest, - Size: l.Size, - Name: l.Name, - } - } - - return server.WriteManifest(name, configLayer, serverLayers) - } - - // Progress callback - progressFn := func(msg string) { - spinner.Stop() - status = msg - spinner = progress.NewSpinner(status) - p.Add("imagegen", spinner) - } - - err := imagegen.CreateModel(modelName, modelDir, quantize, createLayer, createTensorLayer, writeManifest, progressFn) - spinner.Stop() - if err != nil { - return err - } - - fmt.Printf("Created image generation model '%s'\n", modelName) - return nil -} diff --git a/x/imagegen/quantize.go b/x/imagegen/quantize.go deleted file mode 100644 index 09f815caf..000000000 --- a/x/imagegen/quantize.go +++ /dev/null @@ -1,22 +0,0 @@ -package imagegen - -import ( - "io" - "strings" -) - -// QuantizingTensorLayerCreator creates tensor layers with optional quantization. -// When quantize is true, returns multiple layers (weight + scales + biases). -type QuantizingTensorLayerCreator func(r io.Reader, name, dtype string, shape []int32, quantize bool) ([]LayerInfo, error) - -// ShouldQuantize returns true if a tensor should be quantized. -// Quantizes linear weights only, skipping VAE, embeddings, norms, and biases. -func ShouldQuantize(name, component string) bool { - if component == "vae" { - return false - } - if strings.Contains(name, "embed") || strings.Contains(name, "norm") { - return false - } - return strings.HasSuffix(name, ".weight") -} diff --git a/x/server/show.go b/x/server/show.go new file mode 100644 index 000000000..8cadb2c62 --- /dev/null +++ b/x/server/show.go @@ -0,0 +1,284 @@ +package server + +import ( + "encoding/binary" + "encoding/json" + "fmt" + "io" + "os" + "strings" + + "github.com/ollama/ollama/api" + "github.com/ollama/ollama/x/imagegen" +) + +// modelConfig represents the HuggingFace config.json structure +type modelConfig struct { + Architectures []string `json:"architectures"` + ModelType string `json:"model_type"` + HiddenSize int `json:"hidden_size"` + NumHiddenLayers int `json:"num_hidden_layers"` + MaxPositionEmbeddings int `json:"max_position_embeddings"` + IntermediateSize int `json:"intermediate_size"` + NumAttentionHeads int `json:"num_attention_heads"` + NumKeyValueHeads int `json:"num_key_value_heads"` + VocabSize int `json:"vocab_size"` + RMSNormEps float64 `json:"rms_norm_eps"` + RopeTheta float64 `json:"rope_theta"` + TorchDtype string `json:"torch_dtype"` + TextConfig *struct { + HiddenSize int `json:"hidden_size"` + MaxPositionEmbeddings int `json:"max_position_embeddings"` + NumHiddenLayers int `json:"num_hidden_layers"` + } `json:"text_config"` +} + +// GetSafetensorsLLMInfo extracts model information from safetensors LLM models. +// It reads the config.json layer and returns a map compatible with GGML's KV format. 
+func GetSafetensorsLLMInfo(modelName string) (map[string]any, error) { + manifest, err := imagegen.LoadManifest(modelName) + if err != nil { + return nil, fmt.Errorf("failed to load manifest: %w", err) + } + + var config modelConfig + if err := manifest.ReadConfigJSON("config.json", &config); err != nil { + return nil, fmt.Errorf("failed to read config.json: %w", err) + } + + // Calculate total tensor bytes from manifest layers + var totalBytes int64 + var tensorCount int64 + for _, layer := range manifest.Manifest.Layers { + if layer.MediaType == "application/vnd.ollama.image.tensor" { + totalBytes += layer.Size + tensorCount++ + } + } + + return buildModelInfo(config, totalBytes, tensorCount), nil +} + +// buildModelInfo constructs the model info map from config and tensor stats. +// This is separated for testability. +func buildModelInfo(config modelConfig, totalTensorBytes, tensorCount int64) map[string]any { + // Determine architecture + arch := config.ModelType + if arch == "" && len(config.Architectures) > 0 { + // Convert HuggingFace architecture name to Ollama format + // e.g., "Gemma3ForCausalLM" -> "gemma3" + hfArch := config.Architectures[0] + arch = strings.ToLower(hfArch) + arch = strings.TrimSuffix(arch, "forcausallm") + arch = strings.TrimSuffix(arch, "forconditionalgeneration") + } + + // Use text_config values if they exist (for multimodal models) + hiddenSize := config.HiddenSize + maxPosEmbed := config.MaxPositionEmbeddings + numLayers := config.NumHiddenLayers + + if config.TextConfig != nil { + if config.TextConfig.HiddenSize > 0 { + hiddenSize = config.TextConfig.HiddenSize + } + if config.TextConfig.MaxPositionEmbeddings > 0 { + maxPosEmbed = config.TextConfig.MaxPositionEmbeddings + } + if config.TextConfig.NumHiddenLayers > 0 { + numLayers = config.TextConfig.NumHiddenLayers + } + } + + // Get dtype to determine bytes per parameter for count calculation + dtype := config.TorchDtype + + // Determine bytes per parameter based on dtype + var bytesPerParam int64 = 2 // default to float16/bfloat16 + switch strings.ToLower(dtype) { + case "float32": + bytesPerParam = 4 + case "float16", "bfloat16": + bytesPerParam = 2 + case "int8", "uint8": + bytesPerParam = 1 + } + + // Subtract safetensors header overhead (88 bytes per tensor file) + // Each tensor is stored as a minimal safetensors file + totalBytes := totalTensorBytes - tensorCount*88 + + paramCount := totalBytes / bytesPerParam + + info := map[string]any{ + "general.architecture": arch, + } + + if maxPosEmbed > 0 { + info[fmt.Sprintf("%s.context_length", arch)] = maxPosEmbed + } + + if hiddenSize > 0 { + info[fmt.Sprintf("%s.embedding_length", arch)] = hiddenSize + } + + if numLayers > 0 { + info[fmt.Sprintf("%s.block_count", arch)] = numLayers + } + + if config.NumAttentionHeads > 0 { + info[fmt.Sprintf("%s.attention.head_count", arch)] = config.NumAttentionHeads + } + + if config.NumKeyValueHeads > 0 { + info[fmt.Sprintf("%s.attention.head_count_kv", arch)] = config.NumKeyValueHeads + } + + if config.IntermediateSize > 0 { + info[fmt.Sprintf("%s.feed_forward_length", arch)] = config.IntermediateSize + } + + if config.VocabSize > 0 { + info[fmt.Sprintf("%s.vocab_size", arch)] = config.VocabSize + } + + if paramCount > 0 { + info["general.parameter_count"] = paramCount + } + + return info +} + +// GetSafetensorsTensorInfo extracts tensor information from safetensors model layers. +// Each tensor is stored as a minimal safetensors file with an 88-byte header containing metadata. 
+func GetSafetensorsTensorInfo(modelName string) ([]api.Tensor, error) { + manifest, err := imagegen.LoadManifest(modelName) + if err != nil { + return nil, fmt.Errorf("failed to load manifest: %w", err) + } + + return getTensorInfoFromManifest(manifest) +} + +// getTensorInfoFromManifest extracts tensor info from a manifest. +// This is separated for testability. +func getTensorInfoFromManifest(manifest *imagegen.ModelManifest) ([]api.Tensor, error) { + var tensors []api.Tensor + + for _, layer := range manifest.Manifest.Layers { + if layer.MediaType != "application/vnd.ollama.image.tensor" { + continue + } + + // Read the safetensors header from the blob + blobPath := manifest.BlobPath(layer.Digest) + info, err := readSafetensorsHeader(blobPath) + if err != nil { + // Skip tensors we can't read + continue + } + + // Convert shape from int to uint64 + shape := make([]uint64, len(info.Shape)) + for i, s := range info.Shape { + shape[i] = uint64(s) + } + + tensors = append(tensors, api.Tensor{ + Name: layer.Name, + Type: info.Dtype, + Shape: shape, + }) + } + + return tensors, nil +} + +// GetSafetensorsDtype returns the quantization type for a safetensors model. +// If the model is quantized (has _scale tensors), returns the quantization type (e.g., "FP8"). +// Otherwise returns the torch_dtype from config.json. +func GetSafetensorsDtype(modelName string) (string, error) { + manifest, err := imagegen.LoadManifest(modelName) + if err != nil { + return "", fmt.Errorf("failed to load manifest: %w", err) + } + + // Check if model is quantized by looking for _scale tensors + for _, layer := range manifest.Manifest.Layers { + if layer.MediaType == "application/vnd.ollama.image.tensor" { + if strings.HasSuffix(layer.Name, "_scale") { + // Model is quantized - return FP8 (affine quantization) + return "FP8", nil + } + } + } + + // Not quantized - return torch_dtype from config.json + var cfg struct { + TorchDtype string `json:"torch_dtype"` + } + if err := manifest.ReadConfigJSON("config.json", &cfg); err != nil { + return "", fmt.Errorf("failed to read config.json: %w", err) + } + + return cfg.TorchDtype, nil +} + +// safetensorsTensorInfo holds metadata about a tensor from a safetensors header +type safetensorsTensorInfo struct { + Dtype string `json:"dtype"` + Shape []int64 `json:"shape"` +} + +// readSafetensorsHeader reads the JSON header from a safetensors file to get tensor metadata. +// Safetensors format: 8-byte header size (little endian) + JSON header + tensor data +func readSafetensorsHeader(path string) (*safetensorsTensorInfo, error) { + f, err := os.Open(path) + if err != nil { + return nil, err + } + defer f.Close() + + return parseSafetensorsHeader(f) +} + +// parseSafetensorsHeader parses a safetensors header from a reader. +// This is separated for testability. 
+func parseSafetensorsHeader(r io.Reader) (*safetensorsTensorInfo, error) { + // Read header size (8 bytes, little endian) + var headerSize uint64 + if err := binary.Read(r, binary.LittleEndian, &headerSize); err != nil { + return nil, fmt.Errorf("failed to read header size: %w", err) + } + + // Sanity check - header shouldn't be too large + if headerSize > 1024*1024 { + return nil, fmt.Errorf("header size too large: %d", headerSize) + } + + // Read header JSON + headerBytes := make([]byte, headerSize) + if _, err := io.ReadFull(r, headerBytes); err != nil { + return nil, fmt.Errorf("failed to read header: %w", err) + } + + // Parse as map of tensor name -> info + var header map[string]json.RawMessage + if err := json.Unmarshal(headerBytes, &header); err != nil { + return nil, fmt.Errorf("failed to parse header: %w", err) + } + + // Find the first (and should be only) tensor entry + for name, raw := range header { + if name == "__metadata__" { + continue + } + var info safetensorsTensorInfo + if err := json.Unmarshal(raw, &info); err != nil { + return nil, fmt.Errorf("failed to parse tensor info: %w", err) + } + return &info, nil + } + + return nil, fmt.Errorf("no tensor found in header") +} diff --git a/x/server/show_test.go b/x/server/show_test.go new file mode 100644 index 000000000..c510b0d54 --- /dev/null +++ b/x/server/show_test.go @@ -0,0 +1,597 @@ +package server + +import ( + "bytes" + "encoding/binary" + "encoding/json" + "os" + "path/filepath" + "testing" + + "github.com/ollama/ollama/x/imagegen" +) + +func TestBuildModelInfo(t *testing.T) { + tests := []struct { + name string + config modelConfig + totalTensorBytes int64 + tensorCount int64 + wantArch string + wantContextLen int + wantEmbedLen int + wantBlockCount int + wantParamCount int64 + }{ + { + name: "gemma3 model with model_type", + config: modelConfig{ + ModelType: "gemma3", + HiddenSize: 2560, + NumHiddenLayers: 34, + MaxPositionEmbeddings: 131072, + IntermediateSize: 10240, + NumAttentionHeads: 8, + NumKeyValueHeads: 4, + VocabSize: 262144, + TorchDtype: "bfloat16", + }, + totalTensorBytes: 8_600_000_088, // ~4.3B params * 2 bytes + 88 bytes header + tensorCount: 1, + wantArch: "gemma3", + wantContextLen: 131072, + wantEmbedLen: 2560, + wantBlockCount: 34, + wantParamCount: 4_300_000_000, + }, + { + name: "llama model with architectures array", + config: modelConfig{ + Architectures: []string{"LlamaForCausalLM"}, + HiddenSize: 4096, + NumHiddenLayers: 32, + MaxPositionEmbeddings: 4096, + IntermediateSize: 11008, + NumAttentionHeads: 32, + NumKeyValueHeads: 32, + VocabSize: 32000, + TorchDtype: "float16", + }, + totalTensorBytes: 14_000_000_088, // ~7B params * 2 bytes + 88 bytes header + tensorCount: 1, + wantArch: "llama", + wantContextLen: 4096, + wantEmbedLen: 4096, + wantBlockCount: 32, + wantParamCount: 7_000_000_000, + }, + { + name: "multimodal model with text_config", + config: modelConfig{ + Architectures: []string{"Gemma3ForConditionalGeneration"}, + HiddenSize: 1152, // vision hidden size + TextConfig: &struct { + HiddenSize int `json:"hidden_size"` + MaxPositionEmbeddings int `json:"max_position_embeddings"` + NumHiddenLayers int `json:"num_hidden_layers"` + }{ + HiddenSize: 2560, + MaxPositionEmbeddings: 131072, + NumHiddenLayers: 34, + }, + NumAttentionHeads: 8, + NumKeyValueHeads: 4, + VocabSize: 262144, + TorchDtype: "bfloat16", + }, + totalTensorBytes: 8_600_000_088, + tensorCount: 1, + wantArch: "gemma3", + wantContextLen: 131072, + wantEmbedLen: 2560, + wantBlockCount: 34, + wantParamCount: 
4_300_000_000, + }, + { + name: "float32 model", + config: modelConfig{ + ModelType: "test", + HiddenSize: 512, + NumHiddenLayers: 6, + MaxPositionEmbeddings: 2048, + TorchDtype: "float32", + }, + totalTensorBytes: 400_000_088, // 100M params * 4 bytes + 88 bytes header + tensorCount: 1, + wantArch: "test", + wantContextLen: 2048, + wantEmbedLen: 512, + wantBlockCount: 6, + wantParamCount: 100_000_000, + }, + { + name: "multiple tensors with header overhead", + config: modelConfig{ + ModelType: "test", + HiddenSize: 256, + NumHiddenLayers: 4, + MaxPositionEmbeddings: 1024, + TorchDtype: "bfloat16", + }, + totalTensorBytes: 2_000_880, // 1M params * 2 bytes + 10 tensors * 88 bytes + tensorCount: 10, + wantArch: "test", + wantContextLen: 1024, + wantEmbedLen: 256, + wantBlockCount: 4, + wantParamCount: 1_000_000, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + info := buildModelInfo(tt.config, tt.totalTensorBytes, tt.tensorCount) + + // Check architecture + if arch, ok := info["general.architecture"].(string); !ok || arch != tt.wantArch { + t.Errorf("architecture = %v, want %v", info["general.architecture"], tt.wantArch) + } + + // Check context length + contextKey := tt.wantArch + ".context_length" + if contextLen, ok := info[contextKey].(int); !ok || contextLen != tt.wantContextLen { + t.Errorf("context_length = %v, want %v", info[contextKey], tt.wantContextLen) + } + + // Check embedding length + embedKey := tt.wantArch + ".embedding_length" + if embedLen, ok := info[embedKey].(int); !ok || embedLen != tt.wantEmbedLen { + t.Errorf("embedding_length = %v, want %v", info[embedKey], tt.wantEmbedLen) + } + + // Check block count + blockKey := tt.wantArch + ".block_count" + if blockCount, ok := info[blockKey].(int); !ok || blockCount != tt.wantBlockCount { + t.Errorf("block_count = %v, want %v", info[blockKey], tt.wantBlockCount) + } + + // Check parameter count + if paramCount, ok := info["general.parameter_count"].(int64); !ok || paramCount != tt.wantParamCount { + t.Errorf("parameter_count = %v, want %v", info["general.parameter_count"], tt.wantParamCount) + } + }) + } +} + +func TestBuildModelInfo_ArchitectureConversion(t *testing.T) { + tests := []struct { + name string + architectures []string + modelType string + wantArch string + }{ + { + name: "LlamaForCausalLM", + architectures: []string{"LlamaForCausalLM"}, + wantArch: "llama", + }, + { + name: "Gemma3ForCausalLM", + architectures: []string{"Gemma3ForCausalLM"}, + wantArch: "gemma3", + }, + { + name: "Gemma3ForConditionalGeneration", + architectures: []string{"Gemma3ForConditionalGeneration"}, + wantArch: "gemma3", + }, + { + name: "Qwen2ForCausalLM", + architectures: []string{"Qwen2ForCausalLM"}, + wantArch: "qwen2", + }, + { + name: "model_type takes precedence", + architectures: []string{"LlamaForCausalLM"}, + modelType: "custom", + wantArch: "custom", + }, + { + name: "empty architectures with model_type", + architectures: nil, + modelType: "mymodel", + wantArch: "mymodel", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + config := modelConfig{ + Architectures: tt.architectures, + ModelType: tt.modelType, + } + info := buildModelInfo(config, 0, 0) + + if arch, ok := info["general.architecture"].(string); !ok || arch != tt.wantArch { + t.Errorf("architecture = %v, want %v", info["general.architecture"], tt.wantArch) + } + }) + } +} + +func TestBuildModelInfo_BytesPerParam(t *testing.T) { + tests := []struct { + name string + dtype string + totalBytes int64 + 
tensorCount int64 + wantParamCount int64 + }{ + { + name: "bfloat16", + dtype: "bfloat16", + totalBytes: 2_000_088, // 1M * 2 + 88 + tensorCount: 1, + wantParamCount: 1_000_000, + }, + { + name: "float16", + dtype: "float16", + totalBytes: 2_000_088, + tensorCount: 1, + wantParamCount: 1_000_000, + }, + { + name: "float32", + dtype: "float32", + totalBytes: 4_000_088, // 1M * 4 + 88 + tensorCount: 1, + wantParamCount: 1_000_000, + }, + { + name: "int8", + dtype: "int8", + totalBytes: 1_000_088, // 1M * 1 + 88 + tensorCount: 1, + wantParamCount: 1_000_000, + }, + { + name: "unknown dtype defaults to 2 bytes", + dtype: "unknown", + totalBytes: 2_000_088, + tensorCount: 1, + wantParamCount: 1_000_000, + }, + { + name: "empty dtype defaults to 2 bytes", + dtype: "", + totalBytes: 2_000_088, + tensorCount: 1, + wantParamCount: 1_000_000, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + config := modelConfig{ + ModelType: "test", + TorchDtype: tt.dtype, + } + info := buildModelInfo(config, tt.totalBytes, tt.tensorCount) + + if paramCount, ok := info["general.parameter_count"].(int64); !ok || paramCount != tt.wantParamCount { + t.Errorf("parameter_count = %v, want %v", info["general.parameter_count"], tt.wantParamCount) + } + }) + } +} + +func TestParseSafetensorsHeader(t *testing.T) { + tests := []struct { + name string + header map[string]any + wantDtype string + wantShape []int64 + wantErr bool + }{ + { + name: "simple tensor", + header: map[string]any{ + "weight": map[string]any{ + "dtype": "BF16", + "shape": []int64{2560, 262144}, + "data_offsets": []int64{0, 1342177280}, + }, + }, + wantDtype: "BF16", + wantShape: []int64{2560, 262144}, + }, + { + name: "with metadata", + header: map[string]any{ + "__metadata__": map[string]any{ + "format": "pt", + }, + "bias": map[string]any{ + "dtype": "F32", + "shape": []int64{1024}, + "data_offsets": []int64{0, 4096}, + }, + }, + wantDtype: "F32", + wantShape: []int64{1024}, + }, + { + name: "float16 tensor", + header: map[string]any{ + "layer.weight": map[string]any{ + "dtype": "F16", + "shape": []int64{512, 512, 3, 3}, + "data_offsets": []int64{0, 4718592}, + }, + }, + wantDtype: "F16", + wantShape: []int64{512, 512, 3, 3}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Create safetensors format: 8-byte size + JSON header + headerJSON, err := json.Marshal(tt.header) + if err != nil { + t.Fatalf("failed to marshal header: %v", err) + } + + var buf bytes.Buffer + if err := binary.Write(&buf, binary.LittleEndian, uint64(len(headerJSON))); err != nil { + t.Fatalf("failed to write header size: %v", err) + } + buf.Write(headerJSON) + + info, err := parseSafetensorsHeader(&buf) + if (err != nil) != tt.wantErr { + t.Errorf("parseSafetensorsHeader() error = %v, wantErr %v", err, tt.wantErr) + return + } + if tt.wantErr { + return + } + + if info.Dtype != tt.wantDtype { + t.Errorf("Dtype = %v, want %v", info.Dtype, tt.wantDtype) + } + + if len(info.Shape) != len(tt.wantShape) { + t.Errorf("Shape length = %v, want %v", len(info.Shape), len(tt.wantShape)) + } else { + for i, s := range info.Shape { + if s != tt.wantShape[i] { + t.Errorf("Shape[%d] = %v, want %v", i, s, tt.wantShape[i]) + } + } + } + }) + } +} + +func TestParseSafetensorsHeader_Errors(t *testing.T) { + tests := []struct { + name string + data []byte + wantErr string + }{ + { + name: "empty data", + data: []byte{}, + wantErr: "failed to read header size", + }, + { + name: "truncated header size", + data: []byte{0x01, 0x02, 0x03}, + 
wantErr: "failed to read header size", + }, + { + name: "header size too large", + data: func() []byte { + var buf bytes.Buffer + binary.Write(&buf, binary.LittleEndian, uint64(2*1024*1024)) // 2MB + return buf.Bytes() + }(), + wantErr: "header size too large", + }, + { + name: "truncated header", + data: func() []byte { + var buf bytes.Buffer + binary.Write(&buf, binary.LittleEndian, uint64(100)) + buf.Write([]byte("short")) + return buf.Bytes() + }(), + wantErr: "failed to read header", + }, + { + name: "invalid JSON", + data: func() []byte { + var buf bytes.Buffer + binary.Write(&buf, binary.LittleEndian, uint64(10)) + buf.Write([]byte("not json!!")) + return buf.Bytes() + }(), + wantErr: "failed to parse header", + }, + { + name: "no tensors in header", + data: func() []byte { + header := map[string]any{ + "__metadata__": map[string]any{"format": "pt"}, + } + headerJSON, _ := json.Marshal(header) + var buf bytes.Buffer + binary.Write(&buf, binary.LittleEndian, uint64(len(headerJSON))) + buf.Write(headerJSON) + return buf.Bytes() + }(), + wantErr: "no tensor found in header", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + _, err := parseSafetensorsHeader(bytes.NewReader(tt.data)) + if err == nil { + t.Error("expected error, got nil") + return + } + if !bytes.Contains([]byte(err.Error()), []byte(tt.wantErr)) { + t.Errorf("error = %v, want error containing %v", err, tt.wantErr) + } + }) + } +} + +func TestGetTensorInfoFromManifest(t *testing.T) { + // Create a temp directory for blobs + tempDir := t.TempDir() + + // Create test tensor blobs + tensors := []struct { + name string + digest string + dtype string + shape []int64 + }{ + { + name: "model.embed_tokens.weight", + digest: "sha256:abc123", + dtype: "BF16", + shape: []int64{262144, 2560}, + }, + { + name: "model.layers.0.self_attn.q_proj.weight", + digest: "sha256:def456", + dtype: "BF16", + shape: []int64{2560, 2560}, + }, + { + name: "model.norm.weight", + digest: "sha256:ghi789", + dtype: "F32", + shape: []int64{2560}, + }, + } + + // Create blob files + var layers []imagegen.ManifestLayer + for _, tensor := range tensors { + // Create safetensors blob + header := map[string]any{ + tensor.name: map[string]any{ + "dtype": tensor.dtype, + "shape": tensor.shape, + "data_offsets": []int64{0, 1000}, + }, + } + headerJSON, _ := json.Marshal(header) + + var buf bytes.Buffer + binary.Write(&buf, binary.LittleEndian, uint64(len(headerJSON))) + buf.Write(headerJSON) + + // Write blob file + blobName := "sha256-" + tensor.digest[7:] + blobPath := filepath.Join(tempDir, blobName) + if err := os.WriteFile(blobPath, buf.Bytes(), 0o644); err != nil { + t.Fatalf("failed to write blob: %v", err) + } + + layers = append(layers, imagegen.ManifestLayer{ + MediaType: "application/vnd.ollama.image.tensor", + Digest: tensor.digest, + Size: int64(buf.Len() + 1000), // header + fake data + Name: tensor.name, + }) + } + + // Add a non-tensor layer (should be skipped) + layers = append(layers, imagegen.ManifestLayer{ + MediaType: "application/vnd.ollama.image.json", + Digest: "sha256:config", + Size: 100, + Name: "config.json", + }) + + manifest := &imagegen.ModelManifest{ + Manifest: &imagegen.Manifest{ + Layers: layers, + }, + BlobDir: tempDir, + } + + result, err := getTensorInfoFromManifest(manifest) + if err != nil { + t.Fatalf("getTensorInfoFromManifest() error = %v", err) + } + + if len(result) != 3 { + t.Errorf("got %d tensors, want 3", len(result)) + } + + // Verify each tensor + for i, tensor := range tensors { + 
if i >= len(result) { + break + } + if result[i].Name != tensor.name { + t.Errorf("tensor[%d].Name = %v, want %v", i, result[i].Name, tensor.name) + } + if result[i].Type != tensor.dtype { + t.Errorf("tensor[%d].Type = %v, want %v", i, result[i].Type, tensor.dtype) + } + if len(result[i].Shape) != len(tensor.shape) { + t.Errorf("tensor[%d].Shape length = %v, want %v", i, len(result[i].Shape), len(tensor.shape)) + } + } +} + +func TestReadSafetensorsHeader(t *testing.T) { + // Create a temp file with a valid safetensors header + tempDir := t.TempDir() + + header := map[string]any{ + "test_tensor": map[string]any{ + "dtype": "BF16", + "shape": []int64{1024, 768}, + "data_offsets": []int64{0, 1572864}, + }, + } + headerJSON, _ := json.Marshal(header) + + var buf bytes.Buffer + binary.Write(&buf, binary.LittleEndian, uint64(len(headerJSON))) + buf.Write(headerJSON) + + filePath := filepath.Join(tempDir, "test.safetensors") + if err := os.WriteFile(filePath, buf.Bytes(), 0o644); err != nil { + t.Fatalf("failed to write test file: %v", err) + } + + info, err := readSafetensorsHeader(filePath) + if err != nil { + t.Fatalf("readSafetensorsHeader() error = %v", err) + } + + if info.Dtype != "BF16" { + t.Errorf("Dtype = %v, want BF16", info.Dtype) + } + if len(info.Shape) != 2 || info.Shape[0] != 1024 || info.Shape[1] != 768 { + t.Errorf("Shape = %v, want [1024, 768]", info.Shape) + } +} + +func TestReadSafetensorsHeader_FileNotFound(t *testing.T) { + _, err := readSafetensorsHeader("/nonexistent/path/file.safetensors") + if err == nil { + t.Error("expected error for nonexistent file") + } +}
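
As a worked example of the parameter-count estimate in buildModelInfo above (total tensor bytes, minus a fixed per-tensor safetensors header overhead, divided by the bytes-per-parameter implied by torch_dtype), the "multiple tensors with header overhead" test case resolves as:

    overhead    = 10 tensors x 88 B             = 880 B
    tensor data = 2,000,880 B - 880 B            = 2,000,000 B
    parameters  = 2,000,000 B / 2 B (bfloat16)   = 1,000,000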
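
For readers unfamiliar with the blob layout that readSafetensorsHeader/parseSafetensorsHeader consume, below is a minimal sketch of a single-tensor safetensors blob: an 8-byte little-endian header length, followed by the JSON header, followed by the raw tensor bytes. The helper name minimalSafetensors is illustrative only (the createMinimalSafetensors helper referenced by the tests is defined elsewhere in this change), and the printed overhead depends on the JSON key lengths, so it need not equal the 88-byte constant assumed in buildModelInfo.

package main

import (
	"bytes"
	"encoding/binary"
	"encoding/json"
	"fmt"
)

// minimalSafetensors builds a single-tensor blob: 8-byte little-endian header
// size, then the JSON header, then the raw tensor payload.
func minimalSafetensors(name, dtype string, shape []int64, data []byte) []byte {
	header := map[string]any{
		name: map[string]any{
			"dtype":        dtype,
			"shape":        shape,
			"data_offsets": []int64{0, int64(len(data))},
		},
	}
	headerJSON, _ := json.Marshal(header)

	var buf bytes.Buffer
	binary.Write(&buf, binary.LittleEndian, uint64(len(headerJSON))) // size prefix
	buf.Write(headerJSON)                                            // JSON header
	buf.Write(data)                                                  // tensor payload
	return buf.Bytes()
}

func main() {
	payload := make([]byte, 64) // e.g. a 4x8 BF16 tensor = 64 bytes
	blob := minimalSafetensors("weight", "BF16", []int64{4, 8}, payload)
	fmt.Printf("blob=%d bytes, payload=%d bytes, header overhead=%d bytes\n",
		len(blob), len(payload), len(blob)-len(payload))
}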