From 03bf241c33b4bc20a22e58e6133635a59caae94e Mon Sep 17 00:00:00 2001
From: Jeffrey Morgan
Date: Mon, 19 Jan 2026 00:54:54 -0800
Subject: [PATCH] x/imagegen: add FP4 quantization support for image
 generation models (#13773)

Add --quantize fp4 support to ollama create for image generation models
(flux2, z-image-turbo), using MLX's affine 4-bit quantization.

Changes:
- Add fp4 to validation in CreateImageGenModel
- Add FP4 case to quantizeTensor (group_size=32, bits=4, affine mode)
- Add Quantization() to WeightSource interface for dynamic params
- Update LoadLinearLayer to use quantization params from model metadata
---
 x/create/client/quantize.go           |  3 +++
 x/create/imagegen.go                  |  4 ++--
 x/imagegen/safetensors/loader.go      | 22 ++++++++++++++++++----
 x/imagegen/safetensors/safetensors.go |  5 +++++
 x/imagegen/weights.go                 | 15 +++++++++++++++
 5 files changed, 43 insertions(+), 6 deletions(-)

diff --git a/x/create/client/quantize.go b/x/create/client/quantize.go
index 5a4be59d0..3a9f37cfc 100644
--- a/x/create/client/quantize.go
+++ b/x/create/client/quantize.go
@@ -54,6 +54,9 @@ func quantizeTensor(r io.Reader, name, dtype string, shape []int32, quantize str
 	// Quantize based on quantization type
 	var qweight, scales, qbiases *mlx.Array
 	switch quantize {
+	case "fp4":
+		// affine mode: group_size=32, bits=4
+		qweight, scales, qbiases = mlx.Quantize(arr, 32, 4, "affine")
 	case "fp8":
 		// affine mode: group_size=32, bits=8
 		qweight, scales, qbiases = mlx.Quantize(arr, 32, 8, "affine")
diff --git a/x/create/imagegen.go b/x/create/imagegen.go
index ad10d8c69..595a40417 100644
--- a/x/create/imagegen.go
+++ b/x/create/imagegen.go
@@ -20,10 +20,10 @@ import (
 func CreateImageGenModel(modelName, modelDir, quantize string, createLayer LayerCreator, createTensorLayer QuantizingTensorLayerCreator, writeManifest ManifestWriter, fn func(status string)) error {
 	// Validate quantization type
 	switch quantize {
-	case "", "fp8":
+	case "", "fp4", "fp8":
 		// valid
 	default:
-		return fmt.Errorf("unsupported quantization type %q: supported types are fp8", quantize)
+		return fmt.Errorf("unsupported quantization type %q: supported types are fp4, fp8", quantize)
 	}
 
 	var layers []LayerInfo
diff --git a/x/imagegen/safetensors/loader.go b/x/imagegen/safetensors/loader.go
index df21623ed..7f8860b06 100644
--- a/x/imagegen/safetensors/loader.go
+++ b/x/imagegen/safetensors/loader.go
@@ -17,6 +17,18 @@ type WeightSource interface {
 	GetTensor(name string) (*mlx.Array, error)
 	ListTensors() []string
 	HasTensor(name string) bool
+	Quantization() string // Returns "FP4", "FP8", or ""
 }
+
+// quantizationParams returns groupSize, bits, mode for a quantization type.
+// Returns defaults (32, 8, "affine") for unknown types (backward compatibility).
+func quantizationParams(quantization string) (groupSize, bits int, mode string) {
+	switch strings.ToUpper(quantization) {
+	case "FP4":
+		return 32, 4, "affine"
+	default:
+		return 32, 8, "affine" // FP8 or unknown
+	}
+}
 
 // Transformer allows structs to transform weight arrays before assignment.
@@ -233,19 +245,21 @@ func LoadLinearLayer(weights WeightSource, path string) (nn.LinearLayer, error)
 		qbiases, _ = weights.GetTensor(qbiasPath)
 	}
 
+	groupSize, bits, mode := quantizationParams(weights.Quantization())
+
 	if mlx.MetalIsAvailable() {
 		return &nn.QuantizedLinear{
 			Weight:    weight,
 			Scales:    scales,
 			QBiases:   qbiases,
 			Bias:      bias,
-			GroupSize: 32,
-			Bits:      8,
-			Mode:      "affine",
+			GroupSize: groupSize,
+			Bits:      bits,
+			Mode:      mode,
 		}, nil
 	}
 
-	dequantized := mlx.Dequantize(weight, scales, qbiases, 32, 8, "affine")
+	dequantized := mlx.Dequantize(weight, scales, qbiases, groupSize, bits, mode)
 	return nn.NewLinear(dequantized, bias), nil
 }
 
diff --git a/x/imagegen/safetensors/safetensors.go b/x/imagegen/safetensors/safetensors.go
index f7ac327ce..a36052fce 100644
--- a/x/imagegen/safetensors/safetensors.go
+++ b/x/imagegen/safetensors/safetensors.go
@@ -298,6 +298,11 @@ func (mw *ModelWeights) HasTensor(name string) bool {
 	return ok
 }
 
+// Quantization returns empty string for directory-based weights (not quantized).
+func (mw *ModelWeights) Quantization() string {
+	return ""
+}
+
 // ReleaseAll releases all cached native file handles.
 func (mw *ModelWeights) ReleaseAll() {
 	for path, native := range mw.nativeCache {
diff --git a/x/imagegen/weights.go b/x/imagegen/weights.go
index 377707fbf..f49c7e77e 100644
--- a/x/imagegen/weights.go
+++ b/x/imagegen/weights.go
@@ -106,6 +106,21 @@ func (mw *ManifestWeights) HasTensor(name string) bool {
 	return ok
 }
 
+// Quantization returns the model's quantization type from model_index.json.
+// Returns empty string if not quantized or unknown.
+func (mw *ManifestWeights) Quantization() string {
+	if mw.manifest == nil {
+		return ""
+	}
+	var index struct {
+		Quantization string `json:"quantization"`
+	}
+	if err := mw.manifest.ReadConfigJSON("model_index.json", &index); err != nil {
+		return ""
+	}
+	return index.Quantization
+}
+
 // ReleaseAll frees all native handles and clears the tensor cache.
 func (mw *ManifestWeights) ReleaseAll() {
 	for _, sf := range mw.nativeCache {
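
Notes (sketches for reviewers, not part of the patch):

A minimal sketch of the round-trip the new fp4 path performs, using the
same mlx calls and parameters as this patch (group size 32, 4 bits,
affine mode). The helper name is illustrative and does not exist in the
tree:

	// quantizeFP4RoundTrip mirrors what quantizeTensor now does for "fp4"
	// and what the non-Metal fallback in LoadLinearLayer undoes: quantize
	// with 4 bits per weight in groups of 32 (affine mode stores a scale
	// and bias per group), then dequantize with the same parameters.
	func quantizeFP4RoundTrip(w *mlx.Array) *mlx.Array {
		qweight, scales, qbiases := mlx.Quantize(w, 32, 4, "affine")
		return mlx.Dequantize(qweight, scales, qbiases, 32, 4, "affine")
	}

The result approximates w at 4 bits per element, roughly a 4x reduction
versus fp16 before the per-group scale/bias overhead.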
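
The metadata-to-parameters mapping is centralized in quantizationParams;
a table-driven test sketch, assuming it lives alongside the function in
package safetensors (the test is not part of this patch; the expected
tuples come from the switch above):

	func TestQuantizationParams(t *testing.T) {
		cases := []struct {
			in        string
			groupSize int
			bits      int
			mode      string
		}{
			{"FP4", 32, 4, "affine"},
			{"fp4", 32, 4, "affine"}, // strings.ToUpper makes matching case-insensitive
			{"FP8", 32, 8, "affine"},
			{"", 32, 8, "affine"}, // unknown or unset falls back to the old defaults
		}
		for _, c := range cases {
			g, b, m := quantizationParams(c.in)
			if g != c.groupSize || b != c.bits || m != c.mode {
				t.Errorf("quantizationParams(%q) = (%d, %d, %q), want (%d, %d, %q)",
					c.in, g, b, m, c.groupSize, c.bits, c.mode)
			}
		}
	}

This keeps LoadLinearLayer oblivious to the concrete scheme: a future
quantization type only needs a new case here and in quantizeTensor.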
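
On the manifest side the contract is a single key in model_index.json.
A hypothetical payload for an fp4-quantized model, decoded with the same
anonymous struct ManifestWeights.Quantization uses (only the
"quantization" key name comes from this patch; the surrounding fragment
is an assumption for illustration):

	raw := []byte(`{"quantization": "FP4"}`) // assumed example payload
	var index struct {
		Quantization string `json:"quantization"`
	}
	if err := json.Unmarshal(raw, &index); err != nil {
		// malformed metadata: Quantization() just reports ""
	}
	// index.Quantization == "FP4", so quantizationParams yields (32, 4, "affine")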