fix quantizations for Q4 and Q8

Patrick Devine
2026-01-27 18:49:06 -08:00
parent 1a95093e0a
commit 9638fda956
13 changed files with 517 additions and 220 deletions

View File

@@ -1553,7 +1553,7 @@ func TestShowInfoImageGen(t *testing.T) {
Details: api.ModelDetails{
Family: "ZImagePipeline",
ParameterSize: "10.3B",
QuantizationLevel: "FP8",
QuantizationLevel: "Q8",
},
Capabilities: []model.Capability{model.CapabilityImage},
Requires: "0.14.0",
@@ -1565,7 +1565,7 @@ func TestShowInfoImageGen(t *testing.T) {
expect := " Model\n" +
" architecture ZImagePipeline \n" +
" parameters 10.3B \n" +
" quantization FP8 \n" +
" quantization Q8 \n" +
" requires 0.14.0 \n" +
"\n" +
" Capabilities\n" +

View File

@@ -589,7 +589,7 @@ func (s *Scheduler) loadMLX(req *LlmRequest) bool {
runner := &runnerRef{
model: req.model,
modelPath: req.model.Name,
modelPath: req.model.ModelPath,
llama: server,
Options: &req.opts,
loading: false,
@@ -599,7 +599,7 @@ func (s *Scheduler) loadMLX(req *LlmRequest) bool {
}
s.loadedMu.Lock()
s.loaded[req.model.Name] = runner
s.loaded[req.model.ModelPath] = runner
s.loadedMu.Unlock()
// Set up expiration timer

View File

@@ -35,7 +35,7 @@ type ModelfileConfig struct {
type CreateOptions struct {
ModelName string
ModelDir string
Quantize string // "fp8" for quantization
Quantize string // "q4", "q8", "nvfp4", or "mxfp8" for quantization
Modelfile *ModelfileConfig // template/system/license from Modelfile
}
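
For reference, a minimal usage sketch of this struct; the model name and directory are illustrative, and the accepted values follow the comment above (the safetensors path added in this commit also normalizes aliases such as "fp4"/"int4"):

opts := CreateOptions{
    ModelName: "example/glm4-moe-lite", // illustrative name
    ModelDir:  "/path/to/safetensors",  // illustrative path
    Quantize:  "q8",                    // or "q4", "nvfp4", "mxfp8"; empty string disables quantization
}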

View File

@@ -14,9 +14,9 @@ import (
// quantizeTensor loads a tensor from safetensors format, quantizes it,
// and returns safetensors data for the quantized weights, scales, and biases.
// Supported quantization types:
// - "fp4": affine 4-bit, group_size=32 (with qbiases)
// - "q4": affine 4-bit, group_size=32 (with qbiases)
// - "nvfp4": NVIDIA FP4, group_size=16 (no qbiases, E4M3 scales)
// - "fp8": affine 8-bit, group_size=32 (with qbiases)
// - "q8": affine 8-bit, group_size=64 (with qbiases)
// - "mxfp8": Microsoft MX FP8, group_size=32 (no qbiases, E4M3 scales)
// Uses MLX's native SaveSafetensors to ensure correct dtype handling (especially uint32 for quantized weights).
func quantizeTensor(r io.Reader, name, dtype string, shape []int32, quantize string) (qweightData, scalesData, qbiasData []byte, qweightShape, scalesShape, qbiasShape []int32, err error) {
@@ -58,15 +58,15 @@ func quantizeTensor(r io.Reader, name, dtype string, shape []int32, quantize str
// Quantize based on quantization type
var qweight, scales, qbiases *mlx.Array
switch quantize {
case "fp4":
case "q4":
// affine mode: group_size=32, bits=4 (with qbiases for zero-point offset)
qweight, scales, qbiases = mlx.Quantize(arr, 32, 4, "affine")
case "nvfp4":
// NVIDIA FP4: group_size=16, bits=4 (no qbiases, E4M3 scales)
qweight, scales, qbiases = mlx.Quantize(arr, 16, 4, "nvfp4")
case "fp8":
// affine mode: group_size=32, bits=8 (with qbiases for zero-point offset)
qweight, scales, qbiases = mlx.Quantize(arr, 32, 8, "affine")
case "q8":
// affine mode: group_size=64, bits=8 (with qbiases for zero-point offset)
qweight, scales, qbiases = mlx.Quantize(arr, 64, 8, "affine")
case "mxfp8":
// Microsoft MX FP8: group_size=32, bits=8, E4M3 scales (no qbiases)
qweight, scales, qbiases = mlx.Quantize(arr, 32, 8, "mxfp8")
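
As a quick reference for the settings above, a hedged sketch of the shapes each mode should produce for a 2-D weight of shape [out, in] (packing and group-size arithmetic only; the helper name is illustrative):

// q4    (affine, group_size=32, bits=4): qweight [out, in/8] uint32-packed, scales and qbiases [out, in/32]
// q8    (affine, group_size=64, bits=8): qweight [out, in/4] uint32-packed, scales and qbiases [out, in/64]
// nvfp4 (group_size=16, bits=4): one E4M3 scale per 16 weights, no qbiases
// mxfp8 (group_size=32, bits=8): one E4M3 scale per 32 weights, no qbiases
func packedCols(in, bits int64) int64 { return in * bits / 32 } // 32/bits weights packed per uint32 word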

View File

@@ -228,7 +228,7 @@ type LayerCreator func(r io.Reader, mediaType, name string) (LayerInfo, error)
type TensorLayerCreator func(r io.Reader, name, dtype string, shape []int32) (LayerInfo, error)
// QuantizingTensorLayerCreator creates tensor layers with optional quantization.
// When quantize is non-empty (e.g., "fp8"), returns multiple layers (weight + scales + biases).
// When quantize is non-empty (e.g., "q8"), returns multiple layers (weight + scales + biases).
type QuantizingTensorLayerCreator func(r io.Reader, name, dtype string, shape []int32, quantize string) ([]LayerInfo, error)
// ManifestWriter writes the manifest file.
@@ -264,40 +264,132 @@ func ShouldQuantize(name, component string) bool {
// ShouldQuantizeTensor returns true if a tensor should be quantized based on name, shape, and quantize type.
// This is a more detailed check that also considers tensor dimensions.
// The quantize parameter specifies the quantization type (e.g., "fp4", "nvfp4", "fp8", "mxfp8").
// The quantize parameter specifies the quantization type (e.g., "q4", "nvfp4", "q8", "mxfp8").
func ShouldQuantizeTensor(name string, shape []int32, quantize string) bool {
return GetTensorQuantization(name, shape, quantize) != ""
}
// normalizeQuantType converts various quantization type aliases to canonical forms.
// Supports: q4/Q4/int4/INT4/fp4/FP4 -> q4, q8/Q8/int8/INT8/fp8/FP8 -> q8, nvfp4/NVFP4, mxfp8/MXFP8
func normalizeQuantType(quantize string) string {
switch strings.ToUpper(quantize) {
case "Q4", "INT4", "FP4":
return "q4"
case "Q8", "INT8", "FP8":
return "q8"
case "NVFP4":
return "nvfp4"
case "MXFP8":
return "mxfp8"
default:
return quantize
}
}
// getQuantGroupSize returns the group size for a given quantization type.
// These must match the values used in quantize.go when creating quantized models.
func getQuantGroupSize(quantize string) int {
switch normalizeQuantType(quantize) {
case "nvfp4":
return 16
case "q4":
return 32
case "mxfp8":
return 32
case "q8":
return 64
default:
return 32
}
}
// GetTensorQuantization returns the appropriate quantization type for a tensor.
// Returns "" if the tensor should not be quantized.
// This implements mixed-precision quantization:
// - Attention MLA weights (q_a, q_b, kv_a, kv_b): unquantized (most sensitive)
// - Output projection, gate/up weights: q4 (less sensitive)
// - Down projection weights: q8 (more sensitive, would be Q6 in GGML but no MLX kernel)
// - Norms, embeddings, biases, routing gates: no quantization
func GetTensorQuantization(name string, shape []int32, quantize string) string {
// Use basic name-based check first
if !ShouldQuantize(name, "") {
return false
return ""
}
// Only quantize 2D tensors (linear layers) - skip 1D (biases, norms) and higher-D (convolutions if any)
if len(shape) != 2 {
return false
return ""
}
// Skip small tensors (less than 1024 elements) - not worth quantizing
if len(shape) >= 2 && int64(shape[0])*int64(shape[1]) < 1024 {
return false
return ""
}
// Normalize quantization type to canonical form
quantNorm := normalizeQuantType(quantize)
// MLX quantization requires last dimension to be divisible by group size
// NVFP4 uses group_size=16, all other modes (fp4, fp8, mxfp8) use group_size=32
// nvfp4: 16, q4/mxfp8: 32, q8: 64
groupSize := int32(32)
if strings.ToUpper(quantize) == "NVFP4" {
switch quantNorm {
case "nvfp4":
groupSize = 16
case "q8":
groupSize = 64
}
if shape[len(shape)-1]%groupSize != 0 {
return false
return ""
}
return true
// Skip routing gate weights (should stay high precision)
// In safetensors these are: mlp.gate.weight (not mlp.gate_proj.weight)
if strings.Contains(name, "mlp.gate.weight") && !strings.Contains(name, "_proj") {
return ""
}
// For NVFP4 or MXFP8, use the same quantization for all (no mixed precision)
if quantNorm == "nvfp4" || quantNorm == "mxfp8" {
return quantNorm
}
// Attention MLA weights - keep unquantized (bf16)
// These are highly sensitive: errors accumulate in the KV cache over time
// q_a_proj, q_b_proj, kv_a_proj_with_mqa, kv_b_proj
if strings.Contains(name, "q_a_proj") ||
strings.Contains(name, "q_b_proj") ||
strings.Contains(name, "kv_a_proj") ||
strings.Contains(name, "kv_b_proj") {
return "" // No quantization - keep bf16
}
// Down projection weights - use Q8 (would be Q6_K in GGML, but MLX has no Q6 kernel)
// mlp.down_proj, mlp.experts.X.down_proj, mlp.shared_experts.down_proj
if strings.Contains(name, "down_proj") {
return "q8"
}
// Output projection, gate/up weights - use requested quantization (Q4)
// o_proj, gate_proj, up_proj
if strings.Contains(name, "o_proj") ||
strings.Contains(name, "gate_proj") ||
strings.Contains(name, "up_proj") {
return quantNorm
}
// LM head - use requested quantization
if strings.Contains(name, "lm_head") {
return quantNorm
}
// Default to requested quantization for other weights
return quantNorm
}
// CreateSafetensorsModel imports a standard safetensors model from a directory.
// This handles Hugging Face style models with config.json and *.safetensors files.
// Stores each tensor as a separate blob for fine-grained deduplication.
// If quantize is non-empty (e.g., "fp8"), eligible tensors will be quantized.
// If quantize is non-empty (e.g., "q8"), eligible tensors will be quantized.
func CreateSafetensorsModel(modelName, modelDir, quantize string, createLayer LayerCreator, createTensorLayer QuantizingTensorLayerCreator, writeManifest ManifestWriter, fn func(status string)) error {
var layers []LayerInfo
var configLayer LayerInfo
@@ -336,9 +428,10 @@ func CreateSafetensorsModel(modelName, modelDir, quantize string, createLayer La
}
// Determine quantization type for this tensor (empty string if not quantizing)
// GetTensorQuantization handles mixed-precision (e.g., Q8 for attention, Q4 for FFN)
quantizeType := ""
if quantize != "" && ShouldQuantizeTensor(tensorName, td.Shape, quantize) {
quantizeType = quantize
if quantize != "" {
quantizeType = GetTensorQuantization(tensorName, td.Shape, quantize)
}
// Store as minimal safetensors format (88 bytes header overhead)
@@ -398,6 +491,7 @@ func CreateSafetensorsModel(modelName, modelDir, quantize string, createLayer La
if quantize != "" {
modelIndex := map[string]any{
"quantization": strings.ToUpper(quantize),
"group_size": getQuantGroupSize(quantize),
}
indexData, err := json.MarshalIndent(modelIndex, "", " ")
if err != nil {
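
A hedged walk-through of the mixed-precision policy above for a "q4" request; the tensor names follow the patterns referenced in the comments, the shapes are illustrative, and it assumes ShouldQuantize admits these linear weights:

// Expected results under the rules above:
//   GetTensorQuantization("model.layers.3.self_attn.kv_b_proj.weight", []int32{4096, 512}, "q4")      -> ""   (MLA weights stay bf16)
//   GetTensorQuantization("model.layers.3.mlp.experts.0.down_proj.weight", []int32{2048, 1408}, "q4") -> "q8" (down projections promoted)
//   GetTensorQuantization("model.layers.3.mlp.experts.0.gate_proj.weight", []int32{1408, 2048}, "q4") -> "q4" (requested quantization)
//   GetTensorQuantization("model.layers.3.mlp.gate.weight", []int32{64, 2048}, "q4")                  -> ""   (routing gate stays high precision)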

View File

@@ -15,15 +15,15 @@ import (
// CreateImageGenModel imports an image generation model from a directory.
// Stores each tensor as a separate blob for fine-grained deduplication.
// If quantize is specified, linear weights in transformer/text_encoder are quantized.
// Supported quantization types: fp4, fp8, nvfp4, mxfp8 (or empty for no quantization).
// Supported quantization types: q4, q8, nvfp4, mxfp8 (or empty for no quantization).
// Layer creation and manifest writing are done via callbacks to avoid import cycles.
func CreateImageGenModel(modelName, modelDir, quantize string, createLayer LayerCreator, createTensorLayer QuantizingTensorLayerCreator, writeManifest ManifestWriter, fn func(status string)) error {
// Validate quantization type
switch quantize {
case "", "fp4", "fp8", "nvfp4", "mxfp8":
case "", "q4", "q8", "nvfp4", "mxfp8":
// valid
default:
return fmt.Errorf("unsupported quantization type %q: supported types are fp4, fp8, nvfp4, mxfp8", quantize)
return fmt.Errorf("unsupported quantization type %q: supported types are q4, q8, nvfp4, mxfp8", quantize)
}
var layers []LayerInfo
@@ -214,14 +214,17 @@ func CreateImageGenModel(modelName, modelDir, quantize string, createLayer Layer
// canQuantizeShape returns true if a tensor shape is compatible with MLX quantization.
// MLX requires the last dimension to be divisible by the group size.
// NVFP4 uses group_size=16, all other modes use group_size=32.
// nvfp4: 16, q4/mxfp8: 32, q8: 64
func canQuantizeShape(shape []int32, quantize string) bool {
if len(shape) < 2 {
return false
}
groupSize := int32(32)
if strings.ToUpper(quantize) == "NVFP4" {
switch strings.ToUpper(quantize) {
case "NVFP4":
groupSize = 16
case "Q8":
groupSize = 64
}
return shape[len(shape)-1]%groupSize == 0
}
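
A few illustrative checks against the rule above (shapes chosen only to show the divisibility boundaries):

canQuantizeShape([]int32{3072, 4096}, "q8")    // true:  4096 % 64 == 0
canQuantizeShape([]int32{3072, 2016}, "q8")    // false: 2016 % 64 != 0
canQuantizeShape([]int32{3072, 2016}, "q4")    // true:  2016 % 32 == 0
canQuantizeShape([]int32{3072, 2016}, "nvfp4") // true:  2016 % 16 == 0
canQuantizeShape([]int32{4096}, "q4")          // false: needs at least two dimensions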

View File

@@ -19,7 +19,6 @@ import (
"github.com/ollama/ollama/x/imagegen/mlx"
"github.com/ollama/ollama/x/imagegen/models/flux2"
"github.com/ollama/ollama/x/imagegen/models/gemma3"
"github.com/ollama/ollama/x/imagegen/models/glm4_moe_lite"
"github.com/ollama/ollama/x/imagegen/models/gpt_oss"
"github.com/ollama/ollama/x/imagegen/models/llama"
"github.com/ollama/ollama/x/imagegen/models/zimage"
@@ -243,8 +242,6 @@ func load(modelPath string) (Model, error) {
return gemma3.Load(modelPath)
case "gemma3_text":
return gemma3.LoadText(modelPath)
case "glm4_moe_lite":
return glm4_moe_lite.Load(modelPath)
default:
return llama.Load(modelPath)
}

View File

@@ -209,7 +209,7 @@ func GetModelInfo(modelName string) (*ModelInfo, error) {
if info.Quantization == "" {
for _, layer := range manifest.Manifest.Layers {
if strings.HasSuffix(layer.Name, ".weight_scale") {
info.Quantization = "FP8"
info.Quantization = "Q8"
break
}
}

View File

@@ -8,8 +8,6 @@ import (
"encoding/json"
"fmt"
"math"
"os"
"path/filepath"
"github.com/ollama/ollama/x/imagegen"
"github.com/ollama/ollama/x/imagegen/cache"
@@ -59,6 +57,11 @@ type Config struct {
// RoPE scaling
RopeScaling *RopeScaling `json:"rope_scaling"`
// Quantization parameters (set during load based on model quantization)
QuantGroupSize int `json:"-"` // Group size for quantization (default 64)
QuantBits int `json:"-"` // Bits per weight (4 or 8)
QuantMode string `json:"-"` // Quantization mode ("affine", etc.)
// Computed fields
QHeadDim int32 `json:"-"` // qk_nope_head_dim + qk_rope_head_dim
Scale float32 `json:"-"` // 1/sqrt(QHeadDim) with mscale adjustment
@@ -236,9 +239,28 @@ func (g *MoEGate) Forward(x *mlx.Array, cfg *Config) (*mlx.Array, *mlx.Array) {
// SwitchMLP implements the MoE expert computation using stacked weights
// Note: No weight tags - these are populated manually by stacking expert weights
type SwitchMLP struct {
// Dequantized weights (used when GatherQMM not available)
GateWeight *mlx.Array
UpWeight *mlx.Array
DownWeight *mlx.Array
// Quantized weights (used with GatherQMM for 4/8-bit affine)
GateWeightQ, GateScales, GateBiases *mlx.Array
UpWeightQ, UpScales, UpBiases *mlx.Array
DownWeightQ, DownScales, DownBiases *mlx.Array
// Quantization bits per projection (supports mixed precision Q4/Q8)
GateBits int
UpBits int
DownBits int
// Quantization group size per projection (detected from tensor shapes)
GateGroupSize int
UpGroupSize int
DownGroupSize int
// If true, use GatherQMM with quantized weights
UseQuantized bool
}
// Forward applies the switched expert MLP
@@ -270,17 +292,29 @@ func (s *SwitchMLP) Forward(x *mlx.Array, indices *mlx.Array, cfg *Config) *mlx.
idxFlat = mlx.Reshape(mlx.Take(idxAll, order, 0), n, 1)
}
// Expert computation using gather_mm
// gate: x @ gate_weight.T (indices are on the rhs/weight side)
gate := mlx.GatherMM(xFlat, mlx.Transpose(s.GateWeight, 0, 2, 1), nil, idxFlat, doSort)
// up: x @ up_weight.T
up := mlx.GatherMM(xFlat, mlx.Transpose(s.UpWeight, 0, 2, 1), nil, idxFlat, doSort)
var gate, up, hidden, down *mlx.Array
// SwiGLU activation
hidden := mlx.Mul(mlx.SiLU(gate), up)
if s.UseQuantized {
// Use GatherQMM for quantized weights (faster, keeps weights quantized)
// Each projection may have different bits and group sizes (mixed precision: Q4 for gate/up, Q8 for down)
gate = mlx.GatherQMM(xFlat, s.GateWeightQ, s.GateScales, s.GateBiases,
nil, idxFlat, true, s.GateGroupSize, s.GateBits, cfg.QuantMode, doSort)
up = mlx.GatherQMM(xFlat, s.UpWeightQ, s.UpScales, s.UpBiases,
nil, idxFlat, true, s.UpGroupSize, s.UpBits, cfg.QuantMode, doSort)
// down: hidden @ down_weight.T
down := mlx.GatherMM(hidden, mlx.Transpose(s.DownWeight, 0, 2, 1), nil, idxFlat, doSort)
hidden = mlx.Mul(mlx.SiLU(gate), up)
down = mlx.GatherQMM(hidden, s.DownWeightQ, s.DownScales, s.DownBiases,
nil, idxFlat, true, s.DownGroupSize, s.DownBits, cfg.QuantMode, doSort)
} else {
// Use GatherMM for dequantized/non-quantized weights
gate = mlx.GatherMM(xFlat, mlx.Transpose(s.GateWeight, 0, 2, 1), nil, idxFlat, doSort)
up = mlx.GatherMM(xFlat, mlx.Transpose(s.UpWeight, 0, 2, 1), nil, idxFlat, doSort)
hidden = mlx.Mul(mlx.SiLU(gate), up)
down = mlx.GatherMM(hidden, mlx.Transpose(s.DownWeight, 0, 2, 1), nil, idxFlat, doSort)
}
// Unsort if we sorted
if doSort {
@@ -402,9 +436,36 @@ func computeScale(cfg *Config) float32 {
return scale
}
// loadExpertWeight loads an expert weight, dequantizing if necessary.
// GatherMM doesn't support quantized weights, so we must dequantize for MoE.
func loadExpertWeight(weights safetensors.WeightSource, path string) *mlx.Array {
// supportsGatherQMM returns true if the quantization mode has GatherQMM kernel support.
// Currently only 4-bit and 8-bit affine quantization are supported.
func supportsGatherQMM(mode string, bits int) bool {
return mode == "affine" && (bits == 4 || bits == 8)
}
// ExpertWeight holds a single expert's weight with optional quantization components.
type ExpertWeight struct {
Weight *mlx.Array // Quantized weight (if quantized) or dequantized weight
Scales *mlx.Array // Quantization scales (nil if not quantized)
Biases *mlx.Array // Quantization biases (nil if not quantized or mode doesn't use biases)
Bits int // Quantization bits (4 or 8), 0 if not quantized
GroupSize int // Quantization group size, 0 if not quantized
}
// getQuantParams returns quantization parameters from model metadata.
// Returns groupSize, bits, and mode for the model's quantization type.
func getQuantParams(weights safetensors.WeightSource) (groupSize, bits int, mode string) {
groupSize, bits, mode = safetensors.QuantizationParams(weights.Quantization())
// Use metadata group_size if available (overrides default)
if gs := weights.GroupSize(); gs > 0 {
groupSize = gs
}
return groupSize, bits, mode
}
// loadExpertWeight loads an expert weight.
// If useQuantized is true and the weight is quantized with a supported mode, returns quantized components.
// Otherwise dequantizes and returns only the weight.
func loadExpertWeight(weights safetensors.WeightSource, path string, useQuantized bool, cfg *Config) *ExpertWeight {
w, _ := weights.GetTensor(path + ".weight")
if w == nil {
return nil
@@ -419,12 +480,25 @@ func loadExpertWeight(weights safetensors.WeightSource, path string) *mlx.Array
if weights.HasTensor(qbiasPath) {
qbiases, _ = weights.GetTensor(qbiasPath)
}
// Dequantize using the model's quantization parameters
groupSize, bits, mode := safetensors.QuantizationParams(weights.Quantization())
return mlx.Dequantize(w, scales, qbiases, groupSize, bits, mode)
// Get quantization params from metadata
groupSize, bits, mode := getQuantParams(weights)
// Update config with group size (for GatherQMM calls)
if cfg.QuantGroupSize == 0 {
cfg.QuantGroupSize = groupSize
}
// If GatherQMM is supported and requested, return quantized components
if useQuantized && supportsGatherQMM(mode, bits) {
return &ExpertWeight{Weight: w, Scales: scales, Biases: qbiases, Bits: bits, GroupSize: groupSize}
}
// Otherwise dequantize
return &ExpertWeight{Weight: mlx.Dequantize(w, scales, qbiases, groupSize, bits, mode)}
}
return w
return &ExpertWeight{Weight: w}
}
// sanitizeMLAWeights transforms kv_b_proj weights into absorbed MLA format.
@@ -450,7 +524,8 @@ func sanitizeMLAWeights(weights safetensors.WeightSource, prefix string, cfg *Co
if weights.HasTensor(qbiasPath) {
qbiases, _ = weights.GetTensor(qbiasPath)
}
groupSize, bits, mode := safetensors.QuantizationParams(weights.Quantization())
groupSize, bits, mode := getQuantParams(weights)
w = mlx.Dequantize(w, scales, qbiases, groupSize, bits, mode)
}
@@ -477,140 +552,68 @@ func sanitizeMLAWeights(weights safetensors.WeightSource, prefix string, cfg *Co
return embedQ, unembedOut
}
// sanitizeExpertWeights stacks individual expert weights into a single tensor.
// For quantized models, expert weights are dequantized since GatherMM doesn't support quantized weights.
func sanitizeExpertWeights(weights safetensors.WeightSource, prefix string, numExperts int32) (*mlx.Array, *mlx.Array, *mlx.Array) {
var gateWeights, upWeights, downWeights []*mlx.Array
for e := int32(0); e < numExperts; e++ {
gw := loadExpertWeight(weights, fmt.Sprintf("%s.mlp.experts.%d.gate_proj", prefix, e))
uw := loadExpertWeight(weights, fmt.Sprintf("%s.mlp.experts.%d.up_proj", prefix, e))
dw := loadExpertWeight(weights, fmt.Sprintf("%s.mlp.experts.%d.down_proj", prefix, e))
if gw != nil {
gateWeights = append(gateWeights, gw)
}
if uw != nil {
upWeights = append(upWeights, uw)
}
if dw != nil {
downWeights = append(downWeights, dw)
}
}
var stackedGate, stackedUp, stackedDown *mlx.Array
if len(gateWeights) > 0 {
stackedGate = mlx.Stack(gateWeights, 0)
}
if len(upWeights) > 0 {
stackedUp = mlx.Stack(upWeights, 0)
}
if len(downWeights) > 0 {
stackedDown = mlx.Stack(downWeights, 0)
}
return stackedGate, stackedUp, stackedDown
// StackedExpertWeights holds stacked weights for all experts.
type StackedExpertWeights struct {
Weight *mlx.Array // Stacked weights [num_experts, out, in] or [num_experts, out, in_packed]
Scales *mlx.Array // Stacked scales (nil if not quantized)
Biases *mlx.Array // Stacked biases (nil if not quantized)
Bits int // Quantization bits (4 or 8), 0 if not quantized
GroupSize int // Quantization group size, 0 if not quantized
}
// Load loads a GLM4-MoE-Lite model from the given path
func Load(modelPath string) (*Model, error) {
data, err := os.ReadFile(filepath.Join(modelPath, "config.json"))
if err != nil {
return nil, fmt.Errorf("load config: %w", err)
}
// collectAndStackExpertWeights loads and stacks expert weights for one projection type.
func collectAndStackExpertWeights(
weights safetensors.WeightSource,
prefix string,
projName string,
numExperts int32,
useQuantized bool,
cfg *Config,
) *StackedExpertWeights {
var w, s, b []*mlx.Array
var bits, groupSize int
var cfg Config
if err := json.Unmarshal(data, &cfg); err != nil {
return nil, fmt.Errorf("parse config: %w", err)
}
// Compute derived fields
cfg.QHeadDim = cfg.QKNopeHeadDim + cfg.QKRopeHeadDim
cfg.Scale = computeScale(&cfg)
weights, err := safetensors.LoadModelWeights(modelPath)
if err != nil {
return nil, fmt.Errorf("load weights: %w", err)
}
tok, err := tokenizer.Load(filepath.Join(modelPath, "tokenizer.json"))
if err != nil {
return nil, fmt.Errorf("load tokenizer: %w", err)
}
m := &Model{
Layers: make([]Block, cfg.NumHiddenLayers),
Config: &cfg,
tok: tok,
}
// Load embedding, norm, and lm_head
if err := safetensors.LoadModule(m, weights, ""); err != nil {
return nil, err
}
// Load layers manually due to different block types
for i := int32(0); i < cfg.NumHiddenLayers; i++ {
prefix := fmt.Sprintf("model.layers.%d", i)
// Load attention (same for both block types)
attn := &MLAAttention{}
if err := safetensors.LoadModule(attn, weights, prefix); err != nil {
return nil, fmt.Errorf("layer %d attention: %w", i, err)
for e := int32(0); e < numExperts; e++ {
path := fmt.Sprintf("%s.mlp.experts.%d.%s", prefix, e, projName)
ew := loadExpertWeight(weights, path, useQuantized, cfg)
if ew == nil {
continue
}
// Sanitize MLA weights for absorbed attention
embedQ, unembedOut := sanitizeMLAWeights(weights, prefix, &cfg)
attn.EmbedQ = nn.NewMultiLinear(embedQ)
attn.UnembedOut = nn.NewMultiLinear(unembedOut)
if i < cfg.FirstKDenseReplace {
// Dense block
block := &DenseBlock{Attention: attn}
if err := safetensors.LoadModule(block, weights, prefix); err != nil {
return nil, fmt.Errorf("layer %d dense: %w", i, err)
}
m.Layers[i] = block
} else {
// MoE block
block := &MoEBlock{Attention: attn}
if err := safetensors.LoadModule(block, weights, prefix); err != nil {
return nil, fmt.Errorf("layer %d moe block: %w", i, err)
}
// Stack expert weights
gateW, upW, downW := sanitizeExpertWeights(weights, prefix, cfg.NRoutedExperts)
block.MoE = &MoE{
Gate: &MoEGate{},
SwitchMLP: &SwitchMLP{
GateWeight: gateW,
UpWeight: upW,
DownWeight: downW,
},
}
// Load gate weights
if err := safetensors.LoadModule(block.MoE.Gate, weights, prefix); err != nil {
return nil, fmt.Errorf("layer %d gate: %w", i, err)
}
// Load shared experts if present
if cfg.NSharedExperts > 0 {
block.MoE.SharedExperts = &SharedExperts{}
if err := safetensors.LoadModule(block.MoE.SharedExperts, weights, prefix); err != nil {
return nil, fmt.Errorf("layer %d shared experts: %w", i, err)
}
}
m.Layers[i] = block
w = append(w, ew.Weight)
if ew.Scales != nil {
s = append(s, ew.Scales)
}
if ew.Biases != nil {
b = append(b, ew.Biases)
}
if e == 0 {
bits = ew.Bits
groupSize = ew.GroupSize
}
}
mlx.Eval(mlx.Collect(m)...)
weights.ReleaseAll()
result := &StackedExpertWeights{Bits: bits, GroupSize: groupSize}
if len(w) > 0 {
result.Weight = mlx.Stack(w, 0)
if len(s) > 0 {
result.Scales = mlx.Stack(s, 0)
}
if len(b) > 0 {
result.Biases = mlx.Stack(b, 0)
}
}
return result
}
return m, nil
// sanitizeExpertWeights stacks individual expert weights into tensors.
// If useQuantized is true and weights support GatherQMM, returns quantized components.
// Otherwise returns dequantized weights with nil scales/biases.
// Bits and GroupSize are detected per-weight to support mixed-precision (Q4 for gate/up, Q8 for down).
func sanitizeExpertWeights(weights safetensors.WeightSource, prefix string, numExperts int32, useQuantized bool, cfg *Config) (gate, up, down *StackedExpertWeights) {
gate = collectAndStackExpertWeights(weights, prefix, "gate_proj", numExperts, useQuantized, cfg)
up = collectAndStackExpertWeights(weights, prefix, "up_proj", numExperts, useQuantized, cfg)
down = collectAndStackExpertWeights(weights, prefix, "down_proj", numExperts, useQuantized, cfg)
return gate, up, down
}
// LoadFromManifest loads a GLM4-MoE-Lite model from a manifest (Ollama blob storage).
@@ -636,19 +639,19 @@ func LoadFromManifest(manifest *imagegen.ModelManifest) (*Model, error) {
return nil, fmt.Errorf("load weights: %w", err)
}
// Debug: print quantization info and sample tensor names
fmt.Printf("GLM4: quantization=%q, num_tensors=%d\n", weights.Quantization(), len(weights.ListTensors()))
tensors := weights.ListTensors()
for i, name := range tensors {
if i < 20 { // Print first 20 tensor names
fmt.Printf(" tensor[%d]: %s\n", i, name)
}
}
if err := weights.Load(0); err != nil {
return nil, fmt.Errorf("load weight data: %w", err)
}
// Set up quantization parameters (only if model is actually quantized)
// Note: QuantGroupSize will be detected dynamically from tensor shapes during weight loading
quantization := weights.Quantization()
useQuantized := false
if quantization != "" {
_, cfg.QuantBits, cfg.QuantMode = safetensors.QuantizationParams(quantization)
useQuantized = supportsGatherQMM(cfg.QuantMode, cfg.QuantBits)
}
// Load tokenizer from manifest with config files for EOS token detection
tokData, err := manifest.ReadConfig("tokenizer.json")
if err != nil {
@@ -715,16 +718,35 @@ func LoadFromManifest(manifest *imagegen.ModelManifest) (*Model, error) {
return nil, fmt.Errorf("layer %d moe block: %w", i, err)
}
// Stack expert weights
gateW, upW, downW := sanitizeExpertWeights(weights, prefix, cfg.NRoutedExperts)
// Stack expert weights (pass cfg so group sizes can be detected)
gate, up, down := sanitizeExpertWeights(weights, prefix, cfg.NRoutedExperts, useQuantized, &cfg)
switchMLP := &SwitchMLP{UseQuantized: useQuantized}
if useQuantized {
switchMLP.GateWeightQ = gate.Weight
switchMLP.GateScales = gate.Scales
switchMLP.GateBiases = gate.Biases
switchMLP.GateBits = gate.Bits
switchMLP.GateGroupSize = gate.GroupSize
switchMLP.UpWeightQ = up.Weight
switchMLP.UpScales = up.Scales
switchMLP.UpBiases = up.Biases
switchMLP.UpBits = up.Bits
switchMLP.UpGroupSize = up.GroupSize
switchMLP.DownWeightQ = down.Weight
switchMLP.DownScales = down.Scales
switchMLP.DownBiases = down.Biases
switchMLP.DownBits = down.Bits
switchMLP.DownGroupSize = down.GroupSize
} else {
switchMLP.GateWeight = gate.Weight
switchMLP.UpWeight = up.Weight
switchMLP.DownWeight = down.Weight
}
block.MoE = &MoE{
Gate: &MoEGate{},
SwitchMLP: &SwitchMLP{
GateWeight: gateW,
UpWeight: upW,
DownWeight: downW,
},
Gate: &MoEGate{},
SwitchMLP: switchMLP,
}
// Load gate weights
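
For orientation, a hedged sketch of the stacked tensors the GatherQMM path ends up consuming under the mixed-precision layout this commit produces (gate/up at Q4 with group_size=32, down at Q8 with group_size=64); E, h, and m stand for the routed-expert count, hidden size, and MoE intermediate size, and the [out, in] orientation follows the usual safetensors convention:

// gate/up projections (Q4): 8 weights packed per uint32
//   GateWeightQ, UpWeightQ: [E, m, h/8]  uint32
//   GateScales,  UpScales:  [E, m, h/32]
//   GateBiases,  UpBiases:  [E, m, h/32]
// down projection (Q8): 4 weights packed per uint32
//   DownWeightQ: [E, h, m/4]  uint32
//   DownScales:  [E, h, m/64]
//   DownBiases:  [E, h, m/64]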

View File

@@ -17,7 +17,8 @@ type WeightSource interface {
GetTensor(name string) (*mlx.Array, error)
ListTensors() []string
HasTensor(name string) bool
Quantization() string // Returns "NVFP4", "FP4", "FP8", or ""
Quantization() string // Returns "NVFP4", "Q4", "Q8", or ""
GroupSize() int // Returns quantization group size, or 0 if not specified
}
// QuantizationParams returns groupSize, bits, mode for a quantization type.
@@ -38,7 +39,7 @@ func QuantizationParams(quantization string) (groupSize, bits int, mode string)
return 32, 8, "mxfp8"
case "FP8", "Q8", "INT8", "":
// 8-bit quantization with affine mode (default for quantized models)
return 32, 8, "affine"
return 64, 8, "affine"
default:
return 32, 8, "affine" // Default to affine
}
@@ -273,8 +274,48 @@ func LoadMultiLinearLayer(weights WeightSource, path string) (nn.MultiLinearLaye
}
// Always dequantize for MultiLinear - no batched quantized matmul support
groupSize, bits, mode := QuantizationParams(weights.Quantization())
weight = mlx.Dequantize(weight, scales, qbiases, groupSize, bits, mode)
// Detect bits from tensor shapes (supports mixed-precision Q4/Q8)
weightShape := weight.Shape()
scalesShape := scales.Shape()
weightCols := int(weightShape[len(weightShape)-1])
scalesCols := int(scalesShape[len(scalesShape)-1])
// Detect quantization from tensor shapes
// groupSize = weightCols * packFactor / scalesCols
// Note: groupSize4 = 2 * groupSize8 always, so ambiguous cases need metadata
groupSize4 := weightCols * 8 / scalesCols
groupSize8 := weightCols * 4 / scalesCols
var bits, groupSize int
// Use metadata to help disambiguate when shapes are ambiguous
// (e.g., Q4 with group_size=64 has same shapes as Q8 with group_size=32)
quantType := strings.ToUpper(weights.Quantization())
isQ8Type := quantType == "Q8" || quantType == "FP8" || quantType == "INT8"
if groupSize4 == 32 {
// Unambiguous: Q4 with group_size=32
bits = 4
groupSize = 32
} else if groupSize8 == 64 {
// Unambiguous: Q8 with group_size=64
bits = 8
groupSize = 64
} else if groupSize4 == 64 && groupSize8 == 32 {
// Ambiguous: could be Q4/gs=64 or Q8/gs=32, use metadata
if isQ8Type {
bits = 8
groupSize = 32
} else {
bits = 4
groupSize = 64
}
} else {
// Fallback: use global quantization params
_, bits, _ = QuantizationParams(weights.Quantization())
packFactor := 32 / bits
groupSize = weightCols * packFactor / scalesCols
}
weight = mlx.Dequantize(weight, scales, qbiases, groupSize, bits, "affine")
}
return nn.NewMultiLinear(weight), nil
@@ -310,7 +351,48 @@ func LoadLinearLayer(weights WeightSource, path string) (nn.LinearLayer, error)
qbiases, _ = weights.GetTensor(qbiasPath)
}
groupSize, bits, mode := QuantizationParams(weights.Quantization())
// Detect bits from tensor shapes (supports mixed-precision Q4/Q8)
weightShape := weight.Shape()
scalesShape := scales.Shape()
weightCols := int(weightShape[len(weightShape)-1])
scalesCols := int(scalesShape[len(scalesShape)-1])
// Detect quantization from tensor shapes
// groupSize = weightCols * packFactor / scalesCols
// Note: groupSize4 = 2 * groupSize8 always, so ambiguous cases need metadata
groupSize4 := weightCols * 8 / scalesCols
groupSize8 := weightCols * 4 / scalesCols
var bits, groupSize int
mode := "affine"
// Use metadata to help disambiguate when shapes are ambiguous
// (e.g., Q4 with group_size=64 has same shapes as Q8 with group_size=32)
quantType := strings.ToUpper(weights.Quantization())
isQ8Type := quantType == "Q8" || quantType == "FP8" || quantType == "INT8"
if groupSize4 == 32 {
// Unambiguous: Q4 with group_size=32
bits = 4
groupSize = 32
} else if groupSize8 == 64 {
// Unambiguous: Q8 with group_size=64
bits = 8
groupSize = 64
} else if groupSize4 == 64 && groupSize8 == 32 {
// Ambiguous: could be Q4/gs=64 or Q8/gs=32, use metadata
if isQ8Type {
bits = 8
groupSize = 32
} else {
bits = 4
groupSize = 64
}
} else {
// Fallback: use global quantization params
_, bits, mode = QuantizationParams(weights.Quantization())
packFactor := 32 / bits
groupSize = weightCols * packFactor / scalesCols
}
// NVFP4 and MXFP8 don't have native quantized matmul kernels in MLX,
// so we always dequantize at load time. Affine modes (FP4, FP8) have kernel support.
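
A worked example of the shape ambiguity handled above, for a weight whose original row width is 4096 (numbers are illustrative):

// Q4, group_size=32: weightCols = 4096/8 = 512,  scalesCols = 4096/32 = 128 -> groupSize4 = 512*8/128 = 32 (unambiguous)
// Q8, group_size=64: weightCols = 4096/4 = 1024, scalesCols = 4096/64 = 64  -> groupSize8 = 1024*4/64 = 64 (unambiguous)
// Q4, group_size=64: weightCols = 512,  scalesCols = 64  -> (groupSize4, groupSize8) = (64, 32)
// Q8, group_size=32: weightCols = 1024, scalesCols = 128 -> (groupSize4, groupSize8) = (64, 32)
// The last two collide, which is why the Quantization() metadata breaks the tie.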

View File

@@ -303,6 +303,11 @@ func (mw *ModelWeights) Quantization() string {
return ""
}
// GroupSize returns 0 for directory-based weights (use default).
func (mw *ModelWeights) GroupSize() int {
return 0
}
// ReleaseAll releases all cached native file handles.
func (mw *ModelWeights) ReleaseAll() {
for path, native := range mw.nativeCache {

View File

@@ -209,19 +209,36 @@ func (mw *ManifestWeights) Quantization() string {
// So scale size should be ~weight_size * 4 / 32 = weight_size / 8
// Rough size heuristic (assuming float16 scales)
// FP4: scale_bytes ≈ weight_bytes / 4 * 2 / 4 = weight_bytes / 8
// FP8: scale_bytes ≈ weight_bytes / 8 * 2 / 4 = weight_bytes / 16
// Q4: scale_bytes ≈ weight_bytes / 4 * 2 / 4 = weight_bytes / 8
// Q8: scale_bytes ≈ weight_bytes / 8 * 2 / 4 = weight_bytes / 16
ratio := float64(weightLayer.Size) / float64(scaleLayer.Size)
if ratio < 12 {
// Closer to 8 = FP4
return "FP4"
// Closer to 8 = Q4
return "Q4"
}
// Closer to 16 = FP8
return "FP8"
// Closer to 16 = Q8
return "Q8"
}
// Default to FP4 for affine mode (most common)
return "FP4"
// Default to Q4 for affine mode (most common)
return "Q4"
}
// GroupSize returns the quantization group size from model_index.json.
// Returns 0 if not specified (caller should use default based on quantization type).
func (mw *ManifestWeights) GroupSize() int {
if mw.manifest == nil {
return 0
}
var index struct {
GroupSize int `json:"group_size"`
}
if err := mw.manifest.ReadConfigJSON("model_index.json", &index); err == nil && index.GroupSize > 0 {
return index.GroupSize
}
return 0
}
// ReleaseAll frees all native handles and clears the tensor cache.
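
To make the size heuristic above concrete, a worked example for a weight with N elements, assuming float16 scales (as noted above) and uint32-packed weights:

// Q4, group_size=32: weight ≈ N/2 bytes, scales ≈ (N/32)*2 = N/16 bytes -> ratio ≈ 8  (< 12, reported Q4)
// Q8, group_size=32: weight ≈ N   bytes, scales ≈ (N/32)*2 = N/16 bytes -> ratio ≈ 16 (>= 12, reported Q8)
// Q8, group_size=64: weight ≈ N   bytes, scales ≈ (N/64)*2 = N/32 bytes -> ratio ≈ 32 (still reported Q8)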

View File

@@ -163,9 +163,18 @@ func GetSafetensorsTensorInfo(name model.Name) ([]api.Tensor, error) {
// getTensorInfoFromManifest extracts tensor info from a manifest.
// This is separated for testability.
// For quantized models, groups weight/scale/qbias into single entries with detected quantization type.
func getTensorInfoFromManifest(mf *manifest.Manifest) ([]api.Tensor, error) {
var tensors []api.Tensor
// First pass: collect all tensor info and identify scale tensors
type tensorData struct {
info *safetensorsTensorInfo
digest string
}
tensorMap := make(map[string]*tensorData)
scaleMap := make(map[string]*tensorData) // base name -> scale tensor info
for _, layer := range mf.Layers {
if layer.MediaType != manifest.MediaTypeImageTensor {
continue
@@ -178,21 +187,89 @@ func getTensorInfoFromManifest(mf *manifest.Manifest) ([]api.Tensor, error) {
}
info, err := readSafetensorsHeader(blobPath)
if err != nil {
// Skip tensors we can't read
continue
}
// Convert shape from int to uint64
shape := make([]uint64, len(info.Shape))
for i, s := range info.Shape {
shape[i] = uint64(s)
td := &tensorData{info: info, digest: layer.Digest}
if strings.HasSuffix(layer.Name, "_scale") {
baseName := strings.TrimSuffix(layer.Name, "_scale")
scaleMap[baseName] = td
} else if strings.HasSuffix(layer.Name, "_qbias") {
// Skip qbias tensors - they're included with the quantized weight
continue
} else {
tensorMap[layer.Name] = td
}
}
// Second pass: build tensor list with quantization info
for _, layer := range mf.Layers {
if layer.MediaType != manifest.MediaTypeImageTensor {
continue
}
tensors = append(tensors, api.Tensor{
Name: layer.Name,
Type: info.Dtype,
Shape: shape,
})
// Skip scale and qbias tensors
if strings.HasSuffix(layer.Name, "_scale") || strings.HasSuffix(layer.Name, "_qbias") {
continue
}
td := tensorMap[layer.Name]
if td == nil {
continue
}
// Check if this tensor has a corresponding scale tensor (quantized)
scaleTd := scaleMap[layer.Name]
if scaleTd != nil && len(td.info.Shape) >= 2 && len(scaleTd.info.Shape) >= 2 {
// Quantized tensor - detect bits from shapes
weightCols := td.info.Shape[len(td.info.Shape)-1]
scaleCols := scaleTd.info.Shape[len(scaleTd.info.Shape)-1]
// Detect quantization: Q4 has pack_factor=8, Q8 has pack_factor=4
// Q4 uses group_size=32: weightCols * 8 / scaleCols = 32
// Q8 uses group_size=64: weightCols * 4 / scaleCols = 64
var bits int
var quantType string
if weightCols*8/scaleCols == 32 {
bits = 4
quantType = "Q4"
} else if weightCols*4/scaleCols == 64 {
bits = 8
quantType = "Q8"
} else {
// Unknown quantization, show raw
quantType = td.info.Dtype
}
// Calculate unpacked shape
shape := make([]uint64, len(td.info.Shape))
for i, s := range td.info.Shape {
shape[i] = uint64(s)
}
if bits > 0 {
packFactor := int64(32 / bits)
shape[len(shape)-1] = uint64(td.info.Shape[len(td.info.Shape)-1] * packFactor)
}
tensors = append(tensors, api.Tensor{
Name: layer.Name,
Type: quantType,
Shape: shape,
})
} else {
// Non-quantized tensor
shape := make([]uint64, len(td.info.Shape))
for i, s := range td.info.Shape {
shape[i] = uint64(s)
}
tensors = append(tensors, api.Tensor{
Name: layer.Name,
Type: td.info.Dtype,
Shape: shape,
})
}
}
return tensors, nil
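
A worked example of the grouping above, with illustrative layer names and sizes:

// "model.layers.0.mlp.down_proj.weight" stored as [4096, 1024] uint32 with a "_scale" blob of shape [4096, 64]:
//   1024*4/64 = 64 -> Q8, bits = 8, packFactor = 4; reported Type = "Q8", Shape = [4096, 4096]
// "model.layers.0.mlp.gate_proj.weight" stored as [1408, 256] with scales [1408, 64]:
//   256*8/64 = 32 -> Q4, bits = 4, packFactor = 8; reported Type = "Q4", Shape = [1408, 2048]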
@@ -231,9 +308,9 @@ func GetSafetensorsDtype(name model.Name) (string, error) {
if hasScales {
if hasQBias {
// Affine mode (has scale + qbias) - could be FP4 or FP8
// Default to FP4 as it's more common
return "FP4", nil
// Affine mode (has scale + qbias) - could be Q4 or Q8
// Default to Q4 as it's more common
return "Q4", nil
}
// No qbias = NVFP4
return "NVFP4", nil