Mirror of https://github.com/ollama/ollama.git (synced 2026-01-29 07:12:03 +03:00)
x/imagegen: add FP4 quantization support for image generation models (#13773)
Add --quantize fp4 support to ollama create for image generation models (flux2, z-image-turbo), using MLX's affine 4-bit quantization.

Changes:
- Add fp4 to validation in CreateImageGenModel
- Add FP4 case to quantizeTensor (group_size=32, bits=4, affine mode)
- Add Quantization() to the WeightSource interface for dynamic params
- Update LoadLinearLayer to use quantization params from model metadata
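With this change, 4-bit image-generation models go through the existing create flow; a hedged usage sketch (the model name is a placeholder, and any arguments pointing at the source weights are omitted here):

    ollama create my-flux2 --quantize fp4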
@@ -54,6 +54,9 @@ func quantizeTensor(r io.Reader, name, dtype string, shape []int32, quantize str
     // Quantize based on quantization type
     var qweight, scales, qbiases *mlx.Array
     switch quantize {
+    case "fp4":
+        // affine mode: group_size=32, bits=4
+        qweight, scales, qbiases = mlx.Quantize(arr, 32, 4, "affine")
     case "fp8":
         // affine mode: group_size=32, bits=8
         qweight, scales, qbiases = mlx.Quantize(arr, 32, 8, "affine")
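For orientation, quantization and dequantization must agree on the same parameter triple; a fragment-level sketch of the FP4 round trip using the calls shown in this diff (arr is assumed to be an already-loaded *mlx.Array; this is not a complete program):

    // quantize with group_size=32, bits=4, affine mode ...
    qweight, scales, qbiases := mlx.Quantize(arr, 32, 4, "affine")
    // ... and dequantize with the same triple; restored approximates arr
    // up to quantization error
    restored := mlx.Dequantize(qweight, scales, qbiases, 32, 4, "affine")
    _ = restored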
@@ -20,10 +20,10 @@ import (
 func CreateImageGenModel(modelName, modelDir, quantize string, createLayer LayerCreator, createTensorLayer QuantizingTensorLayerCreator, writeManifest ManifestWriter, fn func(status string)) error {
     // Validate quantization type
     switch quantize {
-    case "", "fp8":
+    case "", "fp4", "fp8":
         // valid
     default:
-        return fmt.Errorf("unsupported quantization type %q: supported types are fp8", quantize)
+        return fmt.Errorf("unsupported quantization type %q: supported types are fp4, fp8", quantize)
     }

     var layers []LayerInfo
@@ -17,6 +17,18 @@ type WeightSource interface {
     GetTensor(name string) (*mlx.Array, error)
     ListTensors() []string
     HasTensor(name string) bool
+    Quantization() string // Returns "FP4", "FP8", or ""
 }

+// quantizationParams returns groupSize, bits, mode for a quantization type.
+// Returns defaults (32, 8, "affine") for unknown types (backward compatibility).
+func quantizationParams(quantization string) (groupSize, bits int, mode string) {
+    switch strings.ToUpper(quantization) {
+    case "FP4":
+        return 32, 4, "affine"
+    default:
+        return 32, 8, "affine" // FP8 or unknown
+    }
+}
+
 // Transformer allows structs to transform weight arrays before assignment.
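Together, the new Quantization() method and quantizationParams let loaders derive group size, bit width, and mode from model metadata instead of hard-coding them. A standalone, runnable sketch of the same mapping, written outside the package to show the case-insensitive lookup and the FP8 fallback (the params function and the package main wrapper are illustrative, not part of the change):

    package main

    import (
        "fmt"
        "strings"
    )

    // Restatement of the lookup for illustration: FP4 maps to 4 bits;
    // everything else (FP8, empty, unknown) falls back to the previous
    // default of 8 bits. The comparison is case-insensitive, so metadata
    // stored as "fp4" or "FP4" resolves the same way.
    func params(q string) (groupSize, bits int, mode string) {
        if strings.ToUpper(q) == "FP4" {
            return 32, 4, "affine"
        }
        return 32, 8, "affine"
    }

    func main() {
        for _, q := range []string{"fp4", "FP8", ""} {
            groupSize, bits, mode := params(q)
            fmt.Printf("%q -> group_size=%d bits=%d mode=%s\n", q, groupSize, bits, mode)
        }
    }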
@@ -233,19 +245,21 @@ func LoadLinearLayer(weights WeightSource, path string) (nn.LinearLayer, error)
         qbiases, _ = weights.GetTensor(qbiasPath)
     }

+    groupSize, bits, mode := quantizationParams(weights.Quantization())
+
     if mlx.MetalIsAvailable() {
         return &nn.QuantizedLinear{
             Weight: weight,
             Scales: scales,
             QBiases: qbiases,
             Bias: bias,
-            GroupSize: 32,
-            Bits: 8,
-            Mode: "affine",
+            GroupSize: groupSize,
+            Bits: bits,
+            Mode: mode,
         }, nil
     }

-    dequantized := mlx.Dequantize(weight, scales, qbiases, 32, 8, "affine")
+    dequantized := mlx.Dequantize(weight, scales, qbiases, groupSize, bits, mode)
     return nn.NewLinear(dequantized, bias), nil
 }
@@ -298,6 +298,11 @@ func (mw *ModelWeights) HasTensor(name string) bool {
     return ok
 }

+// Quantization returns empty string for directory-based weights (not quantized).
+func (mw *ModelWeights) Quantization() string {
+    return ""
+}
+
 // ReleaseAll releases all cached native file handles.
 func (mw *ModelWeights) ReleaseAll() {
     for path, native := range mw.nativeCache {
@@ -106,6 +106,21 @@ func (mw *ManifestWeights) HasTensor(name string) bool {
     return ok
 }

+// Quantization returns the model's quantization type from model_index.json.
+// Returns empty string if not quantized or unknown.
+func (mw *ManifestWeights) Quantization() string {
+    if mw.manifest == nil {
+        return ""
+    }
+    var index struct {
+        Quantization string `json:"quantization"`
+    }
+    if err := mw.manifest.ReadConfigJSON("model_index.json", &index); err != nil {
+        return ""
+    }
+    return index.Quantization
+}
+
 // ReleaseAll frees all native handles and clears the tensor cache.
 func (mw *ManifestWeights) ReleaseAll() {
     for _, sf := range mw.nativeCache {
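The manifest-backed lookup hinges on a small "quantization" field in model_index.json. A minimal, self-contained sketch of that decode step (the JSON document below is a made-up example, and encoding/json stands in for the manifest's ReadConfigJSON helper):

    package main

    import (
        "encoding/json"
        "fmt"
    )

    func main() {
        // Hypothetical model_index.json fragment; only the "quantization"
        // key is read here, any other fields are ignored by the decode.
        raw := []byte(`{"quantization": "FP4"}`)

        var index struct {
            Quantization string `json:"quantization"`
        }
        if err := json.Unmarshal(raw, &index); err != nil {
            fmt.Println("decode failed:", err)
            return
        }
        fmt.Println(index.Quantization) // prints: FP4
    }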