diff --git a/x/create/client/quantize.go b/x/create/client/quantize.go
index 5a4be59d0..3a9f37cfc 100644
--- a/x/create/client/quantize.go
+++ b/x/create/client/quantize.go
@@ -54,6 +54,9 @@ func quantizeTensor(r io.Reader, name, dtype string, shape []int32, quantize str
 	// Quantize based on quantization type
 	var qweight, scales, qbiases *mlx.Array
 	switch quantize {
+	case "fp4":
+		// affine mode: group_size=32, bits=4
+		qweight, scales, qbiases = mlx.Quantize(arr, 32, 4, "affine")
 	case "fp8":
 		// affine mode: group_size=32, bits=8
 		qweight, scales, qbiases = mlx.Quantize(arr, 32, 8, "affine")
diff --git a/x/create/imagegen.go b/x/create/imagegen.go
index ad10d8c69..595a40417 100644
--- a/x/create/imagegen.go
+++ b/x/create/imagegen.go
@@ -20,10 +20,10 @@ import (
 func CreateImageGenModel(modelName, modelDir, quantize string, createLayer LayerCreator, createTensorLayer QuantizingTensorLayerCreator, writeManifest ManifestWriter, fn func(status string)) error {
 	// Validate quantization type
 	switch quantize {
-	case "", "fp8":
+	case "", "fp4", "fp8":
 		// valid
 	default:
-		return fmt.Errorf("unsupported quantization type %q: supported types are fp8", quantize)
+		return fmt.Errorf("unsupported quantization type %q: supported types are fp4, fp8", quantize)
 	}
 
 	var layers []LayerInfo
diff --git a/x/imagegen/safetensors/loader.go b/x/imagegen/safetensors/loader.go
index df21623ed..7f8860b06 100644
--- a/x/imagegen/safetensors/loader.go
+++ b/x/imagegen/safetensors/loader.go
@@ -17,6 +17,18 @@ type WeightSource interface {
 	GetTensor(name string) (*mlx.Array, error)
 	ListTensors() []string
 	HasTensor(name string) bool
+	Quantization() string // Returns "FP4", "FP8", or ""
+}
+
+// quantizationParams returns groupSize, bits, mode for a quantization type.
+// Returns defaults (32, 8, "affine") for unknown types (backward compatibility).
+func quantizationParams(quantization string) (groupSize, bits int, mode string) {
+	switch strings.ToUpper(quantization) {
+	case "FP4":
+		return 32, 4, "affine"
+	default:
+		return 32, 8, "affine" // FP8 or unknown
+	}
 }
 
 // Transformer allows structs to transform weight arrays before assignment.
@@ -233,19 +245,21 @@ func LoadLinearLayer(weights WeightSource, path string) (nn.LinearLayer, error)
 		qbiases, _ = weights.GetTensor(qbiasPath)
 	}
 
+	groupSize, bits, mode := quantizationParams(weights.Quantization())
+
 	if mlx.MetalIsAvailable() {
 		return &nn.QuantizedLinear{
 			Weight:    weight,
 			Scales:    scales,
 			QBiases:   qbiases,
 			Bias:      bias,
-			GroupSize: 32,
-			Bits:      8,
-			Mode:      "affine",
+			GroupSize: groupSize,
+			Bits:      bits,
+			Mode:      mode,
 		}, nil
 	}
 
-	dequantized := mlx.Dequantize(weight, scales, qbiases, 32, 8, "affine")
+	dequantized := mlx.Dequantize(weight, scales, qbiases, groupSize, bits, mode)
 	return nn.NewLinear(dequantized, bias), nil
 }
 
diff --git a/x/imagegen/safetensors/safetensors.go b/x/imagegen/safetensors/safetensors.go
index f7ac327ce..a36052fce 100644
--- a/x/imagegen/safetensors/safetensors.go
+++ b/x/imagegen/safetensors/safetensors.go
@@ -298,6 +298,11 @@ func (mw *ModelWeights) HasTensor(name string) bool {
 	return ok
 }
 
+// Quantization returns empty string for directory-based weights (not quantized).
+func (mw *ModelWeights) Quantization() string {
+	return ""
+}
+
 // ReleaseAll releases all cached native file handles.
 func (mw *ModelWeights) ReleaseAll() {
 	for path, native := range mw.nativeCache {
diff --git a/x/imagegen/weights.go b/x/imagegen/weights.go
index 377707fbf..f49c7e77e 100644
--- a/x/imagegen/weights.go
+++ b/x/imagegen/weights.go
@@ -106,6 +106,21 @@ func (mw *ManifestWeights) HasTensor(name string) bool {
 	return ok
 }
 
+// Quantization returns the model's quantization type from model_index.json.
+// Returns empty string if not quantized or unknown.
+func (mw *ManifestWeights) Quantization() string {
+	if mw.manifest == nil {
+		return ""
+	}
+	var index struct {
+		Quantization string `json:"quantization"`
+	}
+	if err := mw.manifest.ReadConfigJSON("model_index.json", &index); err != nil {
+		return ""
+	}
+	return index.Quantization
+}
+
 // ReleaseAll frees all native handles and clears the tensor cache.
 func (mw *ManifestWeights) ReleaseAll() {
 	for _, sf := range mw.nativeCache {
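
Example (not part of the diff): a minimal sketch of the fp4 round trip these changes enable, assuming only the mlx calls whose signatures appear above (mlx.Quantize, mlx.Dequantize). roundTripFP4 is a hypothetical helper for illustration, not code in the repository.

// roundTripFP4 quantizes an array with the same parameters that
// quantizationParams returns for "FP4" (group_size=32, bits=4, affine mode),
// then dequantizes it back to a dense array, mirroring the non-Metal
// fallback path in LoadLinearLayer.
func roundTripFP4(arr *mlx.Array) *mlx.Array {
	groupSize, bits, mode := 32, 4, "affine" // matches quantizationParams("FP4")
	qweight, scales, qbiases := mlx.Quantize(arr, groupSize, bits, mode)
	return mlx.Dequantize(qweight, scales, qbiases, groupSize, bits, mode)
}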