diff --git a/x/imagegen/cache/step.go b/x/imagegen/cache/step.go index 830df447f..f91f22fa0 100644 --- a/x/imagegen/cache/step.go +++ b/x/imagegen/cache/step.go @@ -9,7 +9,7 @@ import "github.com/ollama/ollama/x/imagegen/mlx" // shallow layers change little between consecutive steps, so we can // cache their outputs and skip recomputation on non-refresh steps. // -// Supports both single-stream (Z-Image) and dual-stream (Qwen-Image) architectures: +// Supports both single-stream and dual-stream architectures: // - Single-stream: use Get/Set for the single output per layer // - Dual-stream: use Get/Set for stream 1 (imgH), Get2/Set2 for stream 2 (txtH) // @@ -87,7 +87,7 @@ func (c *StepCache) Set(layer int, arr *mlx.Array) { } // Get2 returns the cached output for a layer (stream 2), or nil if not cached. -// Used for dual-stream architectures like Qwen-Image. +// Used for dual-stream architectures. func (c *StepCache) Get2(layer int) *mlx.Array { if layer < len(c.layers2) { return c.layers2[layer] @@ -96,7 +96,7 @@ func (c *StepCache) Get2(layer int) *mlx.Array { } // Set2 stores a layer output (stream 2), freeing any previous value. -// Used for dual-stream architectures like Qwen-Image. +// Used for dual-stream architectures. func (c *StepCache) Set2(layer int, arr *mlx.Array) { if layer < len(c.layers2) { if c.layers2[layer] != nil { diff --git a/x/imagegen/cmd/engine/main.go b/x/imagegen/cmd/engine/main.go index 003be3a37..f0e705d1c 100644 --- a/x/imagegen/cmd/engine/main.go +++ b/x/imagegen/cmd/engine/main.go @@ -21,8 +21,6 @@ import ( "github.com/ollama/ollama/x/imagegen/models/gemma3" "github.com/ollama/ollama/x/imagegen/models/gpt_oss" "github.com/ollama/ollama/x/imagegen/models/llama" - "github.com/ollama/ollama/x/imagegen/models/qwen_image" - "github.com/ollama/ollama/x/imagegen/models/qwen_image_edit" "github.com/ollama/ollama/x/imagegen/models/zimage" "github.com/ollama/ollama/x/imagegen/safetensors" ) @@ -61,14 +59,11 @@ func main() { listTensors := flag.Bool("list", false, "List tensors only") cpuProfile := flag.String("cpuprofile", "", "Write CPU profile to file") gpuCapture := flag.String("gpu-capture", "", "Capture GPU trace to .gputrace file (run with MTL_CAPTURE_ENABLED=1)") - layerCache := flag.Bool("layer-cache", false, "Enable layer caching for faster diffusion (Z-Image, Qwen-Image). Not compatible with CFG/negative prompts.") wiredLimitGB := flag.Int("wired-limit", 32, "Metal wired memory limit in GB") // Legacy mode flags zimageFlag := flag.Bool("zimage", false, "Z-Image generation") flux2Flag := flag.Bool("flux2", false, "FLUX.2 Klein generation") - qwenImage := flag.Bool("qwen-image", false, "Qwen-Image text-to-image generation") - qwenImageEdit := flag.Bool("qwen-image-edit", false, "Qwen-Image-Edit image editing") var inputImages stringSlice flag.Var(&inputImages, "input-image", "Input image for image editing (can be specified multiple times)") negativePrompt := flag.String("negative-prompt", "", "Negative prompt for CFG (empty = no CFG, matching Python)") @@ -166,60 +161,6 @@ func main() { if err == nil { err = saveImageArray(img, *out) } - case *qwenImage: - m, loadErr := qwen_image.LoadPersistent(*modelPath) - if loadErr != nil { - log.Fatal(loadErr) - } - var img *mlx.Array - img, err = m.GenerateFromConfig(&qwen_image.GenerateConfig{ - Prompt: *prompt, - NegativePrompt: *negativePrompt, - CFGScale: float32(*cfgScale), - Width: int32(*width), - Height: int32(*height), - Steps: *steps, - Seed: *seed, - LayerCache: *layerCache, - }) - if err == nil { - err = saveImageArray(img, *out) - } - case *qwenImageEdit: - if len(inputImages) == 0 { - log.Fatal("qwen-image-edit requires at least one -input-image") - } - - m, loadErr := qwen_image_edit.LoadPersistent(*modelPath) - if loadErr != nil { - log.Fatal(loadErr) - } - // For image editing, use 0 for dimensions to auto-detect from input image - // unless explicitly overridden from defaults - editWidth := int32(0) - editHeight := int32(0) - if *width != 1024 { - editWidth = int32(*width) - } - if *height != 1024 { - editHeight = int32(*height) - } - - cfg := &qwen_image_edit.GenerateConfig{ - Prompt: *prompt, - NegativePrompt: *negativePrompt, - CFGScale: float32(*cfgScale), - Width: editWidth, - Height: editHeight, - Steps: *steps, - Seed: *seed, - } - - var img *mlx.Array - img, err = m.EditFromConfig(inputImages, cfg) - if err == nil { - err = saveImageArray(img, *out) - } case *listTensors: err = listModelTensors(*modelPath) default: diff --git a/x/imagegen/models/qwen_image/pipeline_test.go b/x/imagegen/models/qwen_image/pipeline_test.go deleted file mode 100644 index 4a0ad7135..000000000 --- a/x/imagegen/models/qwen_image/pipeline_test.go +++ /dev/null @@ -1,87 +0,0 @@ -//go:build mlx - -package qwen_image - -import ( - "fmt" - "os" - "path/filepath" - "runtime" - "testing" - - "github.com/ollama/ollama/x/imagegen/mlx" -) - -// TestMain initializes MLX before running tests. -// If MLX libraries are not available, tests are skipped. -func TestMain(m *testing.M) { - // Change to repo root so ./build/lib/ollama/ path works - _, thisFile, _, _ := runtime.Caller(0) - repoRoot := filepath.Join(filepath.Dir(thisFile), "..", "..", "..", "..") - if err := os.Chdir(repoRoot); err != nil { - fmt.Printf("Failed to change to repo root: %v\n", err) - os.Exit(1) - } - - if err := mlx.InitMLX(); err != nil { - fmt.Printf("Skipping qwen_image tests: %v\n", err) - os.Exit(0) - } - os.Exit(m.Run()) -} - -// TestPipelineOutput runs the full pipeline (integration test). -// Skips if model weights not found. Requires ~50GB VRAM. -func TestPipelineOutput(t *testing.T) { - modelPath := "../../../weights/Qwen-Image-2512" - if _, err := os.Stat(modelPath); os.IsNotExist(err) { - t.Skip("Skipping: model weights not found at " + modelPath) - } - - // Load model - pm, err := LoadPersistent(modelPath) - if err != nil { - t.Skipf("Skipping: failed to load model: %v", err) - } - - // Run 2-step pipeline (minimum for stable scheduler) - cfg := &GenerateConfig{ - Prompt: "a cat", - Width: 256, - Height: 256, - Steps: 2, - Seed: 42, - } - - output, err := pm.GenerateFromConfig(cfg) - if err != nil { - t.Fatalf("Pipeline failed: %v", err) - } - mlx.Eval(output) - - // Verify output shape [1, C, H, W] - shape := output.Shape() - if len(shape) != 4 { - t.Errorf("Expected 4D output, got %v", shape) - } - if shape[0] != 1 || shape[1] != 3 || shape[2] != cfg.Height || shape[3] != cfg.Width { - t.Errorf("Shape mismatch: got %v, expected [1, 3, %d, %d]", shape, cfg.Height, cfg.Width) - } - - // Verify values in expected range [0, 1] - data := output.Data() - minVal, maxVal := float32(1.0), float32(0.0) - for _, v := range data { - if v < minVal { - minVal = v - } - if v > maxVal { - maxVal = v - } - } - t.Logf("Output range: [%.4f, %.4f]", minVal, maxVal) - - if minVal < -0.1 || maxVal > 1.1 { - t.Errorf("Output values out of range: [%.4f, %.4f]", minVal, maxVal) - } -} diff --git a/x/imagegen/models/qwen_image/qwen25vl.go b/x/imagegen/models/qwen_image/qwen25vl.go deleted file mode 100644 index af519ee7d..000000000 --- a/x/imagegen/models/qwen_image/qwen25vl.go +++ /dev/null @@ -1,1802 +0,0 @@ -//go:build mlx - -package qwen_image - -import ( - "errors" - "fmt" - "math" - "path/filepath" - - "github.com/ollama/ollama/x/imagegen/mlx" - "github.com/ollama/ollama/x/imagegen/safetensors" - "github.com/ollama/ollama/x/imagegen/tokenizer" -) - -// Qwen25VLConfig holds Qwen2.5-VL configuration -type Qwen25VLConfig struct { - // Text model config - HiddenSize int32 `json:"hidden_size"` // 3584 - NumHiddenLayers int32 `json:"num_hidden_layers"` // 28 - IntermediateSize int32 `json:"intermediate_size"` // 18944 - NumAttentionHeads int32 `json:"num_attention_heads"` // 28 - NumKeyValueHeads int32 `json:"num_key_value_heads"` // 4 - VocabSize int32 `json:"vocab_size"` // 152064 - RMSNormEps float32 `json:"rms_norm_eps"` // 1e-6 - RopeTheta float32 `json:"rope_theta"` // 1000000 - HeadDim int32 // Calculated: HiddenSize / NumAttentionHeads - MRoPESection []int32 // [16, 24, 24] for temporal, height, width - - // Vision config - VisionHiddenSize int32 `json:"vision_hidden_size"` // 1280 - VisionNumLayers int32 `json:"vision_num_layers"` // 32 - VisionNumHeads int32 `json:"vision_num_heads"` // 16 - VisionIntermSize int32 `json:"vision_intermediate"` // 3420 - VisionPatchSize int32 `json:"vision_patch_size"` // 14 - VisionOutHiddenSize int32 `json:"vision_out_hidden"` // 3584 - VisionSpatialMerge int32 `json:"vision_spatial_merge"` // 2 - VisionWindowSize int32 `json:"vision_window_size"` // 112 - VisionFullAttIdx []int32 // [7, 15, 23, 31] - - // Special tokens - ImageTokenID int32 // 151655 - VisionStartTokenID int32 // 151652 - VisionEndTokenID int32 // 151653 -} - -// defaultQwen25VLConfig returns default config -func defaultQwen25VLConfig() *Qwen25VLConfig { - cfg := &Qwen25VLConfig{ - // Text - HiddenSize: 3584, - NumHiddenLayers: 28, - IntermediateSize: 18944, - NumAttentionHeads: 28, - NumKeyValueHeads: 4, - VocabSize: 152064, - RMSNormEps: 1e-6, - RopeTheta: 1000000, - MRoPESection: []int32{16, 24, 24}, - - // Vision - VisionHiddenSize: 1280, - VisionNumLayers: 32, - VisionNumHeads: 16, - VisionIntermSize: 3420, - VisionPatchSize: 14, - VisionOutHiddenSize: 3584, - VisionSpatialMerge: 2, - VisionWindowSize: 112, - VisionFullAttIdx: []int32{7, 15, 23, 31}, - - // Special tokens - ImageTokenID: 151655, - VisionStartTokenID: 151652, - VisionEndTokenID: 153653, - } - cfg.HeadDim = cfg.HiddenSize / cfg.NumAttentionHeads - return cfg -} - -// Qwen25VL is the Qwen2.5-VL vision-language encoder -type Qwen25VL struct { - Config *Qwen25VLConfig - - // Text model - Embedding *mlx.Array - Blocks []*VLTextBlock - FinalNorm *mlx.Array - - // Vision tower (optional - nil for text-only models) - VisionPatchEmbed *VisionPatchEmbed - VisionBlocks []*VisionBlock - VisionMerger *VisionMerger - HasVision bool // True if vision tower is loaded -} - -// LoadTextOnly loads only the text encoder components (skips vision tower) -// Use this for text-to-image generation where vision components are not needed -func (m *Qwen25VL) LoadTextOnly(path string) error { - return m.load(path, false) -} - -// Load loads the vision-language encoder from a directory -// Vision components are loaded if weights exist -func (m *Qwen25VL) Load(path string) error { - return m.load(path, true) -} - -// load is the internal loading function -func (m *Qwen25VL) load(path string, loadVision bool) error { - fmt.Println("Loading Qwen2.5-VL encoder...") - - cfg := defaultQwen25VLConfig() - m.Config = cfg - - weights, err := safetensors.LoadModelWeights(path) - if err != nil { - return fmt.Errorf("weights: %w", err) - } - - // Bulk load all weights as bf16 - fmt.Print(" Loading weights as bf16... ") - if err := weights.Load(mlx.DtypeBFloat16); err != nil { - return fmt.Errorf("failed to load weights: %w", err) - } - fmt.Printf("✓ (%.1f GB)\n", float64(mlx.MetalGetActiveMemory())/(1024*1024*1024)) - - // Load text embedding - fmt.Print(" Loading text embeddings... ") - embedding, err := weights.Get("model.embed_tokens.weight") - if err != nil { - return err - } - m.Embedding = embedding - fmt.Printf("✓ [%v]\n", embedding.Shape()) - - // Load text blocks - m.Blocks = make([]*VLTextBlock, cfg.NumHiddenLayers) - for i := int32(0); i < cfg.NumHiddenLayers; i++ { - fmt.Printf("\r Loading text blocks... %d/%d", i+1, cfg.NumHiddenLayers) - block, err := newVLTextBlock(weights, int(i), cfg) - if err != nil { - return fmt.Errorf("failed to load text block %d: %w", i, err) - } - m.Blocks[i] = block - } - fmt.Printf("\r Loading text blocks... ✓ [%d blocks] \n", cfg.NumHiddenLayers) - - // Load final norm - fmt.Print(" Loading final norm... ") - finalNorm, err := weights.Get("model.norm.weight") - if err != nil { - return err - } - m.FinalNorm = finalNorm - fmt.Println("✓") - - // Try to load vision tower (optional) - m.HasVision = false - if loadVision { - if _, err := weights.Get("visual.patch_embed.proj.weight"); err == nil { - fmt.Print(" Loading vision patch embed... ") - m.VisionPatchEmbed, err = newVisionPatchEmbed(weights, cfg) - if err != nil { - return fmt.Errorf("vision patch embed: %w", err) - } - fmt.Println("✓") - - m.VisionBlocks = make([]*VisionBlock, cfg.VisionNumLayers) - for i := int32(0); i < cfg.VisionNumLayers; i++ { - fmt.Printf("\r Loading vision blocks... %d/%d", i+1, cfg.VisionNumLayers) - block, err := newVisionBlock(weights, int(i), cfg) - if err != nil { - return fmt.Errorf("failed to load vision block %d: %w", i, err) - } - m.VisionBlocks[i] = block - } - fmt.Printf("\r Loading vision blocks... ✓ [%d blocks] \n", cfg.VisionNumLayers) - - fmt.Print(" Loading vision merger... ") - m.VisionMerger, err = newVisionMerger(weights, cfg) - if err != nil { - return fmt.Errorf("vision merger: %w", err) - } - fmt.Println("✓") - - m.HasVision = true - } else { - fmt.Println(" (No vision tower - text-only mode)") - } - } else { - fmt.Println(" (Skipping vision tower)") - } - - weights.ReleaseAll() - return nil -} - -// EncodePrompt encodes a text prompt for image generation (text-only mode) -// Uses the Qwen-Image template and drops the first 34 tokens (system prefix) -func (m *Qwen25VL) EncodePrompt(tok *tokenizer.Tokenizer, prompt string) *mlx.Array { - cfg := m.Config - - // Template from Python: prompt_template_encode (for image generation) - template := "<|im_start|>system\nDescribe the image by detailing the color, shape, size, texture, quantity, text, spatial relationships of the objects and background:<|im_end|>\n<|im_start|>user\n%s<|im_end|>\n<|im_start|>assistant\n" - formattedPrompt := fmt.Sprintf(template, prompt) - - // Tokenize - tokens := tok.Encode(formattedPrompt, false) - - // Create token array - seqLen := int32(len(tokens)) - tokenArr := mlx.NewArrayInt32(tokens, []int32{1, seqLen}) - - // Get text embeddings - textEmbed := mlx.EmbeddingLookup(m.Embedding, tokenArr) - - // Compute RoPE - cossin := m.computeTextRoPE(seqLen, 1) - - // Forward through ALL text blocks - x := textEmbed - for _, block := range m.Blocks { - x = block.Forward(x, cossin) - } - - // Apply final norm - x = mlx.RMSNorm(x, m.FinalNorm, cfg.RMSNormEps) - - // Drop first 34 tokens (system prefix) - // prompt_template_encode_start_idx = 34 - dropIdx := int32(34) - if x.Shape()[1] > dropIdx { - x = mlx.Slice(x, []int32{0, dropIdx, 0}, []int32{1, x.Shape()[1], cfg.HiddenSize}) - } - - return x -} - -// EncodePromptWithImage encodes a text prompt with an image -// Returns: embeddings [B, L, hidden_size], mask [B, L], error -func (m *Qwen25VL) EncodePromptWithImage(tok *tokenizer.Tokenizer, prompt string, image *mlx.Array) (*mlx.Array, *mlx.Array, error) { - if !m.HasVision { - return nil, nil, errors.New("EncodePromptWithImage called on text-only model") - } - - cfg := m.Config - - // Template from Python diffusers pipeline: prompt_template_encode - // Python's _get_qwen_prompt_embeds adds "Picture 1: " before vision tokens - template := "<|im_start|>system\nDescribe the key features of the input image (color, shape, size, texture, objects, background), then explain how the user's text instruction should alter or modify the image. Generate a new image that meets the user's requirements while maintaining consistency with the original input where appropriate.<|im_end|>\n<|im_start|>user\nPicture 1: <|vision_start|><|image_pad|><|vision_end|>%s<|im_end|>\n<|im_start|>assistant\n" - formattedPrompt := fmt.Sprintf(template, prompt) - - // Tokenize - tokens := tok.Encode(formattedPrompt, false) - - // Process vision if image provided - var visionEmbeddings *mlx.Array - var numImageTokens int32 - var visionH, visionW int32 // Grid dims in patches (before spatial merge) - if image != nil { - visionEmbeddings = m.encodeVision(image) - numImageTokens = visionEmbeddings.Shape()[1] - // Get original grid dimensions from image shape - imgShape := image.Shape() - visionH = imgShape[2] / cfg.VisionPatchSize // Height in patches - visionW = imgShape[3] / cfg.VisionPatchSize // Width in patches - } - - // Find image token position and expand - expandedTokens := make([]int32, 0, len(tokens)+int(numImageTokens)) - imageTokenPos := int32(-1) - textAfterCount := int32(0) - for i, t := range tokens { - if t == cfg.ImageTokenID { - imageTokenPos = int32(len(expandedTokens)) - // Insert placeholder tokens for image - for j := int32(0); j < numImageTokens; j++ { - expandedTokens = append(expandedTokens, cfg.ImageTokenID) - } - // Count remaining tokens after image - textAfterCount = int32(len(tokens) - i - 1) - } else { - expandedTokens = append(expandedTokens, t) - } - } - - // Create token array - seqLen := int32(len(expandedTokens)) - tokenArr := mlx.NewArrayInt32(expandedTokens, []int32{1, seqLen}) - - // Get text embeddings - textEmbed := mlx.EmbeddingLookup(m.Embedding, tokenArr) // [1, L, hidden] - - // Replace image token embeddings with vision embeddings - if visionEmbeddings != nil && imageTokenPos >= 0 { - // Split, replace, concat - before := mlx.Slice(textEmbed, []int32{0, 0, 0}, []int32{1, imageTokenPos, cfg.HiddenSize}) - after := mlx.Slice(textEmbed, []int32{0, imageTokenPos + numImageTokens, 0}, []int32{1, seqLen, cfg.HiddenSize}) - textEmbed = mlx.Concatenate([]*mlx.Array{before, visionEmbeddings, after}, 1) - } - - // Compute RoPE - use multimodal RoPE when image is present - var cossin [2]*mlx.Array - if image != nil && imageTokenPos >= 0 { - cossin = m.ComputeMultimodalRoPE(imageTokenPos, visionH, visionW, textAfterCount, cfg.VisionSpatialMerge) - } else { - cossin = m.computeTextRoPE(seqLen, 1) - } - - // Forward through ALL text blocks - // Python uses hidden_states[-1] (LAST layer output, not second-to-last!) - x := textEmbed - for _, block := range m.Blocks { - x = block.Forward(x, cossin) - } - - // Apply final norm (Python DOES apply this for the output) - x = mlx.RMSNorm(x, m.FinalNorm, cfg.RMSNormEps) - - // Drop first N tokens (system prefix) - // prompt_template_encode_start_idx = 64 - dropIdx := int32(64) - if x.Shape()[1] > dropIdx { - x = mlx.Slice(x, []int32{0, dropIdx, 0}, []int32{1, x.Shape()[1], cfg.HiddenSize}) - } - - // Create attention mask (all ones for now) - mask := mlx.Ones(1, x.Shape()[1]) - - return x, mask, nil -} - -// EncodeVision encodes an image through the vision tower (exported for testing) -// image: [B, C, H, W] normalized image tensor -// Returns: [B, num_tokens, hidden_size] vision embeddings -func (m *Qwen25VL) EncodeVision(image *mlx.Array) *mlx.Array { - return m.encodeVision(image) -} - -// VisionRegion describes where vision embeddings are inserted in the sequence -type VisionRegion struct { - StartPos int32 // Position in sequence where vision tokens start - NumTokens int32 // Number of vision tokens - GridH int32 // Vision grid height (in patches, after spatial merge) - GridW int32 // Vision grid width (in patches, after spatial merge) -} - -// EncodePromptWithImages encodes a text prompt with multiple images -// Returns: embeddings [B, L, hidden_size], mask [B, L], regions []VisionRegion, error -func (m *Qwen25VL) EncodePromptWithImages(tok *tokenizer.Tokenizer, prompt string, images []*mlx.Array) (*mlx.Array, *mlx.Array, []VisionRegion, error) { - if !m.HasVision { - return nil, nil, nil, errors.New("EncodePromptWithImages called on text-only model") - } - if len(images) == 0 { - return nil, nil, nil, errors.New("EncodePromptWithImages called with no images") - } - - cfg := m.Config - - // Build image prompt prefix: "Picture 1: ...Picture N: ..." - imgPromptTemplate := "Picture %d: <|vision_start|><|image_pad|><|vision_end|>" - imgPrompt := "" - for i := range images { - imgPrompt += fmt.Sprintf(imgPromptTemplate, i+1) - } - - // Template from Python diffusers pipeline: prompt_template_encode - template := "<|im_start|>system\nDescribe the key features of the input image (color, shape, size, texture, objects, background), then explain how the user's text instruction should alter or modify the image. Generate a new image that meets the user's requirements while maintaining consistency with the original input where appropriate.<|im_end|>\n<|im_start|>user\n%s%s<|im_end|>\n<|im_start|>assistant\n" - formattedPrompt := fmt.Sprintf(template, imgPrompt, prompt) - - // Tokenize - tokens := tok.Encode(formattedPrompt, false) - - // Process each image through vision tower - visionEmbeddings := make([]*mlx.Array, len(images)) - numImageTokens := make([]int32, len(images)) - visionGridH := make([]int32, len(images)) - visionGridW := make([]int32, len(images)) - - for i, image := range images { - visionEmbeddings[i] = m.encodeVision(image) - numImageTokens[i] = visionEmbeddings[i].Shape()[1] - // Get original grid dimensions from image shape - imgShape := image.Shape() - visionH := imgShape[2] / cfg.VisionPatchSize // Height in patches - visionW := imgShape[3] / cfg.VisionPatchSize // Width in patches - // After spatial merge, grid is halved - visionGridH[i] = visionH / cfg.VisionSpatialMerge - visionGridW[i] = visionW / cfg.VisionSpatialMerge - } - - // Find all image token positions and expand tokens - expandedTokens := make([]int32, 0, len(tokens)+int(sum(numImageTokens))) - imagePositions := make([]int32, 0, len(images)) // Start position for each image's tokens - imageIdx := 0 - - for _, t := range tokens { - if t == cfg.ImageTokenID { - if imageIdx < len(images) { - imagePositions = append(imagePositions, int32(len(expandedTokens))) - // Insert placeholder tokens for this image - for j := int32(0); j < numImageTokens[imageIdx]; j++ { - expandedTokens = append(expandedTokens, cfg.ImageTokenID) - } - imageIdx++ - } - } else { - expandedTokens = append(expandedTokens, t) - } - } - - // Create token array - seqLen := int32(len(expandedTokens)) - tokenArr := mlx.NewArrayInt32(expandedTokens, []int32{1, seqLen}) - - // Get text embeddings - textEmbed := mlx.EmbeddingLookup(m.Embedding, tokenArr) // [1, L, hidden] - - // Replace image token embeddings with vision embeddings - // Build list of segments to concatenate - segments := make([]*mlx.Array, 0, len(images)*2+1) - regions := make([]VisionRegion, len(images)) - lastEnd := int32(0) - - for i, imgPos := range imagePositions { - // Text segment before this image - if imgPos > lastEnd { - segments = append(segments, mlx.Slice(textEmbed, []int32{0, lastEnd, 0}, []int32{1, imgPos, cfg.HiddenSize})) - } - // Vision embeddings for this image - segments = append(segments, visionEmbeddings[i]) - regions[i] = VisionRegion{ - StartPos: imgPos, - NumTokens: numImageTokens[i], - GridH: visionGridH[i], - GridW: visionGridW[i], - } - lastEnd = imgPos + numImageTokens[i] - } - // Remaining text after last image - if lastEnd < seqLen { - segments = append(segments, mlx.Slice(textEmbed, []int32{0, lastEnd, 0}, []int32{1, seqLen, cfg.HiddenSize})) - } - - // Concatenate all segments - textEmbed = mlx.Concatenate(segments, 1) - - // Compute RoPE - use multimodal RoPE for multiple images - cossin, err := m.ComputeMultiImageRoPE(imagePositions, visionGridH, visionGridW, numImageTokens, seqLen) - if err != nil { - return nil, nil, nil, fmt.Errorf("computing RoPE: %w", err) - } - - // Forward through ALL text blocks - x := textEmbed - for _, block := range m.Blocks { - x = block.Forward(x, cossin) - } - - // Apply final norm - x = mlx.RMSNorm(x, m.FinalNorm, cfg.RMSNormEps) - - // Drop first N tokens (system prefix) - // prompt_template_encode_start_idx = 64 - dropIdx := int32(64) - if x.Shape()[1] > dropIdx { - x = mlx.Slice(x, []int32{0, dropIdx, 0}, []int32{1, x.Shape()[1], cfg.HiddenSize}) - // Adjust region positions - for i := range regions { - regions[i].StartPos -= dropIdx - } - } - - // Create attention mask (all ones) - mask := mlx.Ones(1, x.Shape()[1]) - - return x, mask, regions, nil -} - -// sum returns the sum of int32 slice -func sum(arr []int32) int32 { - var s int32 - for _, v := range arr { - s += v - } - return s -} - -// EncodeTextOnly encodes text tokens through all text blocks (exported for testing) -// tokens: array of token IDs -// Returns: [B, L, hidden_size] text embeddings after all blocks -func (m *Qwen25VL) EncodeTextOnly(tokens []int32) *mlx.Array { - seqLen := int32(len(tokens)) - tokenArr := mlx.NewArrayInt32(tokens, []int32{1, seqLen}) - - // Get text embeddings - textEmbed := mlx.EmbeddingLookup(m.Embedding, tokenArr) // [1, L, hidden] - - // Compute RoPE - cossin := m.computeTextRoPE(seqLen, 1) - - // Forward through ALL text blocks (unlike Encode which stops at second-to-last) - x := textEmbed - for _, block := range m.Blocks { - x = block.Forward(x, cossin) - } - - // Apply final norm - x = mlx.RMSNorm(x, m.FinalNorm, m.Config.RMSNormEps) - - return x -} - -// encodeVision encodes an image through the vision tower -// image: [B, C, H, W] normalized image tensor -// Returns: [B, num_tokens, hidden_size] vision embeddings -func (m *Qwen25VL) encodeVision(image *mlx.Array) *mlx.Array { - cfg := m.Config - - // Calculate grid dimensions from image - imgShape := image.Shape() - imgH := imgShape[2] - imgW := imgShape[3] - pH := imgH / cfg.VisionPatchSize // grid height in patches - pW := imgW / cfg.VisionPatchSize // grid width in patches - - // Patch embed - x := m.VisionPatchEmbed.Forward(image) - mlx.Eval(x) - - // Get window reordering info - winInfo := m.getWindowInfo(pH, pW) - - // Compute vision RoPE embeddings (already in 2x2-block order) - posEmb := m.computeVisionRoPE(pH, pW) - - shape := x.Shape() - B := shape[0] - L := shape[1] // num patches = pH * pW - D := shape[2] - spatialMergeUnit := winInfo.SpatialMergeUnit - spatialMerge := cfg.VisionSpatialMerge - - // Convert patch embed from row-major to 2x2-block order - // Row-major: (0,0), (0,1), (0,2), ..., (1,0), (1,1), ... - // 2x2-block: (0,0), (0,1), (1,0), (1,1), (0,2), (0,3), (1,2), (1,3), ... - llmGridH := pH / spatialMerge - llmGridW := pW / spatialMerge - blockReorderIdx := make([]int32, L) - idx := int32(0) - for hBlock := int32(0); hBlock < llmGridH; hBlock++ { - for wBlock := int32(0); wBlock < llmGridW; wBlock++ { - for dh := int32(0); dh < spatialMerge; dh++ { - for dw := int32(0); dw < spatialMerge; dw++ { - h := hBlock*spatialMerge + dh - w := wBlock*spatialMerge + dw - rowMajorIdx := h*pW + w - blockReorderIdx[idx] = rowMajorIdx - idx++ - } - } - } - } - blockIdxArr := mlx.NewArrayInt32(blockReorderIdx, []int32{L}) - x = mlx.Take(x, blockIdxArr, 1) // Reorder patches to 2x2-block order - - // Window reorder hidden states and RoPE before blocks - // Python: reshape to [L/4, 4, D], reorder dim 0, reshape back - // Reshape x: [B, L, D] -> [B, L/4, 4, D] - x = mlx.Reshape(x, B, L/spatialMergeUnit, spatialMergeUnit, D) - // Reorder using window index - winIdxArr := mlx.NewArrayInt32(winInfo.WindowIndex, []int32{int32(len(winInfo.WindowIndex))}) - x = mlx.Take(x, winIdxArr, 1) // Take along axis 1 - // Reshape back: [B, L/4, 4, D] -> [B, L, D] - x = mlx.Reshape(x, B, L, D) - - // Similarly reorder RoPE: [L, headDim] -> [L/4, 4, headDim] -> reorder -> [L, headDim] - cosShape := posEmb[0].Shape() - ropeL := cosShape[0] - ropeD := cosShape[1] - cos := mlx.Reshape(posEmb[0], ropeL/spatialMergeUnit, spatialMergeUnit, ropeD) - sin := mlx.Reshape(posEmb[1], ropeL/spatialMergeUnit, spatialMergeUnit, ropeD) - cos = mlx.Take(cos, winIdxArr, 0) - sin = mlx.Take(sin, winIdxArr, 0) - cos = mlx.Reshape(cos, ropeL, ropeD) - sin = mlx.Reshape(sin, ropeL, ropeD) - posEmb = [2]*mlx.Array{cos, sin} - - // Materialize to prevent freeing during block evaluations - mlx.Eval(x, posEmb[0], posEmb[1]) - - // Full sequence cu_seqlens for full attention blocks - cuSeqlensFull := []int32{0, L} - - // Vision blocks - use window attention except at full attention indices - for i, block := range m.VisionBlocks { - useFullAttention := false - for _, idx := range cfg.VisionFullAttIdx { - if int32(i) == idx { - useFullAttention = true - break - } - } - - var cuSeqlens []int32 - if useFullAttention { - cuSeqlens = cuSeqlensFull - } else { - cuSeqlens = winInfo.CuWindowSeqlens - } - - x = block.Forward(x, posEmb, cuSeqlens) - } - - // Spatial merge (2x2 -> 1) - x = m.VisionMerger.ForwardWithDims(x, pH, pW) - - // Reverse window reorder after merger - revIdxArr := mlx.NewArrayInt32(winInfo.ReverseIndex, []int32{int32(len(winInfo.ReverseIndex))}) - x = mlx.Take(x, revIdxArr, 1) - - return x -} - -// WindowInfo holds window reordering and attention boundary info -type WindowInfo struct { - WindowIndex []int32 // Reordering indices for merged tokens - ReverseIndex []int32 // Reverse reordering indices - CuWindowSeqlens []int32 // Cumulative window boundaries in UNMERGED sequence - SpatialMergeUnit int32 // Number of patches per merged token (4 = 2x2) -} - -// getWindowInfo computes window reordering indices and attention boundaries -// pH, pW: patch grid dimensions before 2x2 merge -func (m *Qwen25VL) getWindowInfo(pH, pW int32) *WindowInfo { - cfg := m.Config - spatialMergeUnit := cfg.VisionSpatialMerge * cfg.VisionSpatialMerge // 4 - - // After 2x2 merge - llmGridH := pH / cfg.VisionSpatialMerge - llmGridW := pW / cfg.VisionSpatialMerge - numTokens := llmGridH * llmGridW - - // Window size in merged tokens - // window_size=112, spatial_merge_size=2, patch_size=14 - // vit_merger_window_size = 112 / 2 / 14 = 4 - vitMergerWindowSize := cfg.VisionWindowSize / cfg.VisionSpatialMerge / cfg.VisionPatchSize - - // Calculate padding and number of windows - padH := vitMergerWindowSize - llmGridH%vitMergerWindowSize - if padH == vitMergerWindowSize { - padH = 0 - } - padW := vitMergerWindowSize - llmGridW%vitMergerWindowSize - if padW == vitMergerWindowSize { - padW = 0 - } - - numWindowsH := (llmGridH + padH) / vitMergerWindowSize - numWindowsW := (llmGridW + padW) / vitMergerWindowSize - - // Create padded grid with -1 for padding - paddedH := llmGridH + padH - paddedW := llmGridW + padW - grid := make([]int32, paddedH*paddedW) - for i := range grid { - grid[i] = -1 - } - for h := int32(0); h < llmGridH; h++ { - for w := int32(0); w < llmGridW; w++ { - grid[h*paddedW+w] = h*llmGridW + w - } - } - - // Reorder into windows and track window sizes - windowIndex := make([]int32, 0, numTokens) - windowSizes := make([]int32, 0, numWindowsH*numWindowsW) - ws := vitMergerWindowSize - - for wh := int32(0); wh < numWindowsH; wh++ { - for ww := int32(0); ww < numWindowsW; ww++ { - windowStart := len(windowIndex) - // Extract window - for h := int32(0); h < ws; h++ { - for w := int32(0); w < ws; w++ { - idx := (wh*ws+h)*paddedW + (ww*ws + w) - if grid[idx] >= 0 { - windowIndex = append(windowIndex, grid[idx]) - } - } - } - windowSize := int32(len(windowIndex) - windowStart) - windowSizes = append(windowSizes, windowSize) - } - } - - // Create reverse index (argsort of windowIndex) - reverseIndex := make([]int32, numTokens) - for i, idx := range windowIndex { - reverseIndex[idx] = int32(i) - } - - // Compute cumulative sequence lengths in UNMERGED sequence - // Each merged token corresponds to spatialMergeUnit patches - cuWindowSeqlens := make([]int32, len(windowSizes)+1) - cuWindowSeqlens[0] = 0 - for i, size := range windowSizes { - cuWindowSeqlens[i+1] = cuWindowSeqlens[i] + size*spatialMergeUnit - } - - return &WindowInfo{ - WindowIndex: windowIndex, - ReverseIndex: reverseIndex, - CuWindowSeqlens: cuWindowSeqlens, - SpatialMergeUnit: spatialMergeUnit, - } -} - -// ComputeMultiImageRoPE computes M-RoPE for combined text + multiple vision regions + text sequences -// This extends ComputeMultimodalRoPE to handle N images instead of just one. -// -// Parameters: -// - imagePositions: starting position of each image's tokens in the sequence -// - visionGridH, visionGridW: grid dimensions for each image (after spatial merge) -// - numImageTokens: number of tokens for each image -// - totalLen: total sequence length -func (m *Qwen25VL) ComputeMultiImageRoPE(imagePositions []int32, visionGridH, visionGridW, numImageTokens []int32, totalLen int32) ([2]*mlx.Array, error) { - numImages := len(imagePositions) - - // Build 3D position IDs: [3, 1, totalLen] - // Dimension 0: temporal, Dimension 1: height, Dimension 2: width - posIDs := make([]float32, 3*totalLen) - - // Process sequence in order - stIdx := int32(0) // Running text position counter - seqIdx := int32(0) - - for i := 0; i < numImages; i++ { - imgPos := imagePositions[i] - gridH := visionGridH[i] - gridW := visionGridW[i] - numTokens := numImageTokens[i] - - // Text segment before this image - for seqIdx < imgPos { - posIDs[0*totalLen+seqIdx] = float32(stIdx) - posIDs[1*totalLen+seqIdx] = float32(stIdx) - posIDs[2*totalLen+seqIdx] = float32(stIdx) - stIdx++ - seqIdx++ - } - - // Vision tokens for this image - // Python uses stIdx as base offset for all position dimensions - for h := int32(0); h < gridH; h++ { - for w := int32(0); w < gridW; w++ { - posIDs[0*totalLen+seqIdx] = float32(stIdx) // temporal: constant = stIdx - posIDs[1*totalLen+seqIdx] = float32(stIdx + h) // height: stIdx + row_index - posIDs[2*totalLen+seqIdx] = float32(stIdx + w) // width: stIdx + col_index - seqIdx++ - } - } - - // Verify we processed the expected number of tokens - if seqIdx != imgPos+numTokens { - return [2]*mlx.Array{}, fmt.Errorf("mismatch: processed %d but expected %d tokens for image %d", seqIdx-imgPos, numTokens, i) - } - - // Update stIdx for next text segment: max(temporal, height, width) + 1 - maxVisionPos := stIdx // temporal max - if stIdx+gridH-1 > maxVisionPos { - maxVisionPos = stIdx + gridH - 1 - } - if stIdx+gridW-1 > maxVisionPos { - maxVisionPos = stIdx + gridW - 1 - } - stIdx = maxVisionPos + 1 - } - - // Text after last image - for seqIdx < totalLen { - posIDs[0*totalLen+seqIdx] = float32(stIdx) - posIDs[1*totalLen+seqIdx] = float32(stIdx) - posIDs[2*totalLen+seqIdx] = float32(stIdx) - stIdx++ - seqIdx++ - } - - posIDsArr := mlx.NewArray(posIDs, []int32{3, 1, totalLen}) - return m.computeRoPEFromPositions(posIDsArr, totalLen, 1), nil -} - -// computeTextRoPE computes M-RoPE for text-only sequences -func (m *Qwen25VL) computeTextRoPE(L, B int32) [2]*mlx.Array { - // For text-only, all 3 dims use same positions [0, 1, 2, ..., L-1] - posArr := make([]float32, L*3) - for d := 0; d < 3; d++ { - for i := int32(0); i < L; i++ { - posArr[int32(d)*L+i] = float32(i) - } - } - posIDs := mlx.NewArray(posArr, []int32{3, 1, L}) - posIDs = mlx.Tile(posIDs, []int32{1, B, 1}) - return m.computeRoPEFromPositions(posIDs, L, B) -} - -// ComputeMultimodalRoPE computes M-RoPE for combined text + vision + text sequences -// This matches Python's get_rope_index behavior exactly. -// Exported for testing. -// -// Python pattern discovered from testing: -// -// Vision row 1: temporal=stIdx, height=stIdx, width=[stIdx, stIdx+1, ..., stIdx+gridW-1] -// Vision row 2: temporal=stIdx, height=stIdx+1, width=[stIdx, stIdx+1, ..., stIdx+gridW-1] -// Text after: temporal=stIdx+1+i, height=stIdx+gridH+i, width=stIdx+gridW+i -func (m *Qwen25VL) ComputeMultimodalRoPE(textBefore, visionH, visionW, textAfter int32, spatialMerge int32) [2]*mlx.Array { - // Vision grid after spatial merge - llmGridH := visionH / spatialMerge - llmGridW := visionW / spatialMerge - visionLen := llmGridH * llmGridW - totalLen := textBefore + visionLen + textAfter - - // Build 3D position IDs: [3, 1, totalLen] - // Dimension 0: temporal, Dimension 1: height, Dimension 2: width - posIDs := make([]float32, 3*totalLen) - - // Text before vision: all dims same [0, 1, 2, ..., textBefore-1] - for d := 0; d < 3; d++ { - for i := int32(0); i < textBefore; i++ { - posIDs[int32(d)*totalLen+i] = float32(i) - } - } - - // Vision tokens: 3D grid positions - // Python uses stIdx (textBefore) as base offset for all position dimensions - stIdx := textBefore - for h := int32(0); h < llmGridH; h++ { - for w := int32(0); w < llmGridW; w++ { - idx := stIdx + h*llmGridW + w - posIDs[0*totalLen+idx] = float32(stIdx) // temporal: constant = stIdx - posIDs[1*totalLen+idx] = float32(stIdx + h) // height: stIdx + row_index - posIDs[2*totalLen+idx] = float32(stIdx + w) // width: stIdx + col_index - } - } - - // Text after vision: ALL dimensions continue from max(temporal, height, width) + 1 - // max is max(stIdx, stIdx+llmGridH-1, stIdx+llmGridW-1) = stIdx + max(0, llmGridH-1, llmGridW-1) - // Then st_idx = max + 1 - maxVisionPos := stIdx // temporal max - if stIdx+llmGridH-1 > maxVisionPos { - maxVisionPos = stIdx + llmGridH - 1 - } - if stIdx+llmGridW-1 > maxVisionPos { - maxVisionPos = stIdx + llmGridW - 1 - } - textAfterStart := maxVisionPos + 1 - for i := int32(0); i < textAfter; i++ { - seqIdx := textBefore + visionLen + i - posIDs[0*totalLen+seqIdx] = float32(textAfterStart + i) // temporal - posIDs[1*totalLen+seqIdx] = float32(textAfterStart + i) // height - posIDs[2*totalLen+seqIdx] = float32(textAfterStart + i) // width - } - - posIDsArr := mlx.NewArray(posIDs, []int32{3, 1, totalLen}) - return m.computeRoPEFromPositions(posIDsArr, totalLen, 1) -} - -// computeRoPEFromPositions computes cos/sin from 3D position IDs -// posIDs: [3, B, L] where dim 0 is temporal, 1 is height, 2 is width -func (m *Qwen25VL) computeRoPEFromPositions(posIDs *mlx.Array, L, B int32) [2]*mlx.Array { - cfg := m.Config - half := cfg.HeadDim / 2 - - // Compute inv_freq - invFreqArr := make([]float32, half) - for i := int32(0); i < half; i++ { - invFreqArr[i] = float32(1.0 / math.Pow(float64(cfg.RopeTheta), 2.0*float64(i)/float64(cfg.HeadDim))) - } - invFreq := mlx.NewArray(invFreqArr, []int32{half}) - - // Process each position dimension - var cosAll, sinAll []*mlx.Array - for d := int32(0); d < 3; d++ { - // Get positions for this dimension: [B, L] - pos := mlx.Slice(posIDs, []int32{d, 0, 0}, []int32{d + 1, B, L}) - pos = mlx.Squeeze(pos, 0) // [B, L] - - posExp := mlx.ExpandDims(pos, 2) // [B, L, 1] - invFreqExp := mlx.Reshape(invFreq, 1, 1, half) // [1, 1, half] - freqs := mlx.Mul(posExp, invFreqExp) // [B, L, half] - emb := mlx.Tile(freqs, []int32{1, 1, 2}) // [B, L, D] - - cosAll = append(cosAll, mlx.ExpandDims(mlx.Cos(emb), 0)) - sinAll = append(sinAll, mlx.ExpandDims(mlx.Sin(emb), 0)) - } - - cos := mlx.Concatenate(cosAll, 0) // [3, B, L, D] - sin := mlx.Concatenate(sinAll, 0) - - return [2]*mlx.Array{cos, sin} -} - -// computeVisionRoPE computes RoPE embeddings for vision patches -// pH, pW: grid dimensions in patches -// Returns: [2]*mlx.Array containing (cos, sin) each of shape [numPatches, headDim] -func (m *Qwen25VL) computeVisionRoPE(pH, pW int32) [2]*mlx.Array { - cfg := m.Config - headDim := cfg.VisionHiddenSize / cfg.VisionNumHeads // 80 for 1280/16 - halfDim := headDim / 2 // 40 - quarterDim := halfDim / 2 // 20 - spatialMerge := cfg.VisionSpatialMerge // 2 - - // Python Qwen2_5_VisionRotaryEmbedding uses dim=head_dim/2=40 - // inv_freq = 1.0 / (theta ** (arange(0, dim, 2) / dim)) -> 20 elements - theta := float64(10000.0) - invFreqArr := make([]float32, quarterDim) - for i := int32(0); i < quarterDim; i++ { - invFreqArr[i] = float32(1.0 / math.Pow(theta, float64(2*i)/float64(halfDim))) - } - invFreq := mlx.NewArray(invFreqArr, []int32{quarterDim}) - - // Create position IDs matching Python's 2x2 block ordering: - // Python does: reshape(h//2, 2, w//2, 2), permute(0, 2, 1, 3), flatten - // This groups patches by 2x2 merged token blocks - numPatches := pH * pW - hPosArr := make([]float32, numPatches) - wPosArr := make([]float32, numPatches) - - // Number of merged token blocks - llmGridH := pH / spatialMerge - llmGridW := pW / spatialMerge - - idx := int32(0) - for hBlock := int32(0); hBlock < llmGridH; hBlock++ { - for wBlock := int32(0); wBlock < llmGridW; wBlock++ { - // Within each 2x2 block: (0,0), (0,1), (1,0), (1,1) - for dh := int32(0); dh < spatialMerge; dh++ { - for dw := int32(0); dw < spatialMerge; dw++ { - h := hBlock*spatialMerge + dh - w := wBlock*spatialMerge + dw - hPosArr[idx] = float32(h) - wPosArr[idx] = float32(w) - idx++ - } - } - } - } - - hPos := mlx.NewArray(hPosArr, []int32{numPatches, 1}) - wPos := mlx.NewArray(wPosArr, []int32{numPatches, 1}) - invFreqExp := mlx.Reshape(invFreq, 1, quarterDim) - - // Compute freqs: [numPatches, quarterDim] for each of h and w - hFreqs := mlx.Mul(hPos, invFreqExp) // [L, 20] - wFreqs := mlx.Mul(wPos, invFreqExp) // [L, 20] - - // Concatenate h and w freqs: [numPatches, halfDim] = [L, 40] - freqs := mlx.Concatenate([]*mlx.Array{hFreqs, wFreqs}, 1) - - // Double for cos/sin application: [L, 40] -> [L, 80] = [L, headDim] - emb := mlx.Concatenate([]*mlx.Array{freqs, freqs}, 1) - - cos := mlx.Cos(emb) - sin := mlx.Sin(emb) - - return [2]*mlx.Array{cos, sin} -} - -// VLTextBlock is a single Qwen2.5 transformer block (for VL model) -type VLTextBlock struct { - Attention *VLTextAttention - MLP *VLTextMLP - InputLayerNorm *mlx.Array - PostAttnLayerNorm *mlx.Array - NormEps float32 -} - -// newVLTextBlock creates a text block -func newVLTextBlock(weights *safetensors.ModelWeights, layerIdx int, cfg *Qwen25VLConfig) (*VLTextBlock, error) { - prefix := fmt.Sprintf("model.layers.%d", layerIdx) - - inputNorm, err := weights.Get(prefix + ".input_layernorm.weight") - if err != nil { - return nil, err - } - postAttnNorm, err := weights.Get(prefix + ".post_attention_layernorm.weight") - if err != nil { - return nil, err - } - - attention, err := newVLTextAttention(weights, prefix, cfg) - if err != nil { - return nil, err - } - - mlpLayer, err := newVLTextMLP(weights, prefix) - if err != nil { - return nil, err - } - - return &VLTextBlock{ - Attention: attention, - MLP: mlpLayer, - InputLayerNorm: inputNorm, - PostAttnLayerNorm: postAttnNorm, - NormEps: cfg.RMSNormEps, - }, nil -} - -// Forward applies the block -func (tb *VLTextBlock) Forward(x *mlx.Array, cossin [2]*mlx.Array) *mlx.Array { - h := mlx.RMSNorm(x, tb.InputLayerNorm, tb.NormEps) - attnOut := tb.Attention.Forward(h, cossin) - x = mlx.Add(x, attnOut) - - h = mlx.RMSNorm(x, tb.PostAttnLayerNorm, tb.NormEps) - mlpOut := tb.MLP.Forward(h) - x = mlx.Add(x, mlpOut) - - return x -} - -// VLTextAttention implements Qwen2.5 attention with M-RoPE -type VLTextAttention struct { - QProj *mlx.Array - KProj *mlx.Array - VProj *mlx.Array - OProj *mlx.Array - QBias *mlx.Array - KBias *mlx.Array - VBias *mlx.Array - NHeads int32 - NKVHeads int32 - HeadDim int32 - Scale float32 - MRoPESection []int32 -} - -// newVLTextAttention creates a text attention layer -func newVLTextAttention(weights *safetensors.ModelWeights, prefix string, cfg *Qwen25VLConfig) (*VLTextAttention, error) { - qProj, err := weights.Get(prefix + ".self_attn.q_proj.weight") - if err != nil { - return nil, err - } - kProj, err := weights.Get(prefix + ".self_attn.k_proj.weight") - if err != nil { - return nil, err - } - vProj, err := weights.Get(prefix + ".self_attn.v_proj.weight") - if err != nil { - return nil, err - } - oProj, err := weights.Get(prefix + ".self_attn.o_proj.weight") - if err != nil { - return nil, err - } - - qBias, _ := weights.Get(prefix + ".self_attn.q_proj.bias") - kBias, _ := weights.Get(prefix + ".self_attn.k_proj.bias") - vBias, _ := weights.Get(prefix + ".self_attn.v_proj.bias") - - return &VLTextAttention{ - QProj: mlx.Transpose(qProj, 1, 0), - KProj: mlx.Transpose(kProj, 1, 0), - VProj: mlx.Transpose(vProj, 1, 0), - OProj: mlx.Transpose(oProj, 1, 0), - QBias: qBias, - KBias: kBias, - VBias: vBias, - NHeads: cfg.NumAttentionHeads, - NKVHeads: cfg.NumKeyValueHeads, - HeadDim: cfg.HeadDim, - Scale: float32(1.0 / math.Sqrt(float64(cfg.HeadDim))), - MRoPESection: cfg.MRoPESection, - }, nil -} - -// Forward computes attention -func (attn *VLTextAttention) Forward(x *mlx.Array, cossin [2]*mlx.Array) *mlx.Array { - shape := x.Shape() - B := shape[0] - L := shape[1] - - q := mlx.Linear(x, attn.QProj) - if attn.QBias != nil { - q = mlx.Add(q, attn.QBias) - } - k := mlx.Linear(x, attn.KProj) - if attn.KBias != nil { - k = mlx.Add(k, attn.KBias) - } - v := mlx.Linear(x, attn.VProj) - if attn.VBias != nil { - v = mlx.Add(v, attn.VBias) - } - - q = mlx.Reshape(q, B, L, attn.NHeads, attn.HeadDim) - k = mlx.Reshape(k, B, L, attn.NKVHeads, attn.HeadDim) - v = mlx.Reshape(v, B, L, attn.NKVHeads, attn.HeadDim) - - q = mlx.Transpose(q, 0, 2, 1, 3) - k = mlx.Transpose(k, 0, 2, 1, 3) - v = mlx.Transpose(v, 0, 2, 1, 3) - - // Apply M-RoPE - if cossin[0] != nil && cossin[1] != nil { - q = applyMRoPE(q, cossin[0], cossin[1], attn.MRoPESection) - k = applyMRoPE(k, cossin[0], cossin[1], attn.MRoPESection) - } - - // Repeat KV for GQA - if attn.NKVHeads < attn.NHeads { - repeats := attn.NHeads / attn.NKVHeads - k = repeatKV(k, repeats) - v = repeatKV(v, repeats) - } - - out := mlx.ScaledDotProductAttention(q, k, v, attn.Scale, true) - - out = mlx.Transpose(out, 0, 2, 1, 3) - out = mlx.Reshape(out, B, L, attn.NHeads*attn.HeadDim) - - return mlx.Linear(out, attn.OProj) -} - -// applyMRoPE applies Multi-Resolution RoPE -func applyMRoPE(x *mlx.Array, cos, sin *mlx.Array, section []int32) *mlx.Array { - shape := x.Shape() - B := shape[0] - H := shape[1] - L := shape[2] - D := shape[3] - half := D / 2 - - fullSection := make([]int32, len(section)) - for i, s := range section { - fullSection[i] = s * 2 - } - - var cosParts, sinParts []*mlx.Array - offset := int32(0) - for i, size := range fullSection { - posDim := int32(i % 3) - cosSection := mlx.Slice(cos, []int32{posDim, 0, 0, offset}, []int32{posDim + 1, B, L, offset + size}) - sinSection := mlx.Slice(sin, []int32{posDim, 0, 0, offset}, []int32{posDim + 1, B, L, offset + size}) - cosSection = mlx.Squeeze(cosSection, 0) - sinSection = mlx.Squeeze(sinSection, 0) - cosParts = append(cosParts, cosSection) - sinParts = append(sinParts, sinSection) - offset += size - } - - cosFlat := mlx.Concatenate(cosParts, 2) - sinFlat := mlx.Concatenate(sinParts, 2) - - cosFlat = mlx.Reshape(cosFlat, B, 1, L, D) - sinFlat = mlx.Reshape(sinFlat, B, 1, L, D) - - x1 := mlx.Slice(x, []int32{0, 0, 0, 0}, []int32{B, H, L, half}) - x2 := mlx.Slice(x, []int32{0, 0, 0, half}, []int32{B, H, L, D}) - negX2 := mlx.MulScalar(x2, -1) - rotatedX := mlx.Concatenate([]*mlx.Array{negX2, x1}, 3) - - return mlx.Add(mlx.Mul(x, cosFlat), mlx.Mul(rotatedX, sinFlat)) -} - -// repeatKV repeats key/value heads for GQA -func repeatKV(x *mlx.Array, repeats int32) *mlx.Array { - if repeats == 1 { - return x - } - shape := x.Shape() - x = mlx.ExpandDims(x, 2) - x = mlx.Tile(x, []int32{1, 1, repeats, 1, 1}) - return mlx.Reshape(x, shape[0], shape[1]*repeats, shape[2], shape[3]) -} - -// VLTextMLP implements Qwen2.5 SwiGLU MLP -type VLTextMLP struct { - GateProj *mlx.Array - UpProj *mlx.Array - DownProj *mlx.Array -} - -// newVLTextMLP creates a text MLP layer -func newVLTextMLP(weights *safetensors.ModelWeights, prefix string) (*VLTextMLP, error) { - gateProj, err := weights.Get(prefix + ".mlp.gate_proj.weight") - if err != nil { - return nil, err - } - upProj, err := weights.Get(prefix + ".mlp.up_proj.weight") - if err != nil { - return nil, err - } - downProj, err := weights.Get(prefix + ".mlp.down_proj.weight") - if err != nil { - return nil, err - } - - return &VLTextMLP{ - GateProj: mlx.Transpose(gateProj, 1, 0), - UpProj: mlx.Transpose(upProj, 1, 0), - DownProj: mlx.Transpose(downProj, 1, 0), - }, nil -} - -// Forward applies the SwiGLU MLP -func (mlp *VLTextMLP) Forward(x *mlx.Array) *mlx.Array { - gate := mlx.Linear(x, mlp.GateProj) - gate = mlx.SiLU(gate) - up := mlx.Linear(x, mlp.UpProj) - h := mlx.Mul(gate, up) - return mlx.Linear(h, mlp.DownProj) -} - -// VisionPatchEmbed embeds image patches -type VisionPatchEmbed struct { - ProjWeight *mlx.Array - ProjBias *mlx.Array - PatchSize int32 -} - -// newVisionPatchEmbed creates a vision patch embed layer -func newVisionPatchEmbed(weights *safetensors.ModelWeights, cfg *Qwen25VLConfig) (*VisionPatchEmbed, error) { - projWeight, err := weights.Get("visual.patch_embed.proj.weight") - if err != nil { - return nil, err - } - projBias, _ := weights.Get("visual.patch_embed.proj.bias") - - return &VisionPatchEmbed{ - ProjWeight: projWeight, - ProjBias: projBias, - PatchSize: cfg.VisionPatchSize, - }, nil -} - -// Forward embeds patches from an image -// image: [B, C, H, W] -// Returns: [B, num_patches, hidden_size] -func (pe *VisionPatchEmbed) Forward(image *mlx.Array) *mlx.Array { - // Qwen2.5-VL uses 3D conv for patch embedding to support video - // Weight shape is [O, I, kT, kH, kW] e.g. [1280, 3, 2, 14, 14] - // For single image, we duplicate the frame to match temporal_patch_size - - wShape := pe.ProjWeight.Shape() - if len(wShape) == 5 { - // 3D convolution case - temporalPatchSize := wShape[2] // kT from weight shape - - // Add temporal dimension: [B, C, H, W] -> [B, C, 1, H, W] - image = mlx.ExpandDims(image, 2) - - // Duplicate frame to match temporal_patch_size (Python does this for single images) - // [B, C, 1, H, W] -> [B, C, T, H, W] where T = temporal_patch_size - if temporalPatchSize > 1 { - image = mlx.Tile(image, []int32{1, 1, temporalPatchSize, 1, 1}) - } - - // Convert to channels-last: [B, C, T, H, W] -> [B, T, H, W, C] - image = mlx.Transpose(image, 0, 2, 3, 4, 1) - - // Weight is [O, I, kT, kH, kW] - keep as-is since patches are now in [I, kT, kH, kW] order - // (extractPatches3DStrided transposes each patch to [C, T, H, W] to match Python) - - // Apply 3D conv using manual patch extraction - // Strides: (temporal_patch_size, patch_size, patch_size) - x := conv3DStrided(image, pe.ProjWeight, temporalPatchSize, pe.PatchSize, pe.PatchSize) - - if pe.ProjBias != nil { - outC := pe.ProjBias.Dim(0) - bias := mlx.Reshape(pe.ProjBias, 1, 1, 1, 1, outC) - x = mlx.Add(x, bias) - } - - // x is [B, T', H', W', C], squeeze T' and flatten spatial - shape := x.Shape() - // T' should be 1 for single image (since we used stride=temporal_patch_size) - x = mlx.Reshape(x, shape[0], shape[2]*shape[3], shape[4]) - - return x - } - - // Original 2D case (fallback) - // Convert to channels-last for Conv2d - image = mlx.Transpose(image, 0, 2, 3, 1) // [B, H, W, C] - - // Apply conv with stride=patch_size using manual strided convolution - weight := mlx.Transpose(pe.ProjWeight, 0, 2, 3, 1) // [O, I, kH, kW] -> [O, kH, kW, I] - x := conv2DStrided(image, weight, pe.PatchSize) - if pe.ProjBias != nil { - bias := mlx.Reshape(pe.ProjBias, 1, 1, 1, pe.ProjBias.Dim(0)) - x = mlx.Add(x, bias) - } - - // Flatten patches: [B, pH, pW, C] -> [B, pH*pW, C] - shape := x.Shape() - x = mlx.Reshape(x, shape[0], shape[1]*shape[2], shape[3]) - - return x -} - -// VisionBlock is a single vision transformer block -type VisionBlock struct { - Norm1 *mlx.Array - Norm2 *mlx.Array - Attention *VisionAttention - MLP *VisionMLP -} - -// newVisionBlock creates a vision block -func newVisionBlock(weights *safetensors.ModelWeights, layerIdx int, cfg *Qwen25VLConfig) (*VisionBlock, error) { - prefix := fmt.Sprintf("visual.blocks.%d", layerIdx) - - norm1, err := weights.Get(prefix + ".norm1.weight") - if err != nil { - return nil, err - } - norm2, err := weights.Get(prefix + ".norm2.weight") - if err != nil { - return nil, err - } - - attention, err := newVisionAttention(weights, prefix, cfg) - if err != nil { - return nil, err - } - - mlpLayer, err := newVisionMLP(weights, prefix, cfg) - if err != nil { - return nil, err - } - - return &VisionBlock{ - Norm1: norm1, - Norm2: norm2, - Attention: attention, - MLP: mlpLayer, - }, nil -} - -// Forward applies the vision block -// posEmb: [2]*mlx.Array containing (cos, sin) for RoPE, can be nil -// cuSeqlens: cumulative sequence lengths for window attention -func (vb *VisionBlock) Forward(x *mlx.Array, posEmb [2]*mlx.Array, cuSeqlens []int32) *mlx.Array { - // Python uses RMSNorm, not LayerNorm! - h := mlx.RMSNormNoWeight(x, 1e-6) - h = mlx.Mul(h, vb.Norm1) - attnOut := vb.Attention.Forward(h, posEmb, cuSeqlens) - x = mlx.Add(x, attnOut) - - h = mlx.RMSNormNoWeight(x, 1e-6) - h = mlx.Mul(h, vb.Norm2) - mlpOut := vb.MLP.Forward(h) - x = mlx.Add(x, mlpOut) - - return x -} - -// VisionAttention implements vision attention -type VisionAttention struct { - QKVProj *mlx.Array - QKVBias *mlx.Array - OutProj *mlx.Array - OutBias *mlx.Array - NHeads int32 - HeadDim int32 - Scale float32 -} - -// newVisionAttention creates a vision attention layer -func newVisionAttention(weights *safetensors.ModelWeights, prefix string, cfg *Qwen25VLConfig) (*VisionAttention, error) { - qkvProj, err := weights.Get(prefix + ".attn.qkv.weight") - if err != nil { - return nil, err - } - qkvBias, _ := weights.Get(prefix + ".attn.qkv.bias") - outProj, err := weights.Get(prefix + ".attn.proj.weight") - if err != nil { - return nil, err - } - outBias, _ := weights.Get(prefix + ".attn.proj.bias") - - headDim := cfg.VisionHiddenSize / cfg.VisionNumHeads - - return &VisionAttention{ - QKVProj: mlx.Transpose(qkvProj, 1, 0), - QKVBias: qkvBias, - OutProj: mlx.Transpose(outProj, 1, 0), - OutBias: outBias, - NHeads: cfg.VisionNumHeads, - HeadDim: headDim, - Scale: float32(1.0 / math.Sqrt(float64(headDim))), - }, nil -} - -// Forward applies vision attention with optional RoPE and window attention -// posEmb: [2]*mlx.Array containing (cos, sin) for RoPE, can be nil -// cuSeqlens: cumulative sequence lengths for window boundaries -func (attn *VisionAttention) Forward(x *mlx.Array, posEmb [2]*mlx.Array, cuSeqlens []int32) *mlx.Array { - shape := x.Shape() - B := shape[0] - L := shape[1] - D := shape[2] - - qkv := mlx.Linear(x, attn.QKVProj) - if attn.QKVBias != nil { - qkv = mlx.Add(qkv, attn.QKVBias) - } - - // Split into Q, K, V - qkv = mlx.Reshape(qkv, B, L, 3, attn.NHeads, attn.HeadDim) - q := mlx.Slice(qkv, []int32{0, 0, 0, 0, 0}, []int32{B, L, 1, attn.NHeads, attn.HeadDim}) - k := mlx.Slice(qkv, []int32{0, 0, 1, 0, 0}, []int32{B, L, 2, attn.NHeads, attn.HeadDim}) - v := mlx.Slice(qkv, []int32{0, 0, 2, 0, 0}, []int32{B, L, 3, attn.NHeads, attn.HeadDim}) - - q = mlx.Squeeze(q, 2) // [B, L, H, D] - k = mlx.Squeeze(k, 2) - v = mlx.Squeeze(v, 2) - - // Apply RoPE if position embeddings provided - if posEmb[0] != nil && posEmb[1] != nil { - q, k = applyVisionRoPE(q, k, posEmb[0], posEmb[1]) - } - - q = mlx.Transpose(q, 0, 2, 1, 3) // [B, H, L, D] - k = mlx.Transpose(k, 0, 2, 1, 3) - v = mlx.Transpose(v, 0, 2, 1, 3) - - var out *mlx.Array - - // Check if we need window attention (more than 1 window) - numWindows := len(cuSeqlens) - 1 - if numWindows <= 1 { - // Full attention - single window covering entire sequence - out = mlx.ScaledDotProductAttention(q, k, v, attn.Scale, false) - } else { - // Window attention - process each window separately - attnOutputs := make([]*mlx.Array, numWindows) - - for w := 0; w < numWindows; w++ { - start := cuSeqlens[w] - end := cuSeqlens[w+1] - - // Slice Q, K, V for this window: [B, H, winLen, D] - qWin := mlx.Slice(q, []int32{0, 0, start, 0}, []int32{B, attn.NHeads, end, attn.HeadDim}) - kWin := mlx.Slice(k, []int32{0, 0, start, 0}, []int32{B, attn.NHeads, end, attn.HeadDim}) - vWin := mlx.Slice(v, []int32{0, 0, start, 0}, []int32{B, attn.NHeads, end, attn.HeadDim}) - - // Compute attention for this window - attnWin := mlx.ScaledDotProductAttention(qWin, kWin, vWin, attn.Scale, false) - attnOutputs[w] = attnWin - } - - // Concatenate all window outputs along sequence dimension - out = mlx.Concatenate(attnOutputs, 2) - } - - out = mlx.Transpose(out, 0, 2, 1, 3) // [B, L, H, D] - out = mlx.Reshape(out, B, L, D) - - out = mlx.Linear(out, attn.OutProj) - if attn.OutBias != nil { - out = mlx.Add(out, attn.OutBias) - } - - return out -} - -// applyVisionRoPE applies rotary position embedding to Q and K for vision -// q, k: [B, L, H, D], cos, sin: [L, D] (already doubled: D = head_dim) -// Returns: rotated q, k with same shape -// Note: Python does this computation in float32 for numerical stability -func applyVisionRoPE(q, k, cos, sin *mlx.Array) (*mlx.Array, *mlx.Array) { - // Convert to float32 for numerical stability (matches Python) - origDtype := q.Dtype() - q = mlx.AsType(q, mlx.DtypeFloat32) - k = mlx.AsType(k, mlx.DtypeFloat32) - cos = mlx.AsType(cos, mlx.DtypeFloat32) - sin = mlx.AsType(sin, mlx.DtypeFloat32) - - // Expand cos/sin to match q/k shape: [L, D] -> [1, L, 1, D] - cos = mlx.ExpandDims(cos, 0) - cos = mlx.ExpandDims(cos, 2) - sin = mlx.ExpandDims(sin, 0) - sin = mlx.ExpandDims(sin, 2) - - // rotate_half: split last dim in half and swap with negation - // q_rot = q * cos + rotate_half(q) * sin - qRotated := rotateHalf(q) - kRotated := rotateHalf(k) - - qOut := mlx.Add(mlx.Mul(q, cos), mlx.Mul(qRotated, sin)) - kOut := mlx.Add(mlx.Mul(k, cos), mlx.Mul(kRotated, sin)) - - // Convert back to original dtype - qOut = mlx.AsType(qOut, origDtype) - kOut = mlx.AsType(kOut, origDtype) - - return qOut, kOut -} - -// rotateHalf rotates the last dimension by splitting in half and swapping with negation -// x: [..., D] -> split to [..., D/2] and [..., D/2], then concat(-x2, x1) -func rotateHalf(x *mlx.Array) *mlx.Array { - shape := x.Shape() - lastDim := shape[len(shape)-1] - halfDim := lastDim / 2 - - // Split into two halves - x1 := mlx.Slice(x, []int32{0, 0, 0, 0}, []int32{shape[0], shape[1], shape[2], halfDim}) - x2 := mlx.Slice(x, []int32{0, 0, 0, halfDim}, []int32{shape[0], shape[1], shape[2], lastDim}) - - // Negate x2 and concatenate - x2Neg := mlx.MulScalar(x2, -1.0) - return mlx.Concatenate([]*mlx.Array{x2Neg, x1}, 3) -} - -// VisionMLP implements vision SwiGLU MLP -type VisionMLP struct { - GateProj *mlx.Array - GateProjBias *mlx.Array - UpProj *mlx.Array - UpProjBias *mlx.Array - DownProj *mlx.Array - DownProjBias *mlx.Array -} - -// newVisionMLP creates a vision MLP layer -func newVisionMLP(weights *safetensors.ModelWeights, prefix string, cfg *Qwen25VLConfig) (*VisionMLP, error) { - gateProj, err := weights.Get(prefix + ".mlp.gate_proj.weight") - if err != nil { - return nil, err - } - gateProjBias, _ := weights.Get(prefix + ".mlp.gate_proj.bias") - upProj, err := weights.Get(prefix + ".mlp.up_proj.weight") - if err != nil { - return nil, err - } - upProjBias, _ := weights.Get(prefix + ".mlp.up_proj.bias") - downProj, err := weights.Get(prefix + ".mlp.down_proj.weight") - if err != nil { - return nil, err - } - downProjBias, _ := weights.Get(prefix + ".mlp.down_proj.bias") - - return &VisionMLP{ - GateProj: mlx.Transpose(gateProj, 1, 0), - GateProjBias: gateProjBias, - UpProj: mlx.Transpose(upProj, 1, 0), - UpProjBias: upProjBias, - DownProj: mlx.Transpose(downProj, 1, 0), - DownProjBias: downProjBias, - }, nil -} - -// Forward applies the vision SwiGLU MLP -func (m *VisionMLP) Forward(x *mlx.Array) *mlx.Array { - gate := mlx.Linear(x, m.GateProj) - if m.GateProjBias != nil { - gate = mlx.Add(gate, m.GateProjBias) - } - gate = mlx.SiLU(gate) - - up := mlx.Linear(x, m.UpProj) - if m.UpProjBias != nil { - up = mlx.Add(up, m.UpProjBias) - } - - h := mlx.Mul(gate, up) - h = mlx.Linear(h, m.DownProj) - if m.DownProjBias != nil { - h = mlx.Add(h, m.DownProjBias) - } - return h -} - -// VisionMerger merges spatial patches (2x2 -> 1) -type VisionMerger struct { - MLP0Weight *mlx.Array - MLP0Bias *mlx.Array - MLP2Weight *mlx.Array - MLP2Bias *mlx.Array - LNWeight *mlx.Array -} - -// newVisionMerger creates a vision merger -func newVisionMerger(weights *safetensors.ModelWeights, cfg *Qwen25VLConfig) (*VisionMerger, error) { - mlp0Weight, err := weights.Get("visual.merger.mlp.0.weight") - if err != nil { - return nil, err - } - mlp0Bias, _ := weights.Get("visual.merger.mlp.0.bias") - mlp2Weight, err := weights.Get("visual.merger.mlp.2.weight") - if err != nil { - return nil, err - } - mlp2Bias, _ := weights.Get("visual.merger.mlp.2.bias") - lnWeight, _ := weights.Get("visual.merger.ln_q.weight") - - return &VisionMerger{ - MLP0Weight: mlx.Transpose(mlp0Weight, 1, 0), - MLP0Bias: mlp0Bias, - MLP2Weight: mlx.Transpose(mlp2Weight, 1, 0), - MLP2Bias: mlp2Bias, - LNWeight: lnWeight, - }, nil -} - -// Forward merges 2x2 patches into 1 (assumes square grid - use ForwardWithDims for non-square) -func (m *VisionMerger) Forward(x *mlx.Array) *mlx.Array { - shape := x.Shape() - L := shape[1] - side := int32(math.Sqrt(float64(L))) - return m.ForwardWithDims(x, side, side) -} - -// ForwardWithDims merges 2x2 patches into 1 with explicit grid dimensions -// After window reordering, consecutive 4 patches form a 2x2 block, so we just -// reshape [B, L, D] -> [B, L/4, 4*D] without 2D spatial rearrangement. -func (m *VisionMerger) ForwardWithDims(x *mlx.Array, pH, pW int32) *mlx.Array { - shape := x.Shape() - B := shape[0] - L := shape[1] - D := shape[2] - - // RMSNorm BEFORE merge (applied to each token with D dimensions) - // Python: ln_q = Qwen2RMSNorm(context_dim, eps=1e-6) - if m.LNWeight != nil { - x = mlx.RMSNormNoWeight(x, 1e-6) - x = mlx.Mul(x, m.LNWeight) - } - - // After window reordering, consecutive 4 patches belong to a 2x2 block - // Just reshape to [B, L/4, 4*D] - no spatial rearrangement needed - newL := L / 4 - x = mlx.Reshape(x, B, newL, 4*D) - - // MLP - h := mlx.Linear(x, m.MLP0Weight) - if m.MLP0Bias != nil { - h = mlx.Add(h, m.MLP0Bias) - } - h = mlx.GELU(h) - h = mlx.Linear(h, m.MLP2Weight) - if m.MLP2Bias != nil { - h = mlx.Add(h, m.MLP2Bias) - } - - return h -} - -// LoadQwen25VLFromPath loads the encoder from path -func LoadQwen25VLFromPath(path string) (*Qwen25VL, error) { - m := &Qwen25VL{} - if err := m.Load(filepath.Join(path, "text_encoder")); err != nil { - return nil, err - } - return m, nil -} - -// conv2DStrided applies conv with stride > 1 using manual patch extraction -// x: [B, H, W, C] (channels-last), weight: [O, kH, kW, I] -func conv2DStrided(x, weight *mlx.Array, stride int32) *mlx.Array { - shape := x.Shape() - B := shape[0] - H := shape[1] - W := shape[2] - - wShape := weight.Shape() - Cout := wShape[0] - kH := wShape[1] - kW := wShape[2] - - outH := (H - kH) / stride + 1 - outW := (W - kW) / stride + 1 - - patches := extractPatches2DStrided(x, kH, kW, stride) - wFlat := mlx.Reshape(weight, Cout, -1) - patches = mlx.Reshape(patches, B*outH*outW, -1) - out := mlx.Linear(patches, mlx.Transpose(wFlat, 1, 0)) - return mlx.Reshape(out, B, outH, outW, Cout) -} - -// conv3DStrided applies 3D conv with strides using manual patch extraction -// x: [B, T, H, W, C] (channels-last), weight: [O, I, kT, kH, kW] (PyTorch format) -// strideT, strideH, strideW are the strides for each dimension -// Patches are extracted in [C, T, H, W] order to match Python's preprocessing -func conv3DStrided(x, weight *mlx.Array, strideT, strideH, strideW int32) *mlx.Array { - shape := x.Shape() - B := shape[0] - T := shape[1] - H := shape[2] - W := shape[3] - C := shape[4] - - wShape := weight.Shape() - Cout := wShape[0] - // I := wShape[1] - kT := wShape[2] - kH := wShape[3] - kW := wShape[4] - - // For temporal: if T < kT, we need to repeat frames temporally - // For single image with T=1 and kT=2, we duplicate the frame to T=kT - // Python Qwen2.5-VL duplicates the frame, not zero-pads - if T < kT { - // Tile along T dimension: [B, T, H, W, C] -> [B, kT, H, W, C] - x = mlx.Tile(x, []int32{1, kT, 1, 1, 1}) - T = kT - } - - outT := (T - kT) / strideT + 1 - outH := (H - kH) / strideH + 1 - outW := (W - kW) / strideW + 1 - - // Extract 3D patches in [C, T, H, W] order to match Python - patches := extractPatches3DStrided(x, kT, kH, kW, strideT, strideH, strideW) - // patches shape: [B, outT, outH, outW, C*kT*kH*kW] - - // Weight is [O, I, kT, kH, kW] - flatten to [O, I*kT*kH*kW] to match patch order [C, T, H, W] - wFlat := mlx.Reshape(weight, Cout, -1) // [Cout, I*kT*kH*kW] - patches = mlx.Reshape(patches, B*outT*outH*outW, C*kT*kH*kW) - out := mlx.Linear(patches, mlx.Transpose(wFlat, 1, 0)) - return mlx.Reshape(out, B, outT, outH, outW, Cout) -} - -// extractPatches3DStrided extracts 3D patches with given strides -// Returns patches with values in [C, T, H, W] order to match Python's preprocessing -func extractPatches3DStrided(x *mlx.Array, kT, kH, kW, strideT, strideH, strideW int32) *mlx.Array { - shape := x.Shape() - B := shape[0] - T := shape[1] - H := shape[2] - W := shape[3] - C := shape[4] - - outT := (T - kT) / strideT + 1 - outH := (H - kH) / strideH + 1 - outW := (W - kW) / strideW + 1 - - numPatches := outT * outH * outW - patches := make([]*mlx.Array, numPatches) - idx := 0 - for t := int32(0); t < outT; t++ { - for i := int32(0); i < outH; i++ { - for j := int32(0); j < outW; j++ { - startT := t * strideT - startH := i * strideH - startW := j * strideW - // Extract patch: [B, kT, kH, kW, C] - patch := mlx.Slice(x, - []int32{0, startT, startH, startW, 0}, - []int32{B, startT + kT, startH + kH, startW + kW, C}) - // Transpose from [B, T, H, W, C] to [B, C, T, H, W] to match Python's order - patch = mlx.Transpose(patch, 0, 4, 1, 2, 3) - // Flatten to [B, C*T*H*W] - patch = mlx.Reshape(patch, B, C*kT*kH*kW) - patches[idx] = patch - idx++ - } - } - } - - for i := range patches { - patches[i] = mlx.ExpandDims(patches[i], 1) - } - stacked := mlx.Concatenate(patches, 1) - return mlx.Reshape(stacked, B, outT, outH, outW, C*kT*kH*kW) -} - -// extractPatches2DStrided extracts patches with given stride -func extractPatches2DStrided(x *mlx.Array, kH, kW, stride int32) *mlx.Array { - shape := x.Shape() - B := shape[0] - H := shape[1] - W := shape[2] - C := shape[3] - - outH := (H - kH) / stride + 1 - outW := (W - kW) / stride + 1 - - patches := make([]*mlx.Array, outH*outW) - idx := 0 - for i := int32(0); i < outH; i++ { - for j := int32(0); j < outW; j++ { - startH := i * stride - startW := j * stride - patch := mlx.Slice(x, []int32{0, startH, startW, 0}, []int32{B, startH + kH, startW + kW, C}) - patch = mlx.Reshape(patch, B, kH*kW*C) - patches[idx] = patch - idx++ - } - } - - for i := range patches { - patches[i] = mlx.ExpandDims(patches[i], 1) - } - stacked := mlx.Concatenate(patches, 1) - return mlx.Reshape(stacked, B, outH, outW, kH*kW*C) -} diff --git a/x/imagegen/models/qwen_image/qwen_image.go b/x/imagegen/models/qwen_image/qwen_image.go deleted file mode 100644 index a7e554623..000000000 --- a/x/imagegen/models/qwen_image/qwen_image.go +++ /dev/null @@ -1,367 +0,0 @@ -//go:build mlx - -// Package qwen_image implements the Qwen-Image diffusion transformer model. -package qwen_image - -import ( - "context" - "fmt" - "path/filepath" - "time" - - "github.com/ollama/ollama/x/imagegen/cache" - "github.com/ollama/ollama/x/imagegen/mlx" - "github.com/ollama/ollama/x/imagegen/tokenizer" -) - -// GenerateConfig holds all options for image generation. -type GenerateConfig struct { - Prompt string - NegativePrompt string // Empty = no CFG - CFGScale float32 // Only used if NegativePrompt is set (default: 4.0) - Width int32 // Image width (default: 1024) - Height int32 // Image height (default: 1024) - Steps int // Denoising steps (default: 30) - Seed int64 // Random seed - Progress func(step, totalSteps int) // Optional progress callback - - // Layer caching (DeepCache/Learning-to-Cache speedup) - LayerCache bool // Enable layer caching (default: false) - CacheInterval int // Refresh cache every N steps (default: 3) - CacheLayers int // Number of shallow layers to cache (default: 25) -} - -// Model represents a Qwen-Image diffusion model. -type Model struct { - ModelPath string - Tokenizer *tokenizer.Tokenizer - TextEncoder *Qwen25VL - Transformer *Transformer - VAEDecoder *VAEDecoder -} - -// Load loads the Qwen-Image model from a directory. -func (m *Model) Load(modelPath string) error { - fmt.Println("Loading Qwen-Image model...") - start := time.Now() - - if mlx.GPUIsAvailable() { - mlx.SetDefaultDeviceGPU() - mlx.EnableCompile() - } - - m.ModelPath = modelPath - - // Load tokenizer - fmt.Print(" Loading tokenizer... ") - tokenizerPath := filepath.Join(modelPath, "tokenizer") - tok, err := tokenizer.Load(tokenizerPath) - if err != nil { - return fmt.Errorf("tokenizer: %w", err) - } - m.Tokenizer = tok - fmt.Println("✓") - - // Load text encoder (Qwen2.5-VL in text-only mode - skip vision tower for efficiency) - m.TextEncoder = &Qwen25VL{} - if err := m.TextEncoder.LoadTextOnly(filepath.Join(modelPath, "text_encoder")); err != nil { - return fmt.Errorf("text encoder: %w", err) - } - mlx.Eval(mlx.Collect(m.TextEncoder)...) - fmt.Printf(" (%.1f GB, peak %.1f GB)\n", - float64(mlx.MetalGetActiveMemory())/(1024*1024*1024), - float64(mlx.MetalGetPeakMemory())/(1024*1024*1024)) - - // Load transformer - m.Transformer = &Transformer{} - if err := m.Transformer.Load(filepath.Join(modelPath, "transformer")); err != nil { - return fmt.Errorf("transformer: %w", err) - } - mlx.Eval(mlx.Collect(m.Transformer)...) - fmt.Printf(" (%.1f GB, peak %.1f GB)\n", - float64(mlx.MetalGetActiveMemory())/(1024*1024*1024), - float64(mlx.MetalGetPeakMemory())/(1024*1024*1024)) - - // Load VAE decoder - m.VAEDecoder = &VAEDecoder{} - if err := m.VAEDecoder.Load(filepath.Join(modelPath, "vae")); err != nil { - return fmt.Errorf("VAE decoder: %w", err) - } - mlx.Eval(mlx.Collect(m.VAEDecoder)...) - fmt.Printf(" (%.1f GB, peak %.1f GB)\n", - float64(mlx.MetalGetActiveMemory())/(1024*1024*1024), - float64(mlx.MetalGetPeakMemory())/(1024*1024*1024)) - - mem := mlx.MetalGetActiveMemory() - peak := mlx.MetalGetPeakMemory() - fmt.Printf(" Loaded in %.2fs (%.1f GB active, %.1f GB peak)\n", - time.Since(start).Seconds(), - float64(mem)/(1024*1024*1024), - float64(peak)/(1024*1024*1024)) - - return nil -} - -// Generate creates an image from a prompt. -func (m *Model) Generate(prompt string, width, height int32, steps int, seed int64) (*mlx.Array, error) { - return m.GenerateFromConfig(&GenerateConfig{ - Prompt: prompt, - Width: width, - Height: height, - Steps: steps, - Seed: seed, - }) -} - -// GenerateWithProgress creates an image with progress callback. -func (m *Model) GenerateWithProgress(prompt string, width, height int32, steps int, seed int64, progress func(step, totalSteps int)) (*mlx.Array, error) { - return m.GenerateFromConfig(&GenerateConfig{ - Prompt: prompt, - Width: width, - Height: height, - Steps: steps, - Seed: seed, - Progress: progress, - }) -} - -// GenerateWithCFG creates an image with classifier-free guidance. -func (m *Model) GenerateWithCFG(prompt, negativePrompt string, width, height int32, steps int, seed int64, cfgScale float32, progress func(step, totalSteps int)) (*mlx.Array, error) { - return m.GenerateFromConfig(&GenerateConfig{ - Prompt: prompt, - NegativePrompt: negativePrompt, - CFGScale: cfgScale, - Width: width, - Height: height, - Steps: steps, - Seed: seed, - Progress: progress, - }) -} - -// GenerateFromConfig generates an image using the unified config struct. -func (m *Model) GenerateFromConfig(cfg *GenerateConfig) (*mlx.Array, error) { - start := time.Now() - result, err := m.generate(cfg) - if err != nil { - return nil, err - } - if cfg.NegativePrompt != "" { - fmt.Printf("Generated with CFG (scale=%.1f) in %.2fs (%d steps)\n", cfg.CFGScale, time.Since(start).Seconds(), cfg.Steps) - } else { - fmt.Printf("Generated in %.2fs (%d steps)\n", time.Since(start).Seconds(), cfg.Steps) - } - return result, nil -} - -// GenerateImage implements model.ImageModel interface. -func (m *Model) GenerateImage(ctx context.Context, prompt string, width, height int32, steps int, seed int64) (*mlx.Array, error) { - return m.Generate(prompt, width, height, steps, seed) -} - -// generate is the internal denoising pipeline. -func (m *Model) generate(cfg *GenerateConfig) (*mlx.Array, error) { - // Apply defaults - if cfg.Width <= 0 { - cfg.Width = 1024 - } - if cfg.Height <= 0 { - cfg.Height = 1024 - } - if cfg.Steps <= 0 { - cfg.Steps = 50 - } - if cfg.CFGScale <= 0 { - cfg.CFGScale = 4.0 - } - if cfg.CacheInterval <= 0 { - cfg.CacheInterval = 3 - } - if cfg.CacheLayers <= 0 { - cfg.CacheLayers = 25 // ~42% of 60 layers (similar ratio to Z-Image's 15/38) - } - - useCFG := cfg.NegativePrompt != "" - tcfg := m.Transformer.Config - latentH := cfg.Height / 8 - latentW := cfg.Width / 8 - pH := latentH / tcfg.PatchSize - pW := latentW / tcfg.PatchSize - imgSeqLen := pH * pW - - // Text encoding - var posEmb, negEmb *mlx.Array - { - posEmb = m.TextEncoder.EncodePrompt(m.Tokenizer, cfg.Prompt) - if useCFG { - negEmb = m.TextEncoder.EncodePrompt(m.Tokenizer, cfg.NegativePrompt) - mlx.Keep(posEmb, negEmb) - mlx.Eval(posEmb, negEmb) - } else { - mlx.Keep(posEmb) - mlx.Eval(posEmb) - } - } - - // Pad sequences to same length for CFG - txtLen := posEmb.Shape()[1] - if useCFG { - negLen := negEmb.Shape()[1] - if negLen > txtLen { - txtLen = negLen - } - if posEmb.Shape()[1] < txtLen { - posEmb = padSequence(posEmb, txtLen) - } - if negEmb.Shape()[1] < txtLen { - negEmb = padSequence(negEmb, txtLen) - } - mlx.Keep(posEmb, negEmb) - } - - // Pre-compute batched embeddings for CFG (single forward pass optimization) - var batchedEmb *mlx.Array - if useCFG { - batchedEmb = mlx.Concatenate([]*mlx.Array{posEmb, negEmb}, 0) - mlx.Keep(batchedEmb) - mlx.Eval(batchedEmb) - } - - // Scheduler - scheduler := NewFlowMatchScheduler(DefaultSchedulerConfig()) - scheduler.SetTimesteps(cfg.Steps, imgSeqLen) - - // Init latents [B, C, T, H, W] - var latents *mlx.Array - { - latents = scheduler.InitNoise([]int32{1, tcfg.OutChannels, 1, latentH, latentW}, cfg.Seed) - mlx.Eval(latents) - } - - // RoPE cache - var ropeCache *RoPECache - { - ropeCache = PrepareRoPE(pH, pW, txtLen, tcfg.AxesDimsRope) - mlx.Keep(ropeCache.ImgFreqs, ropeCache.TxtFreqs) - mlx.Eval(ropeCache.ImgFreqs) - } - - // Layer cache for DeepCache/Learning-to-Cache speedup - var stepCache *cache.StepCache - if cfg.LayerCache { - stepCache = cache.NewStepCache(cfg.CacheLayers) - fmt.Printf(" Layer caching: %d layers, refresh every %d steps\n", cfg.CacheLayers, cfg.CacheInterval) - } - - // Denoising loop - for i := 0; i < cfg.Steps; i++ { - stepStart := time.Now() - if cfg.Progress != nil { - cfg.Progress(i+1, cfg.Steps) - } - - t := scheduler.Timesteps[i] - timestep := mlx.ToBFloat16(mlx.NewArray([]float32{t}, []int32{1})) - - // Squeeze temporal dim: [B, C, T, H, W] -> [B, C, H, W] - latents2D := mlx.Squeeze(latents, 2) - patches := PackLatents(latents2D, tcfg.PatchSize) - - var output *mlx.Array - if useCFG { - // CFG Batching: single forward pass with batch=2 - // Note: layer caching with CFG is not supported yet (would need 2 caches) - batchedPatches := mlx.Tile(patches, []int32{2, 1, 1}) - batchedTimestep := mlx.Tile(timestep, []int32{2}) - - // Single batched forward pass - batchedOutput := m.Transformer.Forward(batchedPatches, batchedEmb, batchedTimestep, ropeCache.ImgFreqs, ropeCache.TxtFreqs) - - // Split output: [2, L, D] -> pos [1, L, D], neg [1, L, D] - L := batchedOutput.Shape()[1] - D := batchedOutput.Shape()[2] - posOutput := mlx.Slice(batchedOutput, []int32{0, 0, 0}, []int32{1, L, D}) - negOutput := mlx.Slice(batchedOutput, []int32{1, 0, 0}, []int32{2, L, D}) - - diff := mlx.Sub(posOutput, negOutput) - scaledDiff := mlx.MulScalar(diff, cfg.CFGScale) - combPred := mlx.Add(negOutput, scaledDiff) - - // Norm rescaling: rescale combined prediction to match conditional prediction's norm - condNorm := mlx.Sqrt(mlx.Sum(mlx.Square(posOutput), -1, true)) - combNorm := mlx.Sqrt(mlx.Sum(mlx.Square(combPred), -1, true)) - output = mlx.Mul(combPred, mlx.Div(condNorm, combNorm)) - } else if stepCache != nil { - output = m.Transformer.ForwardWithCache(patches, posEmb, timestep, ropeCache.ImgFreqs, ropeCache.TxtFreqs, - stepCache, i, cfg.CacheInterval, cfg.CacheLayers) - } else { - output = m.Transformer.Forward(patches, posEmb, timestep, ropeCache.ImgFreqs, ropeCache.TxtFreqs) - } - - noisePred := UnpackLatents(output, latentH, latentW, tcfg.PatchSize) - oldLatents := latents - latents = scheduler.Step(noisePred, latents, i) - - // Keep cached arrays alive across cleanup - if stepCache != nil { - mlx.Keep(stepCache.Arrays()...) - } - mlx.Eval(latents) - oldLatents.Free() - - activeMem := float64(mlx.MetalGetActiveMemory()) / (1024 * 1024 * 1024) - peakMem := float64(mlx.MetalGetPeakMemory()) / (1024 * 1024 * 1024) - fmt.Printf(" Step %d/%d: t=%.4f (%.2fs) [%.1f GB active, %.1f GB peak]\n", i+1, cfg.Steps, t, time.Since(stepStart).Seconds(), activeMem, peakMem) - } - - // Free denoising temporaries before VAE decode - posEmb.Free() - if negEmb != nil { - negEmb.Free() - } - if batchedEmb != nil { - batchedEmb.Free() - } - ropeCache.ImgFreqs.Free() - ropeCache.TxtFreqs.Free() - if stepCache != nil { - stepCache.Free() - } - - // VAE decode (Decode manages its own pools for staged memory) - decoded := m.VAEDecoder.Decode(latents) - latents.Free() - // Post-process: squeeze temporal dim and rescale to [0, 1] - { - decoded = mlx.Squeeze(decoded, 2) - decoded = mlx.AddScalar(decoded, 1.0) - decoded = mlx.DivScalar(decoded, 2.0) - mlx.Eval(decoded) - } - - fmt.Printf(" Peak memory: %.2f GB\n", float64(mlx.MetalGetPeakMemory())/(1024*1024*1024)) - - return decoded, nil -} - -// padSequence pads a sequence tensor to the target length with zeros -func padSequence(x *mlx.Array, targetLen int32) *mlx.Array { - shape := x.Shape() - currentLen := shape[1] - if currentLen >= targetLen { - return x - } - padLen := targetLen - currentLen - // Pad on sequence dimension (axis 1) - return mlx.Pad(x, []int32{0, 0, 0, padLen, 0, 0}) -} - -// LoadPersistent is an alias for backward compatibility. -// Use m := &Model{}; m.Load(path) instead. -func LoadPersistent(modelPath string) (*Model, error) { - m := &Model{} - if err := m.Load(modelPath); err != nil { - return nil, err - } - return m, nil -} diff --git a/x/imagegen/models/qwen_image/scheduler.go b/x/imagegen/models/qwen_image/scheduler.go deleted file mode 100644 index d1f0da049..000000000 --- a/x/imagegen/models/qwen_image/scheduler.go +++ /dev/null @@ -1,218 +0,0 @@ -//go:build mlx - -package qwen_image - -import ( - "math" - - "github.com/ollama/ollama/x/imagegen/mlx" -) - -// SchedulerConfig holds FlowMatchEulerDiscreteScheduler configuration -type SchedulerConfig struct { - NumTrainTimesteps int32 `json:"num_train_timesteps"` // 1000 - BaseShift float32 `json:"base_shift"` // 0.5 - MaxShift float32 `json:"max_shift"` // 0.9 - BaseImageSeqLen int32 `json:"base_image_seq_len"` // 256 - MaxImageSeqLen int32 `json:"max_image_seq_len"` // 8192 - ShiftTerminal float32 `json:"shift_terminal"` // 0.02 - UseDynamicShift bool `json:"use_dynamic_shifting"` // true -} - -// DefaultSchedulerConfig returns config for FlowMatchEulerDiscreteScheduler -func DefaultSchedulerConfig() *SchedulerConfig { - return &SchedulerConfig{ - NumTrainTimesteps: 1000, - BaseShift: 0.5, - MaxShift: 0.9, // Matches scheduler_config.json - BaseImageSeqLen: 256, - MaxImageSeqLen: 8192, - ShiftTerminal: 0.02, - UseDynamicShift: true, - } -} - -// FlowMatchScheduler implements the Flow Match Euler discrete scheduler -type FlowMatchScheduler struct { - Config *SchedulerConfig - Timesteps []float32 - Sigmas []float32 - NumSteps int -} - -// NewFlowMatchScheduler creates a new scheduler -func NewFlowMatchScheduler(cfg *SchedulerConfig) *FlowMatchScheduler { - return &FlowMatchScheduler{ - Config: cfg, - } -} - -// CalculateShift computes the dynamic shift based on image sequence length -// This matches Python's calculate_shift function -func CalculateShift(imageSeqLen int32, baseSeqLen int32, maxSeqLen int32, baseShift float32, maxShift float32) float32 { - m := (maxShift - baseShift) / float32(maxSeqLen-baseSeqLen) - b := baseShift - m*float32(baseSeqLen) - mu := float32(imageSeqLen)*m + b - return mu -} - -// SetTimesteps sets up the scheduler for the given number of inference steps -// Matches Python diffusers FlowMatchEulerDiscreteScheduler behavior: -// 1. Create sigmas from sigma_max to sigma_min (linspace) -// 2. Apply time_shift with mu (if dynamic shifting) -// 3. Apply stretch_shift_to_terminal to make final value = shift_terminal -func (s *FlowMatchScheduler) SetTimesteps(numSteps int, imageSeqLen int32) { - s.NumSteps = numSteps - - // Calculate mu for dynamic shifting - var mu float32 - if s.Config.UseDynamicShift { - mu = CalculateShift( - imageSeqLen, - s.Config.BaseImageSeqLen, - s.Config.MaxImageSeqLen, - s.Config.BaseShift, - s.Config.MaxShift, - ) - } - - // Step 1: Create sigmas from 1.0 to 1/num_steps - // Python (pipeline_qwenimage.py:639): - // sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) - // This gives sigmas from 1.0 to 1/30 = 0.033 for 30 steps - sigmas := make([]float32, numSteps) - sigmaMax := float32(1.0) - sigmaMin := 1.0 / float32(numSteps) // 1/30 = 0.033 for 30 steps - if numSteps == 1 { - sigmas[0] = sigmaMax - } else { - for i := 0; i < numSteps; i++ { - sigmas[i] = sigmaMax + float32(i)*(sigmaMin-sigmaMax)/float32(numSteps-1) - } - } - - // Step 2: Apply time shift if using dynamic shifting - if s.Config.UseDynamicShift && mu != 0 { - for i := range sigmas { - sigmas[i] = s.timeShift(mu, sigmas[i]) - } - } - - // Step 3: Apply stretch_shift_to_terminal - if s.Config.ShiftTerminal > 0 { - sigmas = s.stretchShiftToTerminal(sigmas) - } - - // Step 4: Append terminal sigma (0) and store - // Note: Python's scheduler.timesteps are sigmas*1000, but the pipeline divides by 1000 - // before passing to transformer. We skip both steps and just use sigmas directly. - s.Sigmas = make([]float32, numSteps+1) - s.Timesteps = make([]float32, numSteps+1) - for i := 0; i < numSteps; i++ { - s.Sigmas[i] = sigmas[i] - s.Timesteps[i] = sigmas[i] - } - s.Sigmas[numSteps] = 0.0 - s.Timesteps[numSteps] = 0.0 -} - -// stretchShiftToTerminal stretches and shifts the timestep schedule -// so the final value equals shift_terminal (matches Python behavior) -func (s *FlowMatchScheduler) stretchShiftToTerminal(sigmas []float32) []float32 { - if len(sigmas) == 0 { - return sigmas - } - - // one_minus_z = 1 - t - // scale_factor = one_minus_z[-1] / (1 - shift_terminal) - // stretched_t = 1 - (one_minus_z / scale_factor) - lastSigma := sigmas[len(sigmas)-1] - scaleFactor := (1.0 - lastSigma) / (1.0 - s.Config.ShiftTerminal) - - // Handle edge case: if scaleFactor is 0 or near 0, skip stretch - // This happens when lastSigma ≈ 1.0 (e.g., single step with timeshift) - if scaleFactor < 1e-6 { - return sigmas - } - - result := make([]float32, len(sigmas)) - for i, t := range sigmas { - oneMinusZ := 1.0 - t - result[i] = 1.0 - (oneMinusZ / scaleFactor) - } - return result -} - -// timeShift applies the dynamic time shift (exponential) -// exp(mu) / (exp(mu) + (1/t - 1)) -func (s *FlowMatchScheduler) timeShift(mu float32, t float32) float32 { - if t <= 0 { - return 0 - } - expMu := float32(math.Exp(float64(mu))) - return expMu / (expMu + (1.0/t - 1.0)) -} - -// Step performs one denoising step -// modelOutput: predicted velocity from the transformer -// sample: current noisy sample -// timestepIdx: current timestep index -func (s *FlowMatchScheduler) Step(modelOutput, sample *mlx.Array, timestepIdx int) *mlx.Array { - // Get current and next sigma - sigma := s.Sigmas[timestepIdx] - sigmaNext := s.Sigmas[timestepIdx+1] - - // Euler step: x_{t-dt} = x_t + (sigma_next - sigma) * v_t - dt := sigmaNext - sigma - - // Upcast to float32 to avoid precision issues (matches Python diffusers) - sampleF32 := mlx.AsType(sample, mlx.DtypeFloat32) - modelOutputF32 := mlx.AsType(modelOutput, mlx.DtypeFloat32) - - scaledOutput := mlx.MulScalar(modelOutputF32, dt) - result := mlx.Add(sampleF32, scaledOutput) - - // Cast back to original dtype - return mlx.ToBFloat16(result) -} - -// GetTimestep returns the timestep value at the given index -func (s *FlowMatchScheduler) GetTimestep(idx int) float32 { - if idx < len(s.Timesteps) { - return s.Timesteps[idx] - } - return 0.0 -} - -// InitNoise creates initial noise for sampling in unpacked format [B, C, T, H, W] -func (s *FlowMatchScheduler) InitNoise(shape []int32, seed int64) *mlx.Array { - return mlx.RandomNormal(shape, uint64(seed)) -} - -// InitNoisePacked creates initial noise directly in packed format [B, L, C*4] -// This matches how Python diffusers generates noise - directly in packed space. -// Generating in unpacked format and then packing produces different spatial -// correlation structure, which affects model output quality. -func (s *FlowMatchScheduler) InitNoisePacked(batchSize, seqLen, channels int32, seed int64) *mlx.Array { - shape := []int32{batchSize, seqLen, channels} - return mlx.RandomNormal(shape, uint64(seed)) -} - -// GetLatentShape returns the latent shape for a given image size -// For qwen_image: VAE downscale is 8x (spatial), latent has 16 channels -func GetLatentShape(batchSize, height, width int32) []int32 { - latentH := height / 8 - latentW := width / 8 - return []int32{batchSize, 16, 1, latentH, latentW} // [B, C, T, H, W] -} - -// GetPatchedLatentShape returns the patchified latent shape -// After patchification: [B, L, C*patch_size^2] where L = H/2 * W/2 -func GetPatchedLatentShape(batchSize, height, width, patchSize int32) []int32 { - latentH := height / 8 - latentW := width / 8 - pH := latentH / patchSize - pW := latentW / patchSize - inChannels := int32(64) // 16 * patch_size^2 - return []int32{batchSize, pH * pW, inChannels} -} diff --git a/x/imagegen/models/qwen_image/scheduler_test.go b/x/imagegen/models/qwen_image/scheduler_test.go deleted file mode 100644 index 46adeb99a..000000000 --- a/x/imagegen/models/qwen_image/scheduler_test.go +++ /dev/null @@ -1,135 +0,0 @@ -//go:build mlx - -package qwen_image - -import ( - "math" - "testing" -) - -// TestSchedulerSetTimesteps verifies scheduler sigmas match Python diffusers reference. -// Golden values generated via: -// -// python3 -c " -// from diffusers.schedulers import FlowMatchEulerDiscreteScheduler -// import numpy as np -// s = FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000, base_shift=0.5, max_shift=0.9, -// base_image_seq_len=256, max_image_seq_len=8192, shift_terminal=0.02, use_dynamic_shifting=True) -// mu = 4096 * (0.9-0.5)/(8192-256) + 0.5 - (0.9-0.5)/(8192-256)*256 -// sigmas = np.linspace(1.0, 1.0/30, 30) -// s.set_timesteps(sigmas=sigmas, mu=mu) -// print(s.sigmas.numpy())" -func TestSchedulerSetTimesteps(t *testing.T) { - cfg := DefaultSchedulerConfig() - scheduler := NewFlowMatchScheduler(cfg) - scheduler.SetTimesteps(30, 4096) - - // Golden values from Python diffusers (first 3, last 3 before terminal) - wantFirst := []float32{1.000000, 0.982251, 0.963889} - wantLast := []float32{0.142924, 0.083384, 0.020000} - - // Check first 3 - for i, want := range wantFirst { - got := scheduler.Sigmas[i] - if abs32(got-want) > 1e-4 { - t.Errorf("sigma[%d]: got %v, want %v", i, got, want) - } - } - - // Check last 3 (indices 27, 28, 29) - for i, want := range wantLast { - idx := 27 + i - got := scheduler.Sigmas[idx] - if abs32(got-want) > 1e-4 { - t.Errorf("sigma[%d]: got %v, want %v", idx, got, want) - } - } - - // Check terminal is 0 - if scheduler.Sigmas[30] != 0.0 { - t.Errorf("terminal sigma: got %v, want 0", scheduler.Sigmas[30]) - } - - // Check length - if len(scheduler.Sigmas) != 31 { - t.Errorf("sigmas length: got %d, want 31", len(scheduler.Sigmas)) - } -} - -// TestSchedulerProperties tests mathematical invariants of the scheduler. -func TestSchedulerProperties(t *testing.T) { - cfg := DefaultSchedulerConfig() - scheduler := NewFlowMatchScheduler(cfg) - scheduler.SetTimesteps(30, 4096) - - // Property: sigmas monotonically decreasing - for i := 1; i < len(scheduler.Sigmas); i++ { - if scheduler.Sigmas[i] > scheduler.Sigmas[i-1] { - t.Errorf("sigmas not monotonically decreasing at %d: %v > %v", - i, scheduler.Sigmas[i], scheduler.Sigmas[i-1]) - } - } - - // Property: first sigma should be ~1.0 (with time shift) - if scheduler.Sigmas[0] < 0.9 || scheduler.Sigmas[0] > 1.01 { - t.Errorf("first sigma out of expected range [0.9, 1.01]: %v", scheduler.Sigmas[0]) - } - - // Property: terminal sigma should be exactly 0 - if scheduler.Sigmas[len(scheduler.Sigmas)-1] != 0.0 { - t.Errorf("terminal sigma should be 0, got %v", scheduler.Sigmas[len(scheduler.Sigmas)-1]) - } - - // Property: last non-terminal sigma should be shift_terminal (0.02) - lastNonTerminal := scheduler.Sigmas[len(scheduler.Sigmas)-2] - if abs32(lastNonTerminal-0.02) > 1e-5 { - t.Errorf("last non-terminal sigma should be 0.02, got %v", lastNonTerminal) - } - - // Property: length = steps + 1 - if len(scheduler.Sigmas) != scheduler.NumSteps+1 { - t.Errorf("sigmas length should be steps+1: got %d, want %d", - len(scheduler.Sigmas), scheduler.NumSteps+1) - } -} - -// TestCalculateShift verifies the mu calculation against Python reference. -// Golden values from: mu = img_seq_len * m + b where m = (max_shift - base_shift) / (max_seq_len - base_seq_len) -func TestCalculateShift(t *testing.T) { - cases := []struct { - imgSeqLen int32 - want float32 - }{ - {256, 0.5}, // base case - {8192, 0.9}, // max case - {4096, 0.6935}, // middle case (rounded) - } - - for _, c := range cases { - got := CalculateShift(c.imgSeqLen, 256, 8192, 0.5, 0.9) - if abs32(got-c.want) > 0.001 { - t.Errorf("CalculateShift(%d): got %v, want %v", c.imgSeqLen, got, c.want) - } - } -} - -// TestSchedulerStep verifies the Euler step formula. -func TestSchedulerStep(t *testing.T) { - cfg := DefaultSchedulerConfig() - scheduler := NewFlowMatchScheduler(cfg) - scheduler.SetTimesteps(30, 4096) - - // Verify dt calculation for first step - sigma0 := scheduler.Sigmas[0] - sigma1 := scheduler.Sigmas[1] - expectedDt := sigma1 - sigma0 - - // dt should be negative (sigmas decrease) - if expectedDt >= 0 { - t.Errorf("expected negative dt, got %v (sigma0=%v, sigma1=%v)", expectedDt, sigma0, sigma1) - } -} - -func abs32(x float32) float32 { - return float32(math.Abs(float64(x))) -} diff --git a/x/imagegen/models/qwen_image/text_encoder_test.go b/x/imagegen/models/qwen_image/text_encoder_test.go deleted file mode 100644 index 7704513c8..000000000 --- a/x/imagegen/models/qwen_image/text_encoder_test.go +++ /dev/null @@ -1,174 +0,0 @@ -//go:build mlx - -package qwen_image - -import ( - "encoding/json" - "math" - "os" - "path/filepath" - "slices" - "testing" - - "github.com/ollama/ollama/x/imagegen/mlx" - "github.com/ollama/ollama/x/imagegen/safetensors" -) - -// TinyTextEncoderConfig holds config for the tiny test text encoder -type TinyTextEncoderConfig struct { - HiddenSize int32 `json:"hidden_size"` - NumHiddenLayers int32 `json:"num_hidden_layers"` - IntermediateSize int32 `json:"intermediate_size"` - NumAttentionHeads int32 `json:"num_attention_heads"` - NumKeyValueHeads int32 `json:"num_key_value_heads"` - VocabSize int32 `json:"vocab_size"` - RMSNormEps float32 `json:"rms_norm_eps"` - RopeTheta float32 `json:"rope_theta"` - HeadDim int32 `json:"head_dim"` - MRoPESection []int32 `json:"mrope_section"` -} - -// loadTinyTextEncoder loads the tiny text encoder from testdata -func loadTinyTextEncoder(t *testing.T) (*Qwen25VL, *TinyTextEncoderConfig) { - t.Helper() - - testdataDir := filepath.Join("testdata", "tiny_text_encoder") - - // Load config - configData, err := os.ReadFile(filepath.Join(testdataDir, "config.json")) - if err != nil { - t.Skipf("Skipping: tiny weights not found. Regenerate with Python (see models/CLAUDE.md)") - } - - var tinyCfg TinyTextEncoderConfig - if err := json.Unmarshal(configData, &tinyCfg); err != nil { - t.Fatalf("Failed to parse config: %v", err) - } - - // Create encoder config (using Qwen25VLConfig) - cfg := &Qwen25VLConfig{ - HiddenSize: tinyCfg.HiddenSize, - NumHiddenLayers: tinyCfg.NumHiddenLayers, - IntermediateSize: tinyCfg.IntermediateSize, - NumAttentionHeads: tinyCfg.NumAttentionHeads, - NumKeyValueHeads: tinyCfg.NumKeyValueHeads, - VocabSize: tinyCfg.VocabSize, - RMSNormEps: tinyCfg.RMSNormEps, - RopeTheta: tinyCfg.RopeTheta, - HeadDim: tinyCfg.HeadDim, - MRoPESection: tinyCfg.MRoPESection, - } - - // Load weights - weights, err := safetensors.LoadModelWeights(testdataDir) - if err != nil { - t.Fatalf("Failed to load weights: %v", err) - } - - if err := weights.Load(mlx.DtypeBFloat16); err != nil { - t.Fatalf("Failed to bulk load weights: %v", err) - } - - // Build encoder - embedding, err := weights.Get("model.embed_tokens.weight") - if err != nil { - t.Fatalf("Failed to get embedding: %v", err) - } - - blocks := make([]*VLTextBlock, cfg.NumHiddenLayers) - for i := int32(0); i < cfg.NumHiddenLayers; i++ { - block, err := newVLTextBlock(weights, int(i), cfg) - if err != nil { - t.Fatalf("Failed to load block %d: %v", i, err) - } - blocks[i] = block - } - - finalNorm, err := weights.Get("model.norm.weight") - if err != nil { - t.Fatalf("Failed to get final norm: %v", err) - } - - encoder := &Qwen25VL{ - Config: cfg, - Embedding: embedding, - Blocks: blocks, - FinalNorm: finalNorm, - HasVision: false, // Text-only mode - } - - return encoder, &tinyCfg -} - -// TestTextEncoderForward verifies the text encoder forward pass with tiny weights. -func TestTextEncoderForward(t *testing.T) { - encoder, cfg := loadTinyTextEncoder(t) - - // Create test tokens (within vocab range) - tokens := []int32{1, 2, 3, 4, 5} - - // Forward pass using EncodeTextOnly - out := encoder.EncodeTextOnly(tokens) - mlx.Eval(out) - - // Verify output shape: [batch, seq_len, hidden_size] - wantShape := []int32{1, 5, cfg.HiddenSize} - if !slices.Equal(out.Shape(), wantShape) { - t.Errorf("output shape: got %v, want %v", out.Shape(), wantShape) - } - - // Verify output is finite (not NaN or Inf) - data := out.Data() - for i, v := range data { - if math.IsNaN(float64(v)) || math.IsInf(float64(v), 0) { - t.Errorf("output[%d] is not finite: %v", i, v) - break - } - } -} - -// TestTextEncoderBatch tests batch processing. -func TestTextEncoderBatch(t *testing.T) { - encoder, cfg := loadTinyTextEncoder(t) - - // For batch test, we'll use EncodeTextOnly with a single sequence - // (EncodeTextOnly doesn't support batch, but we can verify single sequence works) - tokens := []int32{1, 2, 3} - - out := encoder.EncodeTextOnly(tokens) - mlx.Eval(out) - - wantShape := []int32{1, 3, cfg.HiddenSize} - if !slices.Equal(out.Shape(), wantShape) { - t.Errorf("shape: got %v, want %v", out.Shape(), wantShape) - } -} - -// TestMRoPEComputation verifies M-RoPE frequency computation produces valid values. -func TestMRoPEComputation(t *testing.T) { - encoder, cfg := loadTinyTextEncoder(t) - - cossin := encoder.computeTextRoPE(10, 1) - mlx.Eval(cossin[0], cossin[1]) - - // Verify shapes: [3, B, L, head_dim] - wantShape := []int32{3, 1, 10, cfg.HeadDim} - if !slices.Equal(cossin[0].Shape(), wantShape) { - t.Errorf("cos shape: got %v, want %v", cossin[0].Shape(), wantShape) - } - if !slices.Equal(cossin[1].Shape(), wantShape) { - t.Errorf("sin shape: got %v, want %v", cossin[1].Shape(), wantShape) - } - - // Verify cos/sin values are in valid range [-1, 1] - cosData := cossin[0].Data() - sinData := cossin[1].Data() - for i := 0; i < min(100, len(cosData)); i++ { - if cosData[i] < -1.01 || cosData[i] > 1.01 { - t.Errorf("cos[%d] out of range: %v", i, cosData[i]) - } - if sinData[i] < -1.01 || sinData[i] > 1.01 { - t.Errorf("sin[%d] out of range: %v", i, sinData[i]) - } - } -} diff --git a/x/imagegen/models/qwen_image/transformer.go b/x/imagegen/models/qwen_image/transformer.go deleted file mode 100644 index 06e677619..000000000 --- a/x/imagegen/models/qwen_image/transformer.go +++ /dev/null @@ -1,868 +0,0 @@ -//go:build mlx - -package qwen_image - -import ( - "fmt" - "math" - "path/filepath" - - "github.com/ollama/ollama/x/imagegen/cache" - "github.com/ollama/ollama/x/imagegen/mlx" - "github.com/ollama/ollama/x/imagegen/safetensors" -) - -// TransformerConfig holds Qwen-Image transformer configuration -type TransformerConfig struct { - HiddenDim int32 `json:"hidden_dim"` // 3072 (24 * 128) - NHeads int32 `json:"num_attention_heads"` // 24 - HeadDim int32 `json:"attention_head_dim"` // 128 - NLayers int32 `json:"num_layers"` // 60 - InChannels int32 `json:"in_channels"` // 64 - OutChannels int32 `json:"out_channels"` // 16 - PatchSize int32 `json:"patch_size"` // 2 - JointAttentionDim int32 `json:"joint_attention_dim"` // 3584 (text encoder dim) - NormEps float32 `json:"norm_eps"` // 1e-6 - AxesDimsRope []int32 `json:"axes_dims_rope"` // [16, 56, 56] - GuidanceEmbeds bool `json:"guidance_embeds"` // false -} - -// defaultTransformerConfig returns config for Qwen-Image transformer -func defaultTransformerConfig() *TransformerConfig { - return &TransformerConfig{ - HiddenDim: 3072, // 24 * 128 - NHeads: 24, - HeadDim: 128, - NLayers: 60, - InChannels: 64, - OutChannels: 16, - PatchSize: 2, - JointAttentionDim: 3584, - NormEps: 1e-6, - AxesDimsRope: []int32{16, 56, 56}, - GuidanceEmbeds: false, - } -} - -// TimestepEmbedder creates timestep embeddings -type TimestepEmbedder struct { - Linear1Weight *mlx.Array // [256, hidden_dim] - Linear1Bias *mlx.Array - Linear2Weight *mlx.Array // [hidden_dim, hidden_dim] - Linear2Bias *mlx.Array -} - -// newTimestepEmbedder creates a timestep embedder from weights -func newTimestepEmbedder(weights *safetensors.ModelWeights) (*TimestepEmbedder, error) { - linear1Weight, err := weights.Get("time_text_embed.timestep_embedder.linear_1.weight") - if err != nil { - return nil, err - } - linear1Bias, err := weights.Get("time_text_embed.timestep_embedder.linear_1.bias") - if err != nil { - return nil, err - } - linear2Weight, err := weights.Get("time_text_embed.timestep_embedder.linear_2.weight") - if err != nil { - return nil, err - } - linear2Bias, err := weights.Get("time_text_embed.timestep_embedder.linear_2.bias") - if err != nil { - return nil, err - } - - return &TimestepEmbedder{ - Linear1Weight: mlx.Transpose(linear1Weight, 1, 0), - Linear1Bias: linear1Bias, - Linear2Weight: mlx.Transpose(linear2Weight, 1, 0), - Linear2Bias: linear2Bias, - }, nil -} - -// Forward computes timestep embeddings -// t: [B] timesteps (normalized 0-1, will be scaled by 1000 internally) -func (te *TimestepEmbedder) Forward(t *mlx.Array) *mlx.Array { - half := int32(128) // embedding_dim / 2 - - // Sinusoidal embedding with flip_sin_to_cos=True, scale=1000 - freqs := make([]float32, half) - for i := int32(0); i < half; i++ { - freqs[i] = float32(math.Exp(-math.Log(10000.0) * float64(i) / float64(half))) - } - freqsArr := mlx.NewArray(freqs, []int32{1, half}) - - tExpanded := mlx.ExpandDims(t, 1) - args := mlx.Mul(tExpanded, freqsArr) - args = mlx.MulScalar(args, 1000.0) // scale - - // [cos, sin] (flip_sin_to_cos=True) - sinArgs := mlx.Sin(args) - cosArgs := mlx.Cos(args) - embedding := mlx.Concatenate([]*mlx.Array{cosArgs, sinArgs}, 1) // [B, 256] - - // MLP: linear1 -> silu -> linear2 - h := mlx.Linear(embedding, te.Linear1Weight) - h = mlx.Add(h, te.Linear1Bias) - h = mlx.SiLU(h) - h = mlx.Linear(h, te.Linear2Weight) - h = mlx.Add(h, te.Linear2Bias) - - return h -} - -// JointAttention implements dual-stream joint attention -type JointAttention struct { - // Image projections - ToQ *mlx.Array - ToQB *mlx.Array - ToK *mlx.Array - ToKB *mlx.Array - ToV *mlx.Array - ToVB *mlx.Array - ToOut *mlx.Array - ToOutB *mlx.Array - NormQ *mlx.Array - NormK *mlx.Array - - // Text (added) projections - AddQProj *mlx.Array - AddQProjB *mlx.Array - AddKProj *mlx.Array - AddKProjB *mlx.Array - AddVProj *mlx.Array - AddVProjB *mlx.Array - ToAddOut *mlx.Array - ToAddOutB *mlx.Array - NormAddQ *mlx.Array - NormAddK *mlx.Array - - NHeads int32 - HeadDim int32 - Scale float32 -} - -// newJointAttention creates a joint attention layer -func newJointAttention(weights *safetensors.ModelWeights, prefix string, cfg *TransformerConfig) (*JointAttention, error) { - toQ, _ := weights.Get(prefix + ".attn.to_q.weight") - toQB, _ := weights.Get(prefix + ".attn.to_q.bias") - toK, _ := weights.Get(prefix + ".attn.to_k.weight") - toKB, _ := weights.Get(prefix + ".attn.to_k.bias") - toV, _ := weights.Get(prefix + ".attn.to_v.weight") - toVB, _ := weights.Get(prefix + ".attn.to_v.bias") - toOut, _ := weights.Get(prefix + ".attn.to_out.0.weight") - toOutB, _ := weights.Get(prefix + ".attn.to_out.0.bias") - normQ, _ := weights.Get(prefix + ".attn.norm_q.weight") - normK, _ := weights.Get(prefix + ".attn.norm_k.weight") - - addQProj, _ := weights.Get(prefix + ".attn.add_q_proj.weight") - addQProjB, _ := weights.Get(prefix + ".attn.add_q_proj.bias") - addKProj, _ := weights.Get(prefix + ".attn.add_k_proj.weight") - addKProjB, _ := weights.Get(prefix + ".attn.add_k_proj.bias") - addVProj, _ := weights.Get(prefix + ".attn.add_v_proj.weight") - addVProjB, _ := weights.Get(prefix + ".attn.add_v_proj.bias") - toAddOut, _ := weights.Get(prefix + ".attn.to_add_out.weight") - toAddOutB, _ := weights.Get(prefix + ".attn.to_add_out.bias") - normAddQ, _ := weights.Get(prefix + ".attn.norm_added_q.weight") - normAddK, _ := weights.Get(prefix + ".attn.norm_added_k.weight") - - return &JointAttention{ - ToQ: mlx.Transpose(toQ, 1, 0), - ToQB: toQB, - ToK: mlx.Transpose(toK, 1, 0), - ToKB: toKB, - ToV: mlx.Transpose(toV, 1, 0), - ToVB: toVB, - ToOut: mlx.Transpose(toOut, 1, 0), - ToOutB: toOutB, - NormQ: normQ, - NormK: normK, - AddQProj: mlx.Transpose(addQProj, 1, 0), - AddQProjB: addQProjB, - AddKProj: mlx.Transpose(addKProj, 1, 0), - AddKProjB: addKProjB, - AddVProj: mlx.Transpose(addVProj, 1, 0), - AddVProjB: addVProjB, - ToAddOut: mlx.Transpose(toAddOut, 1, 0), - ToAddOutB: toAddOutB, - NormAddQ: normAddQ, - NormAddK: normAddK, - NHeads: cfg.NHeads, - HeadDim: cfg.HeadDim, - Scale: float32(1.0 / math.Sqrt(float64(cfg.HeadDim))), - }, nil -} - -// Forward computes joint attention -// img: [B, L_img, D], txt: [B, L_txt, D] -// imgFreqs, txtFreqs: complex RoPE frequencies [L, head_dim/2] as interleaved real/imag -func (attn *JointAttention) Forward(img, txt *mlx.Array, imgFreqs, txtFreqs *mlx.Array) (*mlx.Array, *mlx.Array) { - imgShape := img.Shape() - B := imgShape[0] - Limg := imgShape[1] - D := imgShape[2] - - txtShape := txt.Shape() - Ltxt := txtShape[1] - - // === Image Q/K/V === - imgFlat := mlx.Reshape(img, B*Limg, D) - qImg := mlx.Add(mlx.Linear(imgFlat, attn.ToQ), attn.ToQB) - kImg := mlx.Add(mlx.Linear(imgFlat, attn.ToK), attn.ToKB) - vImg := mlx.Add(mlx.Linear(imgFlat, attn.ToV), attn.ToVB) - - qImg = mlx.Reshape(qImg, B, Limg, attn.NHeads, attn.HeadDim) - kImg = mlx.Reshape(kImg, B, Limg, attn.NHeads, attn.HeadDim) - vImg = mlx.Reshape(vImg, B, Limg, attn.NHeads, attn.HeadDim) - - // QK norm (RMSNorm per head) - qImg = mlx.RMSNorm(qImg, attn.NormQ, 1e-6) - kImg = mlx.RMSNorm(kImg, attn.NormK, 1e-6) - - // Apply RoPE - if imgFreqs != nil { - qImg = applyRoPE(qImg, imgFreqs) - kImg = applyRoPE(kImg, imgFreqs) - } - - // === Text Q/K/V === - txtFlat := mlx.Reshape(txt, B*Ltxt, D) - qTxt := mlx.Add(mlx.Linear(txtFlat, attn.AddQProj), attn.AddQProjB) - kTxt := mlx.Add(mlx.Linear(txtFlat, attn.AddKProj), attn.AddKProjB) - vTxt := mlx.Add(mlx.Linear(txtFlat, attn.AddVProj), attn.AddVProjB) - - qTxt = mlx.Reshape(qTxt, B, Ltxt, attn.NHeads, attn.HeadDim) - kTxt = mlx.Reshape(kTxt, B, Ltxt, attn.NHeads, attn.HeadDim) - vTxt = mlx.Reshape(vTxt, B, Ltxt, attn.NHeads, attn.HeadDim) - - qTxt = mlx.RMSNorm(qTxt, attn.NormAddQ, 1e-6) - kTxt = mlx.RMSNorm(kTxt, attn.NormAddK, 1e-6) - - if txtFreqs != nil { - qTxt = applyRoPE(qTxt, txtFreqs) - kTxt = applyRoPE(kTxt, txtFreqs) - } - - // Concatenate for joint attention: [txt, img] order - qJoint := mlx.Concatenate([]*mlx.Array{qTxt, qImg}, 1) - kJoint := mlx.Concatenate([]*mlx.Array{kTxt, kImg}, 1) - vJoint := mlx.Concatenate([]*mlx.Array{vTxt, vImg}, 1) - - // Transpose to [B, nheads, L, head_dim] - qJoint = mlx.Transpose(qJoint, 0, 2, 1, 3) - kJoint = mlx.Transpose(kJoint, 0, 2, 1, 3) - vJoint = mlx.Transpose(vJoint, 0, 2, 1, 3) - - // SDPA - outJoint := mlx.ScaledDotProductAttention(qJoint, kJoint, vJoint, attn.Scale, false) - - // Transpose back and split - outJoint = mlx.Transpose(outJoint, 0, 2, 1, 3) // [B, L, nheads, head_dim] - outJoint = mlx.Reshape(outJoint, B, Ltxt+Limg, D) - - outTxt := mlx.Slice(outJoint, []int32{0, 0, 0}, []int32{B, Ltxt, D}) - outImg := mlx.Slice(outJoint, []int32{0, Ltxt, 0}, []int32{B, Ltxt + Limg, D}) - - // Output projections - outImg = mlx.Reshape(outImg, B*Limg, D) - outImg = mlx.Add(mlx.Linear(outImg, attn.ToOut), attn.ToOutB) - outImg = mlx.Reshape(outImg, B, Limg, D) - - outTxt = mlx.Reshape(outTxt, B*Ltxt, D) - outTxt = mlx.Add(mlx.Linear(outTxt, attn.ToAddOut), attn.ToAddOutB) - outTxt = mlx.Reshape(outTxt, B, Ltxt, D) - - return outImg, outTxt -} - -// applyRoPE applies rotary embeddings using complex multiplication -// x: [B, L, nheads, head_dim] -// freqs: [L, head_dim] as complex (interleaved real/imag pairs) -func applyRoPE(x *mlx.Array, freqs *mlx.Array) *mlx.Array { - shape := x.Shape() - B := shape[0] - L := shape[1] - nheads := shape[2] - headDim := shape[3] - halfDim := headDim / 2 - - // Reshape x to pairs: [B, L, nheads, half, 2] - xPairs := mlx.Reshape(x, B, L, nheads, halfDim, 2) - - // freqs: [L, head_dim] -> [1, L, 1, half, 2] - freqsExp := mlx.Reshape(freqs, 1, L, 1, halfDim, 2) - - // Extract real/imag parts - xReal := mlx.SliceStride(xPairs, []int32{0, 0, 0, 0, 0}, []int32{B, L, nheads, halfDim, 1}, []int32{1, 1, 1, 1, 1}) - xImag := mlx.SliceStride(xPairs, []int32{0, 0, 0, 0, 1}, []int32{B, L, nheads, halfDim, 2}, []int32{1, 1, 1, 1, 1}) - xReal = mlx.Squeeze(xReal, 4) - xImag = mlx.Squeeze(xImag, 4) - - freqReal := mlx.SliceStride(freqsExp, []int32{0, 0, 0, 0, 0}, []int32{1, L, 1, halfDim, 1}, []int32{1, 1, 1, 1, 1}) - freqImag := mlx.SliceStride(freqsExp, []int32{0, 0, 0, 0, 1}, []int32{1, L, 1, halfDim, 2}, []int32{1, 1, 1, 1, 1}) - freqReal = mlx.Squeeze(freqReal, 4) - freqImag = mlx.Squeeze(freqImag, 4) - - // Complex multiplication: (a + bi) * (c + di) = (ac - bd) + (ad + bc)i - outReal := mlx.Sub(mlx.Mul(xReal, freqReal), mlx.Mul(xImag, freqImag)) - outImag := mlx.Add(mlx.Mul(xReal, freqImag), mlx.Mul(xImag, freqReal)) - - // Interleave back - outReal = mlx.ExpandDims(outReal, 4) - outImag = mlx.ExpandDims(outImag, 4) - out := mlx.Concatenate([]*mlx.Array{outReal, outImag}, 4) - - return mlx.Reshape(out, B, L, nheads, headDim) -} - -// MLP implements GELU MLP (not GEGLU) -type MLP struct { - ProjWeight *mlx.Array - ProjBias *mlx.Array - OutWeight *mlx.Array - OutBias *mlx.Array -} - -// newMLP creates a GELU MLP -func newMLP(weights *safetensors.ModelWeights, prefix string) (*MLP, error) { - projWeight, _ := weights.Get(prefix + ".net.0.proj.weight") - projBias, _ := weights.Get(prefix + ".net.0.proj.bias") - outWeight, _ := weights.Get(prefix + ".net.2.weight") - outBias, _ := weights.Get(prefix + ".net.2.bias") - - return &MLP{ - ProjWeight: mlx.Transpose(projWeight, 1, 0), - ProjBias: projBias, - OutWeight: mlx.Transpose(outWeight, 1, 0), - OutBias: outBias, - }, nil -} - -// Forward applies GELU MLP -func (m *MLP) Forward(x *mlx.Array) *mlx.Array { - shape := x.Shape() - B := shape[0] - L := shape[1] - D := shape[2] - - xFlat := mlx.Reshape(x, B*L, D) - h := mlx.Add(mlx.Linear(xFlat, m.ProjWeight), m.ProjBias) - h = geluApprox(h) - h = mlx.Add(mlx.Linear(h, m.OutWeight), m.OutBias) - return mlx.Reshape(h, B, L, m.OutBias.Dim(0)) -} - -// geluApprox implements approximate GELU -func geluApprox(x *mlx.Array) *mlx.Array { - sqrt2OverPi := float32(math.Sqrt(2.0 / math.Pi)) - x3 := mlx.Mul(mlx.Mul(x, x), x) - inner := mlx.Add(x, mlx.MulScalar(x3, 0.044715)) - inner = mlx.MulScalar(inner, sqrt2OverPi) - return mlx.Mul(mlx.MulScalar(x, 0.5), mlx.AddScalar(mlx.Tanh(inner), 1.0)) -} - -// TransformerBlock is a single dual-stream transformer block -type TransformerBlock struct { - Attention *JointAttention - ImgMLP *MLP - TxtMLP *MLP - - ImgModWeight *mlx.Array - ImgModBias *mlx.Array - TxtModWeight *mlx.Array - TxtModBias *mlx.Array - - HiddenDim int32 - NormEps float32 -} - -// newTransformerBlock creates a transformer block -func newTransformerBlock(weights *safetensors.ModelWeights, prefix string, cfg *TransformerConfig) (*TransformerBlock, error) { - attn, err := newJointAttention(weights, prefix, cfg) - if err != nil { - return nil, err - } - - imgMLP, _ := newMLP(weights, prefix+".img_mlp") - txtMLP, _ := newMLP(weights, prefix+".txt_mlp") - - imgModWeight, _ := weights.Get(prefix + ".img_mod.1.weight") - imgModBias, _ := weights.Get(prefix + ".img_mod.1.bias") - txtModWeight, _ := weights.Get(prefix + ".txt_mod.1.weight") - txtModBias, _ := weights.Get(prefix + ".txt_mod.1.bias") - - return &TransformerBlock{ - Attention: attn, - ImgMLP: imgMLP, - TxtMLP: txtMLP, - ImgModWeight: mlx.Transpose(imgModWeight, 1, 0), - ImgModBias: imgModBias, - TxtModWeight: mlx.Transpose(txtModWeight, 1, 0), - TxtModBias: txtModBias, - HiddenDim: cfg.HiddenDim, - NormEps: cfg.NormEps, - }, nil -} - -// Forward applies the transformer block -func (tb *TransformerBlock) Forward(img, txt, temb *mlx.Array, imgFreqs, txtFreqs *mlx.Array) (*mlx.Array, *mlx.Array) { - // Compute modulation: silu(temb) -> linear -> [B, 6*D] - siluT := mlx.SiLU(temb) - imgMod := mlx.Add(mlx.Linear(siluT, tb.ImgModWeight), tb.ImgModBias) - txtMod := mlx.Add(mlx.Linear(siluT, tb.TxtModWeight), tb.TxtModBias) - - // Split into 6 parts: shift1, scale1, gate1, shift2, scale2, gate2 - imgModParts := splitMod6(imgMod, tb.HiddenDim) - txtModParts := splitMod6(txtMod, tb.HiddenDim) - - // Pre-attention: norm + modulate - imgNorm := layerNormNoAffine(img, tb.NormEps) - imgNorm = mlx.Add(mlx.Mul(imgNorm, mlx.AddScalar(imgModParts[1], 1.0)), imgModParts[0]) - - txtNorm := layerNormNoAffine(txt, tb.NormEps) - txtNorm = mlx.Add(mlx.Mul(txtNorm, mlx.AddScalar(txtModParts[1], 1.0)), txtModParts[0]) - - // Joint attention - attnImg, attnTxt := tb.Attention.Forward(imgNorm, txtNorm, imgFreqs, txtFreqs) - - // Residual with gate - img = mlx.Add(img, mlx.Mul(imgModParts[2], attnImg)) - txt = mlx.Add(txt, mlx.Mul(txtModParts[2], attnTxt)) - - // Pre-MLP: norm + modulate - imgNorm2 := layerNormNoAffine(img, tb.NormEps) - imgNorm2 = mlx.Add(mlx.Mul(imgNorm2, mlx.AddScalar(imgModParts[4], 1.0)), imgModParts[3]) - - txtNorm2 := layerNormNoAffine(txt, tb.NormEps) - txtNorm2 = mlx.Add(mlx.Mul(txtNorm2, mlx.AddScalar(txtModParts[4], 1.0)), txtModParts[3]) - - // MLP - mlpImg := tb.ImgMLP.Forward(imgNorm2) - mlpTxt := tb.TxtMLP.Forward(txtNorm2) - - // Residual with gate - img = mlx.Add(img, mlx.Mul(imgModParts[5], mlpImg)) - txt = mlx.Add(txt, mlx.Mul(txtModParts[5], mlpTxt)) - - return img, txt -} - -// splitMod6 splits modulation into 6 parts each [B, 1, D] -func splitMod6(mod *mlx.Array, hiddenDim int32) []*mlx.Array { - shape := mod.Shape() - B := shape[0] - parts := make([]*mlx.Array, 6) - for i := int32(0); i < 6; i++ { - part := mlx.Slice(mod, []int32{0, i * hiddenDim}, []int32{B, (i + 1) * hiddenDim}) - parts[i] = mlx.ExpandDims(part, 1) - } - return parts -} - -// layerNormNoAffine applies layer norm without learnable parameters -func layerNormNoAffine(x *mlx.Array, eps float32) *mlx.Array { - ndim := x.Ndim() - lastAxis := ndim - 1 - mean := mlx.Mean(x, lastAxis, true) - xCentered := mlx.Sub(x, mean) - variance := mlx.Mean(mlx.Square(xCentered), lastAxis, true) - return mlx.Div(xCentered, mlx.Sqrt(mlx.AddScalar(variance, eps))) -} - -// Transformer is the full Qwen-Image transformer model -type Transformer struct { - Config *TransformerConfig - - ImgIn *mlx.Array - ImgInBias *mlx.Array - TxtIn *mlx.Array - TxtInBias *mlx.Array - TxtNorm *mlx.Array - - TEmbed *TimestepEmbedder - Layers []*TransformerBlock - - NormOutWeight *mlx.Array - NormOutBias *mlx.Array - ProjOut *mlx.Array - ProjOutBias *mlx.Array -} - -// Load loads the transformer from a directory -func (m *Transformer) Load(path string) error { - fmt.Println("Loading Qwen-Image transformer...") - - cfg := defaultTransformerConfig() - m.Config = cfg - - weights, err := safetensors.LoadModelWeights(path) - if err != nil { - return fmt.Errorf("weights: %w", err) - } - - // Bulk load all weights as bf16 - fmt.Print(" Loading weights as bf16... ") - if err := weights.Load(mlx.DtypeBFloat16); err != nil { - return fmt.Errorf("load weights: %w", err) - } - fmt.Printf("✓ (%.1f GB)\n", float64(mlx.MetalGetActiveMemory())/(1024*1024*1024)) - - fmt.Print(" Loading input projections... ") - imgIn, _ := weights.Get("img_in.weight") - imgInBias, _ := weights.Get("img_in.bias") - txtIn, _ := weights.Get("txt_in.weight") - txtInBias, _ := weights.Get("txt_in.bias") - txtNorm, _ := weights.Get("txt_norm.weight") - m.ImgIn = mlx.Transpose(imgIn, 1, 0) - m.ImgInBias = imgInBias - m.TxtIn = mlx.Transpose(txtIn, 1, 0) - m.TxtInBias = txtInBias - m.TxtNorm = txtNorm - fmt.Println("✓") - - fmt.Print(" Loading timestep embedder... ") - m.TEmbed, err = newTimestepEmbedder(weights) - if err != nil { - return fmt.Errorf("timestep embedder: %w", err) - } - fmt.Println("✓") - - m.Layers = make([]*TransformerBlock, cfg.NLayers) - for i := int32(0); i < cfg.NLayers; i++ { - fmt.Printf("\r Loading transformer layers... %d/%d", i+1, cfg.NLayers) - prefix := fmt.Sprintf("transformer_blocks.%d", i) - m.Layers[i], err = newTransformerBlock(weights, prefix, cfg) - if err != nil { - return fmt.Errorf("layer %d: %w", i, err) - } - } - fmt.Printf("\r Loading transformer layers... ✓ [%d blocks] \n", cfg.NLayers) - - fmt.Print(" Loading output layers... ") - normOutWeight, _ := weights.Get("norm_out.linear.weight") - normOutBias, _ := weights.Get("norm_out.linear.bias") - projOut, _ := weights.Get("proj_out.weight") - projOutBias, _ := weights.Get("proj_out.bias") - m.NormOutWeight = mlx.Transpose(normOutWeight, 1, 0) - m.NormOutBias = normOutBias - m.ProjOut = mlx.Transpose(projOut, 1, 0) - m.ProjOutBias = projOutBias - fmt.Println("✓") - - weights.ReleaseAll() - return nil -} - -// LoadFromPath is a convenience function to load transformer from path -func LoadTransformerFromPath(path string) (*Transformer, error) { - m := &Transformer{} - if err := m.Load(filepath.Join(path, "transformer")); err != nil { - return nil, err - } - return m, nil -} - -// Forward runs the transformer -// img: [B, L_img, in_channels] patchified latents -// txt: [B, L_txt, joint_attention_dim] text embeddings -// t: [B] timesteps (0-1) -// imgFreqs, txtFreqs: RoPE frequencies -func (tr *Transformer) Forward(img, txt, t *mlx.Array, imgFreqs, txtFreqs *mlx.Array) *mlx.Array { - imgShape := img.Shape() - B := imgShape[0] - Limg := imgShape[1] - - txtShape := txt.Shape() - Ltxt := txtShape[1] - - // Timestep embedding - temb := tr.TEmbed.Forward(t) - - // Project image: [B, L, in_channels] -> [B, L, hidden_dim] - imgFlat := mlx.Reshape(img, B*Limg, tr.Config.InChannels) - imgH := mlx.Add(mlx.Linear(imgFlat, tr.ImgIn), tr.ImgInBias) - imgH = mlx.Reshape(imgH, B, Limg, tr.Config.HiddenDim) - - // Project text: RMSNorm then linear - txtFlat := mlx.Reshape(txt, B*Ltxt, tr.Config.JointAttentionDim) - txtNormed := mlx.RMSNorm(txtFlat, tr.TxtNorm, 1e-6) - txtH := mlx.Add(mlx.Linear(txtNormed, tr.TxtIn), tr.TxtInBias) - txtH = mlx.Reshape(txtH, B, Ltxt, tr.Config.HiddenDim) - - for _, layer := range tr.Layers { - imgH, txtH = layer.Forward(imgH, txtH, temb, imgFreqs, txtFreqs) - } - - // Final norm with modulation (AdaLayerNormContinuous) - // Python: scale, shift = torch.chunk(emb, 2, dim=1) - finalMod := mlx.Add(mlx.Linear(mlx.SiLU(temb), tr.NormOutWeight), tr.NormOutBias) - modShape := finalMod.Shape() - halfDim := modShape[1] / 2 - scale := mlx.ExpandDims(mlx.Slice(finalMod, []int32{0, 0}, []int32{B, halfDim}), 1) - shift := mlx.ExpandDims(mlx.Slice(finalMod, []int32{0, halfDim}, []int32{B, modShape[1]}), 1) - - imgH = layerNormNoAffine(imgH, tr.Config.NormEps) - imgH = mlx.Add(mlx.Mul(imgH, mlx.AddScalar(scale, 1.0)), shift) - - // Final projection: [B, L, hidden_dim] -> [B, L, patch_size^2 * out_channels] - imgFlat = mlx.Reshape(imgH, B*Limg, tr.Config.HiddenDim) - out := mlx.Add(mlx.Linear(imgFlat, tr.ProjOut), tr.ProjOutBias) - - outChannels := tr.Config.PatchSize * tr.Config.PatchSize * tr.Config.OutChannels - return mlx.Reshape(out, B, Limg, outChannels) -} - -// ForwardWithCache runs the transformer with layer caching for speedup. -// Based on DeepCache (CVPR 2024) / Learning-to-Cache (NeurIPS 2024): -// shallow layers change little between denoising steps, so we cache their -// outputs and reuse them on non-refresh steps. -// -// stepCache: cache for layer outputs (use cache.NewStepCache(cacheLayers)) -// step: current denoising step (0-indexed) -// cacheInterval: refresh cache every N steps (e.g., 3) -// cacheLayers: number of shallow layers to cache (e.g., 15) -func (tr *Transformer) ForwardWithCache( - img, txt, t *mlx.Array, - imgFreqs, txtFreqs *mlx.Array, - stepCache *cache.StepCache, - step, cacheInterval, cacheLayers int, -) *mlx.Array { - imgShape := img.Shape() - B := imgShape[0] - Limg := imgShape[1] - - txtShape := txt.Shape() - Ltxt := txtShape[1] - - // Timestep embedding - temb := tr.TEmbed.Forward(t) - - // Project image: [B, L, in_channels] -> [B, L, hidden_dim] - imgFlat := mlx.Reshape(img, B*Limg, tr.Config.InChannels) - imgH := mlx.Add(mlx.Linear(imgFlat, tr.ImgIn), tr.ImgInBias) - imgH = mlx.Reshape(imgH, B, Limg, tr.Config.HiddenDim) - - // Project text: RMSNorm then linear - txtFlat := mlx.Reshape(txt, B*Ltxt, tr.Config.JointAttentionDim) - txtNormed := mlx.RMSNorm(txtFlat, tr.TxtNorm, 1e-6) - txtH := mlx.Add(mlx.Linear(txtNormed, tr.TxtIn), tr.TxtInBias) - txtH = mlx.Reshape(txtH, B, Ltxt, tr.Config.HiddenDim) - - // Check if we should refresh the cache - refreshCache := stepCache.ShouldRefresh(step, cacheInterval) - - for i, layer := range tr.Layers { - if i < cacheLayers && !refreshCache && stepCache.Get(i) != nil { - // Use cached outputs for shallow layers - imgH = stepCache.Get(i) - txtH = stepCache.Get2(i) - } else { - // Compute layer - imgH, txtH = layer.Forward(imgH, txtH, temb, imgFreqs, txtFreqs) - // Cache shallow layers on refresh steps - if i < cacheLayers && refreshCache { - stepCache.Set(i, imgH) - stepCache.Set2(i, txtH) - } - } - } - - // Final norm with modulation (AdaLayerNormContinuous) - finalMod := mlx.Add(mlx.Linear(mlx.SiLU(temb), tr.NormOutWeight), tr.NormOutBias) - modShape := finalMod.Shape() - halfDim := modShape[1] / 2 - scale := mlx.ExpandDims(mlx.Slice(finalMod, []int32{0, 0}, []int32{B, halfDim}), 1) - shift := mlx.ExpandDims(mlx.Slice(finalMod, []int32{0, halfDim}, []int32{B, modShape[1]}), 1) - - imgH = layerNormNoAffine(imgH, tr.Config.NormEps) - imgH = mlx.Add(mlx.Mul(imgH, mlx.AddScalar(scale, 1.0)), shift) - - // Final projection: [B, L, hidden_dim] -> [B, L, patch_size^2 * out_channels] - imgFlat = mlx.Reshape(imgH, B*Limg, tr.Config.HiddenDim) - out := mlx.Add(mlx.Linear(imgFlat, tr.ProjOut), tr.ProjOutBias) - - outChannels := tr.Config.PatchSize * tr.Config.PatchSize * tr.Config.OutChannels - return mlx.Reshape(out, B, Limg, outChannels) -} - -// RoPECache holds precomputed RoPE frequencies -type RoPECache struct { - ImgFreqs *mlx.Array // [L_img, head_dim] - TxtFreqs *mlx.Array // [L_txt, head_dim] -} - -// PrepareRoPE computes RoPE for image and text sequences -// This matches Python's QwenEmbedRope with scale_rope=True -func PrepareRoPE(imgH, imgW int32, txtLen int32, axesDims []int32) *RoPECache { - theta := float64(10000) - maxIdx := int32(4096) - - // Compute base frequencies for each axis dimension - freqsT := ComputeAxisFreqs(axesDims[0], theta) - freqsH := ComputeAxisFreqs(axesDims[1], theta) - freqsW := ComputeAxisFreqs(axesDims[2], theta) - - // Build frequency lookup tables - posFreqsT := MakeFreqTable(maxIdx, freqsT, false) - posFreqsH := MakeFreqTable(maxIdx, freqsH, false) - posFreqsW := MakeFreqTable(maxIdx, freqsW, false) - negFreqsH := MakeFreqTable(maxIdx, freqsH, true) - negFreqsW := MakeFreqTable(maxIdx, freqsW, true) - - // Image frequencies with scale_rope=True - imgLen := imgH * imgW - headDim := int32(len(freqsT)+len(freqsH)+len(freqsW)) * 2 - imgFreqsData := make([]float32, imgLen*headDim) - - hHalf := imgH / 2 - wHalf := imgW / 2 - - idx := int32(0) - for y := int32(0); y < imgH; y++ { - for x := int32(0); x < imgW; x++ { - // Frame = 0 - for i := 0; i < len(freqsT)*2; i++ { - imgFreqsData[idx+int32(i)] = posFreqsT[0][i] - } - idx += int32(len(freqsT) * 2) - - // Height: scale_rope pattern - hNegCount := imgH - hHalf - if y < hNegCount { - negTableIdx := maxIdx - hNegCount + y - for i := 0; i < len(freqsH)*2; i++ { - imgFreqsData[idx+int32(i)] = negFreqsH[negTableIdx][i] - } - } else { - posIdx := y - hNegCount - for i := 0; i < len(freqsH)*2; i++ { - imgFreqsData[idx+int32(i)] = posFreqsH[posIdx][i] - } - } - idx += int32(len(freqsH) * 2) - - // Width: scale_rope pattern - wNegCount := imgW - wHalf - if x < wNegCount { - negTableIdx := maxIdx - wNegCount + x - for i := 0; i < len(freqsW)*2; i++ { - imgFreqsData[idx+int32(i)] = negFreqsW[negTableIdx][i] - } - } else { - posIdx := x - wNegCount - for i := 0; i < len(freqsW)*2; i++ { - imgFreqsData[idx+int32(i)] = posFreqsW[posIdx][i] - } - } - idx += int32(len(freqsW) * 2) - } - } - - imgFreqs := mlx.NewArray(imgFreqsData, []int32{imgLen, headDim}) - imgFreqs = mlx.ToBFloat16(imgFreqs) - - // Text frequencies - maxVidIdx := max(hHalf, wHalf) - txtFreqsData := make([]float32, txtLen*headDim) - - idx = 0 - for t := int32(0); t < txtLen; t++ { - pos := maxVidIdx + t - for i := 0; i < len(freqsT)*2; i++ { - txtFreqsData[idx+int32(i)] = posFreqsT[pos][i] - } - idx += int32(len(freqsT) * 2) - for i := 0; i < len(freqsH)*2; i++ { - txtFreqsData[idx+int32(i)] = posFreqsH[pos][i] - } - idx += int32(len(freqsH) * 2) - for i := 0; i < len(freqsW)*2; i++ { - txtFreqsData[idx+int32(i)] = posFreqsW[pos][i] - } - idx += int32(len(freqsW) * 2) - } - - txtFreqs := mlx.NewArray(txtFreqsData, []int32{txtLen, headDim}) - txtFreqs = mlx.ToBFloat16(txtFreqs) - - return &RoPECache{ - ImgFreqs: imgFreqs, - TxtFreqs: txtFreqs, - } -} - -// ComputeAxisFreqs computes RoPE base frequencies for a given dimension. -func ComputeAxisFreqs(dim int32, theta float64) []float64 { - halfDim := dim / 2 - freqs := make([]float64, halfDim) - for i := int32(0); i < halfDim; i++ { - freqs[i] = 1.0 / math.Pow(theta, float64(i)/float64(halfDim)) - } - return freqs -} - -// MakeFreqTable builds a table of cos/sin values for RoPE positions. -func MakeFreqTable(maxIdx int32, baseFreqs []float64, negative bool) [][]float32 { - table := make([][]float32, maxIdx) - for idx := int32(0); idx < maxIdx; idx++ { - var pos float64 - if negative { - pos = float64(-maxIdx + int32(idx)) - } else { - pos = float64(idx) - } - - row := make([]float32, len(baseFreqs)*2) - for i, f := range baseFreqs { - angle := pos * f - row[i*2] = float32(math.Cos(angle)) - row[i*2+1] = float32(math.Sin(angle)) - } - table[idx] = row - } - return table -} - -func max(a, b int32) int32 { - if a > b { - return a - } - return b -} - -// PackLatents converts [B, C, H, W] to [B, L, C*4] patches -func PackLatents(latents *mlx.Array, patchSize int32) *mlx.Array { - shape := latents.Shape() - B := shape[0] - C := shape[1] - H := shape[2] - W := shape[3] - - pH := H / patchSize - pW := W / patchSize - - // [B, C, H, W] -> [B, C, pH, 2, pW, 2] - x := mlx.Reshape(latents, B, C, pH, patchSize, pW, patchSize) - // -> [B, pH, pW, C, 2, 2] - x = mlx.Transpose(x, 0, 2, 4, 1, 3, 5) - // -> [B, pH*pW, C*4] - return mlx.Reshape(x, B, pH*pW, C*patchSize*patchSize) -} - -// UnpackLatents converts [B, L, C*4] back to [B, C, 1, H, W] (5D for VAE) -func UnpackLatents(patches *mlx.Array, H, W, patchSize int32) *mlx.Array { - shape := patches.Shape() - B := shape[0] - channels := shape[2] / (patchSize * patchSize) - - pH := H / patchSize - pW := W / patchSize - - // [B, L, C*4] -> [B, pH, pW, C, 2, 2] - x := mlx.Reshape(patches, B, pH, pW, channels, patchSize, patchSize) - // -> [B, C, pH, 2, pW, 2] - x = mlx.Transpose(x, 0, 3, 1, 4, 2, 5) - // -> [B, C, H, W] - x = mlx.Reshape(x, B, channels, pH*patchSize, pW*patchSize) - // Add temporal dimension for VAE: [B, C, 1, H, W] - return mlx.ExpandDims(x, 2) -} diff --git a/x/imagegen/models/qwen_image/transformer_test.go b/x/imagegen/models/qwen_image/transformer_test.go deleted file mode 100644 index 5eef53b1d..000000000 --- a/x/imagegen/models/qwen_image/transformer_test.go +++ /dev/null @@ -1,119 +0,0 @@ -//go:build mlx - -package qwen_image - -import ( - "math" - "os" - "testing" - - "github.com/ollama/ollama/x/imagegen/mlx" -) - -// TestTransformerConfig tests configuration invariants. -func TestTransformerConfig(t *testing.T) { - cfg := defaultTransformerConfig() - - // Property: hidden_dim = n_heads * head_dim - if cfg.HiddenDim != cfg.NHeads*cfg.HeadDim { - t.Errorf("hidden_dim != n_heads * head_dim: %d != %d * %d", - cfg.HiddenDim, cfg.NHeads, cfg.HeadDim) - } - - // Property: axes_dims_rope sums to head_dim - var ropeSum int32 - for _, d := range cfg.AxesDimsRope { - ropeSum += d - } - if ropeSum != cfg.HeadDim { - t.Errorf("axes_dims_rope sum != head_dim: %d != %d", ropeSum, cfg.HeadDim) - } - - // Property: in_channels = out_channels * patch_size^2 - expectedIn := cfg.OutChannels * cfg.PatchSize * cfg.PatchSize - if cfg.InChannels != expectedIn { - t.Errorf("in_channels != out_channels * patch_size^2: %d != %d", cfg.InChannels, expectedIn) - } -} - -// TestTransformerRoPE tests RoPE frequency computation produces valid values. -func TestTransformerRoPE(t *testing.T) { - cfg := defaultTransformerConfig() - - // Test with small image dimensions - imgH, imgW := int32(4), int32(4) // 4x4 latent = 16 patches - txtLen := int32(5) - - ropeCache := PrepareRoPE(imgH, imgW, txtLen, cfg.AxesDimsRope) - mlx.Eval(ropeCache.ImgFreqs, ropeCache.TxtFreqs) - - // Verify shapes: [seq_len, head_dim] - imgSeqLen := imgH * imgW - if ropeCache.ImgFreqs.Shape()[0] != imgSeqLen { - t.Errorf("ImgFreqs seq_len: got %d, want %d", ropeCache.ImgFreqs.Shape()[0], imgSeqLen) - } - if ropeCache.ImgFreqs.Shape()[1] != cfg.HeadDim { - t.Errorf("ImgFreqs head_dim: got %d, want %d", ropeCache.ImgFreqs.Shape()[1], cfg.HeadDim) - } - - if ropeCache.TxtFreqs.Shape()[0] != txtLen { - t.Errorf("TxtFreqs seq_len: got %d, want %d", ropeCache.TxtFreqs.Shape()[0], txtLen) - } - - // Verify values are finite - imgData := ropeCache.ImgFreqs.Data() - for i := 0; i < min(100, len(imgData)); i++ { - if math.IsNaN(float64(imgData[i])) || math.IsInf(float64(imgData[i]), 0) { - t.Errorf("ImgFreqs[%d] not finite: %v", i, imgData[i]) - break - } - } -} - -// TestTransformerForward tests full forward pass (integration test). -// Skips if model weights are not available. -func TestTransformerForward(t *testing.T) { - weightsPath := "../../../weights/Qwen-Image-2512/transformer" - if _, err := os.Stat(weightsPath); os.IsNotExist(err) { - t.Skip("Skipping: model weights not found at " + weightsPath) - } - - transformer := &Transformer{} - if err := transformer.Load(weightsPath); err != nil { - t.Fatalf("Failed to load transformer: %v", err) - } - mlx.Keep(mlx.Collect(transformer)...) - cfg := transformer.Config - - // Small test inputs - batchSize := int32(1) - imgH, imgW := int32(4), int32(4) - imgSeqLen := imgH * imgW - txtSeqLen := int32(5) - - hiddenStates := mlx.RandomNormal([]int32{batchSize, imgSeqLen, cfg.InChannels}, 0) - encoderHiddenStates := mlx.RandomNormal([]int32{batchSize, txtSeqLen, cfg.JointAttentionDim}, 0) - timestep := mlx.NewArray([]float32{0.5}, []int32{batchSize}) - - ropeCache := PrepareRoPE(imgH, imgW, txtSeqLen, cfg.AxesDimsRope) - - // Forward pass - out := transformer.Forward(hiddenStates, encoderHiddenStates, timestep, ropeCache.ImgFreqs, ropeCache.TxtFreqs) - mlx.Eval(out) - - // Verify output shape: [batch, img_seq_len, in_channels] - wantShape := []int32{batchSize, imgSeqLen, cfg.InChannels} - gotShape := out.Shape() - if gotShape[0] != wantShape[0] || gotShape[1] != wantShape[1] || gotShape[2] != wantShape[2] { - t.Errorf("output shape: got %v, want %v", gotShape, wantShape) - } - - // Verify output is finite - outData := out.Data() - for i := 0; i < min(100, len(outData)); i++ { - if math.IsNaN(float64(outData[i])) || math.IsInf(float64(outData[i]), 0) { - t.Errorf("output[%d] not finite: %v", i, outData[i]) - break - } - } -} diff --git a/x/imagegen/models/qwen_image/vae.go b/x/imagegen/models/qwen_image/vae.go deleted file mode 100644 index e1c7f5255..000000000 --- a/x/imagegen/models/qwen_image/vae.go +++ /dev/null @@ -1,854 +0,0 @@ -//go:build mlx - -package qwen_image - -import ( - "fmt" - "math" - "path/filepath" - - "github.com/ollama/ollama/x/imagegen/mlx" - "github.com/ollama/ollama/x/imagegen/safetensors" -) - -// VAEConfig holds Qwen-Image VAE configuration -type VAEConfig struct { - ZDim int32 `json:"z_dim"` // 16 - BaseDim int32 `json:"base_dim"` // 96 - DimMult []int32 `json:"dim_mult"` // [1, 2, 4, 4] - NumResBlocks int32 `json:"num_res_blocks"` // 2 - LatentsMean []float32 `json:"latents_mean"` // 16 values - LatentsStd []float32 `json:"latents_std"` // 16 values - TemperalDownsample []bool `json:"temperal_downsample"` // [false, true, true] -} - -// defaultVAEConfig returns config for Qwen-Image VAE -func defaultVAEConfig() *VAEConfig { - return &VAEConfig{ - ZDim: 16, - BaseDim: 96, - DimMult: []int32{1, 2, 4, 4}, - NumResBlocks: 2, - LatentsMean: []float32{ - -0.7571, -0.7089, -0.9113, 0.1075, - -0.1745, 0.9653, -0.1517, 1.5508, - 0.4134, -0.0715, 0.5517, -0.3632, - -0.1922, -0.9497, 0.2503, -0.2921, - }, - LatentsStd: []float32{ - 2.8184, 1.4541, 2.3275, 2.6558, - 1.2196, 1.7708, 2.6052, 2.0743, - 3.2687, 2.1526, 2.8652, 1.5579, - 1.6382, 1.1253, 2.8251, 1.916, - }, - TemperalDownsample: []bool{false, true, true}, - } -} - -// CausalConv3d is a causal 3D convolution (for temporal causality) -type CausalConv3d struct { - Weight *mlx.Array - Bias *mlx.Array - BiasReshaped *mlx.Array // [1, C, 1, 1, 1] - KernelT int32 -} - -// newCausalConv3d creates a 3D causal conv -func newCausalConv3d(weights *safetensors.ModelWeights, prefix string) (*CausalConv3d, error) { - weight, err := weights.Get(prefix + ".weight") - if err != nil { - return nil, fmt.Errorf("weight not found: %s", prefix) - } - bias, _ := weights.Get(prefix + ".bias") - - kernelT := weight.Shape()[2] - outC := weight.Shape()[0] - - var biasReshaped *mlx.Array - if bias != nil { - biasReshaped = mlx.Reshape(bias, 1, outC, 1, 1, 1) - } - - return &CausalConv3d{ - Weight: weight, - Bias: bias, - BiasReshaped: biasReshaped, - KernelT: kernelT, - }, nil -} - -// Forward applies causal 3D convolution -// x: [B, T, H, W, C] (channels-last, MLX format) -func (c *CausalConv3d) Forward(x *mlx.Array) *mlx.Array { - shape := c.Weight.Shape() // PyTorch format: [O, I, kT, kH, kW] - kernelT := shape[2] - kernelH := shape[3] - kernelW := shape[4] - - // Causal temporal padding, same spatial padding - // Input is channels-last: [B, T, H, W, C] - padT := kernelT - 1 - padH := kernelH / 2 - padW := kernelW / 2 - - // Stage 1: Pad - { - x = pad3DChannelsLast(x, padT, 0, padH, padH, padW, padW) - mlx.Eval(x) - } - - // Stage 2: Conv + bias - var out *mlx.Array - { - prev := x - weight := mlx.Transpose(c.Weight, 0, 2, 3, 4, 1) - out = mlx.Conv3d(x, weight, 1, 1, 1, 0, 0, 0) - if c.Bias != nil { - bias := mlx.Reshape(c.Bias, 1, 1, 1, 1, c.Bias.Dim(0)) - out = mlx.Add(out, bias) - } - prev.Free() - mlx.Eval(out) - } - - return out -} - -// RMSNorm3D applies RMS normalization over channels -// Works with channels-last [B, T, H, W, C] format -type RMSNorm3D struct { - Gamma *mlx.Array // [1, 1, 1, 1, C] for broadcasting -} - -// newRMSNorm3D creates an RMS norm -func newRMSNorm3D(weights *safetensors.ModelWeights, prefix string, dim int32) (*RMSNorm3D, error) { - gamma, err := weights.Get(prefix + ".gamma") - if err != nil { - return nil, err - } - // Reshape for channels-last broadcasting: [1, 1, 1, 1, C] - gamma = mlx.Reshape(gamma, 1, 1, 1, 1, gamma.Dim(0)) - return &RMSNorm3D{Gamma: gamma}, nil -} - -// Forward applies RMS norm to channels-last input [B, T, H, W, C] -func (n *RMSNorm3D) Forward(x *mlx.Array) *mlx.Array { - // RMSNorm: x * rsqrt(mean(x^2) + eps) * gamma - normalized := mlx.RMSNormNoWeight(x, 1e-6) - return mlx.Mul(normalized, n.Gamma) -} - -// ResBlock is a residual block with RMS norm and causal convs -type ResBlock struct { - Norm1 *RMSNorm3D - Conv1 *CausalConv3d - Norm2 *RMSNorm3D - Conv2 *CausalConv3d - Shortcut *CausalConv3d -} - -// newResBlock creates a residual block -func newResBlock(weights *safetensors.ModelWeights, prefix string, inDim, outDim int32) (*ResBlock, error) { - norm1, err := newRMSNorm3D(weights, prefix+".norm1", inDim) - if err != nil { - return nil, err - } - conv1, err := newCausalConv3d(weights, prefix+".conv1") - if err != nil { - return nil, err - } - norm2, err := newRMSNorm3D(weights, prefix+".norm2", outDim) - if err != nil { - return nil, err - } - conv2, err := newCausalConv3d(weights, prefix+".conv2") - if err != nil { - return nil, err - } - - var shortcut *CausalConv3d - if inDim != outDim { - shortcut, err = newCausalConv3d(weights, prefix+".conv_shortcut") - if err != nil { - return nil, err - } - } - - return &ResBlock{ - Norm1: norm1, - Conv1: conv1, - Norm2: norm2, - Conv2: conv2, - Shortcut: shortcut, - }, nil -} - -// Forward applies the residual block -func (r *ResBlock) Forward(x *mlx.Array) *mlx.Array { - // Use h as working variable, keep x intact for residual (caller will free x) - // Conv handles its own pools, so we just need pools for non-conv operations - var h *mlx.Array - - // Keep x so it survives Eval() cleanup - needed for residual connection - mlx.Keep(x) - - // Stage 1: norm1 + silu - { - h = r.Norm1.Forward(x) - h = silu3D(h) - mlx.Eval(h) - } - - // Stage 2: conv1 (handles its own pools) - { - prev := h - h = r.Conv1.Forward(h) - prev.Free() - } - - // Stage 3: norm2 + silu - { - prev := h - h = r.Norm2.Forward(h) - h = silu3D(h) - prev.Free() - mlx.Eval(h) - } - - // Stage 4: conv2 (handles its own pools) - { - prev := h - h = r.Conv2.Forward(h) - prev.Free() - } - - // Residual connection (shortcut handles its own pools if present) - if r.Shortcut != nil { - shortcut := r.Shortcut.Forward(x) - h = mlx.Add(h, shortcut) - mlx.Eval(h) - } else { - h = mlx.Add(h, x) - mlx.Eval(h) - } - - return h -} - -// AttentionBlock is a 2D attention block -type AttentionBlock struct { - Norm *RMSNorm3D - ToQKV *mlx.Array - ToQKVBias *mlx.Array - Proj *mlx.Array - ProjBias *mlx.Array - Dim int32 -} - -// newAttentionBlock creates an attention block -func newAttentionBlock(weights *safetensors.ModelWeights, prefix string, dim int32) (*AttentionBlock, error) { - norm, err := newRMSNorm3D(weights, prefix+".norm", dim) - if err != nil { - return nil, err - } - toQKV, _ := weights.Get(prefix + ".to_qkv.weight") - toQKVBias, _ := weights.Get(prefix + ".to_qkv.bias") - proj, _ := weights.Get(prefix + ".proj.weight") - projBias, _ := weights.Get(prefix + ".proj.bias") - - return &AttentionBlock{ - Norm: norm, - ToQKV: toQKV, - ToQKVBias: toQKVBias, - Proj: proj, - ProjBias: projBias, - Dim: dim, - }, nil -} - -// Forward applies 2D attention -// Input: [B, T, H, W, C] (channels-last) -func (a *AttentionBlock) Forward(x *mlx.Array) *mlx.Array { - shape := x.Shape() - B := shape[0] - T := shape[1] - H := shape[2] - W := shape[3] - C := shape[4] - - identity := x - - // Flatten to [B*T, 1, H, W, C] for norm - x = mlx.Reshape(x, B*T, 1, H, W, C) - x = a.Norm.Forward(x) - x = mlx.Reshape(x, B*T, H, W, C) - - // Flatten spatial to [B*T, H*W, C] - x = mlx.Reshape(x, B*T, H*W, C) - - // Linear to get Q, K, V: [B*T, H*W, 3*C] - // Weight is [outC, inC] or [outC, inC, 1, 1] - wShape := a.ToQKV.Shape() - var w *mlx.Array - if len(wShape) == 4 { - w = mlx.Reshape(a.ToQKV, wShape[0], wShape[1]) - } else { - w = a.ToQKV - } - w = mlx.Transpose(w, 1, 0) // [inC, outC] - - qkv := mlx.Linear(x, w) // [B*T, H*W, 3*C] - if a.ToQKVBias != nil { - qkv = mlx.Add(qkv, a.ToQKVBias) - } - qkv = mlx.Reshape(qkv, B*T, 1, H*W, 3*C) - - q := mlx.Slice(qkv, []int32{0, 0, 0, 0}, []int32{B * T, 1, H * W, C}) - k := mlx.Slice(qkv, []int32{0, 0, 0, C}, []int32{B * T, 1, H * W, 2 * C}) - v := mlx.Slice(qkv, []int32{0, 0, 0, 2 * C}, []int32{B * T, 1, H * W, 3 * C}) - - scale := float32(1.0 / math.Sqrt(float64(C))) - out := mlx.ScaledDotProductAttention(q, k, v, scale, false) - - // out: [B*T, 1, H*W, C] - out = mlx.Reshape(out, B*T, H*W, C) - - // Project back - pShape := a.Proj.Shape() - var p *mlx.Array - if len(pShape) == 4 { - p = mlx.Reshape(a.Proj, pShape[0], pShape[1]) - } else { - p = a.Proj - } - p = mlx.Transpose(p, 1, 0) // [inC, outC] - out = mlx.Linear(out, p) // [B*T, H*W, C] - if a.ProjBias != nil { - out = mlx.Add(out, a.ProjBias) - } - - out = mlx.Reshape(out, B, T, H, W, C) - return mlx.Add(out, identity) -} - -// UpBlock handles upsampling in decoder -type UpBlock struct { - ResBlocks []*ResBlock - Upsampler *Upsample -} - -// newUpBlock creates an up block -func newUpBlock(weights *safetensors.ModelWeights, prefix string, inDim, outDim int32, numBlocks int32, upsampleMode string) (*UpBlock, error) { - resBlocks := make([]*ResBlock, numBlocks+1) - - currentDim := inDim - for i := int32(0); i <= numBlocks; i++ { - resPrefix := fmt.Sprintf("%s.resnets.%d", prefix, i) - block, err := newResBlock(weights, resPrefix, currentDim, outDim) - if err != nil { - return nil, err - } - resBlocks[i] = block - currentDim = outDim - } - - var upsampler *Upsample - if upsampleMode != "" { - upsampler = newUpsample(weights, prefix+".upsamplers.0", outDim, upsampleMode) - } - - return &UpBlock{ - ResBlocks: resBlocks, - Upsampler: upsampler, - }, nil -} - -// Forward applies up block with staged memory management -func (u *UpBlock) Forward(x *mlx.Array) *mlx.Array { - // ResBlocks handle their own pools - for _, block := range u.ResBlocks { - prev := x - x = block.Forward(x) - prev.Free() - } - - // Upsampler handles its own pools - if u.Upsampler != nil { - prev := x - x = u.Upsampler.Forward(x) - prev.Free() - } - return x -} - -// Upsample handles spatial upsampling -type Upsample struct { - Conv *mlx.Array - Bias *mlx.Array - Mode string -} - -// newUpsample creates an upsampler -func newUpsample(weights *safetensors.ModelWeights, prefix string, dim int32, mode string) *Upsample { - conv, _ := weights.Get(prefix + ".resample.1.weight") - bias, _ := weights.Get(prefix + ".resample.1.bias") - return &Upsample{ - Conv: conv, - Bias: bias, - Mode: mode, - } -} - -// Forward applies upsampling to channels-last input [B, T, H, W, C] -// Uses staged pools to reduce peak memory during 2x upsampling -func (u *Upsample) Forward(x *mlx.Array) *mlx.Array { - shape := x.Shape() - B := shape[0] - T := shape[1] - H := shape[2] - W := shape[3] - C := shape[4] - outC := u.Conv.Shape()[0] - - // Stage 1: 2x nearest neighbor upsample - { - x = mlx.Reshape(x, B*T, H, W, C) - x = upsample2xChannelsLast(x) - mlx.Eval(x) - } - - // Stage 2: Conv + bias - { - prev := x - weight := mlx.Transpose(u.Conv, 0, 2, 3, 1) - x = conv2D3x3PaddedChannelsLast(x, weight) - if u.Bias != nil { - bias := mlx.Reshape(u.Bias, 1, 1, 1, outC) - x = mlx.Add(x, bias) - } - x = mlx.Reshape(x, B, T, H*2, W*2, outC) - prev.Free() - mlx.Eval(x) - } - - return x -} - -// MidBlock is the middle block of decoder -type MidBlock struct { - ResBlock1 *ResBlock - Attention *AttentionBlock - ResBlock2 *ResBlock -} - -// newMidBlock creates a mid block -func newMidBlock(weights *safetensors.ModelWeights, prefix string, dim int32) (*MidBlock, error) { - res1, err := newResBlock(weights, prefix+".resnets.0", dim, dim) - if err != nil { - return nil, err - } - attn, err := newAttentionBlock(weights, prefix+".attentions.0", dim) - if err != nil { - return nil, err - } - res2, err := newResBlock(weights, prefix+".resnets.1", dim, dim) - if err != nil { - return nil, err - } - - return &MidBlock{ - ResBlock1: res1, - Attention: attn, - ResBlock2: res2, - }, nil -} - -// Forward applies mid block -func (m *MidBlock) Forward(x *mlx.Array) *mlx.Array { - // Each component handles its own pools; we just free inputs - prev := x - x = m.ResBlock1.Forward(x) - prev.Free() - - prev = x - x = m.Attention.Forward(x) - prev.Free() - - prev = x - x = m.ResBlock2.Forward(x) - prev.Free() - - return x -} - -// VAEDecoder is the full VAE decoder -type VAEDecoder struct { - Config *VAEConfig - - PostQuantConv *CausalConv3d - ConvIn *CausalConv3d - MidBlock *MidBlock - UpBlocks []*UpBlock - NormOut *RMSNorm3D - ConvOut *CausalConv3d -} - -// Load loads the VAE decoder from a directory -func (m *VAEDecoder) Load(path string) error { - fmt.Println("Loading Qwen-Image VAE decoder...") - - cfg := defaultVAEConfig() - m.Config = cfg - - weights, err := safetensors.LoadModelWeights(path) - if err != nil { - return fmt.Errorf("weights: %w", err) - } - - // Bulk load all weights as bf16 - fmt.Print(" Loading weights as bf16... ") - if err := weights.Load(mlx.DtypeBFloat16); err != nil { - return fmt.Errorf("failed to load weights: %w", err) - } - fmt.Printf("✓ (%.1f GB)\n", float64(mlx.MetalGetActiveMemory())/(1024*1024*1024)) - - fmt.Print(" Loading post_quant_conv... ") - postQuantConv, err := newCausalConv3d(weights, "post_quant_conv") - if err != nil { - return err - } - m.PostQuantConv = postQuantConv - fmt.Println("✓") - - fmt.Print(" Loading conv_in... ") - convIn, err := newCausalConv3d(weights, "decoder.conv_in") - if err != nil { - return err - } - m.ConvIn = convIn - fmt.Println("✓") - - // Mid block (dim = base_dim * dim_mult[-1] = 96 * 4 = 384) - fmt.Print(" Loading mid_block... ") - midDim := cfg.BaseDim * cfg.DimMult[len(cfg.DimMult)-1] - midBlock, err := newMidBlock(weights, "decoder.mid_block", midDim) - if err != nil { - return err - } - m.MidBlock = midBlock - fmt.Println("✓") - - // Up blocks (reversed dim_mult) - fmt.Print(" Loading up_blocks... ") - numUpBlocks := len(cfg.DimMult) - m.UpBlocks = make([]*UpBlock, numUpBlocks) - - dimsMult := make([]int32, numUpBlocks+1) - dimsMult[0] = cfg.DimMult[numUpBlocks-1] - for i := 0; i < numUpBlocks; i++ { - dimsMult[i+1] = cfg.DimMult[numUpBlocks-1-i] - } - - temporalUpsample := make([]bool, len(cfg.TemperalDownsample)) - for i := range cfg.TemperalDownsample { - temporalUpsample[i] = cfg.TemperalDownsample[len(cfg.TemperalDownsample)-1-i] - } - - for i := 0; i < numUpBlocks; i++ { - inDim := cfg.BaseDim * dimsMult[i] - outDim := cfg.BaseDim * dimsMult[i+1] - - if i > 0 { - inDim = inDim / 2 - } - - upsampleMode := "" - if i < numUpBlocks-1 { - if temporalUpsample[i] { - upsampleMode = "upsample3d" - } else { - upsampleMode = "upsample2d" - } - } - - prefix := fmt.Sprintf("decoder.up_blocks.%d", i) - upBlock, err := newUpBlock(weights, prefix, inDim, outDim, cfg.NumResBlocks, upsampleMode) - if err != nil { - return err - } - m.UpBlocks[i] = upBlock - } - fmt.Printf("✓ [%d blocks]\n", numUpBlocks) - - fmt.Print(" Loading output layers... ") - normOut, err := newRMSNorm3D(weights, "decoder.norm_out", cfg.BaseDim) - if err != nil { - return err - } - m.NormOut = normOut - convOut, err := newCausalConv3d(weights, "decoder.conv_out") - if err != nil { - return err - } - m.ConvOut = convOut - fmt.Println("✓") - - weights.ReleaseAll() - return nil -} - -// LoadVAEDecoderFromPath is a convenience function to load VAE from path -func LoadVAEDecoderFromPath(path string) (*VAEDecoder, error) { - m := &VAEDecoder{} - if err := m.Load(filepath.Join(path, "vae")); err != nil { - return nil, err - } - return m, nil -} - -// Decode converts latents to image -// z: [B, C, T, H, W] normalized latents -// Uses staged pools to free intermediate arrays and reduce peak memory. -func (vae *VAEDecoder) Decode(z *mlx.Array) *mlx.Array { - var x *mlx.Array - - // Stage 1a: Denormalize and transpose - { - z = vae.Denormalize(z) - // Convert from channels-first [N, C, T, H, W] to channels-last [N, T, H, W, C] - z = mlx.Contiguous(mlx.Transpose(z, 0, 2, 3, 4, 1)) - mlx.Eval(z) - } - - // Stage 1b: PostQuantConv (handles its own pools) - x = vae.PostQuantConv.Forward(z) - z.Free() - - // Stage 1c: ConvIn (handles its own pools) - { - prev := x - x = vae.ConvIn.Forward(x) - prev.Free() - } - - // Stage 2: Mid block (handles its own pools) - x = vae.MidBlock.Forward(x) - - // Stage 3: Up blocks (each handles its own pools) - for _, upBlock := range vae.UpBlocks { - x = upBlock.Forward(x) - } - - // Stage 4a: NormOut + silu - { - prev := x - x = vae.NormOut.Forward(x) - x = silu3D(x) - prev.Free() - mlx.Eval(x) - } - - // Stage 4b: ConvOut (handles its own pools) - { - prev := x - x = vae.ConvOut.Forward(x) - prev.Free() - } - - // Stage 4c: Post-processing - { - prev := x - // Clamp to [-1, 1] - x = mlx.ClipScalar(x, -1.0, 1.0, true, true) - // Convert back from channels-last to channels-first - x = mlx.Contiguous(mlx.Transpose(x, 0, 4, 1, 2, 3)) - prev.Free() - mlx.Eval(x) - } - - return x -} - -// Denormalize reverses the normalization applied during encoding -func (vae *VAEDecoder) Denormalize(z *mlx.Array) *mlx.Array { - shape := z.Shape() - C := shape[1] - - mean := mlx.NewArray(vae.Config.LatentsMean[:C], []int32{1, C, 1, 1, 1}) - std := mlx.NewArray(vae.Config.LatentsStd[:C], []int32{1, C, 1, 1, 1}) - - mean = mlx.ToBFloat16(mean) - std = mlx.ToBFloat16(std) - - return mlx.Add(mlx.Mul(z, std), mean) -} - -// Helper functions - -func silu3D(x *mlx.Array) *mlx.Array { - return mlx.Mul(x, mlx.Sigmoid(x)) -} - -// pad3DChannelsLast pads a channels-last [B, T, H, W, C] tensor -func pad3DChannelsLast(x *mlx.Array, tBefore, tAfter, hBefore, hAfter, wBefore, wAfter int32) *mlx.Array { - if tBefore == 0 && tAfter == 0 && hBefore == 0 && hAfter == 0 && wBefore == 0 && wAfter == 0 { - return x - } - // Pad dims: [B before, B after, T before, T after, H before, H after, W before, W after, C before, C after] - return mlx.Pad(x, []int32{0, 0, tBefore, tAfter, hBefore, hAfter, wBefore, wAfter, 0, 0}) -} - -func pad2D(x *mlx.Array, hBefore, hAfter, wBefore, wAfter int32) *mlx.Array { - if hBefore == 0 && hAfter == 0 && wBefore == 0 && wAfter == 0 { - return x - } - return mlx.Pad(x, []int32{0, 0, 0, 0, hBefore, hAfter, wBefore, wAfter}) -} - -func conv2D1x1(x, weight *mlx.Array) *mlx.Array { - shape := x.Shape() - B := shape[0] - H := shape[2] - W := shape[3] - - x = mlx.Transpose(x, 0, 2, 3, 1) - x = mlx.Reshape(x, B*H*W, shape[1]) - - wShape := weight.Shape() - var w *mlx.Array - if len(wShape) == 4 { - w = mlx.Reshape(weight, wShape[0], wShape[1]) - } else { - w = weight - } - w = mlx.Transpose(w, 1, 0) - - out := mlx.Linear(x, w) - outC := w.Dim(1) - out = mlx.Reshape(out, B, H, W, outC) - return mlx.Transpose(out, 0, 3, 1, 2) -} - -func conv2D3x3Padded(x, weight *mlx.Array) *mlx.Array { - x = pad2D(x, 1, 1, 1, 1) - return conv2D(x, weight, 1, 1) -} - -func conv2D(x, w *mlx.Array, strideH, strideW int32) *mlx.Array { - x = mlx.Transpose(x, 0, 2, 3, 1) - w = mlx.Transpose(w, 0, 2, 3, 1) - - shape := x.Shape() - B := shape[0] - H := shape[1] - W := shape[2] - - wShape := w.Shape() - Cout := wShape[0] - kH := wShape[1] - kW := wShape[2] - - outH := (H-kH)/strideH + 1 - outW := (W-kW)/strideW + 1 - - patches := extractPatches2D(x, kH, kW, strideH, strideW) - wFlat := mlx.Reshape(w, Cout, -1) - patches = mlx.Reshape(patches, B*outH*outW, -1) - out := mlx.Linear(patches, mlx.Transpose(wFlat, 1, 0)) - out = mlx.Reshape(out, B, outH, outW, Cout) - return mlx.Transpose(out, 0, 3, 1, 2) -} - -func extractPatches2D(x *mlx.Array, kH, kW, strideH, strideW int32) *mlx.Array { - shape := x.Shape() - B := shape[0] - H := shape[1] - W := shape[2] - C := shape[3] - - outH := (H-kH)/strideH + 1 - outW := (W-kW)/strideW + 1 - - patches := make([]*mlx.Array, outH*outW) - idx := 0 - for i := int32(0); i < outH; i++ { - for j := int32(0); j < outW; j++ { - startH := i * strideH - startW := j * strideW - patch := mlx.Slice(x, []int32{0, startH, startW, 0}, []int32{B, startH + kH, startW + kW, C}) - patch = mlx.Reshape(patch, B, kH*kW*C) - patches[idx] = patch - idx++ - } - } - - for i := range patches { - patches[i] = mlx.ExpandDims(patches[i], 1) - } - stacked := mlx.Concatenate(patches, 1) - return mlx.Reshape(stacked, B, outH, outW, kH*kW*C) -} - -func upsample2x(x *mlx.Array) *mlx.Array { - shape := x.Shape() - H := shape[2] - W := shape[3] - - rowIdxData := make([]int32, H*2) - for i := int32(0); i < H; i++ { - rowIdxData[i*2] = i - rowIdxData[i*2+1] = i - } - rowIdx := mlx.NewArrayInt32(rowIdxData, []int32{H * 2}) - - colIdxData := make([]int32, W*2) - for i := int32(0); i < W; i++ { - colIdxData[i*2] = i - colIdxData[i*2+1] = i - } - colIdx := mlx.NewArrayInt32(colIdxData, []int32{W * 2}) - - x = mlx.Take(x, rowIdx, 2) - x = mlx.Take(x, colIdx, 3) - - return x -} - -// upsample2xChannelsLast upsamples channels-last input [B, H, W, C] by 2x -func upsample2xChannelsLast(x *mlx.Array) *mlx.Array { - shape := x.Shape() - H := shape[1] - W := shape[2] - - // Create repeat indices for rows - rowIdxData := make([]int32, H*2) - for i := int32(0); i < H; i++ { - rowIdxData[i*2] = i - rowIdxData[i*2+1] = i - } - rowIdx := mlx.NewArrayInt32(rowIdxData, []int32{H * 2}) - - // Create repeat indices for columns - colIdxData := make([]int32, W*2) - for i := int32(0); i < W; i++ { - colIdxData[i*2] = i - colIdxData[i*2+1] = i - } - colIdx := mlx.NewArrayInt32(colIdxData, []int32{W * 2}) - - // Take along H (axis 1) then W (axis 2) - x = mlx.Take(x, rowIdx, 1) - x = mlx.Take(x, colIdx, 2) - - return x -} - -// conv2D3x3PaddedChannelsLast applies 3x3 conv with padding to channels-last input [B, H, W, C] -// weight: [outC, kH, kW, inC] (MLX channels-last format) -func conv2D3x3PaddedChannelsLast(x, weight *mlx.Array) *mlx.Array { - // Pad spatial dims: [B, H, W, C] -> pad H and W by 1 each side - x = mlx.Pad(x, []int32{0, 0, 1, 1, 1, 1, 0, 0}) - // Conv2d expects: input [B, H, W, inC], weight [outC, kH, kW, inC] - // stride=1, padding=0 (we already padded manually) - return mlx.Conv2d(x, weight, 1, 0) -} diff --git a/x/imagegen/models/qwen_image/vae_test.go b/x/imagegen/models/qwen_image/vae_test.go deleted file mode 100644 index f15a1134b..000000000 --- a/x/imagegen/models/qwen_image/vae_test.go +++ /dev/null @@ -1,114 +0,0 @@ -//go:build mlx - -package qwen_image - -import ( - "math" - "os" - "testing" - - "github.com/ollama/ollama/x/imagegen/mlx" -) - -// TestVAEConfig tests configuration invariants. -func TestVAEConfig(t *testing.T) { - cfg := defaultVAEConfig() - - // Property: latents_mean and latents_std have z_dim elements - if int32(len(cfg.LatentsMean)) != cfg.ZDim { - t.Errorf("latents_mean length != z_dim: %d != %d", len(cfg.LatentsMean), cfg.ZDim) - } - if int32(len(cfg.LatentsStd)) != cfg.ZDim { - t.Errorf("latents_std length != z_dim: %d != %d", len(cfg.LatentsStd), cfg.ZDim) - } - - // Property: dim_mult defines 4 stages - if len(cfg.DimMult) != 4 { - t.Errorf("dim_mult should have 4 stages: got %d", len(cfg.DimMult)) - } - - // Property: temperal_downsample has 3 elements (for 3 transitions) - if len(cfg.TemperalDownsample) != 3 { - t.Errorf("temperal_downsample should have 3 elements: got %d", len(cfg.TemperalDownsample)) - } -} - -// TestVAELatentsNormalization tests the latent denormalization values. -func TestVAELatentsNormalization(t *testing.T) { - cfg := defaultVAEConfig() - - // Verify latents_std values are all positive - for i, std := range cfg.LatentsStd { - if std <= 0 { - t.Errorf("latents_std[%d] should be positive: %v", i, std) - } - } - - // Verify values are in reasonable range (from actual model) - for i, mean := range cfg.LatentsMean { - if math.Abs(float64(mean)) > 5 { - t.Errorf("latents_mean[%d] seems too large: %v", i, mean) - } - } - for i, std := range cfg.LatentsStd { - if std > 10 { - t.Errorf("latents_std[%d] seems too large: %v", i, std) - } - } -} - -// TestVAEDecoderForward tests full forward pass (integration test). -// Skips if model weights are not available. -func TestVAEDecoderForward(t *testing.T) { - weightsPath := "../../../weights/Qwen-Image-2512/vae" - if _, err := os.Stat(weightsPath); os.IsNotExist(err) { - t.Skip("Skipping: model weights not found at " + weightsPath) - } - - vae := &VAEDecoder{} - if err := vae.Load(weightsPath); err != nil { - t.Fatalf("Failed to load VAE decoder: %v", err) - } - mlx.Keep(mlx.Collect(vae)...) - - // Small test input: [B, C, T, H, W] - // After 4 upsampling stages (2x each), H/W multiply by 16 - batchSize := int32(1) - channels := int32(16) - frames := int32(1) - latentH := int32(4) - latentW := int32(4) - - latents := mlx.RandomNormal([]int32{batchSize, channels, frames, latentH, latentW}, 0) - - // Decode - out := vae.Decode(latents) - mlx.Eval(out) - - // Verify output shape: [B, 3, T, H*16, W*16] - outShape := out.Shape() - if outShape[0] != batchSize { - t.Errorf("batch size: got %d, want %d", outShape[0], batchSize) - } - if outShape[1] != 3 { - t.Errorf("channels: got %d, want 3", outShape[1]) - } - if outShape[2] != frames { - t.Errorf("frames: got %d, want %d", outShape[2], frames) - } - expectedH := latentH * 16 // 4 stages of 2x upsampling - expectedW := latentW * 16 - if outShape[3] != expectedH || outShape[4] != expectedW { - t.Errorf("spatial dims: got [%d, %d], want [%d, %d]", - outShape[3], outShape[4], expectedH, expectedW) - } - - // Verify output is in valid range (should be clamped to [0, 1] by decode) - outData := out.Data() - for i := 0; i < min(100, len(outData)); i++ { - if math.IsNaN(float64(outData[i])) || math.IsInf(float64(outData[i]), 0) { - t.Errorf("output[%d] not finite: %v", i, outData[i]) - break - } - } -} diff --git a/x/imagegen/models/qwen_image_edit/layers.go b/x/imagegen/models/qwen_image_edit/layers.go deleted file mode 100644 index 04c192077..000000000 --- a/x/imagegen/models/qwen_image_edit/layers.go +++ /dev/null @@ -1,682 +0,0 @@ -//go:build mlx - -package qwen_image_edit - -import ( - "fmt" - "math" - - "github.com/ollama/ollama/x/imagegen/mlx" - "github.com/ollama/ollama/x/imagegen/safetensors" -) - -// CausalConv3d is a causal 3D convolution (for temporal causality) -type CausalConv3d struct { - Weight *mlx.Array - Bias *mlx.Array - BiasReshaped *mlx.Array // [1, C, 1, 1, 1] - KernelT int32 -} - -// newCausalConv3d creates a 3D causal conv -func newCausalConv3d(weights *safetensors.ModelWeights, prefix string) (*CausalConv3d, error) { - weight, err := weights.Get(prefix + ".weight") - if err != nil { - return nil, fmt.Errorf("weight not found: %s", prefix) - } - bias, _ := weights.Get(prefix + ".bias") - - kernelT := weight.Shape()[2] - outC := weight.Shape()[0] - - var biasReshaped *mlx.Array - if bias != nil { - biasReshaped = mlx.Reshape(bias, 1, outC, 1, 1, 1) - } - - return &CausalConv3d{ - Weight: weight, - Bias: bias, - BiasReshaped: biasReshaped, - KernelT: kernelT, - }, nil -} - -// Forward applies causal 3D convolution (or 2D if weight is 4D) -// x: [B, T, H, W, C] (channels-last, MLX format) -func (c *CausalConv3d) Forward(x *mlx.Array) *mlx.Array { - shape := c.Weight.Shape() - - // Handle both 5D (3D conv) and 4D (2D conv) weights - if len(shape) == 4 { - // 2D conv: [O, I, kH, kW] - need to apply per-frame - return c.forward2D(x) - } - - // 3D conv: [O, I, kT, kH, kW] - kernelT := shape[2] - kernelH := shape[3] - kernelW := shape[4] - - // Causal temporal padding, same spatial padding - padT := kernelT - 1 - padH := kernelH / 2 - padW := kernelW / 2 - - // Stage 1: Pad - { - x = pad3DChannelsLast(x, padT, 0, padH, padH, padW, padW) - mlx.Eval(x) - } - - // Stage 2: Conv + bias - var out *mlx.Array - { - prev := x - weight := mlx.Transpose(c.Weight, 0, 2, 3, 4, 1) - out = mlx.Conv3d(x, weight, 1, 1, 1, 0, 0, 0) - if c.Bias != nil { - bias := mlx.Reshape(c.Bias, 1, 1, 1, 1, c.Bias.Dim(0)) - out = mlx.Add(out, bias) - } - prev.Free() - mlx.Eval(out) - } - - return out -} - -// forward2D applies 2D conv per-frame for [B, T, H, W, C] input -func (c *CausalConv3d) forward2D(x *mlx.Array) *mlx.Array { - xShape := x.Shape() - B := xShape[0] - T := xShape[1] - H := xShape[2] - W := xShape[3] - C := xShape[4] - - wShape := c.Weight.Shape() // [O, I, kH, kW] - kernelH := wShape[2] - kernelW := wShape[3] - outC := wShape[0] - - padH := kernelH / 2 - padW := kernelW / 2 - - // Reshape to [B*T, H, W, C] for 2D conv - x = mlx.Reshape(x, B*T, H, W, C) - - // Pad spatially - x = mlx.Pad(x, []int32{0, 0, padH, padH, padW, padW, 0, 0}) - - // Apply 2D conv - weight := mlx.Transpose(c.Weight, 0, 2, 3, 1) // [O, I, kH, kW] -> [O, kH, kW, I] - x = mlx.Conv2d(x, weight, 1, 0) - - if c.Bias != nil { - bias := mlx.Reshape(c.Bias, 1, 1, 1, outC) - x = mlx.Add(x, bias) - } - - // Get output spatial dims - outH := H - outW := W - - // Reshape back to [B, T, H, W, C] - x = mlx.Reshape(x, B, T, outH, outW, outC) - mlx.Eval(x) - - return x -} - -// RMSNorm3D applies RMS normalization over channels -type RMSNorm3D struct { - Gamma *mlx.Array // [1, 1, 1, 1, C] for broadcasting -} - -// newRMSNorm3D creates an RMS norm -func newRMSNorm3D(weights *safetensors.ModelWeights, prefix string, dim int32) (*RMSNorm3D, error) { - gamma, err := weights.Get(prefix + ".gamma") - if err != nil { - return nil, err - } - gamma = mlx.Reshape(gamma, 1, 1, 1, 1, gamma.Dim(0)) - return &RMSNorm3D{Gamma: gamma}, nil -} - -// Forward applies RMS norm to channels-last input [B, T, H, W, C] -func (n *RMSNorm3D) Forward(x *mlx.Array) *mlx.Array { - normalized := mlx.RMSNormNoWeight(x, 1e-6) - return mlx.Mul(normalized, n.Gamma) -} - -// ResBlock is a residual block with RMS norm and causal convs -type ResBlock struct { - Norm1 *RMSNorm3D - Conv1 *CausalConv3d - Norm2 *RMSNorm3D - Conv2 *CausalConv3d - Shortcut *CausalConv3d -} - -// newResBlock creates a residual block -func newResBlock(weights *safetensors.ModelWeights, prefix string, inDim, outDim int32) (*ResBlock, error) { - norm1, err := newRMSNorm3D(weights, prefix+".norm1", inDim) - if err != nil { - return nil, err - } - conv1, err := newCausalConv3d(weights, prefix+".conv1") - if err != nil { - return nil, err - } - norm2, err := newRMSNorm3D(weights, prefix+".norm2", outDim) - if err != nil { - return nil, err - } - conv2, err := newCausalConv3d(weights, prefix+".conv2") - if err != nil { - return nil, err - } - - var shortcut *CausalConv3d - if inDim != outDim { - shortcut, err = newCausalConv3d(weights, prefix+".conv_shortcut") - if err != nil { - return nil, err - } - } - - return &ResBlock{ - Norm1: norm1, - Conv1: conv1, - Norm2: norm2, - Conv2: conv2, - Shortcut: shortcut, - }, nil -} - -// Forward applies the residual block -func (r *ResBlock) Forward(x *mlx.Array) *mlx.Array { - var h *mlx.Array - - mlx.Keep(x) - - // Stage 1: norm1 + silu - { - h = r.Norm1.Forward(x) - h = silu3D(h) - mlx.Eval(h) - } - - // Stage 2: conv1 - { - prev := h - h = r.Conv1.Forward(h) - prev.Free() - } - - // Stage 3: norm2 + silu - { - prev := h - h = r.Norm2.Forward(h) - h = silu3D(h) - prev.Free() - mlx.Eval(h) - } - - // Stage 4: conv2 - { - prev := h - h = r.Conv2.Forward(h) - prev.Free() - } - - // Residual connection - if r.Shortcut != nil { - shortcut := r.Shortcut.Forward(x) - h = mlx.Add(h, shortcut) - mlx.Eval(h) - } else { - h = mlx.Add(h, x) - mlx.Eval(h) - } - - return h -} - -// AttentionBlock is a 2D attention block -type AttentionBlock struct { - Norm *RMSNorm3D - ToQKV *mlx.Array - ToQKVBias *mlx.Array - Proj *mlx.Array - ProjBias *mlx.Array - Dim int32 -} - -// newAttentionBlock creates an attention block -func newAttentionBlock(weights *safetensors.ModelWeights, prefix string, dim int32) (*AttentionBlock, error) { - norm, err := newRMSNorm3D(weights, prefix+".norm", dim) - if err != nil { - return nil, err - } - toQKV, _ := weights.Get(prefix + ".to_qkv.weight") - toQKVBias, _ := weights.Get(prefix + ".to_qkv.bias") - proj, _ := weights.Get(prefix + ".proj.weight") - projBias, _ := weights.Get(prefix + ".proj.bias") - - return &AttentionBlock{ - Norm: norm, - ToQKV: toQKV, - ToQKVBias: toQKVBias, - Proj: proj, - ProjBias: projBias, - Dim: dim, - }, nil -} - -// Forward applies 2D attention -// Input: [B, T, H, W, C] (channels-last) -func (a *AttentionBlock) Forward(x *mlx.Array) *mlx.Array { - shape := x.Shape() - B := shape[0] - T := shape[1] - H := shape[2] - W := shape[3] - C := shape[4] - - identity := x - - // Flatten to [B*T, 1, H, W, C] for norm - x = mlx.Reshape(x, B*T, 1, H, W, C) - x = a.Norm.Forward(x) - x = mlx.Reshape(x, B*T, H, W, C) - - // Flatten spatial to [B*T, H*W, C] - x = mlx.Reshape(x, B*T, H*W, C) - - // Linear to get Q, K, V - wShape := a.ToQKV.Shape() - var w *mlx.Array - if len(wShape) == 4 { - w = mlx.Reshape(a.ToQKV, wShape[0], wShape[1]) - } else { - w = a.ToQKV - } - w = mlx.Transpose(w, 1, 0) - - qkv := mlx.Linear(x, w) - if a.ToQKVBias != nil { - qkv = mlx.Add(qkv, a.ToQKVBias) - } - qkv = mlx.Reshape(qkv, B*T, 1, H*W, 3*C) - - q := mlx.Slice(qkv, []int32{0, 0, 0, 0}, []int32{B * T, 1, H * W, C}) - k := mlx.Slice(qkv, []int32{0, 0, 0, C}, []int32{B * T, 1, H * W, 2 * C}) - v := mlx.Slice(qkv, []int32{0, 0, 0, 2 * C}, []int32{B * T, 1, H * W, 3 * C}) - - scale := float32(1.0 / math.Sqrt(float64(C))) - out := mlx.ScaledDotProductAttention(q, k, v, scale, false) - - out = mlx.Reshape(out, B*T, H*W, C) - - // Project back - pShape := a.Proj.Shape() - var p *mlx.Array - if len(pShape) == 4 { - p = mlx.Reshape(a.Proj, pShape[0], pShape[1]) - } else { - p = a.Proj - } - p = mlx.Transpose(p, 1, 0) - out = mlx.Linear(out, p) - if a.ProjBias != nil { - out = mlx.Add(out, a.ProjBias) - } - - out = mlx.Reshape(out, B, T, H, W, C) - return mlx.Add(out, identity) -} - -// UpBlock handles upsampling in decoder -type UpBlock struct { - ResBlocks []*ResBlock - Upsampler *Upsample -} - -// newUpBlock creates an up block -func newUpBlock(weights *safetensors.ModelWeights, prefix string, inDim, outDim int32, numBlocks int32, upsampleMode string) (*UpBlock, error) { - resBlocks := make([]*ResBlock, numBlocks+1) - - currentDim := inDim - for i := int32(0); i <= numBlocks; i++ { - resPrefix := fmt.Sprintf("%s.resnets.%d", prefix, i) - block, err := newResBlock(weights, resPrefix, currentDim, outDim) - if err != nil { - return nil, err - } - resBlocks[i] = block - currentDim = outDim - } - - var upsampler *Upsample - if upsampleMode != "" { - upsampler = newUpsample(weights, prefix+".upsamplers.0", outDim, upsampleMode) - } - - return &UpBlock{ - ResBlocks: resBlocks, - Upsampler: upsampler, - }, nil -} - -// Forward applies up block -func (u *UpBlock) Forward(x *mlx.Array) *mlx.Array { - for _, block := range u.ResBlocks { - prev := x - x = block.Forward(x) - prev.Free() - } - - if u.Upsampler != nil { - prev := x - x = u.Upsampler.Forward(x) - prev.Free() - } - return x -} - -// Upsample handles spatial upsampling -type Upsample struct { - Conv *mlx.Array - Bias *mlx.Array - Mode string -} - -// newUpsample creates an upsampler -func newUpsample(weights *safetensors.ModelWeights, prefix string, dim int32, mode string) *Upsample { - conv, _ := weights.Get(prefix + ".resample.1.weight") - bias, _ := weights.Get(prefix + ".resample.1.bias") - return &Upsample{ - Conv: conv, - Bias: bias, - Mode: mode, - } -} - -// Forward applies upsampling to channels-last input [B, T, H, W, C] -func (u *Upsample) Forward(x *mlx.Array) *mlx.Array { - shape := x.Shape() - B := shape[0] - T := shape[1] - H := shape[2] - W := shape[3] - C := shape[4] - outC := u.Conv.Shape()[0] - - // Stage 1: 2x nearest neighbor upsample - { - x = mlx.Reshape(x, B*T, H, W, C) - x = upsample2xChannelsLast(x) - mlx.Eval(x) - } - - // Stage 2: Conv + bias - { - prev := x - weight := mlx.Transpose(u.Conv, 0, 2, 3, 1) - x = conv2D3x3PaddedChannelsLast(x, weight) - if u.Bias != nil { - bias := mlx.Reshape(u.Bias, 1, 1, 1, outC) - x = mlx.Add(x, bias) - } - x = mlx.Reshape(x, B, T, H*2, W*2, outC) - prev.Free() - mlx.Eval(x) - } - - return x -} - -// MidBlock is the middle block -type MidBlock struct { - ResBlock1 *ResBlock - Attention *AttentionBlock - ResBlock2 *ResBlock -} - -// newMidBlock creates a mid block -func newMidBlock(weights *safetensors.ModelWeights, prefix string, dim int32) (*MidBlock, error) { - res1, err := newResBlock(weights, prefix+".resnets.0", dim, dim) - if err != nil { - return nil, err - } - attn, err := newAttentionBlock(weights, prefix+".attentions.0", dim) - if err != nil { - return nil, err - } - res2, err := newResBlock(weights, prefix+".resnets.1", dim, dim) - if err != nil { - return nil, err - } - - return &MidBlock{ - ResBlock1: res1, - Attention: attn, - ResBlock2: res2, - }, nil -} - -// Forward applies mid block -func (m *MidBlock) Forward(x *mlx.Array) *mlx.Array { - prev := x - x = m.ResBlock1.Forward(x) - prev.Free() - - prev = x - x = m.Attention.Forward(x) - prev.Free() - - prev = x - x = m.ResBlock2.Forward(x) - prev.Free() - - return x -} - -// Helper functions - -func silu3D(x *mlx.Array) *mlx.Array { - return mlx.Mul(x, mlx.Sigmoid(x)) -} - -// pad3DChannelsLast pads a channels-last [B, T, H, W, C] tensor -func pad3DChannelsLast(x *mlx.Array, tBefore, tAfter, hBefore, hAfter, wBefore, wAfter int32) *mlx.Array { - if tBefore == 0 && tAfter == 0 && hBefore == 0 && hAfter == 0 && wBefore == 0 && wAfter == 0 { - return x - } - return mlx.Pad(x, []int32{0, 0, tBefore, tAfter, hBefore, hAfter, wBefore, wAfter, 0, 0}) -} - -// upsample2xChannelsLast upsamples channels-last input [B, H, W, C] by 2x -func upsample2xChannelsLast(x *mlx.Array) *mlx.Array { - shape := x.Shape() - H := shape[1] - W := shape[2] - - rowIdxData := make([]int32, H*2) - for i := int32(0); i < H; i++ { - rowIdxData[i*2] = i - rowIdxData[i*2+1] = i - } - rowIdx := mlx.NewArrayInt32(rowIdxData, []int32{H * 2}) - - colIdxData := make([]int32, W*2) - for i := int32(0); i < W; i++ { - colIdxData[i*2] = i - colIdxData[i*2+1] = i - } - colIdx := mlx.NewArrayInt32(colIdxData, []int32{W * 2}) - - x = mlx.Take(x, rowIdx, 1) - x = mlx.Take(x, colIdx, 2) - - return x -} - -// conv2D3x3PaddedChannelsLast applies 3x3 conv with padding to channels-last input [B, H, W, C] -func conv2D3x3PaddedChannelsLast(x, weight *mlx.Array) *mlx.Array { - x = mlx.Pad(x, []int32{0, 0, 1, 1, 1, 1, 0, 0}) - return mlx.Conv2d(x, weight, 1, 0) -} - -// conv2DStrided applies conv with stride > 1 using manual patch extraction -// x: [B, H, W, C] (channels-last), weight: [O, kH, kW, I] -func conv2DStrided(x, weight *mlx.Array, stride int32) *mlx.Array { - shape := x.Shape() - B := shape[0] - H := shape[1] - W := shape[2] - - wShape := weight.Shape() - Cout := wShape[0] - kH := wShape[1] - kW := wShape[2] - - outH := (H - kH) / stride + 1 - outW := (W - kW) / stride + 1 - - patches := extractPatches2DStrided(x, kH, kW, stride) - wFlat := mlx.Reshape(weight, Cout, -1) - patches = mlx.Reshape(patches, B*outH*outW, -1) - out := mlx.Linear(patches, mlx.Transpose(wFlat, 1, 0)) - return mlx.Reshape(out, B, outH, outW, Cout) -} - -// conv3DStrided applies 3D conv with strides using manual patch extraction -// x: [B, T, H, W, C] (channels-last), weight: [O, I, kT, kH, kW] (PyTorch format) -// strideT, strideH, strideW are the strides for each dimension -// Patches are extracted in [C, T, H, W] order to match Python's preprocessing -func conv3DStrided(x, weight *mlx.Array, strideT, strideH, strideW int32) *mlx.Array { - shape := x.Shape() - B := shape[0] - T := shape[1] - H := shape[2] - W := shape[3] - C := shape[4] - - wShape := weight.Shape() - Cout := wShape[0] - // I := wShape[1] - kT := wShape[2] - kH := wShape[3] - kW := wShape[4] - - // For temporal: if T < kT, we need to repeat frames temporally - // For single image with T=1 and kT=2, we duplicate the frame to T=kT - // Python Qwen2.5-VL duplicates the frame, not zero-pads - if T < kT { - // Tile along T dimension: [B, T, H, W, C] -> [B, kT, H, W, C] - x = mlx.Tile(x, []int32{1, kT, 1, 1, 1}) - T = kT - } - - outT := (T - kT) / strideT + 1 - outH := (H - kH) / strideH + 1 - outW := (W - kW) / strideW + 1 - - // Extract 3D patches in [C, T, H, W] order to match Python - patches := extractPatches3DStrided(x, kT, kH, kW, strideT, strideH, strideW) - // patches shape: [B, outT, outH, outW, C*kT*kH*kW] - - // Weight is [O, I, kT, kH, kW] - flatten to [O, I*kT*kH*kW] to match patch order [C, T, H, W] - wFlat := mlx.Reshape(weight, Cout, -1) // [Cout, I*kT*kH*kW] - patches = mlx.Reshape(patches, B*outT*outH*outW, C*kT*kH*kW) - out := mlx.Linear(patches, mlx.Transpose(wFlat, 1, 0)) - return mlx.Reshape(out, B, outT, outH, outW, Cout) -} - -// extractPatches3DStrided extracts 3D patches with given strides -// Returns patches with values in [C, T, H, W] order to match Python's preprocessing -func extractPatches3DStrided(x *mlx.Array, kT, kH, kW, strideT, strideH, strideW int32) *mlx.Array { - shape := x.Shape() - B := shape[0] - T := shape[1] - H := shape[2] - W := shape[3] - C := shape[4] - - outT := (T - kT) / strideT + 1 - outH := (H - kH) / strideH + 1 - outW := (W - kW) / strideW + 1 - - numPatches := outT * outH * outW - patches := make([]*mlx.Array, numPatches) - idx := 0 - for t := int32(0); t < outT; t++ { - for i := int32(0); i < outH; i++ { - for j := int32(0); j < outW; j++ { - startT := t * strideT - startH := i * strideH - startW := j * strideW - // Extract patch: [B, kT, kH, kW, C] - patch := mlx.Slice(x, - []int32{0, startT, startH, startW, 0}, - []int32{B, startT + kT, startH + kH, startW + kW, C}) - // Transpose from [B, T, H, W, C] to [B, C, T, H, W] to match Python's order - patch = mlx.Transpose(patch, 0, 4, 1, 2, 3) - // Flatten to [B, C*T*H*W] - patch = mlx.Reshape(patch, B, C*kT*kH*kW) - patches[idx] = patch - idx++ - } - } - } - - for i := range patches { - patches[i] = mlx.ExpandDims(patches[i], 1) - } - stacked := mlx.Concatenate(patches, 1) - return mlx.Reshape(stacked, B, outT, outH, outW, C*kT*kH*kW) -} - -// extractPatches2DStrided extracts patches with given stride -func extractPatches2DStrided(x *mlx.Array, kH, kW, stride int32) *mlx.Array { - shape := x.Shape() - B := shape[0] - H := shape[1] - W := shape[2] - C := shape[3] - - outH := (H - kH) / stride + 1 - outW := (W - kW) / stride + 1 - - patches := make([]*mlx.Array, outH*outW) - idx := 0 - for i := int32(0); i < outH; i++ { - for j := int32(0); j < outW; j++ { - startH := i * stride - startW := j * stride - patch := mlx.Slice(x, []int32{0, startH, startW, 0}, []int32{B, startH + kH, startW + kW, C}) - patch = mlx.Reshape(patch, B, kH*kW*C) - patches[idx] = patch - idx++ - } - } - - for i := range patches { - patches[i] = mlx.ExpandDims(patches[i], 1) - } - stacked := mlx.Concatenate(patches, 1) - return mlx.Reshape(stacked, B, outH, outW, kH*kW*C) -} - -// layerNormNoAffine applies layer norm without learnable parameters -func layerNormNoAffine(x *mlx.Array, eps float32) *mlx.Array { - ndim := x.Ndim() - lastAxis := ndim - 1 - mean := mlx.Mean(x, lastAxis, true) - xCentered := mlx.Sub(x, mean) - variance := mlx.Mean(mlx.Square(xCentered), lastAxis, true) - return mlx.Div(xCentered, mlx.Sqrt(mlx.AddScalar(variance, eps))) -} diff --git a/x/imagegen/models/qwen_image_edit/processor.go b/x/imagegen/models/qwen_image_edit/processor.go deleted file mode 100644 index c80f5a3b1..000000000 --- a/x/imagegen/models/qwen_image_edit/processor.go +++ /dev/null @@ -1,475 +0,0 @@ -//go:build mlx - -package qwen_image_edit - -import ( - "fmt" - "image" - "image/color" - _ "image/jpeg" - _ "image/png" - "math" - "os" - - "github.com/ollama/ollama/x/imagegen/mlx" - "golang.org/x/image/draw" - _ "golang.org/x/image/webp" -) - -// loadImageFile loads an image from disk -func loadImageFile(path string) (image.Image, error) { - f, err := os.Open(path) - if err != nil { - return nil, fmt.Errorf("open image: %w", err) - } - defer f.Close() - - img, _, err := image.Decode(f) - if err != nil { - return nil, fmt.Errorf("decode image: %w", err) - } - return img, nil -} - -// imageToFloat32Pixels converts an image to a float32 pixel array [H, W, C] in [0, 1] range -func imageToFloat32Pixels(img image.Image, width, height int) []float32 { - pixels := make([]float32, width*height*3) - idx := 0 - for y := 0; y < height; y++ { - for x := 0; x < width; x++ { - r, g, b, _ := img.At(x, y).RGBA() - pixels[idx] = float32(r) / 65535.0 - pixels[idx+1] = float32(g) / 65535.0 - pixels[idx+2] = float32(b) / 65535.0 - idx += 3 - } - } - return pixels -} - -// normalizeImageNet applies ImageNet normalization to an image tensor -func (p *Processor) normalizeImageNet(arr *mlx.Array) *mlx.Array { - mean := mlx.NewArray(p.Config.ImageMean, []int32{1, 1, 3}) - std := mlx.NewArray(p.Config.ImageStd, []int32{1, 1, 3}) - return mlx.Div(mlx.Sub(arr, mean), std) -} - -// prepareImageTensor transforms [H, W, C] to [B, C, H, W] and converts to bf16 -func prepareImageTensor(arr *mlx.Array) *mlx.Array { - // Transpose to [C, H, W] and make contiguous - arr = mlx.Contiguous(mlx.Transpose(arr, 2, 0, 1)) - // Add batch dimension [1, C, H, W] - arr = mlx.ExpandDims(arr, 0) - // Convert to bf16 - arr = mlx.ToBFloat16(arr) - mlx.Eval(arr) - return arr -} - -// clampFloat clamps a value to [0, 255] and returns uint8 -func clampFloat(v, weightSum float64) uint8 { - v /= weightSum - if v < 0 { - v = 0 - } - if v > 255 { - v = 255 - } - return uint8(math.Round(v)) -} - -// ImageDims holds dimensions for a preprocessed image -type ImageDims struct { - // Original image dimensions - OrigW, OrigH int32 - // Condition image dimensions (for vision encoder) - CondW, CondH int32 - // VAE image dimensions - VaeW, VaeH int32 - // Latent dimensions (VAE dims / vae_scale_factor) - LatentW, LatentH int32 - // Patch dimensions (latent dims / patch_size) - PatchW, PatchH int32 -} - -// ProcessorConfig holds image processor configuration -type ProcessorConfig struct { - // Condition image size (target pixel area for vision encoder input) - // Python: CONDITION_IMAGE_SIZE = 384 * 384 = 147456 - // Pipeline resizes image to this area before passing to encode_prompt - ConditionImageSize int32 - - // VAE image size (target pixel area) - // Python: VAE_IMAGE_SIZE = 1024 * 1024 = 1048576 - VAEImageSize int32 - - // Image normalization (ImageNet stats for vision encoder) - ImageMean []float32 - ImageStd []float32 -} - -// defaultProcessorConfig returns default processor config -func defaultProcessorConfig() *ProcessorConfig { - return &ProcessorConfig{ - ConditionImageSize: 384 * 384, // 147456 - matches Python CONDITION_IMAGE_SIZE - VAEImageSize: 1024 * 1024, // 1048576 - matches Python VAE_IMAGE_SIZE - ImageMean: []float32{0.48145466, 0.4578275, 0.40821073}, - ImageStd: []float32{0.26862954, 0.26130258, 0.27577711}, - } -} - -// Processor handles image preprocessing for Qwen-Image-Edit -type Processor struct { - Config *ProcessorConfig -} - -// Load loads the processor config -func (p *Processor) Load(path string) error { - p.Config = defaultProcessorConfig() - return nil -} - -// LoadAndPreprocess loads an image and preprocesses it for both paths -// Returns: condImage (for vision encoder), vaeImage (for VAE encoding) -func (p *Processor) LoadAndPreprocess(imagePath string) (*mlx.Array, *mlx.Array, error) { - img, err := loadImageFile(imagePath) - if err != nil { - return nil, nil, err - } - - bounds := img.Bounds() - origW := bounds.Dx() - origH := bounds.Dy() - ratio := float64(origW) / float64(origH) - - // Calculate dimensions for condition image (vision encoder) - // Python pipeline does TWO resizes: - // 1. VaeImageProcessor.resize with Lanczos to CONDITION_IMAGE_SIZE (384x384 area) - // 2. Qwen2VLProcessor's smart_resize with Bicubic to multiple of 28 - intermediateW, intermediateH := calculateDimensions(p.Config.ConditionImageSize, ratio, 32) - finalH, finalW := smartResize(intermediateH, intermediateW, 28, 56*56, 28*28*1280) - - // Calculate dimensions for VAE image (1024x1024 area) - // Use multiple of 32 (vae_scale_factor * patch_size * 2 = 8 * 2 * 2 = 32) - vaeW, vaeH := calculateDimensions(p.Config.VAEImageSize, ratio, 32) - - // Preprocess for condition (vision encoder) - two-step resize - condImage := p.preprocessImageTwoStep(img, intermediateW, intermediateH, finalW, finalH) - - // Preprocess for VAE ([-1, 1] range, 5D tensor) - vaeImage := p.preprocessImageForVAE(img, vaeW, vaeH) - - return condImage, vaeImage, nil -} - -// preprocessImageLanczos does single-step Lanczos resize for vision encoder -// Matches Python VaeImageProcessor.resize with resample='lanczos' (the default) -// Used by edit_plus pipeline for multi-image input -// Returns: [B, C, H, W] normalized tensor -func (p *Processor) preprocessImageLanczos(img image.Image, width, height int32) *mlx.Array { - resized := resizeImageLanczos(img, int(width), int(height)) - pixels := imageToFloat32Pixels(resized, int(width), int(height)) - arr := mlx.NewArray(pixels, []int32{height, width, 3}) - arr = p.normalizeImageNet(arr) - return prepareImageTensor(arr) -} - -// preprocessImageTwoStep does two-step resize for vision encoder to match Python pipeline -// Step 1: Lanczos resize from original to intermediate size (VaeImageProcessor.resize) -// Step 2: Bicubic resize from intermediate to final size (Qwen2VLProcessor smart_resize) -// Returns: [B, C, H, W] normalized tensor -func (p *Processor) preprocessImageTwoStep(img image.Image, intermediateW, intermediateH, finalW, finalH int32) *mlx.Array { - intermediate := resizeImageLanczos(img, int(intermediateW), int(intermediateH)) - resized := resizeImageBicubic(intermediate, int(finalW), int(finalH)) - pixels := imageToFloat32Pixels(resized, int(finalW), int(finalH)) - arr := mlx.NewArray(pixels, []int32{finalH, finalW, 3}) - arr = p.normalizeImageNet(arr) - return prepareImageTensor(arr) -} - -// preprocessImage converts image to tensor for vision encoder -// Returns: [B, C, H, W] normalized tensor -func (p *Processor) preprocessImage(img image.Image, width, height int32, normalize bool) *mlx.Array { - resized := resizeImageBicubic(img, int(width), int(height)) - pixels := imageToFloat32Pixels(resized, int(width), int(height)) - arr := mlx.NewArray(pixels, []int32{height, width, 3}) - if normalize { - arr = p.normalizeImageNet(arr) - } - return prepareImageTensor(arr) -} - -// preprocessImageForVAE converts image to tensor for VAE encoding -// Returns: [B, C, T, H, W] tensor in [-1, 1] range -func (p *Processor) preprocessImageForVAE(img image.Image, width, height int32) *mlx.Array { - resized := resizeImageLanczos(img, int(width), int(height)) - pixels := imageToFloat32Pixels(resized, int(width), int(height)) - arr := mlx.NewArray(pixels, []int32{height, width, 3}) - - // Scale to [-1, 1]: arr * 2 - 1 - arr = mlx.MulScalar(arr, 2.0) - arr = mlx.AddScalar(arr, -1.0) - - // Transpose to [C, H, W] and make contiguous - arr = mlx.Contiguous(mlx.Transpose(arr, 2, 0, 1)) - - // Add batch and temporal dimensions [1, C, 1, H, W] - arr = mlx.ExpandDims(arr, 0) // [1, C, H, W] - arr = mlx.ExpandDims(arr, 2) // [1, C, 1, H, W] - - arr = mlx.ToBFloat16(arr) - mlx.Eval(arr) - return arr -} - -// smartResize implements Python Qwen2VL processor's smart_resize logic -// Returns (resizedHeight, resizedWidth) that fit within min/max pixel constraints -func smartResize(height, width, factor, minPixels, maxPixels int32) (int32, int32) { - // Round to factor - hBar := int32(math.Round(float64(height)/float64(factor))) * factor - wBar := int32(math.Round(float64(width)/float64(factor))) * factor - - // Ensure minimum factor size - if hBar < factor { - hBar = factor - } - if wBar < factor { - wBar = factor - } - - // Check pixel constraints - total := hBar * wBar - if total > maxPixels { - // Scale down - beta := math.Sqrt(float64(maxPixels) / float64(total)) - hBar = int32(math.Floor(float64(height)*beta/float64(factor))) * factor - wBar = int32(math.Floor(float64(width)*beta/float64(factor))) * factor - } else if total < minPixels { - // Scale up - beta := math.Sqrt(float64(minPixels) / float64(total)) - hBar = int32(math.Ceil(float64(height)*beta/float64(factor))) * factor - wBar = int32(math.Ceil(float64(width)*beta/float64(factor))) * factor - } - - return hBar, wBar -} - -// calculateDimensions calculates width and height for a target area while maintaining ratio -// multiple: the value to round dimensions to (e.g., 28 for vision encoder with patch 14 and 2x2 merge) -func calculateDimensions(targetArea int32, ratio float64, multiple int32) (int32, int32) { - width := math.Sqrt(float64(targetArea) * ratio) - height := width / ratio - - m := float64(multiple) - width = math.Round(width/m) * m - height = math.Round(height/m) * m - - // Ensure minimum dimensions - if width < m { - width = m - } - if height < m { - height = m - } - - return int32(width), int32(height) -} - -// resizeImageLanczos resizes an image using Lanczos3 interpolation (matches PIL.LANCZOS) -func resizeImageLanczos(img image.Image, width, height int) image.Image { - bounds := img.Bounds() - dst := image.NewRGBA(image.Rect(0, 0, width, height)) - - // Lanczos3 kernel (a=3) to match PIL.LANCZOS - lanczos3 := &draw.Kernel{ - Support: 3.0, - At: func(t float64) float64 { - if t == 0 { - return 1.0 - } - if t < 0 { - t = -t - } - if t >= 3.0 { - return 0.0 - } - // sinc(t) * sinc(t/3) - piT := math.Pi * t - return (math.Sin(piT) / piT) * (math.Sin(piT/3) / (piT / 3)) - }, - } - lanczos3.Scale(dst, dst.Bounds(), img, bounds, draw.Over, nil) - - return dst -} - -// resizeImageBicubic resizes an image using bicubic interpolation (matches PIL.BICUBIC) -// Uses separable interpolation with PIL's coordinate mapping for exact match -func resizeImageBicubic(img image.Image, width, height int) image.Image { - bounds := img.Bounds() - srcW := bounds.Dx() - srcH := bounds.Dy() - - // Convert to RGBA if needed - var src *image.RGBA - if rgba, ok := img.(*image.RGBA); ok { - src = rgba - } else { - src = image.NewRGBA(bounds) - for y := bounds.Min.Y; y < bounds.Max.Y; y++ { - for x := bounds.Min.X; x < bounds.Max.X; x++ { - src.Set(x, y, img.At(x, y)) - } - } - } - - // Keys cubic with a=-0.5 (PIL BICUBIC) - cubic := func(x float64) float64 { - if x < 0 { - x = -x - } - if x < 1 { - return 1.5*x*x*x - 2.5*x*x + 1 - } - if x < 2 { - return -0.5*x*x*x + 2.5*x*x - 4*x + 2 - } - return 0 - } - - // Horizontal pass: srcW -> width, keep srcH rows - temp := image.NewRGBA(image.Rect(0, 0, width, srcH)) - for y := 0; y < srcH; y++ { - for dstX := 0; dstX < width; dstX++ { - // PIL coordinate mapping: center-to-center - srcXf := (float64(dstX)+0.5)*(float64(srcW)/float64(width)) - 0.5 - baseX := int(math.Floor(srcXf)) - - var sumR, sumG, sumB, sumA, weightSum float64 - for i := -1; i <= 2; i++ { - sx := baseX + i - if sx < 0 { - sx = 0 - } - if sx >= srcW { - sx = srcW - 1 - } - - w := cubic(math.Abs(srcXf - float64(baseX+i))) - c := src.RGBAAt(sx, y) - sumR += float64(c.R) * w - sumG += float64(c.G) * w - sumB += float64(c.B) * w - sumA += float64(c.A) * w - weightSum += w - } - - temp.SetRGBA(dstX, y, color.RGBA{ - clampFloat(sumR, weightSum), - clampFloat(sumG, weightSum), - clampFloat(sumB, weightSum), - clampFloat(sumA, weightSum), - }) - } - } - - // Vertical pass: srcH -> height - dst := image.NewRGBA(image.Rect(0, 0, width, height)) - for x := 0; x < width; x++ { - for dstY := 0; dstY < height; dstY++ { - srcYf := (float64(dstY)+0.5)*(float64(srcH)/float64(height)) - 0.5 - baseY := int(math.Floor(srcYf)) - - var sumR, sumG, sumB, sumA, weightSum float64 - for j := -1; j <= 2; j++ { - sy := baseY + j - if sy < 0 { - sy = 0 - } - if sy >= srcH { - sy = srcH - 1 - } - - w := cubic(math.Abs(srcYf - float64(baseY+j))) - c := temp.RGBAAt(x, sy) - sumR += float64(c.R) * w - sumG += float64(c.G) * w - sumB += float64(c.B) * w - sumA += float64(c.A) * w - weightSum += w - } - - dst.SetRGBA(x, dstY, color.RGBA{ - clampFloat(sumR, weightSum), - clampFloat(sumG, weightSum), - clampFloat(sumB, weightSum), - clampFloat(sumA, weightSum), - }) - } - } - - return dst -} - -// LoadAndPreprocessMultiple loads multiple images and preprocesses them -// Returns: condImages (for vision encoder), vaeImages (for VAE encoding), dims (per-image dimensions) -func (p *Processor) LoadAndPreprocessMultiple(imagePaths []string) ([]*mlx.Array, []*mlx.Array, []ImageDims, error) { - const vaeScaleFactor int32 = 8 - const patchSize int32 = 2 - - condImages := make([]*mlx.Array, len(imagePaths)) - vaeImages := make([]*mlx.Array, len(imagePaths)) - dims := make([]ImageDims, len(imagePaths)) - - for i, imagePath := range imagePaths { - img, err := loadImageFile(imagePath) - if err != nil { - return nil, nil, nil, fmt.Errorf("image %d: %w", i, err) - } - - bounds := img.Bounds() - origW := int32(bounds.Dx()) - origH := int32(bounds.Dy()) - ratio := float64(origW) / float64(origH) - - // Calculate dimensions for condition image (vision encoder) - // Python pipeline does TWO resizes: - // 1. VaeImageProcessor.resize with Lanczos to CONDITION_IMAGE_SIZE (384x384 area) - // 2. Qwen2VLProcessor's smart_resize with Bicubic to multiple of 28 - intermediateW, intermediateH := calculateDimensions(p.Config.ConditionImageSize, ratio, 32) - condH, condW := smartResize(intermediateH, intermediateW, 28, 56*56, 28*28*1280) - - // Calculate dimensions for VAE image (1024x1024 area) - vaeW, vaeH := calculateDimensions(p.Config.VAEImageSize, ratio, 32) - - // Calculate derived dimensions - latentW := vaeW / vaeScaleFactor - latentH := vaeH / vaeScaleFactor - patchW := latentW / patchSize - patchH := latentH / patchSize - - dims[i] = ImageDims{ - OrigW: origW, - OrigH: origH, - CondW: condW, - CondH: condH, - VaeW: vaeW, - VaeH: vaeH, - LatentW: latentW, - LatentH: latentH, - PatchW: patchW, - PatchH: patchH, - } - - fmt.Printf(" Image %d: orig=%dx%d, cond=%dx%d, vae=%dx%d, latent=%dx%d, patch=%dx%d\n", - i+1, origW, origH, condW, condH, vaeW, vaeH, latentW, latentH, patchW, patchH) - - // Preprocess for condition (vision encoder) - two-step resize to match Python pipeline - condImages[i] = p.preprocessImageTwoStep(img, intermediateW, intermediateH, condW, condH) - - // Preprocess for VAE ([-1, 1] range, 5D tensor) - vaeImages[i] = p.preprocessImageForVAE(img, vaeW, vaeH) - } - - return condImages, vaeImages, dims, nil -} diff --git a/x/imagegen/models/qwen_image_edit/qwen_image_edit.go b/x/imagegen/models/qwen_image_edit/qwen_image_edit.go deleted file mode 100644 index d1e394986..000000000 --- a/x/imagegen/models/qwen_image_edit/qwen_image_edit.go +++ /dev/null @@ -1,625 +0,0 @@ -//go:build mlx - -// Package qwen_image_edit implements the Qwen-Image-Edit diffusion model for image editing. -// It reuses components from qwen_image where possible. -package qwen_image_edit - -import ( - "context" - "fmt" - "path/filepath" - "time" - - "github.com/ollama/ollama/x/imagegen/mlx" - "github.com/ollama/ollama/x/imagegen/models/qwen_image" - "github.com/ollama/ollama/x/imagegen/tokenizer" -) - -// GenerateConfig holds all options for image editing. -type GenerateConfig struct { - Prompt string - NegativePrompt string // Unconditional prompt for CFG (empty string "" is valid) - CFGScale float32 // CFG enabled when > 1.0 (default: 4.0) - Width int32 // Output width (default: from input image) - Height int32 // Output height (default: from input image) - Steps int // Denoising steps (default: 50) - Seed int64 // Random seed - Progress func(step, totalSteps int) // Optional progress callback -} - -// Model represents a Qwen-Image-Edit diffusion model. -type Model struct { - ModelPath string - Tokenizer *tokenizer.Tokenizer - Processor *Processor // Image processor for vision encoder - TextEncoder *qwen_image.Qwen25VL // Qwen2.5-VL vision-language encoder (from qwen_image) - Transformer *qwen_image.Transformer // Reuse qwen_image transformer - VAE *VAE // Combined encoder + decoder -} - -// Load loads the Qwen-Image-Edit model from a directory. -func (m *Model) Load(modelPath string) error { - fmt.Println("Loading Qwen-Image-Edit model...") - start := time.Now() - - if mlx.GPUIsAvailable() { - mlx.SetDefaultDeviceGPU() - mlx.EnableCompile() - } - - m.ModelPath = modelPath - - // Load tokenizer from processor directory - fmt.Print(" Loading tokenizer... ") - processorPath := filepath.Join(modelPath, "processor") - tok, err := tokenizer.Load(processorPath) - if err != nil { - // Fallback to tokenizer directory - tokenizerPath := filepath.Join(modelPath, "tokenizer") - tok, err = tokenizer.Load(tokenizerPath) - if err != nil { - return fmt.Errorf("tokenizer: %w", err) - } - } - m.Tokenizer = tok - fmt.Println("✓") - - // Load processor (image preprocessing config) - fmt.Print(" Loading processor... ") - m.Processor = &Processor{} - if err := m.Processor.Load(processorPath); err != nil { - return fmt.Errorf("processor: %w", err) - } - fmt.Println("✓") - - // Load vision-language text encoder (Qwen2.5-VL from qwen_image package) - m.TextEncoder = &qwen_image.Qwen25VL{} - if err := m.TextEncoder.Load(filepath.Join(modelPath, "text_encoder")); err != nil { - return fmt.Errorf("text encoder: %w", err) - } - mlx.Eval(mlx.Collect(m.TextEncoder)...) - fmt.Printf(" (%.1f GB, peak %.1f GB)\n", - float64(mlx.MetalGetActiveMemory())/(1024*1024*1024), - float64(mlx.MetalGetPeakMemory())/(1024*1024*1024)) - - // Load transformer (reuse qwen_image) - m.Transformer = &qwen_image.Transformer{} - if err := m.Transformer.Load(filepath.Join(modelPath, "transformer")); err != nil { - return fmt.Errorf("transformer: %w", err) - } - mlx.Eval(mlx.Collect(m.Transformer)...) - fmt.Printf(" (%.1f GB, peak %.1f GB)\n", - float64(mlx.MetalGetActiveMemory())/(1024*1024*1024), - float64(mlx.MetalGetPeakMemory())/(1024*1024*1024)) - - // Load VAE (encoder + decoder) - m.VAE = &VAE{} - if err := m.VAE.Load(filepath.Join(modelPath, "vae")); err != nil { - return fmt.Errorf("VAE: %w", err) - } - mlx.Eval(mlx.Collect(m.VAE)...) - fmt.Printf(" (%.1f GB, peak %.1f GB)\n", - float64(mlx.MetalGetActiveMemory())/(1024*1024*1024), - float64(mlx.MetalGetPeakMemory())/(1024*1024*1024)) - - mem := mlx.MetalGetActiveMemory() - peak := mlx.MetalGetPeakMemory() - fmt.Printf(" Loaded in %.2fs (%.1f GB active, %.1f GB peak)\n", - time.Since(start).Seconds(), - float64(mem)/(1024*1024*1024), - float64(peak)/(1024*1024*1024)) - - return nil -} - -// Edit edits an image based on a text prompt. -// inputImagePath: path to input image -// prompt: text description of desired edit -func (m *Model) Edit(inputImagePath string, prompt string, width, height int32, steps int, seed int64) (*mlx.Array, error) { - return m.EditFromConfig([]string{inputImagePath}, &GenerateConfig{ - Prompt: prompt, - Width: width, - Height: height, - Steps: steps, - Seed: seed, - }) -} - -// EditFromConfig edits images using the unified config struct. -// Accepts one or more input images. -func (m *Model) EditFromConfig(inputImagePaths []string, cfg *GenerateConfig) (*mlx.Array, error) { - if len(inputImagePaths) == 0 { - return nil, fmt.Errorf("no input images provided") - } - - start := time.Now() - result, err := m.edit(inputImagePaths, cfg) - if err != nil { - return nil, err - } - - if cfg.NegativePrompt != "" { - fmt.Printf("Edited %d image(s) with CFG (scale=%.1f) in %.2fs (%d steps)\n", - len(inputImagePaths), cfg.CFGScale, time.Since(start).Seconds(), cfg.Steps) - } else { - fmt.Printf("Edited %d image(s) in %.2fs (%d steps)\n", - len(inputImagePaths), time.Since(start).Seconds(), cfg.Steps) - } - return result, nil -} - -// EditImage implements model.ImageEditModel interface. -func (m *Model) EditImage(ctx context.Context, inputImagePath, prompt string, width, height int32, steps int, seed int64) (*mlx.Array, error) { - return m.Edit(inputImagePath, prompt, width, height, steps, seed) -} - -// EditMultiImage edits using multiple source images. -// This matches diffusers' QwenImageEditPlusPipeline behavior. -func (m *Model) EditMultiImage(inputImagePaths []string, cfg *GenerateConfig) (*mlx.Array, error) { - return m.EditFromConfig(inputImagePaths, cfg) -} - -// edit is the internal editing pipeline that handles one or more images. -func (m *Model) edit(inputImagePaths []string, cfg *GenerateConfig) (*mlx.Array, error) { - // Apply defaults - if cfg.Steps <= 0 { - cfg.Steps = 50 - } - if cfg.CFGScale <= 0 { - cfg.CFGScale = 4.0 - } - - // Load and preprocess all input images - fmt.Printf("Loading %d image(s)...\n", len(inputImagePaths)) - condImages, vaeImages, inputDims, err := m.Processor.LoadAndPreprocessMultiple(inputImagePaths) - if err != nil { - return nil, fmt.Errorf("preprocess images: %w", err) - } - for _, img := range condImages { - mlx.Keep(img) - } - for _, img := range vaeImages { - mlx.Keep(img) - } - mlx.Eval(append(condImages, vaeImages...)...) - - useCFG := cfg.NegativePrompt != "" - tcfg := m.Transformer.Config - vaeScaleFactor := int32(8) - - // Output dimensions - if not specified, use first input image dimensions - if cfg.Width <= 0 { - cfg.Width = inputDims[0].VaeW - } - if cfg.Height <= 0 { - cfg.Height = inputDims[0].VaeH - } - - // Output (noise) latent dimensions - outLatentH := cfg.Height / vaeScaleFactor - outLatentW := cfg.Width / vaeScaleFactor - outPH := outLatentH / tcfg.PatchSize - outPW := outLatentW / tcfg.PatchSize - noiseSeqLen := outPH * outPW - imgSeqLen := noiseSeqLen - - // Encode prompt with all images for conditioning - posEmb, _, _, err := m.TextEncoder.EncodePromptWithImages(m.Tokenizer, cfg.Prompt, condImages) - if err != nil { - return nil, fmt.Errorf("encoding prompt: %w", err) - } - mlx.Keep(posEmb) - mlx.Eval(posEmb) - - var negEmb *mlx.Array - if useCFG { - negEmb, _, _, err = m.TextEncoder.EncodePromptWithImages(m.Tokenizer, cfg.NegativePrompt, condImages) - if err != nil { - return nil, fmt.Errorf("encoding negative prompt: %w", err) - } - mlx.Keep(negEmb) - mlx.Eval(negEmb) - } - - // Pad sequences to same length for CFG - txtLen := posEmb.Shape()[1] - if useCFG { - negLen := negEmb.Shape()[1] - if negLen > txtLen { - txtLen = negLen - } - if posEmb.Shape()[1] < txtLen { - posEmb = padSequence(posEmb, txtLen) - } - if negEmb.Shape()[1] < txtLen { - negEmb = padSequence(negEmb, txtLen) - } - mlx.Keep(posEmb, negEmb) - mlx.Eval(posEmb, negEmb) - } - - // Pre-compute batched embeddings for CFG (single forward pass optimization) - var batchedEmb *mlx.Array - if useCFG { - batchedEmb = mlx.Concatenate([]*mlx.Array{posEmb, negEmb}, 0) - mlx.Keep(batchedEmb) - mlx.Eval(batchedEmb) - } - - // Encode all input images to latents and concatenate - fmt.Println("Encoding images to latents...") - allImageLatentsPacked := make([]*mlx.Array, len(vaeImages)) - for i, vaeImage := range vaeImages { - imageLatents := m.VAE.Encode(vaeImage) - imageLatents = m.VAE.Normalize(imageLatents) - imageLatents2D := mlx.Squeeze(imageLatents, 2) - packed := qwen_image.PackLatents(imageLatents2D, tcfg.PatchSize) - mlx.Keep(packed) - mlx.Eval(packed) - allImageLatentsPacked[i] = packed - } - - imageLatentsPacked := mlx.Concatenate(allImageLatentsPacked, 1) - mlx.Keep(imageLatentsPacked) - mlx.Eval(imageLatentsPacked) - - // Scheduler - scheduler := qwen_image.NewFlowMatchScheduler(qwen_image.DefaultSchedulerConfig()) - scheduler.SetTimesteps(cfg.Steps, noiseSeqLen) - - // Init noise latents in packed format - packedChannels := tcfg.OutChannels * tcfg.PatchSize * tcfg.PatchSize - packedNoise := scheduler.InitNoisePacked(1, noiseSeqLen, packedChannels, cfg.Seed) - latents := qwen_image.UnpackLatents(packedNoise, outLatentH, outLatentW, tcfg.PatchSize) - mlx.Eval(latents) - - // RoPE cache - ropeCache := PrepareRoPEMultiImage(outPH, outPW, inputDims, txtLen, tcfg.AxesDimsRope) - mlx.Keep(ropeCache.ImgFreqs, ropeCache.TxtFreqs) - mlx.Eval(ropeCache.ImgFreqs, ropeCache.TxtFreqs) - - // Denoising loop - fmt.Printf("Running denoising (%d steps)...\n", cfg.Steps) - for i := 0; i < cfg.Steps; i++ { - stepStart := time.Now() - if cfg.Progress != nil { - cfg.Progress(i+1, cfg.Steps) - } - - t := scheduler.Timesteps[i] - timestep := mlx.ToBFloat16(mlx.NewArray([]float32{t}, []int32{1})) - mlx.Eval(timestep) - - latents2D := mlx.Squeeze(latents, 2) - patches := qwen_image.PackLatents(latents2D, tcfg.PatchSize) - latentInput := mlx.Concatenate([]*mlx.Array{patches, imageLatentsPacked}, 1) - - var output *mlx.Array - if useCFG { - // CFG Batching: single forward pass with batch=2 - // Tile inputs: [1, L, D] -> [2, L, D] - batchedLatentInput := mlx.Tile(latentInput, []int32{2, 1, 1}) - batchedTimestep := mlx.Tile(timestep, []int32{2}) - - // Single batched forward pass - batchedOutput := m.Transformer.Forward(batchedLatentInput, batchedEmb, batchedTimestep, ropeCache.ImgFreqs, ropeCache.TxtFreqs) - - // Split output: [2, L, D] -> pos [1, L, D], neg [1, L, D] - D := batchedOutput.Shape()[2] - posOutput := mlx.Slice(batchedOutput, []int32{0, 0, 0}, []int32{1, imgSeqLen, D}) - negOutput := mlx.Slice(batchedOutput, []int32{1, 0, 0}, []int32{2, imgSeqLen, D}) - - output = applyCFGWithNormRescale(posOutput, negOutput, cfg.CFGScale) - } else { - output = m.Transformer.Forward(latentInput, posEmb, timestep, ropeCache.ImgFreqs, ropeCache.TxtFreqs) - output = mlx.Slice(output, []int32{0, 0, 0}, []int32{1, imgSeqLen, output.Shape()[2]}) - } - - noisePred := qwen_image.UnpackLatents(output, outLatentH, outLatentW, tcfg.PatchSize) - oldLatents := latents - latents = scheduler.Step(noisePred, latents, i) - mlx.Eval(latents) - oldLatents.Free() - - fmt.Printf(" Step %d/%d: t=%.4f (%.2fs)\n", i+1, cfg.Steps, t, time.Since(stepStart).Seconds()) - } - - // Free denoising temporaries - posEmb.Free() - if negEmb != nil { - negEmb.Free() - } - if batchedEmb != nil { - batchedEmb.Free() - } - ropeCache.ImgFreqs.Free() - ropeCache.TxtFreqs.Free() - imageLatentsPacked.Free() - - // Decode latents - decoded := m.decodeAndPostprocess(latents) - latents.Free() - - fmt.Printf(" Peak memory: %.2f GB\n", float64(mlx.MetalGetPeakMemory())/(1024*1024*1024)) - return decoded, nil -} - -// applyCFGWithNormRescale applies classifier-free guidance with norm rescaling. -// This prevents CFG from inflating magnitude too much. -func applyCFGWithNormRescale(posOutput, negOutput *mlx.Array, scale float32) *mlx.Array { - // Upcast to float32 for precision - posF32 := mlx.AsType(posOutput, mlx.DtypeFloat32) - negF32 := mlx.AsType(negOutput, mlx.DtypeFloat32) - - // CFG: pred = neg + scale * (pos - neg) - diff := mlx.Sub(posF32, negF32) - scaledDiff := mlx.MulScalar(diff, scale) - combPred := mlx.Add(negF32, scaledDiff) - - // Norm rescaling: rescale combined prediction to match conditional norm - condNorm := mlx.Sqrt(mlx.Sum(mlx.Square(posF32), -1, true)) - combNorm := mlx.Sqrt(mlx.Sum(mlx.Square(combPred), -1, true)) - output := mlx.Mul(combPred, mlx.Div(condNorm, combNorm)) - - mlx.Eval(output) - return mlx.ToBFloat16(output) -} - -// decodeAndPostprocess denormalizes latents, decodes through VAE, and scales to [0,1]. -func (m *Model) decodeAndPostprocess(latents *mlx.Array) *mlx.Array { - latents = m.VAE.Denormalize(latents) - decoded := m.VAE.Decode(latents) - - // Post-process: squeeze temporal dim and rescale to [0, 1] - decoded = mlx.Squeeze(decoded, 2) - decoded = mlx.AddScalar(decoded, 1.0) - decoded = mlx.DivScalar(decoded, 2.0) - decoded = mlx.ClipScalar(decoded, 0.0, 1.0, true, true) - mlx.Eval(decoded) - return decoded -} - -// padSequence pads a sequence tensor to the target length with zeros -func padSequence(x *mlx.Array, targetLen int32) *mlx.Array { - shape := x.Shape() - currentLen := shape[1] - if currentLen >= targetLen { - return x - } - padLen := targetLen - currentLen - // Pad on sequence dimension (axis 1) - return mlx.Pad(x, []int32{0, 0, 0, padLen, 0, 0}) -} - -// LoadPersistent is an alias for backward compatibility. -func LoadPersistent(modelPath string) (*Model, error) { - m := &Model{} - if err := m.Load(modelPath); err != nil { - return nil, err - } - return m, nil -} - -// PrepareRoPEMultiImage computes RoPE with interpolation for image editing. -// Handles single or multiple input images with different resolutions. -// -// Parameters: -// - outPH, outPW: output patch dimensions (noise latent resolution) -// - inputDims: patch dimensions for each input image [(pH1, pW1), (pH2, pW2), ...] -// - txtLen: text sequence length -// - axesDims: RoPE axis dimensions [16, 56, 56] -// -// Returns RoPE cache where: -// - ImgFreqs has (outPH*outPW + sum(inPH*inPW for each image)) positions -// - First outPH*outPW positions are for noise latents (standard RoPE at output res) -// - Following positions are for each input image (interpolated from output res) -func PrepareRoPEMultiImage(outPH, outPW int32, inputDims []ImageDims, txtLen int32, axesDims []int32) *qwen_image.RoPECache { - theta := float64(10000) - maxIdx := int32(4096) - - // Compute base frequencies for each axis dimension - freqsT := qwen_image.ComputeAxisFreqs(axesDims[0], theta) - freqsH := qwen_image.ComputeAxisFreqs(axesDims[1], theta) - freqsW := qwen_image.ComputeAxisFreqs(axesDims[2], theta) - - // Build frequency lookup tables - posFreqsT := qwen_image.MakeFreqTable(maxIdx, freqsT, false) - posFreqsH := qwen_image.MakeFreqTable(maxIdx, freqsH, false) - posFreqsW := qwen_image.MakeFreqTable(maxIdx, freqsW, false) - negFreqsT := qwen_image.MakeFreqTable(maxIdx, freqsT, true) // For frame -1 on last condition image - negFreqsH := qwen_image.MakeFreqTable(maxIdx, freqsH, true) - negFreqsW := qwen_image.MakeFreqTable(maxIdx, freqsW, true) - - headDim := int32(len(freqsT)+len(freqsH)+len(freqsW)) * 2 - - // Helper to compute RoPE for a single position at output resolution with scale_rope - computePosFreqs := func(framePos, y, x int32) []float32 { - row := make([]float32, headDim) - idx := 0 - - // Frame position - for i := 0; i < len(freqsT)*2; i++ { - row[idx+i] = posFreqsT[framePos][i] - } - idx += len(freqsT) * 2 - - // Height with scale_rope centering (using OUTPUT dimensions) - outHHalf := outPH / 2 - hNegCount := outPH - outHHalf - if y < hNegCount { - negTableIdx := maxIdx - hNegCount + y - for i := 0; i < len(freqsH)*2; i++ { - row[idx+i] = negFreqsH[negTableIdx][i] - } - } else { - posIdx := y - hNegCount - for i := 0; i < len(freqsH)*2; i++ { - row[idx+i] = posFreqsH[posIdx][i] - } - } - idx += len(freqsH) * 2 - - // Width with scale_rope centering (using OUTPUT dimensions) - outWHalf := outPW / 2 - wNegCount := outPW - outWHalf - if x < wNegCount { - negTableIdx := maxIdx - wNegCount + x - for i := 0; i < len(freqsW)*2; i++ { - row[idx+i] = negFreqsW[negTableIdx][i] - } - } else { - posIdx := x - wNegCount - for i := 0; i < len(freqsW)*2; i++ { - row[idx+i] = posFreqsW[posIdx][i] - } - } - - return row - } - - // Helper to compute RoPE for frame -1 (used for last condition image) - // This matches Python's _compute_condition_freqs which uses freqs_neg[0][-1:] - computeNegFrameFreqs := func(y, x int32) []float32 { - row := make([]float32, headDim) - idx := 0 - - // Frame -1: use last row of negative frame frequencies - negFrameIdx := maxIdx - 1 - for i := 0; i < len(freqsT)*2; i++ { - row[idx+i] = negFreqsT[negFrameIdx][i] - } - idx += len(freqsT) * 2 - - // Height with scale_rope centering (using OUTPUT dimensions) - outHHalf := outPH / 2 - hNegCount := outPH - outHHalf - if y < hNegCount { - negTableIdx := maxIdx - hNegCount + y - for i := 0; i < len(freqsH)*2; i++ { - row[idx+i] = negFreqsH[negTableIdx][i] - } - } else { - posIdx := y - hNegCount - for i := 0; i < len(freqsH)*2; i++ { - row[idx+i] = posFreqsH[posIdx][i] - } - } - idx += len(freqsH) * 2 - - // Width with scale_rope centering (using OUTPUT dimensions) - outWHalf := outPW / 2 - wNegCount := outPW - outWHalf - if x < wNegCount { - negTableIdx := maxIdx - wNegCount + x - for i := 0; i < len(freqsW)*2; i++ { - row[idx+i] = negFreqsW[negTableIdx][i] - } - } else { - posIdx := x - wNegCount - for i := 0; i < len(freqsW)*2; i++ { - row[idx+i] = posFreqsW[posIdx][i] - } - } - - return row - } - - // Total image sequence length: noise + all input images - noiseSeqLen := outPH * outPW - totalImgLen := noiseSeqLen - for _, dims := range inputDims { - totalImgLen += dims.PatchH * dims.PatchW - } - - imgFreqsData := make([]float32, totalImgLen*headDim) - idx := int32(0) - - // Segment 0: Noise latents - standard RoPE at output resolution (frame 0) - for y := int32(0); y < outPH; y++ { - for x := int32(0); x < outPW; x++ { - row := computePosFreqs(0, y, x) - copy(imgFreqsData[idx:], row) - idx += headDim - } - } - - // Segments 1..N: Edit image latents - INTERPOLATED RoPE - // For single image: use frame 1 (matches original PrepareRoPEInterpolated) - // For multiple images: Python uses frame -1 for the LAST condition image - // (_compute_condition_freqs), positive indices for others. - numImages := len(inputDims) - lastImgIdx := numImages - 1 - for imgIdx, dims := range inputDims { - inPH := dims.PatchH - inPW := dims.PatchW - - // Determine frame index for this image - // Single image case: use frame 1 (like original PrepareRoPEInterpolated) - // Multi-image case: last image uses frame -1, others use frame 1, 2, etc. - useNegFrame := numImages > 1 && imgIdx == lastImgIdx - - // Map each input position to an output position using linear interpolation - for y := int32(0); y < inPH; y++ { - for x := int32(0); x < inPW; x++ { - // Interpolate: map input (y, x) to output grid position - // This is the key fix from DiffSynth's forward_sampling - var yOut, xOut int32 - if inPH == 1 { - yOut = 0 - } else { - // Linear interpolation: y_out = y * (outPH - 1) / (inPH - 1) - yOut = y * (outPH - 1) / (inPH - 1) - } - if inPW == 1 { - xOut = 0 - } else { - xOut = x * (outPW - 1) / (inPW - 1) - } - - var row []float32 - if useNegFrame { - // Last image in multi-image uses frame -1 - row = computeNegFrameFreqs(yOut, xOut) - } else { - // Single image uses frame 1, multi-image uses frame 1, 2, etc. - frameIdx := int32(imgIdx + 1) - row = computePosFreqs(frameIdx, yOut, xOut) - } - copy(imgFreqsData[idx:], row) - idx += headDim - } - } - } - - imgFreqs := mlx.NewArray(imgFreqsData, []int32{totalImgLen, headDim}) - imgFreqs = mlx.ToBFloat16(imgFreqs) - - // Text frequencies - start after max video index - maxVidIdx := max(outPH/2, outPW/2) - - txtFreqsData := make([]float32, txtLen*headDim) - idx = 0 - for t := int32(0); t < txtLen; t++ { - pos := maxVidIdx + t - for i := 0; i < len(freqsT)*2; i++ { - txtFreqsData[idx+int32(i)] = posFreqsT[pos][i] - } - idx += int32(len(freqsT) * 2) - for i := 0; i < len(freqsH)*2; i++ { - txtFreqsData[idx+int32(i)] = posFreqsH[pos][i] - } - idx += int32(len(freqsH) * 2) - for i := 0; i < len(freqsW)*2; i++ { - txtFreqsData[idx+int32(i)] = posFreqsW[pos][i] - } - idx += int32(len(freqsW) * 2) - } - - txtFreqs := mlx.NewArray(txtFreqsData, []int32{txtLen, headDim}) - txtFreqs = mlx.ToBFloat16(txtFreqs) - - return &qwen_image.RoPECache{ - ImgFreqs: imgFreqs, - TxtFreqs: txtFreqs, - } -} diff --git a/x/imagegen/models/qwen_image_edit/rope_test.go b/x/imagegen/models/qwen_image_edit/rope_test.go deleted file mode 100644 index 200940fbe..000000000 --- a/x/imagegen/models/qwen_image_edit/rope_test.go +++ /dev/null @@ -1,249 +0,0 @@ -//go:build mlx - -package qwen_image_edit - -import ( - "fmt" - "math" - "os" - "path/filepath" - "runtime" - "testing" - - "github.com/ollama/ollama/x/imagegen/mlx" - "github.com/ollama/ollama/x/imagegen/models/qwen_image" -) - -// TestMain initializes MLX before running tests. -// If MLX libraries are not available, tests are skipped. -func TestMain(m *testing.M) { - // Change to repo root so ./build/lib/ollama/ path works - _, thisFile, _, _ := runtime.Caller(0) - repoRoot := filepath.Join(filepath.Dir(thisFile), "..", "..", "..", "..") - if err := os.Chdir(repoRoot); err != nil { - fmt.Printf("Failed to change to repo root: %v\n", err) - os.Exit(1) - } - - if err := mlx.InitMLX(); err != nil { - fmt.Printf("Skipping qwen_image_edit tests: %v\n", err) - os.Exit(0) - } - os.Exit(m.Run()) -} - -// TestComputeAxisFreqs verifies frequency computation matches Python reference -func TestComputeAxisFreqs(t *testing.T) { - theta := float64(10000) - - // Expected values from Python: - // freqs = 1.0 / (theta ** (np.arange(0, half_dim) / half_dim)) - expectedFreqsT := []float64{ - 1.000000000000000, 0.316227766016838, 0.100000000000000, 0.031622776601684, - 0.010000000000000, 0.003162277660168, 0.001000000000000, 0.000316227766017, - } - - expectedFreqsH_first4 := []float64{ - 1.000000000000000, 0.719685673001152, 0.517947467923121, 0.372759372031494, - } - - expectedFreqsH_last4 := []float64{ - 0.000372759372031, 0.000268269579528, 0.000193069772888, 0.000138949549437, - } - - // Test temporal frequencies (dim=16) - freqsT := qwen_image.ComputeAxisFreqs(16, theta) - if len(freqsT) != 8 { - t.Fatalf("expected 8 temporal frequencies, got %d", len(freqsT)) - } - for i, expected := range expectedFreqsT { - if diff := math.Abs(freqsT[i] - expected); diff > 1e-10 { - t.Errorf("freqsT[%d]: expected %.15f, got %.15f, diff %.2e", i, expected, freqsT[i], diff) - } - } - - // Test height/width frequencies (dim=56) - freqsH := qwen_image.ComputeAxisFreqs(56, theta) - if len(freqsH) != 28 { - t.Fatalf("expected 28 height frequencies, got %d", len(freqsH)) - } - for i, expected := range expectedFreqsH_first4 { - if diff := math.Abs(freqsH[i] - expected); diff > 1e-10 { - t.Errorf("freqsH[%d]: expected %.15f, got %.15f, diff %.2e", i, expected, freqsH[i], diff) - } - } - for i, expected := range expectedFreqsH_last4 { - idx := 24 + i // last 4 of 28 - if diff := math.Abs(freqsH[idx] - expected); diff > 1e-10 { - t.Errorf("freqsH[%d]: expected %.15f, got %.15f, diff %.2e", idx, expected, freqsH[idx], diff) - } - } -} - -// TestMakeFreqTable verifies the frequency lookup table for both positive and negative positions -func TestMakeFreqTable(t *testing.T) { - theta := float64(10000) - freqsT := qwen_image.ComputeAxisFreqs(16, theta) - maxIdx := int32(4096) - - // Test positive table - posTable := qwen_image.MakeFreqTable(maxIdx, freqsT, false) - - // Position 0 should give cos=1, sin=0 for all frequencies - for i := 0; i < len(freqsT)*2; i += 2 { - if posTable[0][i] != 1.0 { - t.Errorf("posTable[0][%d] (cos): expected 1.0, got %f", i, posTable[0][i]) - } - if posTable[0][i+1] != 0.0 { - t.Errorf("posTable[0][%d] (sin): expected 0.0, got %f", i+1, posTable[0][i+1]) - } - } - - // Position 1, first frequency (1.0): angle = 1*1 = 1 - // cos(1) = 0.5403, sin(1) = 0.8415 - if diff := math.Abs(float64(posTable[1][0]) - 0.5403023058681398); diff > 1e-6 { - t.Errorf("posTable[1][0] (cos): expected 0.5403, got %f", posTable[1][0]) - } - if diff := math.Abs(float64(posTable[1][1]) - 0.8414709848078965); diff > 1e-6 { - t.Errorf("posTable[1][1] (sin): expected 0.8415, got %f", posTable[1][1]) - } - - // Test negative table - negTable := qwen_image.MakeFreqTable(maxIdx, freqsT, true) - - // negTable[4095] corresponds to position -1 - // cos(-1) = cos(1), sin(-1) = -sin(1) - if diff := math.Abs(float64(negTable[4095][0]) - 0.5403023058681398); diff > 1e-6 { - t.Errorf("negTable[4095][0] (cos(-1)): expected 0.5403, got %f", negTable[4095][0]) - } - if diff := math.Abs(float64(negTable[4095][1]) - (-0.8414709848078965)); diff > 1e-6 { - t.Errorf("negTable[4095][1] (sin(-1)): expected -0.8415, got %f", negTable[4095][1]) - } - - // negTable[4094] corresponds to position -2 - // cos(-2) = cos(2), sin(-2) = -sin(2) - cos2 := math.Cos(2.0) - sin2 := math.Sin(2.0) - if diff := math.Abs(float64(negTable[4094][0]) - cos2); diff > 1e-6 { - t.Errorf("negTable[4094][0] (cos(-2)): expected %f, got %f", cos2, negTable[4094][0]) - } - if diff := math.Abs(float64(negTable[4094][1]) - (-sin2)); diff > 1e-6 { - t.Errorf("negTable[4094][1] (sin(-2)): expected %f, got %f", -sin2, negTable[4094][1]) - } -} - -// TestPrepareRoPE_QwenImage verifies qwen_image.PrepareRoPE for single-segment case -func TestPrepareRoPE_QwenImage(t *testing.T) { - if !mlx.GPUIsAvailable() { - t.Skip("GPU not available") - } - - mlx.SetDefaultDeviceCPU() - - // 4x4 patch grid, single image - imgH, imgW := int32(4), int32(4) - txtLen := int32(5) - axesDims := []int32{16, 56, 56} - - cache := qwen_image.PrepareRoPE(imgH, imgW, txtLen, axesDims) - mlx.Eval(cache.ImgFreqs, cache.TxtFreqs) - - // Check shapes - imgShape := cache.ImgFreqs.Shape() - if imgShape[0] != 16 { // 4*4 patches - t.Errorf("ImgFreqs seq len: expected 16, got %d", imgShape[0]) - } - - // For single image (frame=0), all temporal values should be cos=1, sin=0 - imgFreqsCPU := mlx.AsType(cache.ImgFreqs, mlx.DtypeFloat32) - mlx.Eval(imgFreqsCPU) - imgData := imgFreqsCPU.Data() - - // Check first 16 values of patch 0 (temporal cos/sin pairs) - for i := 0; i < 16; i += 2 { - cosVal := imgData[i] - sinVal := imgData[i+1] - if diff := math.Abs(float64(cosVal - 1.0)); diff > 1e-5 { - t.Errorf("ImgFreqs[0][%d] (cos): expected 1.0, got %f", i, cosVal) - } - if diff := math.Abs(float64(sinVal - 0.0)); diff > 1e-5 { - t.Errorf("ImgFreqs[0][%d] (sin): expected 0.0, got %f", i+1, sinVal) - } - } - - cache.ImgFreqs.Free() - cache.TxtFreqs.Free() -} - -// TestScaleRopePositions verifies the centered position calculation for scale_rope=True -func TestScaleRopePositions(t *testing.T) { - // For a 4x4 grid with scale_rope=True: - // hHalf = 2, wHalf = 2 - // hNegCount = 4 - 2 = 2 (positions 0,1 are negative) - // wNegCount = 4 - 2 = 2 (positions 0,1 are negative) - // - // Height positions: - // y=0: -(4-2) + 0 = -2 - // y=1: -(4-2) + 1 = -1 - // y=2: 2 - 2 = 0 - // y=3: 3 - 2 = 1 - // - // Same for width - - pH, pW := int32(4), int32(4) - hHalf := pH / 2 - wHalf := pW / 2 - hNegCount := pH - hHalf - wNegCount := pW - wHalf - - expectedH := []int32{-2, -1, 0, 1} - expectedW := []int32{-2, -1, 0, 1} - - for y := int32(0); y < pH; y++ { - var hPos int32 - if y < hNegCount { - hPos = -(pH - hHalf) + y - } else { - hPos = y - hNegCount - } - if hPos != expectedH[y] { - t.Errorf("y=%d: expected h_pos=%d, got %d", y, expectedH[y], hPos) - } - } - - for x := int32(0); x < pW; x++ { - var wPos int32 - if x < wNegCount { - wPos = -(pW - wHalf) + x - } else { - wPos = x - wNegCount - } - if wPos != expectedW[x] { - t.Errorf("x=%d: expected w_pos=%d, got %d", x, expectedW[x], wPos) - } - } -} - -// TestRoPEHeadDimensions verifies the head dimension breakdown -func TestRoPEHeadDimensions(t *testing.T) { - // axes_dims_rope = [16, 56, 56] - // Each dimension uses half the values for frequencies - // So we get: 8 + 28 + 28 = 64 frequency values - // Each frequency produces cos + sin, so: 64 * 2 = 128 total values per position - - axesDims := []int32{16, 56, 56} - expectedFreqs := (axesDims[0]/2 + axesDims[1]/2 + axesDims[2]/2) - expectedHeadDim := expectedFreqs * 2 - - if expectedFreqs != 64 { - t.Errorf("expected 64 frequency values, got %d", expectedFreqs) - } - if expectedHeadDim != 128 { - t.Errorf("expected head_dim=128, got %d", expectedHeadDim) - } - - // This should match the transformer's attention head dimension - // hidden_size = 3072, num_heads = 24 - // head_dim = 3072 / 24 = 128 -} - diff --git a/x/imagegen/models/qwen_image_edit/vae.go b/x/imagegen/models/qwen_image_edit/vae.go deleted file mode 100644 index 3dbe7ef3c..000000000 --- a/x/imagegen/models/qwen_image_edit/vae.go +++ /dev/null @@ -1,642 +0,0 @@ -//go:build mlx - -package qwen_image_edit - -import ( - "fmt" - - "github.com/ollama/ollama/x/imagegen/mlx" - "github.com/ollama/ollama/x/imagegen/safetensors" -) - -// VAEConfig holds Qwen-Image VAE configuration -type VAEConfig struct { - ZDim int32 `json:"z_dim"` // 16 - BaseDim int32 `json:"base_dim"` // 96 - DimMult []int32 `json:"dim_mult"` // [1, 2, 4, 4] - NumResBlocks int32 `json:"num_res_blocks"` // 2 - LatentsMean []float32 `json:"latents_mean"` // 16 values - LatentsStd []float32 `json:"latents_std"` // 16 values - TemperalDownsample []bool `json:"temperal_downsample"` // [false, true, true] -} - -// defaultVAEConfig returns config for Qwen-Image VAE -func defaultVAEConfig() *VAEConfig { - return &VAEConfig{ - ZDim: 16, - BaseDim: 96, - DimMult: []int32{1, 2, 4, 4}, - NumResBlocks: 2, - LatentsMean: []float32{ - -0.7571, -0.7089, -0.9113, 0.1075, - -0.1745, 0.9653, -0.1517, 1.5508, - 0.4134, -0.0715, 0.5517, -0.3632, - -0.1922, -0.9497, 0.2503, -0.2921, - }, - LatentsStd: []float32{ - 2.8184, 1.4541, 2.3275, 2.6558, - 1.2196, 1.7708, 2.6052, 2.0743, - 3.2687, 2.1526, 2.8652, 1.5579, - 1.6382, 1.1253, 2.8251, 1.916, - }, - TemperalDownsample: []bool{false, true, true}, - } -} - -// VAE is the full VAE with encoder and decoder -type VAE struct { - Config *VAEConfig - Encoder *VAEEncoder - Decoder *VAEDecoder -} - -// Load loads the VAE from a directory -func (m *VAE) Load(path string) error { - fmt.Println("Loading Qwen-Image-Edit VAE (encoder + decoder)...") - - cfg := defaultVAEConfig() - m.Config = cfg - - weights, err := safetensors.LoadModelWeights(path) - if err != nil { - return fmt.Errorf("weights: %w", err) - } - - // Load weights as f32 for quality (matches Python default behavior) - // VAE decoder precision is critical for final image quality - fmt.Print(" Loading weights as f32... ") - if err := weights.Load(mlx.DtypeFloat32); err != nil { - return fmt.Errorf("failed to load weights: %w", err) - } - fmt.Printf("✓ (%.1f GB)\n", float64(mlx.MetalGetActiveMemory())/(1024*1024*1024)) - - // Load encoder - fmt.Print(" Loading encoder... ") - m.Encoder = &VAEEncoder{} - if err := m.Encoder.loadFromWeights(weights, cfg); err != nil { - return fmt.Errorf("encoder: %w", err) - } - fmt.Println("✓") - - // Load decoder - fmt.Print(" Loading decoder... ") - m.Decoder = &VAEDecoder{} - if err := m.Decoder.loadFromWeights(weights, cfg); err != nil { - return fmt.Errorf("decoder: %w", err) - } - fmt.Println("✓") - - weights.ReleaseAll() - return nil -} - -// Encode encodes an image to latents -// x: [B, C, T, H, W] image tensor in [-1, 1] range -// Returns: [B, C, T, H/8, W/8] latents (unnormalized) -func (m *VAE) Encode(x *mlx.Array) *mlx.Array { - return m.Encoder.Encode(x) -} - -// Decode decodes latents to image -// z: [B, C, T, H, W] latents (denormalized) -// Returns: [B, C, T, H*8, W*8] image in [-1, 1] -func (m *VAE) Decode(z *mlx.Array) *mlx.Array { - return m.Decoder.Decode(z) -} - -// Normalize applies latent normalization -// Input z should be f32 (from VAE encoder), output is f32 for transformer -func (m *VAE) Normalize(z *mlx.Array) *mlx.Array { - shape := z.Shape() - C := shape[1] - - mean := mlx.NewArray(m.Config.LatentsMean[:C], []int32{1, C, 1, 1, 1}) - std := mlx.NewArray(m.Config.LatentsStd[:C], []int32{1, C, 1, 1, 1}) - - // Mean/std are f32, will match z dtype through broadcasting - return mlx.Div(mlx.Sub(z, mean), std) -} - -// Denormalize reverses latent normalization -// Input z is bf16 (from transformer), output converted to f32 for VAE decoder -func (m *VAE) Denormalize(z *mlx.Array) *mlx.Array { - shape := z.Shape() - C := shape[1] - - // Convert latents to f32 for VAE decoder quality - z = mlx.AsType(z, mlx.DtypeFloat32) - - mean := mlx.NewArray(m.Config.LatentsMean[:C], []int32{1, C, 1, 1, 1}) - std := mlx.NewArray(m.Config.LatentsStd[:C], []int32{1, C, 1, 1, 1}) - - return mlx.Add(mlx.Mul(z, std), mean) -} - -// VAEEncoder is the encoder part of the VAE -// The encoder uses a flat structure where down_blocks contains a mix of ResBlocks and Downsamplers: -// - Blocks 0,1: ResBlocks (base_dim) -// - Block 2: Downsample -// - Blocks 3,4: ResBlocks (base_dim*2) -// - Block 5: Downsample + temporal -// - Blocks 6,7: ResBlocks (base_dim*4) -// - Block 8: Downsample + temporal -// - Blocks 9,10: ResBlocks (base_dim*4) -type VAEEncoder struct { - Config *VAEConfig - - ConvIn *CausalConv3d - Blocks []EncoderBlock // Flat list of ResBlocks and Downsamplers - MidBlock *MidBlock - NormOut *RMSNorm3D - ConvOut *CausalConv3d - QuantConv *CausalConv3d -} - -// EncoderBlock is either a ResBlock or a Downsample -type EncoderBlock interface { - Forward(x *mlx.Array) *mlx.Array - IsDownsample() bool -} - -// EncoderResBlock wraps ResBlock -type EncoderResBlock struct { - *ResBlock -} - -func (b *EncoderResBlock) IsDownsample() bool { return false } - -// EncoderDownsample is a downsample layer -type EncoderDownsample struct { - Resample *CausalConv3d - TimeConv *CausalConv3d // Optional temporal downsample -} - -func (d *EncoderDownsample) IsDownsample() bool { return true } - -func (d *EncoderDownsample) Forward(x *mlx.Array) *mlx.Array { - // Spatial downsample with stride 2 - // WAN VAE uses: ZeroPad2d(0,1,0,1) + Conv2d(3x3, stride=2) - x = d.forwardSpatialDownsample(x) - - // NOTE: In WAN VAE, time_conv is ONLY used in streaming/chunked mode - // with feat_cache. For single-frame encoding (T=1), time_conv is skipped. - // The Python forward checks: if feat_cache is not None ... then use time_conv - // Since we don't support streaming, we skip time_conv entirely. - return x -} - -// forwardSpatialDownsample applies 2D conv with stride 2 for spatial downsampling -func (d *EncoderDownsample) forwardSpatialDownsample(x *mlx.Array) *mlx.Array { - xShape := x.Shape() - B := xShape[0] - T := xShape[1] - H := xShape[2] - W := xShape[3] - C := xShape[4] - - wShape := d.Resample.Weight.Shape() - outC := wShape[0] - - // Reshape to [B*T, H, W, C] for 2D conv - x = mlx.Reshape(x, B*T, H, W, C) - - // Asymmetric padding: pad right and bottom by 1 (WAN VAE style) - // ZeroPad2d(0, 1, 0, 1) means (left=0, right=1, top=0, bottom=1) - x = mlx.Pad(x, []int32{0, 0, 0, 1, 0, 1, 0, 0}) // [B, H, W, C] -> pad H and W - - // Apply 2D conv with stride 2 - weight := mlx.Transpose(d.Resample.Weight, 0, 2, 3, 1) // [O, I, kH, kW] -> [O, kH, kW, I] - x = conv2DStrided(x, weight, 2) - - if d.Resample.Bias != nil { - bias := mlx.Reshape(d.Resample.Bias, 1, 1, 1, outC) - x = mlx.Add(x, bias) - } - - // Output dims after stride 2: (H+1)/2, (W+1)/2 - outH := (H + 1) / 2 - outW := (W + 1) / 2 - - // Reshape back to [B, T, H', W', C] - x = mlx.Reshape(x, B, T, outH, outW, outC) - mlx.Eval(x) - - return x -} - -// loadFromWeights loads the encoder from pre-loaded weights -func (e *VAEEncoder) loadFromWeights(weights *safetensors.ModelWeights, cfg *VAEConfig) error { - e.Config = cfg - - // Conv in - convIn, err := newCausalConv3d(weights, "encoder.conv_in") - if err != nil { - return err - } - e.ConvIn = convIn - - // Encoder uses flat block structure: - // dim_mult = [1, 2, 4, 4], num_res_blocks = 2, temporal_downsample = [false, true, true] - // Block layout: res,res,down, res,res,down+t, res,res,down+t, res,res - // That's 11 blocks: 0,1=res, 2=down, 3,4=res, 5=down+t, 6,7=res, 8=down+t, 9,10=res - e.Blocks = make([]EncoderBlock, 0, 11) - - // Track dimensions - dims := []int32{cfg.BaseDim, cfg.BaseDim * 2, cfg.BaseDim * 4, cfg.BaseDim * 4} - blockIdx := 0 - - for stage := 0; stage < len(cfg.DimMult); stage++ { - inDim := cfg.BaseDim - if stage > 0 { - inDim = dims[stage-1] - } - outDim := dims[stage] - - // ResBlocks for this stage (num_res_blocks per stage) - for r := int32(0); r < cfg.NumResBlocks; r++ { - prefix := fmt.Sprintf("encoder.down_blocks.%d", blockIdx) - currentInDim := inDim - if r > 0 { - currentInDim = outDim - } - block, err := newEncoderResBlock(weights, prefix, currentInDim, outDim) - if err != nil { - return fmt.Errorf("encoder res block %d: %w", blockIdx, err) - } - e.Blocks = append(e.Blocks, block) - blockIdx++ - } - - // Downsample after each stage except the last - if stage < len(cfg.DimMult)-1 { - prefix := fmt.Sprintf("encoder.down_blocks.%d", blockIdx) - down, err := newEncoderDownsample(weights, prefix, cfg.TemperalDownsample[stage]) - if err != nil { - return fmt.Errorf("encoder downsample %d: %w", blockIdx, err) - } - e.Blocks = append(e.Blocks, down) - blockIdx++ - } - } - - // Mid block - midDim := cfg.BaseDim * cfg.DimMult[len(cfg.DimMult)-1] - midBlock, err := newMidBlock(weights, "encoder.mid_block", midDim) - if err != nil { - return err - } - e.MidBlock = midBlock - - // Norm out - normOut, err := newRMSNorm3D(weights, "encoder.norm_out", midDim) - if err != nil { - return err - } - e.NormOut = normOut - - // Conv out - convOut, err := newCausalConv3d(weights, "encoder.conv_out") - if err != nil { - return err - } - e.ConvOut = convOut - - // Quant conv - quantConv, err := newCausalConv3d(weights, "quant_conv") - if err != nil { - return err - } - e.QuantConv = quantConv - - return nil -} - -// newEncoderResBlock creates a ResBlock for the encoder (flat structure) -func newEncoderResBlock(weights *safetensors.ModelWeights, prefix string, inDim, outDim int32) (*EncoderResBlock, error) { - block, err := newResBlock(weights, prefix, inDim, outDim) - if err != nil { - return nil, err - } - return &EncoderResBlock{block}, nil -} - -// newEncoderDownsample creates a downsample layer for the encoder -func newEncoderDownsample(weights *safetensors.ModelWeights, prefix string, temporal bool) (*EncoderDownsample, error) { - resample, err := newCausalConv3d(weights, prefix+".resample.1") - if err != nil { - return nil, err - } - - var timeConv *CausalConv3d - if temporal { - timeConv, _ = newCausalConv3d(weights, prefix+".time_conv") - } - - return &EncoderDownsample{ - Resample: resample, - TimeConv: timeConv, - }, nil -} - -// Encode encodes an image to latents -// x: [B, C, T, H, W] image tensor (channels-first) -// Returns: [B, latent_C, T, H/8, W/8] latent distribution mode -func (e *VAEEncoder) Encode(x *mlx.Array) *mlx.Array { - // Convert from channels-first [N, C, T, H, W] to channels-last [N, T, H, W, C] - x = mlx.Contiguous(mlx.Transpose(x, 0, 2, 3, 4, 1)) - mlx.Eval(x) - - // Conv in - x = e.ConvIn.Forward(x) - - // Encoder blocks (mix of ResBlocks and Downsamplers) - for _, block := range e.Blocks { - prev := x - x = block.Forward(x) - prev.Free() - } - - // Mid block - x = e.MidBlock.Forward(x) - - // Norm + silu - { - prev := x - x = e.NormOut.Forward(x) - x = silu3D(x) - prev.Free() - mlx.Eval(x) - } - - // Conv out - { - prev := x - x = e.ConvOut.Forward(x) - prev.Free() - } - - // Quant conv - { - prev := x - x = e.QuantConv.Forward(x) - prev.Free() - } - - // Get mode from distribution (first half of channels = mean) - // Output is [B, T, H, W, 2*latent_C], we take first latent_C channels - shape := x.Shape() - latentC := shape[4] / 2 - x = mlx.Slice(x, []int32{0, 0, 0, 0, 0}, []int32{shape[0], shape[1], shape[2], shape[3], latentC}) - - // Convert back to channels-first [N, C, T, H, W] - x = mlx.Contiguous(mlx.Transpose(x, 0, 4, 1, 2, 3)) - mlx.Eval(x) - - return x -} - -// VAEDecoder is the decoder part of the VAE -type VAEDecoder struct { - Config *VAEConfig - - PostQuantConv *CausalConv3d - ConvIn *CausalConv3d - MidBlock *MidBlock - UpBlocks []*UpBlock - NormOut *RMSNorm3D - ConvOut *CausalConv3d -} - -// loadFromWeights loads the decoder from pre-loaded weights -func (d *VAEDecoder) loadFromWeights(weights *safetensors.ModelWeights, cfg *VAEConfig) error { - d.Config = cfg - - postQuantConv, err := newCausalConv3d(weights, "post_quant_conv") - if err != nil { - return err - } - d.PostQuantConv = postQuantConv - - convIn, err := newCausalConv3d(weights, "decoder.conv_in") - if err != nil { - return err - } - d.ConvIn = convIn - - // Mid block - midDim := cfg.BaseDim * cfg.DimMult[len(cfg.DimMult)-1] - midBlock, err := newMidBlock(weights, "decoder.mid_block", midDim) - if err != nil { - return err - } - d.MidBlock = midBlock - - // Up blocks (reversed dim_mult) - numUpBlocks := len(cfg.DimMult) - d.UpBlocks = make([]*UpBlock, numUpBlocks) - - dimsMult := make([]int32, numUpBlocks+1) - dimsMult[0] = cfg.DimMult[numUpBlocks-1] - for i := 0; i < numUpBlocks; i++ { - dimsMult[i+1] = cfg.DimMult[numUpBlocks-1-i] - } - - temporalUpsample := make([]bool, len(cfg.TemperalDownsample)) - for i := range cfg.TemperalDownsample { - temporalUpsample[i] = cfg.TemperalDownsample[len(cfg.TemperalDownsample)-1-i] - } - - for i := 0; i < numUpBlocks; i++ { - inDim := cfg.BaseDim * dimsMult[i] - outDim := cfg.BaseDim * dimsMult[i+1] - - if i > 0 { - inDim = inDim / 2 - } - - upsampleMode := "" - if i < numUpBlocks-1 { - if temporalUpsample[i] { - upsampleMode = "upsample3d" - } else { - upsampleMode = "upsample2d" - } - } - - prefix := fmt.Sprintf("decoder.up_blocks.%d", i) - upBlock, err := newUpBlock(weights, prefix, inDim, outDim, cfg.NumResBlocks, upsampleMode) - if err != nil { - return err - } - d.UpBlocks[i] = upBlock - } - - normOut, err := newRMSNorm3D(weights, "decoder.norm_out", cfg.BaseDim) - if err != nil { - return err - } - d.NormOut = normOut - - convOut, err := newCausalConv3d(weights, "decoder.conv_out") - if err != nil { - return err - } - d.ConvOut = convOut - - return nil -} - -// Decode converts latents to image -// z: [B, C, T, H, W] denormalized latents -func (d *VAEDecoder) Decode(z *mlx.Array) *mlx.Array { - var x *mlx.Array - - // Convert from channels-first to channels-last - { - z = mlx.Contiguous(mlx.Transpose(z, 0, 2, 3, 4, 1)) - mlx.Eval(z) - } - - // PostQuantConv - x = d.PostQuantConv.Forward(z) - z.Free() - - // ConvIn - { - prev := x - x = d.ConvIn.Forward(x) - prev.Free() - } - - // Mid block - x = d.MidBlock.Forward(x) - - // Up blocks - for _, upBlock := range d.UpBlocks { - x = upBlock.Forward(x) - } - - // NormOut + silu - { - prev := x - x = d.NormOut.Forward(x) - x = silu3D(x) - prev.Free() - mlx.Eval(x) - } - - // ConvOut - { - prev := x - x = d.ConvOut.Forward(x) - prev.Free() - } - - // Post-processing: clamp and convert back to channels-first - { - prev := x - x = mlx.ClipScalar(x, -1.0, 1.0, true, true) - x = mlx.Contiguous(mlx.Transpose(x, 0, 4, 1, 2, 3)) - prev.Free() - mlx.Eval(x) - } - - return x -} - -// DownBlock handles downsampling in encoder -type DownBlock struct { - ResBlocks []*ResBlock - Downsampler *Downsample -} - -// newDownBlock creates a down block -func newDownBlock(weights *safetensors.ModelWeights, prefix string, inDim, outDim int32, numBlocks int32, downsampleMode string) (*DownBlock, error) { - resBlocks := make([]*ResBlock, numBlocks+1) - - currentDim := inDim - for i := int32(0); i <= numBlocks; i++ { - resPrefix := fmt.Sprintf("%s.resnets.%d", prefix, i) - block, err := newResBlock(weights, resPrefix, currentDim, outDim) - if err != nil { - return nil, err - } - resBlocks[i] = block - currentDim = outDim - } - - var downsampler *Downsample - if downsampleMode != "" { - downsampler = newDownsample(weights, prefix+".downsamplers.0", outDim, downsampleMode) - } - - return &DownBlock{ - ResBlocks: resBlocks, - Downsampler: downsampler, - }, nil -} - -// Forward applies down block -func (d *DownBlock) Forward(x *mlx.Array) *mlx.Array { - for _, block := range d.ResBlocks { - prev := x - x = block.Forward(x) - prev.Free() - } - - if d.Downsampler != nil { - prev := x - x = d.Downsampler.Forward(x) - prev.Free() - } - return x -} - -// Downsample handles spatial downsampling -type Downsample struct { - Conv *mlx.Array - Bias *mlx.Array - Mode string -} - -// newDownsample creates a downsampler -func newDownsample(weights *safetensors.ModelWeights, prefix string, dim int32, mode string) *Downsample { - conv, _ := weights.Get(prefix + ".resample.1.weight") - bias, _ := weights.Get(prefix + ".resample.1.bias") - return &Downsample{ - Conv: conv, - Bias: bias, - Mode: mode, - } -} - -// Forward applies downsampling to channels-last input [B, T, H, W, C] -func (d *Downsample) Forward(x *mlx.Array) *mlx.Array { - shape := x.Shape() - B := shape[0] - T := shape[1] - H := shape[2] - W := shape[3] - C := shape[4] - outC := d.Conv.Shape()[0] - - // Reshape to [B*T, H, W, C] for 2D conv - x = mlx.Reshape(x, B*T, H, W, C) - - // Pad for stride-2 conv: need (3-1)/2 = 1 on each side, but for stride 2 we need specific padding - // For 3x3 stride 2: pad 1 on all sides - x = mlx.Pad(x, []int32{0, 0, 1, 1, 1, 1, 0, 0}) - - // Conv with stride 2 using manual strided patching - weight := mlx.Transpose(d.Conv, 0, 2, 3, 1) - x = conv2DStrided(x, weight, 2) - if d.Bias != nil { - bias := mlx.Reshape(d.Bias, 1, 1, 1, outC) - x = mlx.Add(x, bias) - } - - x = mlx.Reshape(x, B, T, H/2, W/2, outC) - mlx.Eval(x) - - return x -}