Mirror of https://github.com/ollama/ollama.git
x/imagegen: add image edit capabilities (#13846)
@@ -609,3 +609,49 @@ func ImageGenerationsMiddleware() gin.HandlerFunc {
 		c.Next()
 	}
 }
+
+func ImageEditsMiddleware() gin.HandlerFunc {
+	return func(c *gin.Context) {
+		var req openai.ImageEditRequest
+		if err := c.ShouldBindJSON(&req); err != nil {
+			c.AbortWithStatusJSON(http.StatusBadRequest, openai.NewError(http.StatusBadRequest, err.Error()))
+			return
+		}
+
+		if req.Prompt == "" {
+			c.AbortWithStatusJSON(http.StatusBadRequest, openai.NewError(http.StatusBadRequest, "prompt is required"))
+			return
+		}
+
+		if req.Model == "" {
+			c.AbortWithStatusJSON(http.StatusBadRequest, openai.NewError(http.StatusBadRequest, "model is required"))
+			return
+		}
+
+		if req.Image == "" {
+			c.AbortWithStatusJSON(http.StatusBadRequest, openai.NewError(http.StatusBadRequest, "image is required"))
+			return
+		}
+
+		genReq, err := openai.FromImageEditRequest(req)
+		if err != nil {
+			c.AbortWithStatusJSON(http.StatusBadRequest, openai.NewError(http.StatusBadRequest, err.Error()))
+			return
+		}
+
+		var b bytes.Buffer
+		if err := json.NewEncoder(&b).Encode(genReq); err != nil {
+			c.AbortWithStatusJSON(http.StatusInternalServerError, openai.NewError(http.StatusInternalServerError, err.Error()))
+			return
+		}
+
+		c.Request.Body = io.NopCloser(&b)
+
+		w := &ImageWriter{
+			BaseWriter: BaseWriter{ResponseWriter: c.Writer},
+		}
+
+		c.Writer = w
+		c.Next()
+	}
+}
@@ -1112,3 +1112,129 @@ func TestImageWriterResponse(t *testing.T) {
 		t.Errorf("expected image data 'dGVzdC1pbWFnZS1kYXRh', got %s", imageResp.Data[0].B64JSON)
 	}
 }
+
+func TestImageEditsMiddleware(t *testing.T) {
+	type testCase struct {
+		name string
+		body string
+		req  api.GenerateRequest
+		err  openai.ErrorResponse
+	}
+
+	var capturedRequest *api.GenerateRequest
+
+	// Base64-encoded test image (1x1 pixel PNG)
+	testImage := "data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR42mNk+A8AAQUBAScY42YAAAAASUVORK5CYII="
+	decodedImage, _ := base64.StdEncoding.DecodeString("iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR42mNk+A8AAQUBAScY42YAAAAASUVORK5CYII=")
+
+	testCases := []testCase{
+		{
+			name: "image edit basic",
+			body: `{
+				"model": "test-model",
+				"prompt": "make it blue",
+				"image": "` + testImage + `"
+			}`,
+			req: api.GenerateRequest{
+				Model:  "test-model",
+				Prompt: "make it blue",
+				Images: []api.ImageData{decodedImage},
+			},
+		},
+		{
+			name: "image edit with size",
+			body: `{
+				"model": "test-model",
+				"prompt": "make it blue",
+				"image": "` + testImage + `",
+				"size": "512x768"
+			}`,
+			req: api.GenerateRequest{
+				Model:  "test-model",
+				Prompt: "make it blue",
+				Images: []api.ImageData{decodedImage},
+				Width:  512,
+				Height: 768,
+			},
+		},
+		{
+			name: "image edit missing prompt",
+			body: `{
+				"model": "test-model",
+				"image": "` + testImage + `"
+			}`,
+			err: openai.ErrorResponse{
+				Error: openai.Error{
+					Message: "prompt is required",
+					Type:    "invalid_request_error",
+				},
+			},
+		},
+		{
+			name: "image edit missing model",
+			body: `{
+				"prompt": "make it blue",
+				"image": "` + testImage + `"
+			}`,
+			err: openai.ErrorResponse{
+				Error: openai.Error{
+					Message: "model is required",
+					Type:    "invalid_request_error",
+				},
+			},
+		},
+		{
+			name: "image edit missing image",
+			body: `{
+				"model": "test-model",
+				"prompt": "make it blue"
+			}`,
+			err: openai.ErrorResponse{
+				Error: openai.Error{
+					Message: "image is required",
+					Type:    "invalid_request_error",
+				},
+			},
+		},
+	}
+
+	endpoint := func(c *gin.Context) {
+		c.Status(http.StatusOK)
+	}
+
+	gin.SetMode(gin.TestMode)
+	router := gin.New()
+	router.Use(ImageEditsMiddleware(), captureRequestMiddleware(&capturedRequest))
+	router.Handle(http.MethodPost, "/api/generate", endpoint)
+
+	for _, tc := range testCases {
+		t.Run(tc.name, func(t *testing.T) {
+			req, _ := http.NewRequest(http.MethodPost, "/api/generate", strings.NewReader(tc.body))
+			req.Header.Set("Content-Type", "application/json")
+
+			defer func() { capturedRequest = nil }()
+
+			resp := httptest.NewRecorder()
+			router.ServeHTTP(resp, req)
+
+			if tc.err.Error.Message != "" {
+				var errResp openai.ErrorResponse
+				if err := json.Unmarshal(resp.Body.Bytes(), &errResp); err != nil {
+					t.Fatal(err)
+				}
+				if diff := cmp.Diff(tc.err, errResp); diff != "" {
+					t.Fatalf("errors did not match:\n%s", diff)
+				}
+				return
+			}
+
+			if resp.Code != http.StatusOK {
+				t.Fatalf("expected status 200, got %d: %s", resp.Code, resp.Body.String())
+			}
+
+			if diff := cmp.Diff(&tc.req, capturedRequest); diff != "" {
+				t.Fatalf("requests did not match:\n%s", diff)
+			}
+		})
+	}
+}
@@ -794,3 +794,47 @@ func ToImageGenerationResponse(resp api.GenerateResponse) ImageGenerationRespons
 		Data: data,
 	}
 }
+
+// ImageEditRequest is an OpenAI-compatible image edit request.
+type ImageEditRequest struct {
+	Model  string `json:"model"`
+	Prompt string `json:"prompt"`
+	Image  string `json:"image"`          // Base64-encoded image data
+	Size   string `json:"size,omitempty"` // e.g., "1024x1024"
+	Seed   *int64 `json:"seed,omitempty"`
+}
+
+// FromImageEditRequest converts an OpenAI image edit request to an Ollama GenerateRequest.
+func FromImageEditRequest(r ImageEditRequest) (api.GenerateRequest, error) {
+	req := api.GenerateRequest{
+		Model:  r.Model,
+		Prompt: r.Prompt,
+	}
+
+	// Decode the input image
+	if r.Image != "" {
+		imgData, err := decodeImageURL(r.Image)
+		if err != nil {
+			return api.GenerateRequest{}, fmt.Errorf("invalid image: %w", err)
+		}
+		req.Images = append(req.Images, imgData)
+	}
+
+	// Parse size if provided (e.g., "1024x768")
+	if r.Size != "" {
+		var w, h int32
+		if _, err := fmt.Sscanf(r.Size, "%dx%d", &w, &h); err == nil {
+			req.Width = w
+			req.Height = h
+		}
+	}
+
+	if r.Seed != nil {
+		if req.Options == nil {
+			req.Options = map[string]any{}
+		}
+		req.Options["seed"] = *r.Seed
+	}
+
+	return req, nil
+}
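A quick illustration of the size handling in FromImageEditRequest above (a standalone sketch, not part of the patch): the "WIDTHxHEIGHT" string is parsed with fmt.Sscanf, so a well-formed value sets Width and Height while a malformed value is silently ignored rather than rejected.

package main

import "fmt"

// parseSize mirrors the Sscanf-based size handling in FromImageEditRequest.
func parseSize(size string) (w, h int32, ok bool) {
	if _, err := fmt.Sscanf(size, "%dx%d", &w, &h); err != nil {
		return 0, 0, false // malformed sizes are skipped, not treated as errors
	}
	return w, h, true
}

func main() {
	fmt.Println(parseSize("512x768")) // 512 768 true
	fmt.Println(parseSize("square"))  // 0 0 false
}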
@@ -448,3 +448,86 @@ func TestFromChatRequest_TopLogprobsRange(t *testing.T) {
 		})
 	}
 }
+
+func TestFromImageEditRequest_Basic(t *testing.T) {
+	req := ImageEditRequest{
+		Model:  "test-model",
+		Prompt: "make it blue",
+		Image:  prefix + image,
+	}
+
+	result, err := FromImageEditRequest(req)
+	if err != nil {
+		t.Fatalf("unexpected error: %v", err)
+	}
+
+	if result.Model != "test-model" {
+		t.Errorf("expected model 'test-model', got %q", result.Model)
+	}
+
+	if result.Prompt != "make it blue" {
+		t.Errorf("expected prompt 'make it blue', got %q", result.Prompt)
+	}
+
+	if len(result.Images) != 1 {
+		t.Fatalf("expected 1 image, got %d", len(result.Images))
+	}
+}
+
+func TestFromImageEditRequest_WithSize(t *testing.T) {
+	req := ImageEditRequest{
+		Model:  "test-model",
+		Prompt: "make it blue",
+		Image:  prefix + image,
+		Size:   "512x768",
+	}
+
+	result, err := FromImageEditRequest(req)
+	if err != nil {
+		t.Fatalf("unexpected error: %v", err)
+	}
+
+	if result.Width != 512 {
+		t.Errorf("expected width 512, got %d", result.Width)
+	}
+
+	if result.Height != 768 {
+		t.Errorf("expected height 768, got %d", result.Height)
+	}
+}
+
+func TestFromImageEditRequest_WithSeed(t *testing.T) {
+	seed := int64(12345)
+	req := ImageEditRequest{
+		Model:  "test-model",
+		Prompt: "make it blue",
+		Image:  prefix + image,
+		Seed:   &seed,
+	}
+
+	result, err := FromImageEditRequest(req)
+	if err != nil {
+		t.Fatalf("unexpected error: %v", err)
+	}
+
+	if result.Options == nil {
+		t.Fatal("expected options to be set")
+	}
+
+	if result.Options["seed"] != seed {
+		t.Errorf("expected seed %d, got %v", seed, result.Options["seed"])
+	}
+}
+
+func TestFromImageEditRequest_InvalidImage(t *testing.T) {
+	req := ImageEditRequest{
+		Model:  "test-model",
+		Prompt: "make it blue",
+		Image:  "not-valid-base64",
+	}
+
+	_, err := FromImageEditRequest(req)
+	if err == nil {
+		t.Error("expected error for invalid image")
+	}
+}
@@ -75,12 +75,6 @@ type Model struct {
 func (m *Model) Capabilities() []model.Capability {
 	capabilities := []model.Capability{}
 
-	// Check for image generation model via config capabilities
-	if slices.Contains(m.Config.Capabilities, "image") {
-		return []model.Capability{model.CapabilityImage}
-	}
-
-	// Check for completion capability
 	if m.ModelPath != "" {
 		f, err := gguf.Open(m.ModelPath)
 		if err == nil {
@@ -56,6 +56,15 @@ func TestModelCapabilities(t *testing.T) {
 			},
 			expectedCaps: []model.Capability{model.CapabilityImage},
 		},
+		{
+			name: "model with image and vision capability (image editing)",
+			model: Model{
+				Config: model.ConfigV2{
+					Capabilities: []string{"image", "vision"},
+				},
+			},
+			expectedCaps: []model.Capability{model.CapabilityImage, model.CapabilityVision},
+		},
 		{
 			name: "model with completion capability",
 			model: Model{
@@ -1604,8 +1604,9 @@ func (s *Server) GenerateRoutes(rc *ollama.Registry) (http.Handler, error) {
 	r.GET("/v1/models", middleware.ListMiddleware(), s.ListHandler)
 	r.GET("/v1/models/:model", middleware.RetrieveMiddleware(), s.ShowHandler)
 	r.POST("/v1/responses", middleware.ResponsesMiddleware(), s.ChatHandler)
-	// OpenAI-compatible image generation endpoint
+	// OpenAI-compatible image generation endpoints
 	r.POST("/v1/images/generations", middleware.ImageGenerationsMiddleware(), s.GenerateHandler)
+	r.POST("/v1/images/edits", middleware.ImageEditsMiddleware(), s.GenerateHandler)
 
 	// Inference (Anthropic compatibility)
 	r.POST("/v1/messages", middleware.AnthropicMessagesMiddleware(), s.ChatHandler)
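For reference, a minimal client sketch against the new route (assumptions: an Ollama server on the default localhost:11434 address, a hypothetical model name, and an OpenAI-style image response with base64 data as produced by the ImageWriter wrapper; none of this code is part of the patch):

package main

import (
	"bytes"
	"encoding/json"
	"fmt"
	"net/http"
)

func main() {
	// Request fields match openai.ImageEditRequest: model, prompt and image are
	// required by ImageEditsMiddleware; size and seed are optional.
	body, _ := json.Marshal(map[string]any{
		"model":  "my-image-model",            // hypothetical model name
		"prompt": "make it blue",
		"image":  "data:image/png;base64,...", // placeholder for a real base64 image
		"size":   "1024x1024",
	})

	resp, err := http.Post("http://localhost:11434/v1/images/edits", "application/json", bytes.NewReader(body))
	if err != nil {
		panic(err)
	}
	defer resp.Body.Close()

	// The middleware swaps in an ImageWriter, so the reply is expected to be an
	// OpenAI-style image response carrying base64 data.
	var out struct {
		Data []struct {
			B64JSON string `json:"b64_json"`
		} `json:"data"`
	}
	if err := json.NewDecoder(resp.Body).Decode(&out); err != nil {
		panic(err)
	}
	fmt.Println("images returned:", len(out.Data))
}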
@@ -2523,6 +2524,11 @@ func (s *Server) handleImageGenerate(c *gin.Context, req api.GenerateRequest, mo
 		}
 	}
 
+	var images []llm.ImageData
+	for i, imgData := range req.Images {
+		images = append(images, llm.ImageData{ID: i, Data: imgData})
+	}
+
 	var streamStarted bool
 	if err := runner.Completion(c.Request.Context(), llm.CompletionRequest{
 		Prompt: req.Prompt,
@@ -2530,6 +2536,7 @@ func (s *Server) handleImageGenerate(c *gin.Context, req api.GenerateRequest, mo
 		Height: req.Height,
 		Steps:  req.Steps,
 		Seed:   seed,
+		Images: images,
 	}, func(cr llm.CompletionResponse) {
 		streamStarted = true
 		res := api.GenerateResponse{
@@ -2193,3 +2193,157 @@ func TestGenerateUnload(t *testing.T) {
 		}
 	})
 }
+
+func TestGenerateWithImages(t *testing.T) {
+	gin.SetMode(gin.TestMode)
+
+	mock := mockRunner{
+		CompletionResponse: llm.CompletionResponse{
+			Done:               true,
+			DoneReason:         llm.DoneReasonStop,
+			PromptEvalCount:    1,
+			PromptEvalDuration: 1,
+			EvalCount:          1,
+			EvalDuration:       1,
+		},
+	}
+
+	s := Server{
+		sched: &Scheduler{
+			pendingReqCh:    make(chan *LlmRequest, 1),
+			finishedReqCh:   make(chan *LlmRequest, 1),
+			expiredCh:       make(chan *runnerRef, 1),
+			unloadedCh:      make(chan any, 1),
+			loaded:          make(map[string]*runnerRef),
+			newServerFn:     newMockServer(&mock),
+			getGpuFn:        getGpuFn,
+			getSystemInfoFn: getSystemInfoFn,
+			waitForRecovery: 250 * time.Millisecond,
+			loadFn: func(req *LlmRequest, _ *ggml.GGML, _ ml.SystemInfo, _ []ml.DeviceInfo, _ bool) bool {
+				time.Sleep(time.Millisecond)
+				req.successCh <- &runnerRef{
+					llama: &mock,
+				}
+				return false
+			},
+		},
+	}
+
+	go s.sched.Run(t.Context())
+
+	_, digest := createBinFile(t, ggml.KV{
+		"general.architecture":          "llama",
+		"llama.block_count":             uint32(1),
+		"llama.context_length":          uint32(8192),
+		"llama.embedding_length":        uint32(4096),
+		"llama.attention.head_count":    uint32(32),
+		"llama.attention.head_count_kv": uint32(8),
+		"tokenizer.ggml.tokens":         []string{""},
+		"tokenizer.ggml.scores":         []float32{0},
+		"tokenizer.ggml.token_type":     []int32{0},
+	}, []*ggml.Tensor{
+		{Name: "token_embd.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
+		{Name: "blk.0.attn_norm.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
+		{Name: "blk.0.ffn_down.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
+		{Name: "blk.0.ffn_gate.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
+		{Name: "blk.0.ffn_up.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
+		{Name: "blk.0.ffn_norm.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
+		{Name: "blk.0.attn_k.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
+		{Name: "blk.0.attn_output.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
+		{Name: "blk.0.attn_q.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
+		{Name: "blk.0.attn_v.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
+		{Name: "output.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
+	})
+
+	w := createRequest(t, s.CreateHandler, api.CreateRequest{
+		Model:  "test",
+		Files:  map[string]string{"file.gguf": digest},
+		Stream: &stream,
+	})
+
+	if w.Code != http.StatusOK {
+		t.Fatalf("expected status 200, got %d", w.Code)
+	}
+
+	t.Run("images passed to completion request", func(t *testing.T) {
+		testImage := []byte("test-image-data")
+
+		mock.CompletionResponse.Content = "Image processed"
+		w := createRequest(t, s.GenerateHandler, api.GenerateRequest{
+			Model:  "test",
+			Prompt: "Describe this image",
+			Images: []api.ImageData{testImage},
+			Stream: &stream,
+		})
+
+		if w.Code != http.StatusOK {
+			t.Fatalf("expected status 200, got %d: %s", w.Code, w.Body.String())
+		}
+
+		// Verify images were passed to the completion request
+		if len(mock.CompletionRequest.Images) != 1 {
+			t.Fatalf("expected 1 image in completion request, got %d", len(mock.CompletionRequest.Images))
+		}
+
+		if !bytes.Equal(mock.CompletionRequest.Images[0].Data, testImage) {
+			t.Errorf("image data mismatch in completion request")
+		}
+
+		if mock.CompletionRequest.Images[0].ID != 0 {
+			t.Errorf("expected image ID 0, got %d", mock.CompletionRequest.Images[0].ID)
+		}
+	})
+
+	t.Run("multiple images passed to completion request", func(t *testing.T) {
+		testImage1 := []byte("test-image-1")
+		testImage2 := []byte("test-image-2")
+
+		mock.CompletionResponse.Content = "Images processed"
+		w := createRequest(t, s.GenerateHandler, api.GenerateRequest{
+			Model:  "test",
+			Prompt: "Compare these images",
+			Images: []api.ImageData{testImage1, testImage2},
+			Stream: &stream,
+		})
+
+		if w.Code != http.StatusOK {
+			t.Fatalf("expected status 200, got %d: %s", w.Code, w.Body.String())
+		}
+
+		// Verify both images were passed
+		if len(mock.CompletionRequest.Images) != 2 {
+			t.Fatalf("expected 2 images in completion request, got %d", len(mock.CompletionRequest.Images))
+		}
+
+		if !bytes.Equal(mock.CompletionRequest.Images[0].Data, testImage1) {
+			t.Errorf("first image data mismatch")
+		}
+
+		if !bytes.Equal(mock.CompletionRequest.Images[1].Data, testImage2) {
+			t.Errorf("second image data mismatch")
+		}
+
+		if mock.CompletionRequest.Images[0].ID != 0 || mock.CompletionRequest.Images[1].ID != 1 {
+			t.Errorf("expected image IDs 0 and 1, got %d and %d",
+				mock.CompletionRequest.Images[0].ID, mock.CompletionRequest.Images[1].ID)
+		}
+	})
+
+	t.Run("no images when none provided", func(t *testing.T) {
+		mock.CompletionResponse.Content = "No images"
+		w := createRequest(t, s.GenerateHandler, api.GenerateRequest{
+			Model:  "test",
+			Prompt: "Hello",
+			Stream: &stream,
+		})
+
+		if w.Code != http.StatusOK {
+			t.Fatalf("expected status 200, got %d: %s", w.Code, w.Body.String())
+		}
+
+		// Verify no images in completion request
+		if len(mock.CompletionRequest.Images) != 0 {
+			t.Fatalf("expected 0 images in completion request, got %d", len(mock.CompletionRequest.Images))
+		}
+	})
+}
@@ -10,7 +10,10 @@ import (
 	"errors"
 	"fmt"
 	"io"
+	"net/http"
 	"os"
+	"regexp"
+	"slices"
 	"strconv"
 	"strings"
 	"time"
@@ -75,6 +78,7 @@ Image Generation Flags (experimental):
 // RunCLI handles the CLI for image generation models.
 // Returns true if it handled the request, false if the caller should continue with normal flow.
 // Supports flags: --width, --height, --steps, --seed, --negative
+// Image paths can be included in the prompt and will be extracted automatically.
 func RunCLI(cmd *cobra.Command, name string, prompt string, interactive bool, keepAlive *api.Duration) error {
 	// Get options from flags (with env var defaults)
 	opts := DefaultOptions()
@@ -111,9 +115,16 @@ func generateImageWithOptions(cmd *cobra.Command, modelName, prompt string, keep
 		return err
 	}
 
+	// Extract any image paths from the prompt
+	prompt, images, err := extractFileData(prompt)
+	if err != nil {
+		return err
+	}
+
 	req := &api.GenerateRequest{
 		Model:  modelName,
 		Prompt: prompt,
+		Images: images,
 		Width:  int32(opts.Width),
 		Height: int32(opts.Height),
 		Steps:  int32(opts.Steps),
@@ -254,14 +265,33 @@ func runInteractive(cmd *cobra.Command, modelName string, keepAlive *api.Duratio
 			printCurrentSettings(opts)
 			continue
 		case strings.HasPrefix(line, "/"):
-			fmt.Fprintf(os.Stderr, "Unknown command: %s (try /help)\n", line)
+			// Check if it's a file path, not a command
+			args := strings.Fields(line)
+			isFile := false
+			for _, f := range extractFileNames(line) {
+				if strings.HasPrefix(f, args[0]) {
+					isFile = true
+					break
+				}
+			}
+			if !isFile {
+				fmt.Fprintf(os.Stderr, "Unknown command: %s (try /help)\n", args[0])
+				continue
+			}
+		}
+
+		// Extract any image paths from the input
+		prompt, images, err := extractFileData(line)
+		if err != nil {
+			fmt.Fprintf(os.Stderr, "Error: %v\n", err)
 			continue
 		}
+
 		// Generate image with current options
 		req := &api.GenerateRequest{
 			Model:  modelName,
-			Prompt: line,
+			Prompt: prompt,
+			Images: images,
 			Width:  int32(opts.Width),
 			Height: int32(opts.Height),
 			Steps:  int32(opts.Steps),
@@ -486,3 +516,59 @@ func displayImageInTerminal(imagePath string) bool {
 		return false
 	}
 }
+
+// extractFileNames finds image file paths in the input string.
+func extractFileNames(input string) []string {
+	// Regex to match file paths with image extensions
+	regexPattern := `(?:[a-zA-Z]:)?(?:\./|/|\\)[\S\\ ]+?\.(?i:jpg|jpeg|png|webp)\b`
+	re := regexp.MustCompile(regexPattern)
+	return re.FindAllString(input, -1)
+}
+
+// extractFileData extracts image data from file paths found in the input.
+// Returns the cleaned prompt (with file paths removed) and the image data.
+func extractFileData(input string) (string, []api.ImageData, error) {
+	filePaths := extractFileNames(input)
+	var imgs []api.ImageData
+
+	for _, fp := range filePaths {
+		// Normalize escaped spaces
+		nfp := strings.ReplaceAll(fp, "\\ ", " ")
+		nfp = strings.ReplaceAll(nfp, "%20", " ")
+
+		data, err := getImageData(nfp)
+		if errors.Is(err, os.ErrNotExist) {
+			continue
+		} else if err != nil {
+			return "", nil, err
+		}
+		fmt.Fprintf(os.Stderr, "Added image '%s'\n", nfp)
+		input = strings.ReplaceAll(input, fp, "")
+		imgs = append(imgs, data)
+	}
+	return strings.TrimSpace(input), imgs, nil
+}
+
+// getImageData reads and validates image data from a file.
+func getImageData(filePath string) ([]byte, error) {
+	file, err := os.Open(filePath)
+	if err != nil {
+		return nil, err
+	}
+	defer file.Close()
+
+	buf := make([]byte, 512)
+	_, err = file.Read(buf)
+	if err != nil {
+		return nil, err
+	}
+
+	contentType := http.DetectContentType(buf)
+	allowedTypes := []string{"image/jpeg", "image/jpg", "image/png", "image/webp"}
+	if !slices.Contains(allowedTypes, contentType) {
+		return nil, fmt.Errorf("invalid image type: %s", contentType)
+	}
+
+	// Re-read the full file
+	return os.ReadFile(filePath)
+}
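A standalone sketch of what the extractFileNames pattern above matches (the sample prompts are made up): a path must begin with ./, /, \ or a drive letter, so a bare file name in the prompt is left alone, and escaped spaces are accepted and later normalized by extractFileData.

package main

import (
	"fmt"
	"regexp"
)

func main() {
	// Same pattern as extractFileNames above.
	re := regexp.MustCompile(`(?:[a-zA-Z]:)?(?:\./|/|\\)[\S\\ ]+?\.(?i:jpg|jpeg|png|webp)\b`)

	for _, prompt := range []string{
		"make it blue ./photo.png",
		`sharpen /home/me/my\ dog.JPEG please`,
		`edit C:\pics\cat.webp`,
		"recolor logo.png", // no leading ./, / or \, so it is not treated as a path
	} {
		fmt.Printf("%-40q -> %q\n", prompt, re.FindAllString(prompt, -1))
	}
}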
@@ -177,6 +177,20 @@ func (m *Model) GenerateImage(ctx context.Context, prompt string, width, height
 	})
 }
+
+// GenerateImageWithInputs implements runner.ImageEditModel interface.
+// It generates an image conditioned on the provided input images for image editing.
+func (m *Model) GenerateImageWithInputs(ctx context.Context, prompt string, width, height int32, steps int, seed int64, inputImages []image.Image, progress func(step, total int)) (*mlx.Array, error) {
+	return m.GenerateFromConfig(ctx, &GenerateConfig{
+		Prompt:      prompt,
+		Width:       width,
+		Height:      height,
+		Steps:       steps,
+		Seed:        seed,
+		InputImages: inputImages,
+		Progress:    progress,
+	})
+}
 
 // MaxOutputPixels is the maximum output resolution (4 megapixels, ~2048x2048)
 const MaxOutputPixels = 2048 * 2048
 
@@ -9,6 +9,7 @@ import (
 	"encoding/json"
 	"flag"
 	"fmt"
+	"image"
 	"log/slog"
 	"net/http"
 	"os"
@@ -25,11 +26,12 @@ import (
 
 // Request is the image generation request format
 type Request struct {
 	Prompt string `json:"prompt"`
 	Width  int32  `json:"width,omitempty"`
 	Height int32  `json:"height,omitempty"`
 	Steps  int    `json:"steps,omitempty"`
 	Seed   int64  `json:"seed,omitempty"`
+	Images [][]byte `json:"images,omitempty"` // Input images for image editing/conditioning
 }
 
 // Response is streamed back for each progress update
@@ -46,6 +48,13 @@ type ImageModel interface {
 	GenerateImage(ctx context.Context, prompt string, width, height int32, steps int, seed int64, progress func(step, total int)) (*mlx.Array, error)
 }
+
+// ImageEditModel extends ImageModel with image editing/conditioning capability.
+// Models that support input images for editing should implement this interface.
+type ImageEditModel interface {
+	ImageModel
+	GenerateImageWithInputs(ctx context.Context, prompt string, width, height int32, steps int, seed int64, inputImages []image.Image, progress func(step, total int)) (*mlx.Array, error)
+}
 
 // Server holds the model and handles requests
 type Server struct {
 	mu sync.Mutex
@@ -153,6 +162,44 @@ func (s *Server) completionHandler(w http.ResponseWriter, r *http.Request) {
 		return
 	}
+
+	// Validate and decode input images
+	const maxInputImages = 2
+	if len(req.Images) > maxInputImages {
+		http.Error(w, fmt.Sprintf("too many input images, maximum is %d", maxInputImages), http.StatusBadRequest)
+		return
+	}
+
+	var inputImages []image.Image
+	if len(req.Images) > 0 {
+		// TODO: add memory check for input images
+
+		inputImages = make([]image.Image, len(req.Images))
+		for i, imgBytes := range req.Images {
+			img, err := imagegen.DecodeImage(imgBytes)
+			if err != nil {
+				http.Error(w, fmt.Sprintf("invalid image %d: %v", i, err), http.StatusBadRequest)
+				return
+			}
+			inputImages[i] = img
+		}
+		slog.Info("decoded input images", "count", len(inputImages))
+
+		// Default width/height to first input image dimensions, scaled to max 1024
+		bounds := inputImages[0].Bounds()
+		w, h := bounds.Dx(), bounds.Dy()
+		if w > 1024 || h > 1024 {
+			if w > h {
+				h = h * 1024 / w
+				w = 1024
+			} else {
+				w = w * 1024 / h
+				h = 1024
+			}
+		}
+		req.Width = int32(w)
+		req.Height = int32(h)
+	}
 
 	// Serialize generation requests - MLX model may not handle concurrent generation
 	s.mu.Lock()
 	defer s.mu.Unlock()
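A worked sketch of the default-dimension rule above (same integer arithmetic, not additional server code): the longer side is clamped to 1024 and the other side is scaled proportionally, so a 3000x2000 input defaults to 1024x682 while anything already within 1024x1024 keeps its size.

package main

import "fmt"

// defaultDims mirrors the scaling in completionHandler: clamp the longer side
// to 1024 and scale the other side with integer division.
func defaultDims(w, h int) (int, int) {
	if w > 1024 || h > 1024 {
		if w > h {
			h = h * 1024 / w
			w = 1024
		} else {
			w = w * 1024 / h
			h = 1024
		}
	}
	return w, h
}

func main() {
	fmt.Println(defaultDims(3000, 2000)) // 1024 682
	fmt.Println(defaultDims(800, 600))   // 800 600
}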
@@ -184,7 +231,19 @@ func (s *Server) completionHandler(w http.ResponseWriter, r *http.Request) {
 		flusher.Flush()
 	}
 
-	img, err := s.model.GenerateImage(ctx, req.Prompt, req.Width, req.Height, req.Steps, req.Seed, progress)
+	// Use ImageEditModel if available and images provided, otherwise use basic ImageModel
+	var img *mlx.Array
+	var err error
+	if len(inputImages) > 0 {
+		editModel, ok := s.model.(ImageEditModel)
+		if !ok {
+			http.Error(w, "model does not support image editing", http.StatusBadRequest)
+			return
+		}
+		img, err = editModel.GenerateImageWithInputs(ctx, req.Prompt, req.Width, req.Height, req.Steps, req.Seed, inputImages, progress)
+	} else {
+		img, err = s.model.GenerateImage(ctx, req.Prompt, req.Width, req.Height, req.Steps, req.Seed, progress)
+	}
 
 	if err != nil {
 		// Don't send error for cancellation
@@ -7,6 +7,7 @@ import (
 	"encoding/json"
 	"errors"
 	"fmt"
+	"io"
 	"log/slog"
 	"math/rand"
 	"net"
@@ -232,19 +233,27 @@ func (s *Server) Completion(ctx context.Context, req llm.CompletionRequest, fn f
 		seed = time.Now().UnixNano()
 	}
 
+	// Extract raw image bytes from llm.ImageData slice
+	var images [][]byte
+	for _, img := range req.Images {
+		images = append(images, img.Data)
+	}
+
 	// Build request for subprocess
 	creq := struct {
 		Prompt string   `json:"prompt"`
 		Width  int32    `json:"width,omitempty"`
 		Height int32    `json:"height,omitempty"`
 		Steps  int32    `json:"steps,omitempty"`
 		Seed   int64    `json:"seed,omitempty"`
+		Images [][]byte `json:"images,omitempty"`
 	}{
 		Prompt: req.Prompt,
 		Width:  req.Width,
 		Height: req.Height,
 		Steps:  req.Steps,
 		Seed:   seed,
+		Images: images,
 	}
 
 	body, err := json.Marshal(creq)
@@ -266,7 +275,8 @@ func (s *Server) Completion(ctx context.Context, req llm.CompletionRequest, fn f
 	defer resp.Body.Close()
 
 	if resp.StatusCode != http.StatusOK {
-		return fmt.Errorf("request failed: %d", resp.StatusCode)
+		body, _ := io.ReadAll(resp.Body)
+		return fmt.Errorf("%s", strings.TrimSpace(string(body)))
 	}
 
 	scanner := bufio.NewScanner(resp.Body)