diff --git a/middleware/openai.go b/middleware/openai.go index beaa9ee97..e2db8f965 100644 --- a/middleware/openai.go +++ b/middleware/openai.go @@ -609,3 +609,49 @@ func ImageGenerationsMiddleware() gin.HandlerFunc { c.Next() } } + +func ImageEditsMiddleware() gin.HandlerFunc { + return func(c *gin.Context) { + var req openai.ImageEditRequest + if err := c.ShouldBindJSON(&req); err != nil { + c.AbortWithStatusJSON(http.StatusBadRequest, openai.NewError(http.StatusBadRequest, err.Error())) + return + } + + if req.Prompt == "" { + c.AbortWithStatusJSON(http.StatusBadRequest, openai.NewError(http.StatusBadRequest, "prompt is required")) + return + } + + if req.Model == "" { + c.AbortWithStatusJSON(http.StatusBadRequest, openai.NewError(http.StatusBadRequest, "model is required")) + return + } + + if req.Image == "" { + c.AbortWithStatusJSON(http.StatusBadRequest, openai.NewError(http.StatusBadRequest, "image is required")) + return + } + + genReq, err := openai.FromImageEditRequest(req) + if err != nil { + c.AbortWithStatusJSON(http.StatusBadRequest, openai.NewError(http.StatusBadRequest, err.Error())) + return + } + + var b bytes.Buffer + if err := json.NewEncoder(&b).Encode(genReq); err != nil { + c.AbortWithStatusJSON(http.StatusInternalServerError, openai.NewError(http.StatusInternalServerError, err.Error())) + return + } + + c.Request.Body = io.NopCloser(&b) + + w := &ImageWriter{ + BaseWriter: BaseWriter{ResponseWriter: c.Writer}, + } + + c.Writer = w + c.Next() + } +} diff --git a/middleware/openai_test.go b/middleware/openai_test.go index cc7c3c215..0cf2558ae 100644 --- a/middleware/openai_test.go +++ b/middleware/openai_test.go @@ -1112,3 +1112,129 @@ func TestImageWriterResponse(t *testing.T) { t.Errorf("expected image data 'dGVzdC1pbWFnZS1kYXRh', got %s", imageResp.Data[0].B64JSON) } } + +func TestImageEditsMiddleware(t *testing.T) { + type testCase struct { + name string + body string + req api.GenerateRequest + err openai.ErrorResponse + } + + var capturedRequest *api.GenerateRequest + + // Base64-encoded test image (1x1 pixel PNG) + testImage := "data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR42mNk+A8AAQUBAScY42YAAAAASUVORK5CYII=" + decodedImage, _ := base64.StdEncoding.DecodeString("iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR42mNk+A8AAQUBAScY42YAAAAASUVORK5CYII=") + + testCases := []testCase{ + { + name: "image edit basic", + body: `{ + "model": "test-model", + "prompt": "make it blue", + "image": "` + testImage + `" + }`, + req: api.GenerateRequest{ + Model: "test-model", + Prompt: "make it blue", + Images: []api.ImageData{decodedImage}, + }, + }, + { + name: "image edit with size", + body: `{ + "model": "test-model", + "prompt": "make it blue", + "image": "` + testImage + `", + "size": "512x768" + }`, + req: api.GenerateRequest{ + Model: "test-model", + Prompt: "make it blue", + Images: []api.ImageData{decodedImage}, + Width: 512, + Height: 768, + }, + }, + { + name: "image edit missing prompt", + body: `{ + "model": "test-model", + "image": "` + testImage + `" + }`, + err: openai.ErrorResponse{ + Error: openai.Error{ + Message: "prompt is required", + Type: "invalid_request_error", + }, + }, + }, + { + name: "image edit missing model", + body: `{ + "prompt": "make it blue", + "image": "` + testImage + `" + }`, + err: openai.ErrorResponse{ + Error: openai.Error{ + Message: "model is required", + Type: "invalid_request_error", + }, + }, + }, + { + name: "image edit missing image", + body: `{ + "model": "test-model", + "prompt": "make 
it blue" + }`, + err: openai.ErrorResponse{ + Error: openai.Error{ + Message: "image is required", + Type: "invalid_request_error", + }, + }, + }, + } + + endpoint := func(c *gin.Context) { + c.Status(http.StatusOK) + } + + gin.SetMode(gin.TestMode) + router := gin.New() + router.Use(ImageEditsMiddleware(), captureRequestMiddleware(&capturedRequest)) + router.Handle(http.MethodPost, "/api/generate", endpoint) + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + req, _ := http.NewRequest(http.MethodPost, "/api/generate", strings.NewReader(tc.body)) + req.Header.Set("Content-Type", "application/json") + + defer func() { capturedRequest = nil }() + + resp := httptest.NewRecorder() + router.ServeHTTP(resp, req) + + if tc.err.Error.Message != "" { + var errResp openai.ErrorResponse + if err := json.Unmarshal(resp.Body.Bytes(), &errResp); err != nil { + t.Fatal(err) + } + if diff := cmp.Diff(tc.err, errResp); diff != "" { + t.Fatalf("errors did not match:\n%s", diff) + } + return + } + + if resp.Code != http.StatusOK { + t.Fatalf("expected status 200, got %d: %s", resp.Code, resp.Body.String()) + } + + if diff := cmp.Diff(&tc.req, capturedRequest); diff != "" { + t.Fatalf("requests did not match:\n%s", diff) + } + }) + } +} diff --git a/openai/openai.go b/openai/openai.go index d1f75c4aa..acc755354 100644 --- a/openai/openai.go +++ b/openai/openai.go @@ -794,3 +794,47 @@ func ToImageGenerationResponse(resp api.GenerateResponse) ImageGenerationRespons Data: data, } } + +// ImageEditRequest is an OpenAI-compatible image edit request. +type ImageEditRequest struct { + Model string `json:"model"` + Prompt string `json:"prompt"` + Image string `json:"image"` // Base64-encoded image data + Size string `json:"size,omitempty"` // e.g., "1024x1024" + Seed *int64 `json:"seed,omitempty"` +} + +// FromImageEditRequest converts an OpenAI image edit request to an Ollama GenerateRequest. 
+func FromImageEditRequest(r ImageEditRequest) (api.GenerateRequest, error) { + req := api.GenerateRequest{ + Model: r.Model, + Prompt: r.Prompt, + } + + // Decode the input image + if r.Image != "" { + imgData, err := decodeImageURL(r.Image) + if err != nil { + return api.GenerateRequest{}, fmt.Errorf("invalid image: %w", err) + } + req.Images = append(req.Images, imgData) + } + + // Parse size if provided (e.g., "1024x768") + if r.Size != "" { + var w, h int32 + if _, err := fmt.Sscanf(r.Size, "%dx%d", &w, &h); err == nil { + req.Width = w + req.Height = h + } + } + + if r.Seed != nil { + if req.Options == nil { + req.Options = map[string]any{} + } + req.Options["seed"] = *r.Seed + } + + return req, nil +} diff --git a/openai/openai_test.go b/openai/openai_test.go index f76af7090..b2e98ead4 100644 --- a/openai/openai_test.go +++ b/openai/openai_test.go @@ -448,3 +448,86 @@ func TestFromChatRequest_TopLogprobsRange(t *testing.T) { }) } } + +func TestFromImageEditRequest_Basic(t *testing.T) { + req := ImageEditRequest{ + Model: "test-model", + Prompt: "make it blue", + Image: prefix + image, + } + + result, err := FromImageEditRequest(req) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if result.Model != "test-model" { + t.Errorf("expected model 'test-model', got %q", result.Model) + } + + if result.Prompt != "make it blue" { + t.Errorf("expected prompt 'make it blue', got %q", result.Prompt) + } + + if len(result.Images) != 1 { + t.Fatalf("expected 1 image, got %d", len(result.Images)) + } +} + +func TestFromImageEditRequest_WithSize(t *testing.T) { + req := ImageEditRequest{ + Model: "test-model", + Prompt: "make it blue", + Image: prefix + image, + Size: "512x768", + } + + result, err := FromImageEditRequest(req) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if result.Width != 512 { + t.Errorf("expected width 512, got %d", result.Width) + } + + if result.Height != 768 { + t.Errorf("expected height 768, got %d", result.Height) + } +} + +func TestFromImageEditRequest_WithSeed(t *testing.T) { + seed := int64(12345) + req := ImageEditRequest{ + Model: "test-model", + Prompt: "make it blue", + Image: prefix + image, + Seed: &seed, + } + + result, err := FromImageEditRequest(req) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if result.Options == nil { + t.Fatal("expected options to be set") + } + + if result.Options["seed"] != seed { + t.Errorf("expected seed %d, got %v", seed, result.Options["seed"]) + } +} + +func TestFromImageEditRequest_InvalidImage(t *testing.T) { + req := ImageEditRequest{ + Model: "test-model", + Prompt: "make it blue", + Image: "not-valid-base64", + } + + _, err := FromImageEditRequest(req) + if err == nil { + t.Error("expected error for invalid image") + } +} diff --git a/server/images.go b/server/images.go index 2955d26f7..05dfe1468 100644 --- a/server/images.go +++ b/server/images.go @@ -75,12 +75,6 @@ type Model struct { func (m *Model) Capabilities() []model.Capability { capabilities := []model.Capability{} - // Check for image generation model via config capabilities - if slices.Contains(m.Config.Capabilities, "image") { - return []model.Capability{model.CapabilityImage} - } - - // Check for completion capability if m.ModelPath != "" { f, err := gguf.Open(m.ModelPath) if err == nil { diff --git a/server/images_test.go b/server/images_test.go index 9e581c8c3..639cf8662 100644 --- a/server/images_test.go +++ b/server/images_test.go @@ -56,6 +56,15 @@ func TestModelCapabilities(t *testing.T) { }, 
expectedCaps: []model.Capability{model.CapabilityImage}, }, + { + name: "model with image and vision capability (image editing)", + model: Model{ + Config: model.ConfigV2{ + Capabilities: []string{"image", "vision"}, + }, + }, + expectedCaps: []model.Capability{model.CapabilityImage, model.CapabilityVision}, + }, { name: "model with completion capability", model: Model{ diff --git a/server/routes.go b/server/routes.go index 2ecf64869..383e9b29c 100644 --- a/server/routes.go +++ b/server/routes.go @@ -1604,8 +1604,9 @@ func (s *Server) GenerateRoutes(rc *ollama.Registry) (http.Handler, error) { r.GET("/v1/models", middleware.ListMiddleware(), s.ListHandler) r.GET("/v1/models/:model", middleware.RetrieveMiddleware(), s.ShowHandler) r.POST("/v1/responses", middleware.ResponsesMiddleware(), s.ChatHandler) - // OpenAI-compatible image generation endpoint + // OpenAI-compatible image generation endpoints r.POST("/v1/images/generations", middleware.ImageGenerationsMiddleware(), s.GenerateHandler) + r.POST("/v1/images/edits", middleware.ImageEditsMiddleware(), s.GenerateHandler) // Inference (Anthropic compatibility) r.POST("/v1/messages", middleware.AnthropicMessagesMiddleware(), s.ChatHandler) @@ -2523,6 +2524,11 @@ func (s *Server) handleImageGenerate(c *gin.Context, req api.GenerateRequest, mo } } + var images []llm.ImageData + for i, imgData := range req.Images { + images = append(images, llm.ImageData{ID: i, Data: imgData}) + } + var streamStarted bool if err := runner.Completion(c.Request.Context(), llm.CompletionRequest{ Prompt: req.Prompt, @@ -2530,6 +2536,7 @@ func (s *Server) handleImageGenerate(c *gin.Context, req api.GenerateRequest, mo Height: req.Height, Steps: req.Steps, Seed: seed, + Images: images, }, func(cr llm.CompletionResponse) { streamStarted = true res := api.GenerateResponse{ diff --git a/server/routes_generate_test.go b/server/routes_generate_test.go index cad721f8b..ca149641a 100644 --- a/server/routes_generate_test.go +++ b/server/routes_generate_test.go @@ -2193,3 +2193,157 @@ func TestGenerateUnload(t *testing.T) { } }) } + +func TestGenerateWithImages(t *testing.T) { + gin.SetMode(gin.TestMode) + + mock := mockRunner{ + CompletionResponse: llm.CompletionResponse{ + Done: true, + DoneReason: llm.DoneReasonStop, + PromptEvalCount: 1, + PromptEvalDuration: 1, + EvalCount: 1, + EvalDuration: 1, + }, + } + + s := Server{ + sched: &Scheduler{ + pendingReqCh: make(chan *LlmRequest, 1), + finishedReqCh: make(chan *LlmRequest, 1), + expiredCh: make(chan *runnerRef, 1), + unloadedCh: make(chan any, 1), + loaded: make(map[string]*runnerRef), + newServerFn: newMockServer(&mock), + getGpuFn: getGpuFn, + getSystemInfoFn: getSystemInfoFn, + waitForRecovery: 250 * time.Millisecond, + loadFn: func(req *LlmRequest, _ *ggml.GGML, _ ml.SystemInfo, _ []ml.DeviceInfo, _ bool) bool { + time.Sleep(time.Millisecond) + req.successCh <- &runnerRef{ + llama: &mock, + } + return false + }, + }, + } + + go s.sched.Run(t.Context()) + + _, digest := createBinFile(t, ggml.KV{ + "general.architecture": "llama", + "llama.block_count": uint32(1), + "llama.context_length": uint32(8192), + "llama.embedding_length": uint32(4096), + "llama.attention.head_count": uint32(32), + "llama.attention.head_count_kv": uint32(8), + "tokenizer.ggml.tokens": []string{""}, + "tokenizer.ggml.scores": []float32{0}, + "tokenizer.ggml.token_type": []int32{0}, + }, []*ggml.Tensor{ + {Name: "token_embd.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))}, + {Name: "blk.0.attn_norm.weight", Shape: 
[]uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))}, + {Name: "blk.0.ffn_down.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))}, + {Name: "blk.0.ffn_gate.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))}, + {Name: "blk.0.ffn_up.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))}, + {Name: "blk.0.ffn_norm.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))}, + {Name: "blk.0.attn_k.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))}, + {Name: "blk.0.attn_output.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))}, + {Name: "blk.0.attn_q.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))}, + {Name: "blk.0.attn_v.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))}, + {Name: "output.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))}, + }) + + w := createRequest(t, s.CreateHandler, api.CreateRequest{ + Model: "test", + Files: map[string]string{"file.gguf": digest}, + Stream: &stream, + }) + + if w.Code != http.StatusOK { + t.Fatalf("expected status 200, got %d", w.Code) + } + + t.Run("images passed to completion request", func(t *testing.T) { + testImage := []byte("test-image-data") + + mock.CompletionResponse.Content = "Image processed" + w := createRequest(t, s.GenerateHandler, api.GenerateRequest{ + Model: "test", + Prompt: "Describe this image", + Images: []api.ImageData{testImage}, + Stream: &stream, + }) + + if w.Code != http.StatusOK { + t.Fatalf("expected status 200, got %d: %s", w.Code, w.Body.String()) + } + + // Verify images were passed to the completion request + if len(mock.CompletionRequest.Images) != 1 { + t.Fatalf("expected 1 image in completion request, got %d", len(mock.CompletionRequest.Images)) + } + + if !bytes.Equal(mock.CompletionRequest.Images[0].Data, testImage) { + t.Errorf("image data mismatch in completion request") + } + + if mock.CompletionRequest.Images[0].ID != 0 { + t.Errorf("expected image ID 0, got %d", mock.CompletionRequest.Images[0].ID) + } + }) + + t.Run("multiple images passed to completion request", func(t *testing.T) { + testImage1 := []byte("test-image-1") + testImage2 := []byte("test-image-2") + + mock.CompletionResponse.Content = "Images processed" + w := createRequest(t, s.GenerateHandler, api.GenerateRequest{ + Model: "test", + Prompt: "Compare these images", + Images: []api.ImageData{testImage1, testImage2}, + Stream: &stream, + }) + + if w.Code != http.StatusOK { + t.Fatalf("expected status 200, got %d: %s", w.Code, w.Body.String()) + } + + // Verify both images were passed + if len(mock.CompletionRequest.Images) != 2 { + t.Fatalf("expected 2 images in completion request, got %d", len(mock.CompletionRequest.Images)) + } + + if !bytes.Equal(mock.CompletionRequest.Images[0].Data, testImage1) { + t.Errorf("first image data mismatch") + } + + if !bytes.Equal(mock.CompletionRequest.Images[1].Data, testImage2) { + t.Errorf("second image data mismatch") + } + + if mock.CompletionRequest.Images[0].ID != 0 || mock.CompletionRequest.Images[1].ID != 1 { + t.Errorf("expected image IDs 0 and 1, got %d and %d", + mock.CompletionRequest.Images[0].ID, mock.CompletionRequest.Images[1].ID) + } + }) + + t.Run("no images when none provided", func(t *testing.T) { + mock.CompletionResponse.Content = "No images" + w := createRequest(t, s.GenerateHandler, api.GenerateRequest{ + Model: "test", + Prompt: "Hello", + Stream: &stream, + }) + + if w.Code != http.StatusOK { + 
t.Fatalf("expected status 200, got %d: %s", w.Code, w.Body.String()) + } + + // Verify no images in completion request + if len(mock.CompletionRequest.Images) != 0 { + t.Fatalf("expected 0 images in completion request, got %d", len(mock.CompletionRequest.Images)) + } + }) +} diff --git a/x/imagegen/cli.go b/x/imagegen/cli.go index a55a1b016..6c8ea0f54 100644 --- a/x/imagegen/cli.go +++ b/x/imagegen/cli.go @@ -10,7 +10,10 @@ import ( "errors" "fmt" "io" + "net/http" "os" + "regexp" + "slices" "strconv" "strings" "time" @@ -75,6 +78,7 @@ Image Generation Flags (experimental): // RunCLI handles the CLI for image generation models. // Returns true if it handled the request, false if the caller should continue with normal flow. // Supports flags: --width, --height, --steps, --seed, --negative +// Image paths can be included in the prompt and will be extracted automatically. func RunCLI(cmd *cobra.Command, name string, prompt string, interactive bool, keepAlive *api.Duration) error { // Get options from flags (with env var defaults) opts := DefaultOptions() @@ -111,9 +115,16 @@ func generateImageWithOptions(cmd *cobra.Command, modelName, prompt string, keep return err } + // Extract any image paths from the prompt + prompt, images, err := extractFileData(prompt) + if err != nil { + return err + } + req := &api.GenerateRequest{ Model: modelName, Prompt: prompt, + Images: images, Width: int32(opts.Width), Height: int32(opts.Height), Steps: int32(opts.Steps), @@ -254,14 +265,33 @@ func runInteractive(cmd *cobra.Command, modelName string, keepAlive *api.Duratio printCurrentSettings(opts) continue case strings.HasPrefix(line, "/"): - fmt.Fprintf(os.Stderr, "Unknown command: %s (try /help)\n", line) + // Check if it's a file path, not a command + args := strings.Fields(line) + isFile := false + for _, f := range extractFileNames(line) { + if strings.HasPrefix(f, args[0]) { + isFile = true + break + } + } + if !isFile { + fmt.Fprintf(os.Stderr, "Unknown command: %s (try /help)\n", args[0]) + continue + } + } + + // Extract any image paths from the input + prompt, images, err := extractFileData(line) + if err != nil { + fmt.Fprintf(os.Stderr, "Error: %v\n", err) continue } // Generate image with current options req := &api.GenerateRequest{ Model: modelName, - Prompt: line, + Prompt: prompt, + Images: images, Width: int32(opts.Width), Height: int32(opts.Height), Steps: int32(opts.Steps), @@ -486,3 +516,59 @@ func displayImageInTerminal(imagePath string) bool { return false } } + +// extractFileNames finds image file paths in the input string. +func extractFileNames(input string) []string { + // Regex to match file paths with image extensions + regexPattern := `(?:[a-zA-Z]:)?(?:\./|/|\\)[\S\\ ]+?\.(?i:jpg|jpeg|png|webp)\b` + re := regexp.MustCompile(regexPattern) + return re.FindAllString(input, -1) +} + +// extractFileData extracts image data from file paths found in the input. +// Returns the cleaned prompt (with file paths removed) and the image data. 
+func extractFileData(input string) (string, []api.ImageData, error) { + filePaths := extractFileNames(input) + var imgs []api.ImageData + + for _, fp := range filePaths { + // Normalize escaped spaces + nfp := strings.ReplaceAll(fp, "\\ ", " ") + nfp = strings.ReplaceAll(nfp, "%20", " ") + + data, err := getImageData(nfp) + if errors.Is(err, os.ErrNotExist) { + continue + } else if err != nil { + return "", nil, err + } + fmt.Fprintf(os.Stderr, "Added image '%s'\n", nfp) + input = strings.ReplaceAll(input, fp, "") + imgs = append(imgs, data) + } + return strings.TrimSpace(input), imgs, nil +} + +// getImageData reads and validates image data from a file. +func getImageData(filePath string) ([]byte, error) { + file, err := os.Open(filePath) + if err != nil { + return nil, err + } + defer file.Close() + + buf := make([]byte, 512) + _, err = file.Read(buf) + if err != nil { + return nil, err + } + + contentType := http.DetectContentType(buf) + allowedTypes := []string{"image/jpeg", "image/jpg", "image/png", "image/webp"} + if !slices.Contains(allowedTypes, contentType) { + return nil, fmt.Errorf("invalid image type: %s", contentType) + } + + // Re-read the full file + return os.ReadFile(filePath) +} diff --git a/x/imagegen/models/flux2/flux2.go b/x/imagegen/models/flux2/flux2.go index 348490ba7..23a00e631 100644 --- a/x/imagegen/models/flux2/flux2.go +++ b/x/imagegen/models/flux2/flux2.go @@ -177,6 +177,20 @@ func (m *Model) GenerateImage(ctx context.Context, prompt string, width, height }) } +// GenerateImageWithInputs implements runner.ImageEditModel interface. +// It generates an image conditioned on the provided input images for image editing. +func (m *Model) GenerateImageWithInputs(ctx context.Context, prompt string, width, height int32, steps int, seed int64, inputImages []image.Image, progress func(step, total int)) (*mlx.Array, error) { + return m.GenerateFromConfig(ctx, &GenerateConfig{ + Prompt: prompt, + Width: width, + Height: height, + Steps: steps, + Seed: seed, + InputImages: inputImages, + Progress: progress, + }) +} + // MaxOutputPixels is the maximum output resolution (4 megapixels, ~2048x2048) const MaxOutputPixels = 2048 * 2048 diff --git a/x/imagegen/runner/runner.go b/x/imagegen/runner/runner.go index 8fe5c2de1..f43276468 100644 --- a/x/imagegen/runner/runner.go +++ b/x/imagegen/runner/runner.go @@ -9,6 +9,7 @@ import ( "encoding/json" "flag" "fmt" + "image" "log/slog" "net/http" "os" @@ -25,11 +26,12 @@ import ( // Request is the image generation request format type Request struct { - Prompt string `json:"prompt"` - Width int32 `json:"width,omitempty"` - Height int32 `json:"height,omitempty"` - Steps int `json:"steps,omitempty"` - Seed int64 `json:"seed,omitempty"` + Prompt string `json:"prompt"` + Width int32 `json:"width,omitempty"` + Height int32 `json:"height,omitempty"` + Steps int `json:"steps,omitempty"` + Seed int64 `json:"seed,omitempty"` + Images [][]byte `json:"images,omitempty"` // Input images for image editing/conditioning } // Response is streamed back for each progress update @@ -46,6 +48,13 @@ type ImageModel interface { GenerateImage(ctx context.Context, prompt string, width, height int32, steps int, seed int64, progress func(step, total int)) (*mlx.Array, error) } +// ImageEditModel extends ImageModel with image editing/conditioning capability. +// Models that support input images for editing should implement this interface. 
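+// When a request carries input images, the runner requires this interface and falls back to plain
+// ImageModel generation otherwise.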
+type ImageEditModel interface { + ImageModel + GenerateImageWithInputs(ctx context.Context, prompt string, width, height int32, steps int, seed int64, inputImages []image.Image, progress func(step, total int)) (*mlx.Array, error) +} + // Server holds the model and handles requests type Server struct { mu sync.Mutex @@ -153,6 +162,44 @@ func (s *Server) completionHandler(w http.ResponseWriter, r *http.Request) { return } + // Validate and decode input images + const maxInputImages = 2 + if len(req.Images) > maxInputImages { + http.Error(w, fmt.Sprintf("too many input images, maximum is %d", maxInputImages), http.StatusBadRequest) + return + } + + var inputImages []image.Image + if len(req.Images) > 0 { + // TODO: add memory check for input images + + inputImages = make([]image.Image, len(req.Images)) + for i, imgBytes := range req.Images { + img, err := imagegen.DecodeImage(imgBytes) + if err != nil { + http.Error(w, fmt.Sprintf("invalid image %d: %v", i, err), http.StatusBadRequest) + return + } + inputImages[i] = img + } + slog.Info("decoded input images", "count", len(inputImages)) + + // Default width/height to first input image dimensions, scaled to max 1024 + bounds := inputImages[0].Bounds() + w, h := bounds.Dx(), bounds.Dy() + if w > 1024 || h > 1024 { + if w > h { + h = h * 1024 / w + w = 1024 + } else { + w = w * 1024 / h + h = 1024 + } + } + req.Width = int32(w) + req.Height = int32(h) + } + // Serialize generation requests - MLX model may not handle concurrent generation s.mu.Lock() defer s.mu.Unlock() @@ -184,7 +231,19 @@ func (s *Server) completionHandler(w http.ResponseWriter, r *http.Request) { flusher.Flush() } - img, err := s.model.GenerateImage(ctx, req.Prompt, req.Width, req.Height, req.Steps, req.Seed, progress) + // Use ImageEditModel if available and images provided, otherwise use basic ImageModel + var img *mlx.Array + var err error + if len(inputImages) > 0 { + editModel, ok := s.model.(ImageEditModel) + if !ok { + http.Error(w, "model does not support image editing", http.StatusBadRequest) + return + } + img, err = editModel.GenerateImageWithInputs(ctx, req.Prompt, req.Width, req.Height, req.Steps, req.Seed, inputImages, progress) + } else { + img, err = s.model.GenerateImage(ctx, req.Prompt, req.Width, req.Height, req.Steps, req.Seed, progress) + } if err != nil { // Don't send error for cancellation diff --git a/x/imagegen/server.go b/x/imagegen/server.go index b645e3065..ae13f5ad7 100644 --- a/x/imagegen/server.go +++ b/x/imagegen/server.go @@ -7,6 +7,7 @@ import ( "encoding/json" "errors" "fmt" + "io" "log/slog" "math/rand" "net" @@ -232,19 +233,27 @@ func (s *Server) Completion(ctx context.Context, req llm.CompletionRequest, fn f seed = time.Now().UnixNano() } + // Extract raw image bytes from llm.ImageData slice + var images [][]byte + for _, img := range req.Images { + images = append(images, img.Data) + } + // Build request for subprocess creq := struct { - Prompt string `json:"prompt"` - Width int32 `json:"width,omitempty"` - Height int32 `json:"height,omitempty"` - Steps int32 `json:"steps,omitempty"` - Seed int64 `json:"seed,omitempty"` + Prompt string `json:"prompt"` + Width int32 `json:"width,omitempty"` + Height int32 `json:"height,omitempty"` + Steps int32 `json:"steps,omitempty"` + Seed int64 `json:"seed,omitempty"` + Images [][]byte `json:"images,omitempty"` }{ Prompt: req.Prompt, Width: req.Width, Height: req.Height, Steps: req.Steps, Seed: seed, + Images: images, } body, err := json.Marshal(creq) @@ -266,7 +275,8 @@ func (s *Server) 
Completion(ctx context.Context, req llm.CompletionRequest, fn f defer resp.Body.Close() if resp.StatusCode != http.StatusOK { - return fmt.Errorf("request failed: %d", resp.StatusCode) + body, _ := io.ReadAll(resp.Body) + return fmt.Errorf("%s", strings.TrimSpace(string(body))) } scanner := bufio.NewScanner(resp.Body)
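
Example usage (a minimal sketch): the snippet below exercises the new /v1/images/edits route using the JSON shape defined by ImageEditRequest above. The listen address assumes Ollama's default, the model name and input file are placeholders, and an image-generation-capable model is assumed to be loaded under that name.

package main

import (
	"bytes"
	"encoding/base64"
	"encoding/json"
	"fmt"
	"net/http"
	"os"
)

func main() {
	// Placeholder input image; JPEG, PNG, or WebP all work with the CLI's content-type check.
	raw, err := os.ReadFile("input.png")
	if err != nil {
		panic(err)
	}

	// Fields mirror openai.ImageEditRequest: model, prompt, image, and optional size/seed.
	// The image is sent as a base64 data URL, matching the middleware tests.
	body, err := json.Marshal(map[string]any{
		"model":  "my-image-model", // placeholder model name
		"prompt": "make it blue",
		"image":  "data:image/png;base64," + base64.StdEncoding.EncodeToString(raw),
		"size":   "1024x1024", // optional, parsed as WIDTHxHEIGHT
	})
	if err != nil {
		panic(err)
	}

	// Assumes the default Ollama listen address.
	resp, err := http.Post("http://localhost:11434/v1/images/edits", "application/json", bytes.NewReader(body))
	if err != nil {
		panic(err)
	}
	defer resp.Body.Close()

	// The response body follows the OpenAI image response format produced by ImageWriter.
	fmt.Println(resp.Status)
}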