x/imagegen: add FP4 quantization support for image generation models (#13773)

Add --quantize fp4 support to ollama create for image generation models
(flux2, z-image-turbo), using MLX's affine 4-bit quantization.

Changes:
- Add fp4 to validation in CreateImageGenModel
- Add FP4 case to quantizeTensor (group_size=32, bits=4, affine mode)
- Add Quantization() to WeightSource interface for dynamic params
- Update LoadLinearLayer to use quantization params from model metadata
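
With this change, 4-bit quantization is requested at create time via the new flag. A minimal invocation (the model name here is hypothetical; only the --quantize fp4 flag comes from this change):

    # hypothetical model name; quantizes weights to 4-bit on import
    ollama create my-flux2 --quantize fp4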
Jeffrey Morgan, 2026-01-19 00:54:54 -08:00 (committed by GitHub)
parent a887406c24, commit 03bf241c33
5 changed files with 43 additions and 6 deletions


@@ -54,6 +54,9 @@ func quantizeTensor(r io.Reader, name, dtype string, shape []int32, quantize str
 	// Quantize based on quantization type
 	var qweight, scales, qbiases *mlx.Array
 	switch quantize {
+	case "fp4":
+		// affine mode: group_size=32, bits=4
+		qweight, scales, qbiases = mlx.Quantize(arr, 32, 4, "affine")
 	case "fp8":
 		// affine mode: group_size=32, bits=8
 		qweight, scales, qbiases = mlx.Quantize(arr, 32, 8, "affine")
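
For context, MLX's affine mode stores each group of group_size weights as bits-bit integer levels plus a per-group scale and bias, so each weight is recovered as w ≈ q*scale + bias. A rough Go sketch of the per-group math (illustrative only; mlx.Quantize additionally packs the levels into uint32 words):

    // affineQuantizeGroup sketches the per-group affine quantization math:
    // map each weight in the group to an integer level in [0, 2^bits-1]
    // such that w ≈ float32(q)*scale + bias. Bit-packing is omitted.
    func affineQuantizeGroup(w []float32, bits int) (q []uint8, scale, bias float32) {
        lo, hi := w[0], w[0]
        for _, v := range w {
            if v < lo {
                lo = v
            }
            if v > hi {
                hi = v
            }
        }
        levels := float32(int(1)<<bits - 1) // 15 levels for 4-bit
        scale = (hi - lo) / levels
        if scale == 0 {
            scale = 1 // constant group; avoid division by zero
        }
        bias = lo
        q = make([]uint8, len(w))
        for i, v := range w {
            q[i] = uint8((v-bias)/scale + 0.5) // round to nearest level
        }
        return q, scale, bias
    }

With group_size=32 and bits=4, a group of 32 weights packs into 16 bytes of levels plus one scale and one bias: close to a 4x reduction versus 16-bit weights, before the small per-group overhead.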


@@ -20,10 +20,10 @@ import (
 func CreateImageGenModel(modelName, modelDir, quantize string, createLayer LayerCreator, createTensorLayer QuantizingTensorLayerCreator, writeManifest ManifestWriter, fn func(status string)) error {
 	// Validate quantization type
 	switch quantize {
-	case "", "fp8":
+	case "", "fp4", "fp8":
 		// valid
 	default:
-		return fmt.Errorf("unsupported quantization type %q: supported types are fp8", quantize)
+		return fmt.Errorf("unsupported quantization type %q: supported types are fp4, fp8", quantize)
 	}
 
 	var layers []LayerInfo


@@ -17,6 +17,18 @@ type WeightSource interface {
 	GetTensor(name string) (*mlx.Array, error)
 	ListTensors() []string
 	HasTensor(name string) bool
+	Quantization() string // Returns "FP4", "FP8", or ""
 }
 
+// quantizationParams returns groupSize, bits, mode for a quantization type.
+// Returns defaults (32, 8, "affine") for unknown types (backward compatibility).
+func quantizationParams(quantization string) (groupSize, bits int, mode string) {
+	switch strings.ToUpper(quantization) {
+	case "FP4":
+		return 32, 4, "affine"
+	default:
+		return 32, 8, "affine" // FP8 or unknown
+	}
+}
+
 // Transformer allows structs to transform weight arrays before assignment.
@@ -233,19 +245,21 @@ func LoadLinearLayer(weights WeightSource, path string) (nn.LinearLayer, error)
 		qbiases, _ = weights.GetTensor(qbiasPath)
 	}
 
+	groupSize, bits, mode := quantizationParams(weights.Quantization())
+
 	if mlx.MetalIsAvailable() {
 		return &nn.QuantizedLinear{
 			Weight:    weight,
 			Scales:    scales,
 			QBiases:   qbiases,
 			Bias:      bias,
-			GroupSize: 32,
-			Bits:      8,
-			Mode:      "affine",
+			GroupSize: groupSize,
+			Bits:      bits,
+			Mode:      mode,
 		}, nil
 	}
 
-	dequantized := mlx.Dequantize(weight, scales, qbiases, 32, 8, "affine")
+	dequantized := mlx.Dequantize(weight, scales, qbiases, groupSize, bits, mode)
 	return nn.NewLinear(dequantized, bias), nil
 }
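
On machines without Metal, the layer is instead dequantized eagerly through mlx.Dequantize with the same dynamic parameters. Conceptually this inverts the per-group math sketched earlier (again an illustration pairing with the hypothetical helper above, not MLX's actual kernel):

    // affineDequantizeGroup inverts affineQuantizeGroup: recover
    // approximate weights from the stored levels, scale, and bias.
    func affineDequantizeGroup(q []uint8, scale, bias float32) []float32 {
        w := make([]float32, len(q))
        for i, v := range q {
            w[i] = float32(v)*scale + bias
        }
        return w
    }

Because quantizationParams defaults to (32, 8, "affine") for an empty or unknown string, existing FP8 manifests and unquantized directory weights load exactly as before.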


@@ -298,6 +298,11 @@ func (mw *ModelWeights) HasTensor(name string) bool {
 	return ok
 }
 
+// Quantization returns empty string for directory-based weights (not quantized).
+func (mw *ModelWeights) Quantization() string {
+	return ""
+}
+
 // ReleaseAll releases all cached native file handles.
 func (mw *ModelWeights) ReleaseAll() {
 	for path, native := range mw.nativeCache {


@@ -106,6 +106,21 @@ func (mw *ManifestWeights) HasTensor(name string) bool {
 	return ok
 }
 
+// Quantization returns the model's quantization type from model_index.json.
+// Returns empty string if not quantized or unknown.
+func (mw *ManifestWeights) Quantization() string {
+	if mw.manifest == nil {
+		return ""
+	}
+	var index struct {
+		Quantization string `json:"quantization"`
+	}
+	if err := mw.manifest.ReadConfigJSON("model_index.json", &index); err != nil {
+		return ""
+	}
+	return index.Quantization
+}
+
 // ReleaseAll frees all native handles and clears the tensor cache.
 func (mw *ManifestWeights) ReleaseAll() {
 	for _, sf := range mw.nativeCache {
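
The dynamic parameters thus ultimately originate from the quantization field in the model's model_index.json config. A minimal illustration of the relevant field for an FP4 model (other fields omitted; exact contents vary by model):

    {
      "quantization": "FP4"
    }

Any read error or missing field falls back to the empty string, which quantizationParams maps to the previous FP8-era defaults.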