diff --git a/cmd/cmd_test.go b/cmd/cmd_test.go index eedd0c61d..433b2ab1b 100644 --- a/cmd/cmd_test.go +++ b/cmd/cmd_test.go @@ -1553,7 +1553,7 @@ func TestShowInfoImageGen(t *testing.T) { Details: api.ModelDetails{ Family: "ZImagePipeline", ParameterSize: "10.3B", - QuantizationLevel: "FP8", + QuantizationLevel: "Q8", }, Capabilities: []model.Capability{model.CapabilityImage}, Requires: "0.14.0", @@ -1565,7 +1565,7 @@ func TestShowInfoImageGen(t *testing.T) { expect := " Model\n" + " architecture ZImagePipeline \n" + " parameters 10.3B \n" + - " quantization FP8 \n" + + " quantization Q8 \n" + " requires 0.14.0 \n" + "\n" + " Capabilities\n" + diff --git a/server/sched.go b/server/sched.go index 2036ca111..3aa9969a0 100644 --- a/server/sched.go +++ b/server/sched.go @@ -589,7 +589,7 @@ func (s *Scheduler) loadMLX(req *LlmRequest) bool { runner := &runnerRef{ model: req.model, - modelPath: req.model.Name, + modelPath: req.model.ModelPath, llama: server, Options: &req.opts, loading: false, @@ -599,7 +599,7 @@ func (s *Scheduler) loadMLX(req *LlmRequest) bool { } s.loadedMu.Lock() - s.loaded[req.model.Name] = runner + s.loaded[req.model.ModelPath] = runner s.loadedMu.Unlock() // Set up expiration timer diff --git a/x/create/client/create.go b/x/create/client/create.go index 58902bcc2..36e7f164b 100644 --- a/x/create/client/create.go +++ b/x/create/client/create.go @@ -35,7 +35,7 @@ type ModelfileConfig struct { type CreateOptions struct { ModelName string ModelDir string - Quantize string // "fp8" for quantization + Quantize string // "q4", "q8", "nvfp4", or "mxfp8" for quantization Modelfile *ModelfileConfig // template/system/license from Modelfile } diff --git a/x/create/client/quantize.go b/x/create/client/quantize.go index 63792025e..e69003f73 100644 --- a/x/create/client/quantize.go +++ b/x/create/client/quantize.go @@ -14,9 +14,9 @@ import ( // quantizeTensor loads a tensor from safetensors format, quantizes it, // and returns safetensors data for the quantized weights, scales, and biases. // Supported quantization types: -// - "fp4": affine 4-bit, group_size=32 (with qbiases) +// - "q4": affine 4-bit, group_size=32 (with qbiases) // - "nvfp4": NVIDIA FP4, group_size=16 (no qbiases, E4M3 scales) -// - "fp8": affine 8-bit, group_size=32 (with qbiases) +// - "q8": affine 8-bit, group_size=64 (with qbiases) // - "mxfp8": Microsoft MX FP8, group_size=32 (no qbiases, E4M3 scales) // Uses MLX's native SaveSafetensors to ensure correct dtype handling (especially uint32 for quantized weights). 
func quantizeTensor(r io.Reader, name, dtype string, shape []int32, quantize string) (qweightData, scalesData, qbiasData []byte, qweightShape, scalesShape, qbiasShape []int32, err error) { @@ -58,15 +58,15 @@ func quantizeTensor(r io.Reader, name, dtype string, shape []int32, quantize str // Quantize based on quantization type var qweight, scales, qbiases *mlx.Array switch quantize { - case "fp4": + case "q4": // affine mode: group_size=32, bits=4 (with qbiases for zero-point offset) qweight, scales, qbiases = mlx.Quantize(arr, 32, 4, "affine") case "nvfp4": // NVIDIA FP4: group_size=16, bits=4 (no qbiases, E4M3 scales) qweight, scales, qbiases = mlx.Quantize(arr, 16, 4, "nvfp4") - case "fp8": - // affine mode: group_size=32, bits=8 (with qbiases for zero-point offset) - qweight, scales, qbiases = mlx.Quantize(arr, 32, 8, "affine") + case "q8": + // affine mode: group_size=64, bits=8 (with qbiases for zero-point offset) + qweight, scales, qbiases = mlx.Quantize(arr, 64, 8, "affine") case "mxfp8": // Microsoft MX FP8: group_size=32, bits=8, E4M3 scales (no qbiases) qweight, scales, qbiases = mlx.Quantize(arr, 32, 8, "mxfp8") diff --git a/x/create/create.go b/x/create/create.go index 4e18f3c84..2474c8c66 100644 --- a/x/create/create.go +++ b/x/create/create.go @@ -228,7 +228,7 @@ type LayerCreator func(r io.Reader, mediaType, name string) (LayerInfo, error) type TensorLayerCreator func(r io.Reader, name, dtype string, shape []int32) (LayerInfo, error) // QuantizingTensorLayerCreator creates tensor layers with optional quantization. -// When quantize is non-empty (e.g., "fp8"), returns multiple layers (weight + scales + biases). +// When quantize is non-empty (e.g., "q8"), returns multiple layers (weight + scales + biases). type QuantizingTensorLayerCreator func(r io.Reader, name, dtype string, shape []int32, quantize string) ([]LayerInfo, error) // ManifestWriter writes the manifest file. @@ -264,40 +264,132 @@ func ShouldQuantize(name, component string) bool { // ShouldQuantizeTensor returns true if a tensor should be quantized based on name, shape, and quantize type. // This is a more detailed check that also considers tensor dimensions. -// The quantize parameter specifies the quantization type (e.g., "fp4", "nvfp4", "fp8", "mxfp8"). +// The quantize parameter specifies the quantization type (e.g., "q4", "nvfp4", "q8", "mxfp8"). func ShouldQuantizeTensor(name string, shape []int32, quantize string) bool { + return GetTensorQuantization(name, shape, quantize) != "" +} + +// normalizeQuantType converts various quantization type aliases to canonical forms. +// Supports: q4/Q4/int4/INT4/fp4/FP4 -> q4, q8/Q8/int8/INT8/fp8/FP8 -> q8, nvfp4/NVFP4, mxfp8/MXFP8 +func normalizeQuantType(quantize string) string { + switch strings.ToUpper(quantize) { + case "Q4", "INT4", "FP4": + return "q4" + case "Q8", "INT8", "FP8": + return "q8" + case "NVFP4": + return "nvfp4" + case "MXFP8": + return "mxfp8" + default: + return quantize + } +} + +// getQuantGroupSize returns the group size for a given quantization type. +// These must match the values used in quantize.go when creating quantized models. +func getQuantGroupSize(quantize string) int { + switch normalizeQuantType(quantize) { + case "nvfp4": + return 16 + case "q4": + return 32 + case "mxfp8": + return 32 + case "q8": + return 64 + default: + return 32 + } +} + +// GetTensorQuantization returns the appropriate quantization type for a tensor. +// Returns "" if the tensor should not be quantized. 
+// This implements mixed-precision quantization: +// - Attention MLA weights (q_a, q_b, kv_a, kv_b): unquantized (most sensitive) +// - Output projection, gate/up weights: q4 (less sensitive) +// - Down projection weights: q8 (more sensitive, would be Q6 in GGML but no MLX kernel) +// - Norms, embeddings, biases, routing gates: no quantization +func GetTensorQuantization(name string, shape []int32, quantize string) string { // Use basic name-based check first if !ShouldQuantize(name, "") { - return false + return "" } // Only quantize 2D tensors (linear layers) - skip 1D (biases, norms) and higher-D (convolutions if any) if len(shape) != 2 { - return false + return "" } // Skip small tensors (less than 1024 elements) - not worth quantizing if len(shape) >= 2 && int64(shape[0])*int64(shape[1]) < 1024 { - return false + return "" } + // Normalize quantization type to canonical form + quantNorm := normalizeQuantType(quantize) + // MLX quantization requires last dimension to be divisible by group size - // NVFP4 uses group_size=16, all other modes (fp4, fp8, mxfp8) use group_size=32 + // nvfp4: 16, q4/mxfp8: 32, q8: 64 groupSize := int32(32) - if strings.ToUpper(quantize) == "NVFP4" { + switch quantNorm { + case "nvfp4": groupSize = 16 + case "q8": + groupSize = 64 } if shape[len(shape)-1]%groupSize != 0 { - return false + return "" } - return true + // Skip routing gate weights (should stay high precision) + // In safetensors these are: mlp.gate.weight (not mlp.gate_proj.weight) + if strings.Contains(name, "mlp.gate.weight") && !strings.Contains(name, "_proj") { + return "" + } + + // For NVFP4 or MXFP8, use the same quantization for all (no mixed precision) + if quantNorm == "nvfp4" || quantNorm == "mxfp8" { + return quantNorm + } + + // Attention MLA weights - keep unquantized (bf16) + // These are highly sensitive: errors accumulate in the KV cache over time + // q_a_proj, q_b_proj, kv_a_proj_with_mqa, kv_b_proj + if strings.Contains(name, "q_a_proj") || + strings.Contains(name, "q_b_proj") || + strings.Contains(name, "kv_a_proj") || + strings.Contains(name, "kv_b_proj") { + return "" // No quantization - keep bf16 + } + + // Down projection weights - use Q8 (would be Q6_K in GGML, but MLX has no Q6 kernel) + // mlp.down_proj, mlp.experts.X.down_proj, mlp.shared_experts.down_proj + if strings.Contains(name, "down_proj") { + return "q8" + } + + // Output projection, gate/up weights - use requested quantization (Q4) + // o_proj, gate_proj, up_proj + if strings.Contains(name, "o_proj") || + strings.Contains(name, "gate_proj") || + strings.Contains(name, "up_proj") { + return quantNorm + } + + // LM head - use requested quantization + if strings.Contains(name, "lm_head") { + return quantNorm + } + + // Default to requested quantization for other weights + return quantNorm } // CreateSafetensorsModel imports a standard safetensors model from a directory. // This handles Hugging Face style models with config.json and *.safetensors files. // Stores each tensor as a separate blob for fine-grained deduplication. -// If quantize is non-empty (e.g., "fp8"), eligible tensors will be quantized. +// If quantize is non-empty (e.g., "q8"), eligible tensors will be quantized. 
func CreateSafetensorsModel(modelName, modelDir, quantize string, createLayer LayerCreator, createTensorLayer QuantizingTensorLayerCreator, writeManifest ManifestWriter, fn func(status string)) error { var layers []LayerInfo var configLayer LayerInfo @@ -336,9 +428,10 @@ func CreateSafetensorsModel(modelName, modelDir, quantize string, createLayer La } // Determine quantization type for this tensor (empty string if not quantizing) + // GetTensorQuantization handles mixed-precision (e.g., Q8 for attention, Q4 for FFN) quantizeType := "" - if quantize != "" && ShouldQuantizeTensor(tensorName, td.Shape, quantize) { - quantizeType = quantize + if quantize != "" { + quantizeType = GetTensorQuantization(tensorName, td.Shape, quantize) } // Store as minimal safetensors format (88 bytes header overhead) @@ -398,6 +491,7 @@ func CreateSafetensorsModel(modelName, modelDir, quantize string, createLayer La if quantize != "" { modelIndex := map[string]any{ "quantization": strings.ToUpper(quantize), + "group_size": getQuantGroupSize(quantize), } indexData, err := json.MarshalIndent(modelIndex, "", " ") if err != nil { diff --git a/x/create/imagegen.go b/x/create/imagegen.go index d5f775e2e..0da0e764a 100644 --- a/x/create/imagegen.go +++ b/x/create/imagegen.go @@ -15,15 +15,15 @@ import ( // CreateImageGenModel imports an image generation model from a directory. // Stores each tensor as a separate blob for fine-grained deduplication. // If quantize is specified, linear weights in transformer/text_encoder are quantized. -// Supported quantization types: fp4, fp8, nvfp4, mxfp8 (or empty for no quantization). +// Supported quantization types: q4, q8, nvfp4, mxfp8 (or empty for no quantization). // Layer creation and manifest writing are done via callbacks to avoid import cycles. func CreateImageGenModel(modelName, modelDir, quantize string, createLayer LayerCreator, createTensorLayer QuantizingTensorLayerCreator, writeManifest ManifestWriter, fn func(status string)) error { // Validate quantization type switch quantize { - case "", "fp4", "fp8", "nvfp4", "mxfp8": + case "", "q4", "q8", "nvfp4", "mxfp8": // valid default: - return fmt.Errorf("unsupported quantization type %q: supported types are fp4, fp8, nvfp4, mxfp8", quantize) + return fmt.Errorf("unsupported quantization type %q: supported types are q4, q8, nvfp4, mxfp8", quantize) } var layers []LayerInfo @@ -214,14 +214,17 @@ func CreateImageGenModel(modelName, modelDir, quantize string, createLayer Layer // canQuantizeShape returns true if a tensor shape is compatible with MLX quantization. // MLX requires the last dimension to be divisible by the group size. -// NVFP4 uses group_size=16, all other modes use group_size=32. 
+// nvfp4: 16, q4/mxfp8: 32, q8: 64 func canQuantizeShape(shape []int32, quantize string) bool { if len(shape) < 2 { return false } groupSize := int32(32) - if strings.ToUpper(quantize) == "NVFP4" { + switch strings.ToUpper(quantize) { + case "NVFP4": groupSize = 16 + case "Q8": + groupSize = 64 } return shape[len(shape)-1]%groupSize == 0 } diff --git a/x/imagegen/cmd/engine/main.go b/x/imagegen/cmd/engine/main.go index ae924a127..f0e705d1c 100644 --- a/x/imagegen/cmd/engine/main.go +++ b/x/imagegen/cmd/engine/main.go @@ -19,7 +19,6 @@ import ( "github.com/ollama/ollama/x/imagegen/mlx" "github.com/ollama/ollama/x/imagegen/models/flux2" "github.com/ollama/ollama/x/imagegen/models/gemma3" - "github.com/ollama/ollama/x/imagegen/models/glm4_moe_lite" "github.com/ollama/ollama/x/imagegen/models/gpt_oss" "github.com/ollama/ollama/x/imagegen/models/llama" "github.com/ollama/ollama/x/imagegen/models/zimage" @@ -243,8 +242,6 @@ func load(modelPath string) (Model, error) { return gemma3.Load(modelPath) case "gemma3_text": return gemma3.LoadText(modelPath) - case "glm4_moe_lite": - return glm4_moe_lite.Load(modelPath) default: return llama.Load(modelPath) } diff --git a/x/imagegen/manifest.go b/x/imagegen/manifest.go index e71c89c2c..da55cbe81 100644 --- a/x/imagegen/manifest.go +++ b/x/imagegen/manifest.go @@ -209,7 +209,7 @@ func GetModelInfo(modelName string) (*ModelInfo, error) { if info.Quantization == "" { for _, layer := range manifest.Manifest.Layers { if strings.HasSuffix(layer.Name, ".weight_scale") { - info.Quantization = "FP8" + info.Quantization = "Q8" break } } diff --git a/x/imagegen/models/glm4_moe_lite/glm4_moe_lite.go b/x/imagegen/models/glm4_moe_lite/glm4_moe_lite.go index 10cbebbdd..caebbe361 100644 --- a/x/imagegen/models/glm4_moe_lite/glm4_moe_lite.go +++ b/x/imagegen/models/glm4_moe_lite/glm4_moe_lite.go @@ -8,8 +8,6 @@ import ( "encoding/json" "fmt" "math" - "os" - "path/filepath" "github.com/ollama/ollama/x/imagegen" "github.com/ollama/ollama/x/imagegen/cache" @@ -59,6 +57,11 @@ type Config struct { // RoPE scaling RopeScaling *RopeScaling `json:"rope_scaling"` + // Quantization parameters (set during load based on model quantization) + QuantGroupSize int `json:"-"` // Group size for quantization (default 64) + QuantBits int `json:"-"` // Bits per weight (4 or 8) + QuantMode string `json:"-"` // Quantization mode ("affine", etc.) 
+ // Computed fields QHeadDim int32 `json:"-"` // qk_nope_head_dim + qk_rope_head_dim Scale float32 `json:"-"` // 1/sqrt(QHeadDim) with mscale adjustment @@ -236,9 +239,28 @@ func (g *MoEGate) Forward(x *mlx.Array, cfg *Config) (*mlx.Array, *mlx.Array) { // SwitchMLP implements the MoE expert computation using stacked weights // Note: No weight tags - these are populated manually by stacking expert weights type SwitchMLP struct { + // Dequantized weights (used when GatherQMM not available) GateWeight *mlx.Array UpWeight *mlx.Array DownWeight *mlx.Array + + // Quantized weights (used with GatherQMM for 4/8-bit affine) + GateWeightQ, GateScales, GateBiases *mlx.Array + UpWeightQ, UpScales, UpBiases *mlx.Array + DownWeightQ, DownScales, DownBiases *mlx.Array + + // Quantization bits per projection (supports mixed precision Q4/Q8) + GateBits int + UpBits int + DownBits int + + // Quantization group size per projection (detected from tensor shapes) + GateGroupSize int + UpGroupSize int + DownGroupSize int + + // If true, use GatherQMM with quantized weights + UseQuantized bool } // Forward applies the switched expert MLP @@ -270,17 +292,29 @@ func (s *SwitchMLP) Forward(x *mlx.Array, indices *mlx.Array, cfg *Config) *mlx. idxFlat = mlx.Reshape(mlx.Take(idxAll, order, 0), n, 1) } - // Expert computation using gather_mm - // gate: x @ gate_weight.T (indices are on the rhs/weight side) - gate := mlx.GatherMM(xFlat, mlx.Transpose(s.GateWeight, 0, 2, 1), nil, idxFlat, doSort) - // up: x @ up_weight.T - up := mlx.GatherMM(xFlat, mlx.Transpose(s.UpWeight, 0, 2, 1), nil, idxFlat, doSort) + var gate, up, hidden, down *mlx.Array - // SwiGLU activation - hidden := mlx.Mul(mlx.SiLU(gate), up) + if s.UseQuantized { + // Use GatherQMM for quantized weights (faster, keeps weights quantized) + // Each projection may have different bits and group sizes (mixed precision: Q4 for gate/up, Q8 for down) + gate = mlx.GatherQMM(xFlat, s.GateWeightQ, s.GateScales, s.GateBiases, + nil, idxFlat, true, s.GateGroupSize, s.GateBits, cfg.QuantMode, doSort) + up = mlx.GatherQMM(xFlat, s.UpWeightQ, s.UpScales, s.UpBiases, + nil, idxFlat, true, s.UpGroupSize, s.UpBits, cfg.QuantMode, doSort) - // down: hidden @ down_weight.T - down := mlx.GatherMM(hidden, mlx.Transpose(s.DownWeight, 0, 2, 1), nil, idxFlat, doSort) + hidden = mlx.Mul(mlx.SiLU(gate), up) + + down = mlx.GatherQMM(hidden, s.DownWeightQ, s.DownScales, s.DownBiases, + nil, idxFlat, true, s.DownGroupSize, s.DownBits, cfg.QuantMode, doSort) + } else { + // Use GatherMM for dequantized/non-quantized weights + gate = mlx.GatherMM(xFlat, mlx.Transpose(s.GateWeight, 0, 2, 1), nil, idxFlat, doSort) + up = mlx.GatherMM(xFlat, mlx.Transpose(s.UpWeight, 0, 2, 1), nil, idxFlat, doSort) + + hidden = mlx.Mul(mlx.SiLU(gate), up) + + down = mlx.GatherMM(hidden, mlx.Transpose(s.DownWeight, 0, 2, 1), nil, idxFlat, doSort) + } // Unsort if we sorted if doSort { @@ -402,9 +436,36 @@ func computeScale(cfg *Config) float32 { return scale } -// loadExpertWeight loads an expert weight, dequantizing if necessary. -// GatherMM doesn't support quantized weights, so we must dequantize for MoE. -func loadExpertWeight(weights safetensors.WeightSource, path string) *mlx.Array { +// supportsGatherQMM returns true if the quantization mode has GatherQMM kernel support. +// Currently only 4-bit and 8-bit affine quantization are supported. 
+func supportsGatherQMM(mode string, bits int) bool { + return mode == "affine" && (bits == 4 || bits == 8) +} + +// ExpertWeight holds a single expert's weight with optional quantization components. +type ExpertWeight struct { + Weight *mlx.Array // Quantized weight (if quantized) or dequantized weight + Scales *mlx.Array // Quantization scales (nil if not quantized) + Biases *mlx.Array // Quantization biases (nil if not quantized or mode doesn't use biases) + Bits int // Quantization bits (4 or 8), 0 if not quantized + GroupSize int // Quantization group size, 0 if not quantized +} + +// getQuantParams returns quantization parameters from model metadata. +// Returns groupSize, bits, and mode for the model's quantization type. +func getQuantParams(weights safetensors.WeightSource) (groupSize, bits int, mode string) { + groupSize, bits, mode = safetensors.QuantizationParams(weights.Quantization()) + // Use metadata group_size if available (overrides default) + if gs := weights.GroupSize(); gs > 0 { + groupSize = gs + } + return groupSize, bits, mode +} + +// loadExpertWeight loads an expert weight. +// If useQuantized is true and the weight is quantized with a supported mode, returns quantized components. +// Otherwise dequantizes and returns only the weight. +func loadExpertWeight(weights safetensors.WeightSource, path string, useQuantized bool, cfg *Config) *ExpertWeight { w, _ := weights.GetTensor(path + ".weight") if w == nil { return nil @@ -419,12 +480,25 @@ func loadExpertWeight(weights safetensors.WeightSource, path string) *mlx.Array if weights.HasTensor(qbiasPath) { qbiases, _ = weights.GetTensor(qbiasPath) } - // Dequantize using the model's quantization parameters - groupSize, bits, mode := safetensors.QuantizationParams(weights.Quantization()) - return mlx.Dequantize(w, scales, qbiases, groupSize, bits, mode) + + // Get quantization params from metadata + groupSize, bits, mode := getQuantParams(weights) + + // Update config with group size (for GatherQMM calls) + if cfg.QuantGroupSize == 0 { + cfg.QuantGroupSize = groupSize + } + + // If GatherQMM is supported and requested, return quantized components + if useQuantized && supportsGatherQMM(mode, bits) { + return &ExpertWeight{Weight: w, Scales: scales, Biases: qbiases, Bits: bits, GroupSize: groupSize} + } + + // Otherwise dequantize + return &ExpertWeight{Weight: mlx.Dequantize(w, scales, qbiases, groupSize, bits, mode)} } - return w + return &ExpertWeight{Weight: w} } // sanitizeMLAWeights transforms kv_b_proj weights into absorbed MLA format. @@ -450,7 +524,8 @@ func sanitizeMLAWeights(weights safetensors.WeightSource, prefix string, cfg *Co if weights.HasTensor(qbiasPath) { qbiases, _ = weights.GetTensor(qbiasPath) } - groupSize, bits, mode := safetensors.QuantizationParams(weights.Quantization()) + + groupSize, bits, mode := getQuantParams(weights) w = mlx.Dequantize(w, scales, qbiases, groupSize, bits, mode) } @@ -477,140 +552,68 @@ func sanitizeMLAWeights(weights safetensors.WeightSource, prefix string, cfg *Co return embedQ, unembedOut } -// sanitizeExpertWeights stacks individual expert weights into a single tensor. -// For quantized models, expert weights are dequantized since GatherMM doesn't support quantized weights. 
-func sanitizeExpertWeights(weights safetensors.WeightSource, prefix string, numExperts int32) (*mlx.Array, *mlx.Array, *mlx.Array) { - var gateWeights, upWeights, downWeights []*mlx.Array - - for e := int32(0); e < numExperts; e++ { - gw := loadExpertWeight(weights, fmt.Sprintf("%s.mlp.experts.%d.gate_proj", prefix, e)) - uw := loadExpertWeight(weights, fmt.Sprintf("%s.mlp.experts.%d.up_proj", prefix, e)) - dw := loadExpertWeight(weights, fmt.Sprintf("%s.mlp.experts.%d.down_proj", prefix, e)) - - if gw != nil { - gateWeights = append(gateWeights, gw) - } - if uw != nil { - upWeights = append(upWeights, uw) - } - if dw != nil { - downWeights = append(downWeights, dw) - } - } - - var stackedGate, stackedUp, stackedDown *mlx.Array - if len(gateWeights) > 0 { - stackedGate = mlx.Stack(gateWeights, 0) - } - if len(upWeights) > 0 { - stackedUp = mlx.Stack(upWeights, 0) - } - if len(downWeights) > 0 { - stackedDown = mlx.Stack(downWeights, 0) - } - - return stackedGate, stackedUp, stackedDown +// StackedExpertWeights holds stacked weights for all experts. +type StackedExpertWeights struct { + Weight *mlx.Array // Stacked weights [num_experts, out, in] or [num_experts, out, in_packed] + Scales *mlx.Array // Stacked scales (nil if not quantized) + Biases *mlx.Array // Stacked biases (nil if not quantized) + Bits int // Quantization bits (4 or 8), 0 if not quantized + GroupSize int // Quantization group size, 0 if not quantized } -// Load loads a GLM4-MoE-Lite model from the given path -func Load(modelPath string) (*Model, error) { - data, err := os.ReadFile(filepath.Join(modelPath, "config.json")) - if err != nil { - return nil, fmt.Errorf("load config: %w", err) - } +// collectAndStackExpertWeights loads and stacks expert weights for one projection type. +func collectAndStackExpertWeights( + weights safetensors.WeightSource, + prefix string, + projName string, + numExperts int32, + useQuantized bool, + cfg *Config, +) *StackedExpertWeights { + var w, s, b []*mlx.Array + var bits, groupSize int - var cfg Config - if err := json.Unmarshal(data, &cfg); err != nil { - return nil, fmt.Errorf("parse config: %w", err) - } - - // Compute derived fields - cfg.QHeadDim = cfg.QKNopeHeadDim + cfg.QKRopeHeadDim - cfg.Scale = computeScale(&cfg) - - weights, err := safetensors.LoadModelWeights(modelPath) - if err != nil { - return nil, fmt.Errorf("load weights: %w", err) - } - - tok, err := tokenizer.Load(filepath.Join(modelPath, "tokenizer.json")) - if err != nil { - return nil, fmt.Errorf("load tokenizer: %w", err) - } - - m := &Model{ - Layers: make([]Block, cfg.NumHiddenLayers), - Config: &cfg, - tok: tok, - } - - // Load embedding, norm, and lm_head - if err := safetensors.LoadModule(m, weights, ""); err != nil { - return nil, err - } - - // Load layers manually due to different block types - for i := int32(0); i < cfg.NumHiddenLayers; i++ { - prefix := fmt.Sprintf("model.layers.%d", i) - - // Load attention (same for both block types) - attn := &MLAAttention{} - if err := safetensors.LoadModule(attn, weights, prefix); err != nil { - return nil, fmt.Errorf("layer %d attention: %w", i, err) + for e := int32(0); e < numExperts; e++ { + path := fmt.Sprintf("%s.mlp.experts.%d.%s", prefix, e, projName) + ew := loadExpertWeight(weights, path, useQuantized, cfg) + if ew == nil { + continue } - - // Sanitize MLA weights for absorbed attention - embedQ, unembedOut := sanitizeMLAWeights(weights, prefix, &cfg) - attn.EmbedQ = nn.NewMultiLinear(embedQ) - attn.UnembedOut = nn.NewMultiLinear(unembedOut) - - if i < 
cfg.FirstKDenseReplace { - // Dense block - block := &DenseBlock{Attention: attn} - if err := safetensors.LoadModule(block, weights, prefix); err != nil { - return nil, fmt.Errorf("layer %d dense: %w", i, err) - } - m.Layers[i] = block - } else { - // MoE block - block := &MoEBlock{Attention: attn} - if err := safetensors.LoadModule(block, weights, prefix); err != nil { - return nil, fmt.Errorf("layer %d moe block: %w", i, err) - } - - // Stack expert weights - gateW, upW, downW := sanitizeExpertWeights(weights, prefix, cfg.NRoutedExperts) - - block.MoE = &MoE{ - Gate: &MoEGate{}, - SwitchMLP: &SwitchMLP{ - GateWeight: gateW, - UpWeight: upW, - DownWeight: downW, - }, - } - - // Load gate weights - if err := safetensors.LoadModule(block.MoE.Gate, weights, prefix); err != nil { - return nil, fmt.Errorf("layer %d gate: %w", i, err) - } - - // Load shared experts if present - if cfg.NSharedExperts > 0 { - block.MoE.SharedExperts = &SharedExperts{} - if err := safetensors.LoadModule(block.MoE.SharedExperts, weights, prefix); err != nil { - return nil, fmt.Errorf("layer %d shared experts: %w", i, err) - } - } - - m.Layers[i] = block + w = append(w, ew.Weight) + if ew.Scales != nil { + s = append(s, ew.Scales) + } + if ew.Biases != nil { + b = append(b, ew.Biases) + } + if e == 0 { + bits = ew.Bits + groupSize = ew.GroupSize } } - mlx.Eval(mlx.Collect(m)...) - weights.ReleaseAll() + result := &StackedExpertWeights{Bits: bits, GroupSize: groupSize} + if len(w) > 0 { + result.Weight = mlx.Stack(w, 0) + if len(s) > 0 { + result.Scales = mlx.Stack(s, 0) + } + if len(b) > 0 { + result.Biases = mlx.Stack(b, 0) + } + } + return result +} - return m, nil +// sanitizeExpertWeights stacks individual expert weights into tensors. +// If useQuantized is true and weights support GatherQMM, returns quantized components. +// Otherwise returns dequantized weights with nil scales/biases. +// Bits and GroupSize are detected per-weight to support mixed-precision (Q4 for gate/up, Q8 for down). +func sanitizeExpertWeights(weights safetensors.WeightSource, prefix string, numExperts int32, useQuantized bool, cfg *Config) (gate, up, down *StackedExpertWeights) { + gate = collectAndStackExpertWeights(weights, prefix, "gate_proj", numExperts, useQuantized, cfg) + up = collectAndStackExpertWeights(weights, prefix, "up_proj", numExperts, useQuantized, cfg) + down = collectAndStackExpertWeights(weights, prefix, "down_proj", numExperts, useQuantized, cfg) + return gate, up, down } // LoadFromManifest loads a GLM4-MoE-Lite model from a manifest (Ollama blob storage). 
@@ -636,19 +639,19 @@ func LoadFromManifest(manifest *imagegen.ModelManifest) (*Model, error) { return nil, fmt.Errorf("load weights: %w", err) } - // Debug: print quantization info and sample tensor names - fmt.Printf("GLM4: quantization=%q, num_tensors=%d\n", weights.Quantization(), len(weights.ListTensors())) - tensors := weights.ListTensors() - for i, name := range tensors { - if i < 20 { // Print first 20 tensor names - fmt.Printf(" tensor[%d]: %s\n", i, name) - } - } - if err := weights.Load(0); err != nil { return nil, fmt.Errorf("load weight data: %w", err) } + // Set up quantization parameters (only if model is actually quantized) + // Note: QuantGroupSize will be detected dynamically from tensor shapes during weight loading + quantization := weights.Quantization() + useQuantized := false + if quantization != "" { + _, cfg.QuantBits, cfg.QuantMode = safetensors.QuantizationParams(quantization) + useQuantized = supportsGatherQMM(cfg.QuantMode, cfg.QuantBits) + } + // Load tokenizer from manifest with config files for EOS token detection tokData, err := manifest.ReadConfig("tokenizer.json") if err != nil { @@ -715,16 +718,35 @@ func LoadFromManifest(manifest *imagegen.ModelManifest) (*Model, error) { return nil, fmt.Errorf("layer %d moe block: %w", i, err) } - // Stack expert weights - gateW, upW, downW := sanitizeExpertWeights(weights, prefix, cfg.NRoutedExperts) + // Stack expert weights (pass cfg so group sizes can be detected) + gate, up, down := sanitizeExpertWeights(weights, prefix, cfg.NRoutedExperts, useQuantized, &cfg) + + switchMLP := &SwitchMLP{UseQuantized: useQuantized} + if useQuantized { + switchMLP.GateWeightQ = gate.Weight + switchMLP.GateScales = gate.Scales + switchMLP.GateBiases = gate.Biases + switchMLP.GateBits = gate.Bits + switchMLP.GateGroupSize = gate.GroupSize + switchMLP.UpWeightQ = up.Weight + switchMLP.UpScales = up.Scales + switchMLP.UpBiases = up.Biases + switchMLP.UpBits = up.Bits + switchMLP.UpGroupSize = up.GroupSize + switchMLP.DownWeightQ = down.Weight + switchMLP.DownScales = down.Scales + switchMLP.DownBiases = down.Biases + switchMLP.DownBits = down.Bits + switchMLP.DownGroupSize = down.GroupSize + } else { + switchMLP.GateWeight = gate.Weight + switchMLP.UpWeight = up.Weight + switchMLP.DownWeight = down.Weight + } block.MoE = &MoE{ - Gate: &MoEGate{}, - SwitchMLP: &SwitchMLP{ - GateWeight: gateW, - UpWeight: upW, - DownWeight: downW, - }, + Gate: &MoEGate{}, + SwitchMLP: switchMLP, } // Load gate weights diff --git a/x/imagegen/safetensors/loader.go b/x/imagegen/safetensors/loader.go index 899ddc3a8..4c1d0a9af 100644 --- a/x/imagegen/safetensors/loader.go +++ b/x/imagegen/safetensors/loader.go @@ -17,7 +17,8 @@ type WeightSource interface { GetTensor(name string) (*mlx.Array, error) ListTensors() []string HasTensor(name string) bool - Quantization() string // Returns "NVFP4", "FP4", "FP8", or "" + Quantization() string // Returns "NVFP4", "Q4", "Q8", or "" + GroupSize() int // Returns quantization group size, or 0 if not specified } // QuantizationParams returns groupSize, bits, mode for a quantization type. 
@@ -38,7 +39,7 @@ func QuantizationParams(quantization string) (groupSize, bits int, mode string) return 32, 8, "mxfp8" case "FP8", "Q8", "INT8", "": // 8-bit quantization with affine mode (default for quantized models) - return 32, 8, "affine" + return 64, 8, "affine" default: return 32, 8, "affine" // Default to affine } @@ -273,8 +274,48 @@ func LoadMultiLinearLayer(weights WeightSource, path string) (nn.MultiLinearLaye } // Always dequantize for MultiLinear - no batched quantized matmul support - groupSize, bits, mode := QuantizationParams(weights.Quantization()) - weight = mlx.Dequantize(weight, scales, qbiases, groupSize, bits, mode) + // Detect bits from tensor shapes (supports mixed-precision Q4/Q8) + weightShape := weight.Shape() + scalesShape := scales.Shape() + weightCols := int(weightShape[len(weightShape)-1]) + scalesCols := int(scalesShape[len(scalesShape)-1]) + + // Detect quantization from tensor shapes + // groupSize = weightCols * packFactor / scalesCols + // Note: groupSize4 = 2 * groupSize8 always, so ambiguous cases need metadata + groupSize4 := weightCols * 8 / scalesCols + groupSize8 := weightCols * 4 / scalesCols + + var bits, groupSize int + // Use metadata to help disambiguate when shapes are ambiguous + // (e.g., Q4 with group_size=64 has same shapes as Q8 with group_size=32) + quantType := strings.ToUpper(weights.Quantization()) + isQ8Type := quantType == "Q8" || quantType == "FP8" || quantType == "INT8" + + if groupSize4 == 32 { + // Unambiguous: Q4 with group_size=32 + bits = 4 + groupSize = 32 + } else if groupSize8 == 64 { + // Unambiguous: Q8 with group_size=64 + bits = 8 + groupSize = 64 + } else if groupSize4 == 64 && groupSize8 == 32 { + // Ambiguous: could be Q4/gs=64 or Q8/gs=32, use metadata + if isQ8Type { + bits = 8 + groupSize = 32 + } else { + bits = 4 + groupSize = 64 + } + } else { + // Fallback: use global quantization params + _, bits, _ = QuantizationParams(weights.Quantization()) + packFactor := 32 / bits + groupSize = weightCols * packFactor / scalesCols + } + weight = mlx.Dequantize(weight, scales, qbiases, groupSize, bits, "affine") } return nn.NewMultiLinear(weight), nil @@ -310,7 +351,48 @@ func LoadLinearLayer(weights WeightSource, path string) (nn.LinearLayer, error) qbiases, _ = weights.GetTensor(qbiasPath) } - groupSize, bits, mode := QuantizationParams(weights.Quantization()) + // Detect bits from tensor shapes (supports mixed-precision Q4/Q8) + weightShape := weight.Shape() + scalesShape := scales.Shape() + weightCols := int(weightShape[len(weightShape)-1]) + scalesCols := int(scalesShape[len(scalesShape)-1]) + + // Detect quantization from tensor shapes + // groupSize = weightCols * packFactor / scalesCols + // Note: groupSize4 = 2 * groupSize8 always, so ambiguous cases need metadata + groupSize4 := weightCols * 8 / scalesCols + groupSize8 := weightCols * 4 / scalesCols + + var bits, groupSize int + mode := "affine" + // Use metadata to help disambiguate when shapes are ambiguous + // (e.g., Q4 with group_size=64 has same shapes as Q8 with group_size=32) + quantType := strings.ToUpper(weights.Quantization()) + isQ8Type := quantType == "Q8" || quantType == "FP8" || quantType == "INT8" + + if groupSize4 == 32 { + // Unambiguous: Q4 with group_size=32 + bits = 4 + groupSize = 32 + } else if groupSize8 == 64 { + // Unambiguous: Q8 with group_size=64 + bits = 8 + groupSize = 64 + } else if groupSize4 == 64 && groupSize8 == 32 { + // Ambiguous: could be Q4/gs=64 or Q8/gs=32, use metadata + if isQ8Type { + bits = 8 + groupSize = 32 + } 
else { + bits = 4 + groupSize = 64 + } + } else { + // Fallback: use global quantization params + _, bits, mode = QuantizationParams(weights.Quantization()) + packFactor := 32 / bits + groupSize = weightCols * packFactor / scalesCols + } // NVFP4 and MXFP8 don't have native quantized matmul kernels in MLX, // so we always dequantize at load time. Affine modes (FP4, FP8) have kernel support. diff --git a/x/imagegen/safetensors/safetensors.go b/x/imagegen/safetensors/safetensors.go index a36052fce..4dbcf59a3 100644 --- a/x/imagegen/safetensors/safetensors.go +++ b/x/imagegen/safetensors/safetensors.go @@ -303,6 +303,11 @@ func (mw *ModelWeights) Quantization() string { return "" } +// GroupSize returns 0 for directory-based weights (use default). +func (mw *ModelWeights) GroupSize() int { + return 0 +} + // ReleaseAll releases all cached native file handles. func (mw *ModelWeights) ReleaseAll() { for path, native := range mw.nativeCache { diff --git a/x/imagegen/weights.go b/x/imagegen/weights.go index 470c11fdf..eb60c9895 100644 --- a/x/imagegen/weights.go +++ b/x/imagegen/weights.go @@ -209,19 +209,36 @@ func (mw *ManifestWeights) Quantization() string { // So scale size should be ~weight_size * 4 / 32 = weight_size / 8 // Rough size heuristic (assuming float16 scales) - // FP4: scale_bytes ≈ weight_bytes / 4 * 2 / 4 = weight_bytes / 8 - // FP8: scale_bytes ≈ weight_bytes / 8 * 2 / 4 = weight_bytes / 16 + // Q4: scale_bytes ≈ weight_bytes / 4 * 2 / 4 = weight_bytes / 8 + // Q8: scale_bytes ≈ weight_bytes / 8 * 2 / 4 = weight_bytes / 16 ratio := float64(weightLayer.Size) / float64(scaleLayer.Size) if ratio < 12 { - // Closer to 8 = FP4 - return "FP4" + // Closer to 8 = Q4 + return "Q4" } - // Closer to 16 = FP8 - return "FP8" + // Closer to 16 = Q8 + return "Q8" } - // Default to FP4 for affine mode (most common) - return "FP4" + // Default to Q4 for affine mode (most common) + return "Q4" +} + +// GroupSize returns the quantization group size from model_index.json. +// Returns 0 if not specified (caller should use default based on quantization type). +func (mw *ManifestWeights) GroupSize() int { + if mw.manifest == nil { + return 0 + } + + var index struct { + GroupSize int `json:"group_size"` + } + if err := mw.manifest.ReadConfigJSON("model_index.json", &index); err == nil && index.GroupSize > 0 { + return index.GroupSize + } + + return 0 } // ReleaseAll frees all native handles and clears the tensor cache. diff --git a/x/server/show.go b/x/server/show.go index 7158418fb..652293e77 100644 --- a/x/server/show.go +++ b/x/server/show.go @@ -163,9 +163,18 @@ func GetSafetensorsTensorInfo(name model.Name) ([]api.Tensor, error) { // getTensorInfoFromManifest extracts tensor info from a manifest. // This is separated for testability. +// For quantized models, groups weight/scale/qbias into single entries with detected quantization type. 
func getTensorInfoFromManifest(mf *manifest.Manifest) ([]api.Tensor, error) { var tensors []api.Tensor + // First pass: collect all tensor info and identify scale tensors + type tensorData struct { + info *safetensorsTensorInfo + digest string + } + tensorMap := make(map[string]*tensorData) + scaleMap := make(map[string]*tensorData) // base name -> scale tensor info + for _, layer := range mf.Layers { if layer.MediaType != manifest.MediaTypeImageTensor { continue @@ -178,21 +187,89 @@ func getTensorInfoFromManifest(mf *manifest.Manifest) ([]api.Tensor, error) { } info, err := readSafetensorsHeader(blobPath) if err != nil { - // Skip tensors we can't read continue } - // Convert shape from int to uint64 - shape := make([]uint64, len(info.Shape)) - for i, s := range info.Shape { - shape[i] = uint64(s) + td := &tensorData{info: info, digest: layer.Digest} + + if strings.HasSuffix(layer.Name, "_scale") { + baseName := strings.TrimSuffix(layer.Name, "_scale") + scaleMap[baseName] = td + } else if strings.HasSuffix(layer.Name, "_qbias") { + // Skip qbias tensors - they're included with the quantized weight + continue + } else { + tensorMap[layer.Name] = td + } + } + + // Second pass: build tensor list with quantization info + for _, layer := range mf.Layers { + if layer.MediaType != manifest.MediaTypeImageTensor { + continue } - tensors = append(tensors, api.Tensor{ - Name: layer.Name, - Type: info.Dtype, - Shape: shape, - }) + // Skip scale and qbias tensors + if strings.HasSuffix(layer.Name, "_scale") || strings.HasSuffix(layer.Name, "_qbias") { + continue + } + + td := tensorMap[layer.Name] + if td == nil { + continue + } + + // Check if this tensor has a corresponding scale tensor (quantized) + scaleTd := scaleMap[layer.Name] + if scaleTd != nil && len(td.info.Shape) >= 2 && len(scaleTd.info.Shape) >= 2 { + // Quantized tensor - detect bits from shapes + weightCols := td.info.Shape[len(td.info.Shape)-1] + scaleCols := scaleTd.info.Shape[len(scaleTd.info.Shape)-1] + + // Detect quantization: Q4 has pack_factor=8, Q8 has pack_factor=4 + // Q4 uses group_size=32: weightCols * 8 / scaleCols = 32 + // Q8 uses group_size=64: weightCols * 4 / scaleCols = 64 + var bits int + var quantType string + if weightCols*8/scaleCols == 32 { + bits = 4 + quantType = "Q4" + } else if weightCols*4/scaleCols == 64 { + bits = 8 + quantType = "Q8" + } else { + // Unknown quantization, show raw + quantType = td.info.Dtype + } + + // Calculate unpacked shape + shape := make([]uint64, len(td.info.Shape)) + for i, s := range td.info.Shape { + shape[i] = uint64(s) + } + if bits > 0 { + packFactor := int64(32 / bits) + shape[len(shape)-1] = uint64(td.info.Shape[len(td.info.Shape)-1] * packFactor) + } + + tensors = append(tensors, api.Tensor{ + Name: layer.Name, + Type: quantType, + Shape: shape, + }) + } else { + // Non-quantized tensor + shape := make([]uint64, len(td.info.Shape)) + for i, s := range td.info.Shape { + shape[i] = uint64(s) + } + + tensors = append(tensors, api.Tensor{ + Name: layer.Name, + Type: td.info.Dtype, + Shape: shape, + }) + } } return tensors, nil @@ -231,9 +308,9 @@ func GetSafetensorsDtype(name model.Name) (string, error) { if hasScales { if hasQBias { - // Affine mode (has scale + qbias) - could be FP4 or FP8 - // Default to FP4 as it's more common - return "FP4", nil + // Affine mode (has scale + qbias) - could be Q4 or Q8 + // Default to Q4 as it's more common + return "Q4", nil } // No qbias = NVFP4 return "NVFP4", nil
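
For reviewers, a minimal usage sketch of the mixed-precision routing that `GetTensorQuantization` (x/create/create.go) applies when a safetensors import is created with q4 quantization. The import path, tensor names, and shape below are illustrative assumptions, and the expected results in the comments additionally assume `ShouldQuantize` admits these names; this is a sketch, not code from the patch.

```go
package main

import (
	"fmt"

	"github.com/ollama/ollama/x/create" // assumed import path for x/create
)

func main() {
	// 2D, well over 1024 elements, last dim divisible by the q4 group size (32).
	shape := []int32{4096, 2048}

	for _, name := range []string{
		"model.layers.0.self_attn.kv_b_proj.weight",     // MLA attention: expect "" (kept bf16)
		"model.layers.3.mlp.experts.7.down_proj.weight", // down projection: expect "q8"
		"model.layers.3.mlp.experts.7.gate_proj.weight", // gate/up projection: expect "q4"
		"model.layers.3.mlp.gate.weight",                // routing gate: expect "" (high precision)
		"lm_head.weight",                                // expect "q4"
	} {
		fmt.Printf("%-48s -> %q\n", name, create.GetTensorQuantization(name, shape, "q4"))
	}
}
```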
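
The quantized-import path now also records the group size in model_index.json so loaders no longer have to guess it from shapes alone. A small sketch of the payload `CreateSafetensorsModel` ends up writing for a q8 import; key order and indentation here come from encoding/json in this sketch, not necessarily the exact bytes the server writes.

```go
package main

import (
	"encoding/json"
	"fmt"
)

func main() {
	// Mirrors the modelIndex map built in CreateSafetensorsModel:
	// quantization is upper-cased, group_size comes from getQuantGroupSize("q8").
	modelIndex := map[string]any{
		"quantization": "Q8",
		"group_size":   64, // read back via ManifestWeights.GroupSize() at load time
	}
	out, err := json.MarshalIndent(modelIndex, "", "  ")
	if err != nil {
		panic(err)
	}
	fmt.Println(string(out))
	// {
	//   "group_size": 64,
	//   "quantization": "Q8"
	// }
}
```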
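
On the load side, `LoadLinearLayer`/`LoadMultiLinearLayer` now infer bits and group size from tensor shapes, falling back to the `Quantization()`/`GroupSize()` metadata when shapes are ambiguous. A standalone sketch of that arithmetic with made-up dimensions, showing why Q4/group 64 and Q8/group 32 collide (both give a weight:scales column ratio of 8):

```go
package main

import "fmt"

// detectQuant infers (bits, groupSize) from the packed weight's last dimension
// and the scales' last dimension, mirroring the logic added in
// x/imagegen/safetensors/loader.go. Quantized weights pack 32/bits values per
// uint32 column; scales have one column per group.
func detectQuant(weightCols, scalesCols int, isQ8Hint bool) (bits, groupSize int) {
	groupSize4 := weightCols * 8 / scalesCols // implied group size if 4-bit
	groupSize8 := weightCols * 4 / scalesCols // implied group size if 8-bit

	switch {
	case groupSize4 == 32: // unambiguous: only Q4/gs=32 fits
		return 4, 32
	case groupSize8 == 64: // unambiguous: only Q8/gs=64 fits
		return 8, 64
	case groupSize4 == 64 && groupSize8 == 32: // ambiguous: need metadata hint
		if isQ8Hint {
			return 8, 32
		}
		return 4, 64
	default:
		return 0, 0 // caller falls back to model-level quantization params
	}
}

func main() {
	const inFeatures = 2048 // hypothetical linear-layer input width

	// Q4, group 32: 8 weights per uint32, one scale per 32 weights.
	fmt.Println(detectQuant(inFeatures*4/32, inFeatures/32, false)) // 4 32

	// Q8, group 64: 4 weights per uint32, one scale per 64 weights.
	fmt.Println(detectQuant(inFeatures*8/32, inFeatures/64, true)) // 8 64

	// Same column ratio as Q4/gs=64: Q8, group 32 needs the metadata hint.
	fmt.Println(detectQuant(inFeatures*8/32, inFeatures/32, true)) // 8 32
}
```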