diff --git a/api/types.go b/api/types.go
index 2434fe478..09961d994 100644
--- a/api/types.go
+++ b/api/types.go
@@ -127,6 +127,20 @@ type GenerateRequest struct {
 	// each with an associated log probability. Only applies when Logprobs is true.
 	// Valid values are 0-20. Default is 0 (only return the selected token's logprob).
 	TopLogprobs int `json:"top_logprobs,omitempty"`
+
+	// Experimental: Image generation fields (may change or be removed)
+
+	// Width is the width of the generated image in pixels.
+	// Only used for image generation models.
+	Width int32 `json:"width,omitempty"`
+
+	// Height is the height of the generated image in pixels.
+	// Only used for image generation models.
+	Height int32 `json:"height,omitempty"`
+
+	// Steps is the number of diffusion steps for image generation.
+	// Only used for image generation models.
+	Steps int32 `json:"steps,omitempty"`
 }
 
 // ChatRequest describes a request sent by [Client.Chat].
@@ -860,6 +874,20 @@ type GenerateResponse struct {
 	// Logprobs contains log probability information for the generated tokens,
 	// if requested via the Logprobs parameter.
 	Logprobs []Logprob `json:"logprobs,omitempty"`
+
+	// Experimental: Image generation fields (may change or be removed)
+
+	// Image contains a base64-encoded generated image.
+	// Only present for image generation models.
+	Image string `json:"image,omitempty"`
+
+	// Completed is the number of completed steps in image generation.
+	// Only present for image generation models during streaming.
+	Completed int64 `json:"completed,omitempty"`
+
+	// Total is the total number of steps for image generation.
+	// Only present for image generation models during streaming.
+	Total int64 `json:"total,omitempty"`
 }
 
 // ModelDetails provides details about a model.
diff --git a/cmd/cmd.go b/cmd/cmd.go
index a5d46f90d..5139c05cb 100644
--- a/cmd/cmd.go
+++ b/cmd/cmd.go
@@ -600,7 +600,7 @@ func RunHandler(cmd *cobra.Command, args []string) error {
 	}
 
 	// Check if this is an image generation model
-	if slices.Contains(info.Capabilities, model.CapabilityImageGeneration) {
+	if slices.Contains(info.Capabilities, model.CapabilityImage) {
 		if opts.Prompt == "" && !interactive {
 			return errors.New("image generation models require a prompt.\nUsage: ollama run " + name + " \"your prompt here\"")
 		}
@@ -1985,6 +1985,7 @@ func NewCLI() *cobra.Command {
 	} {
 		switch cmd {
 		case runCmd:
+			imagegen.AppendFlagsDocs(cmd)
 			appendEnvDocs(cmd, []envconfig.EnvVar{envVars["OLLAMA_HOST"], envVars["OLLAMA_NOHISTORY"]})
 		case serveCmd:
 			appendEnvDocs(cmd, []envconfig.EnvVar{
diff --git a/cmd/cmd_test.go b/cmd/cmd_test.go
index 4f651135b..eedd0c61d 100644
--- a/cmd/cmd_test.go
+++ b/cmd/cmd_test.go
@@ -1555,7 +1555,7 @@ func TestShowInfoImageGen(t *testing.T) {
 			ParameterSize:     "10.3B",
 			QuantizationLevel: "FP8",
 		},
-		Capabilities: []model.Capability{model.CapabilityImageGeneration},
+		Capabilities: []model.Capability{model.CapabilityImage},
 		Requires:     "0.14.0",
 	}, false, &b)
 	if err != nil {
diff --git a/docs/api.md b/docs/api.md
index 7c32c9597..150479e6a 100644
--- a/docs/api.md
+++ b/docs/api.md
@@ -16,6 +16,7 @@
 - [Generate Embeddings](#generate-embeddings)
 - [List Running Models](#list-running-models)
 - [Version](#version)
+- [Experimental: Image Generation](#image-generation-experimental)
 
 ## Conventions
 
@@ -58,6 +59,15 @@ Advanced parameters (optional):
 - `keep_alive`: controls how long the model will stay loaded into memory following the request (default: `5m`)
 - `context` (deprecated): the context parameter returned from a previous request to `/generate`, this can be used to keep a short conversational memory
 
+Experimental image generation parameters (for image generation models only):
+
+> [!WARNING]
+> These parameters are experimental and may change in future versions.
+
+- `width`: width of the generated image in pixels
+- `height`: height of the generated image in pixels
+- `steps`: number of diffusion steps
+
 #### Structured outputs
 
 Structured outputs are supported by providing a JSON schema in the `format` parameter. The model will generate a response that matches the schema. See the [structured outputs](#request-structured-outputs) example below.
@@ -1867,3 +1877,55 @@ curl http://localhost:11434/api/version
   "version": "0.5.1"
 }
 ```
+
+## Experimental Features
+
+### Image Generation (Experimental)
+
+> [!WARNING]
+> Image generation is experimental and may change in future versions.
+
+Image generation is supported through the standard `/api/generate` endpoint. The server automatically detects when the requested model is an image generation model, so no separate endpoint is required.
+
+See the [Generate a completion](#generate-a-completion) section for the full API documentation. The experimental image generation parameters (`width`, `height`, `steps`) are documented there.
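+
+For Go clients, the same flow can be driven with the `api` package. The following is a minimal sketch; it relies on the experimental `width`/`height`/`steps` request fields and `image`/`completed`/`total` response fields added in this change, so it may change along with them:
+
+```go
+package main
+
+import (
+	"context"
+	"encoding/base64"
+	"fmt"
+	"log"
+	"os"
+
+	"github.com/ollama/ollama/api"
+)
+
+func main() {
+	client, err := api.ClientFromEnvironment()
+	if err != nil {
+		log.Fatal(err)
+	}
+
+	req := &api.GenerateRequest{
+		Model:  "x/z-image-turbo",
+		Prompt: "a sunset over mountains",
+		Width:  1024,
+		Height: 768,
+	}
+
+	err = client.Generate(context.Background(), req, func(resp api.GenerateResponse) error {
+		// Streaming responses report step progress until the image is ready.
+		if resp.Total > 0 {
+			fmt.Printf("\rstep %d/%d", resp.Completed, resp.Total)
+		}
+		// The final response carries the base64-encoded image.
+		if resp.Done && resp.Image != "" {
+			png, err := base64.StdEncoding.DecodeString(resp.Image)
+			if err != nil {
+				return err
+			}
+			return os.WriteFile("sunset.png", png, 0o644)
+		}
+		return nil
+	})
+	if err != nil {
+		log.Fatal(err)
+	}
+}
+```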
+
+#### Example
+
+##### Request
+
+```shell
+curl http://localhost:11434/api/generate -d '{
+  "model": "x/z-image-turbo",
+  "prompt": "a sunset over mountains",
+  "width": 1024,
+  "height": 768
+}'
+```
+
+##### Response (streaming)
+
+Progress updates during generation:
+
+```json
+{
+  "model": "x/z-image-turbo",
+  "created_at": "2024-01-15T10:30:00.000000Z",
+  "completed": 5,
+  "total": 20,
+  "done": false
+}
+```
+
+##### Final Response
+
+```json
+{
+  "model": "x/z-image-turbo",
+  "created_at": "2024-01-15T10:30:15.000000Z",
+  "image": "iVBORw0KGgoAAAANSUhEUg...",
+  "done": true,
+  "done_reason": "stop",
+  "total_duration": 15000000000,
+  "load_duration": 2000000000
+}
+```
diff --git a/docs/api/openai-compatibility.mdx b/docs/api/openai-compatibility.mdx
index a0882053e..27d3951d5 100644
--- a/docs/api/openai-compatibility.mdx
+++ b/docs/api/openai-compatibility.mdx
@@ -275,6 +275,73 @@ curl -X POST http://localhost:11434/v1/chat/completions \
 - [x] `dimensions`
 - [ ] `user`
 
+### `/v1/images/generations` (experimental)
+
+> Note: This endpoint is experimental and may change or be removed in future versions.
+
+Generate images using image generation models.
+
+```python images.py
+from openai import OpenAI
+
+client = OpenAI(
+    base_url='http://localhost:11434/v1/',
+    api_key='ollama',  # required but ignored
+)
+
+response = client.images.generate(
+    model='x/z-image-turbo',
+    prompt='A cute robot learning to paint',
+    size='1024x1024',
+    response_format='b64_json',
+)
+print(response.data[0].b64_json[:50] + '...')
+```
+
+```javascript images.js
+import OpenAI from "openai";
+
+const openai = new OpenAI({
+  baseURL: "http://localhost:11434/v1/",
+  apiKey: "ollama", // required but ignored
+});
+
+const response = await openai.images.generate({
+  model: "x/z-image-turbo",
+  prompt: "A cute robot learning to paint",
+  size: "1024x1024",
+  response_format: "b64_json",
+});
+
+console.log(response.data[0].b64_json.slice(0, 50) + "...");
+```
+
+```shell images.sh
+curl -X POST http://localhost:11434/v1/images/generations \
+-H "Content-Type: application/json" \
+-d '{
+    "model": "x/z-image-turbo",
+    "prompt": "A cute robot learning to paint",
+    "size": "1024x1024",
+    "response_format": "b64_json"
+}'
+```
+
+#### Supported request fields
+
+- [x] `model`
+- [x] `prompt`
+- [x] `size` (e.g. "1024x1024")
+- [x] `response_format` (only `b64_json` supported)
+- [ ] `n`
+- [ ] `quality`
+- [ ] `style`
+- [ ] `user`
+
 ### `/v1/responses`
 
 > Note: Added in Ollama v0.13.3
diff --git a/llm/server.go b/llm/server.go
index b8ebd40a5..8fedc8468 100644
--- a/llm/server.go
+++ b/llm/server.go
@@ -1468,6 +1468,7 @@ type CompletionRequest struct {
 	// Image generation fields
 	Width  int32 `json:"width,omitempty"`
 	Height int32 `json:"height,omitempty"`
+	Steps  int32 `json:"steps,omitempty"`
 	Seed   int64 `json:"seed,omitempty"`
 }
 
@@ -1518,10 +1519,14 @@ type CompletionResponse struct {
 	// Logprobs contains log probability information if requested
 	Logprobs []Logprob `json:"logprobs,omitempty"`
 
-	// Image generation fields
-	Image []byte `json:"image,omitempty"` // Generated image
-	Step  int    `json:"step,omitempty"`  // Current generation step
-	Total int    `json:"total,omitempty"` // Total generation steps
+	// Image contains base64-encoded image data for image generation
+	Image string `json:"image,omitempty"`
+
+	// Step is the current step in image generation
+	Step int `json:"step,omitempty"`
+
+	// TotalSteps is the total number of steps for image generation
+	TotalSteps int `json:"total_steps,omitempty"`
 }
 
 func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn func(CompletionResponse)) error {
diff --git a/middleware/openai.go b/middleware/openai.go
index 64dc97ec1..beaa9ee97 100644
--- a/middleware/openai.go
+++ b/middleware/openai.go
@@ -546,3 +546,66 @@ func ResponsesMiddleware() gin.HandlerFunc {
 		c.Next()
 	}
 }
+
+type ImageWriter struct {
+	BaseWriter
+}
+
+func (w *ImageWriter) writeResponse(data []byte) (int, error) {
+	var generateResponse api.GenerateResponse
+	if err := json.Unmarshal(data, &generateResponse); err != nil {
+		return 0, err
+	}
+
+	// Only write response when done with image
+	if generateResponse.Done && generateResponse.Image != "" {
+		w.ResponseWriter.Header().Set("Content-Type", "application/json")
+		return len(data), json.NewEncoder(w.ResponseWriter).Encode(openai.ToImageGenerationResponse(generateResponse))
+	}
+
+	return len(data), nil
+}
+
+func (w *ImageWriter) Write(data []byte) (int, error) {
+	code := w.ResponseWriter.Status()
+	if code != http.StatusOK {
+		return w.writeError(data)
+	}
+
+	return w.writeResponse(data)
+}
+
+func ImageGenerationsMiddleware() gin.HandlerFunc {
+	return func(c *gin.Context) {
+		var req openai.ImageGenerationRequest
+		if err := c.ShouldBindJSON(&req); err != nil {
+			c.AbortWithStatusJSON(http.StatusBadRequest, openai.NewError(http.StatusBadRequest, err.Error()))
+			return
+		}
+
+		if req.Prompt == "" {
+			c.AbortWithStatusJSON(http.StatusBadRequest, openai.NewError(http.StatusBadRequest, "prompt is required"))
+			return
+		}
+
+		if req.Model == "" {
+			c.AbortWithStatusJSON(http.StatusBadRequest, openai.NewError(http.StatusBadRequest, "model is required"))
+			return
+		}
+
+		var b bytes.Buffer
+		if err := json.NewEncoder(&b).Encode(openai.FromImageGenerationRequest(req)); err != nil {
+			c.AbortWithStatusJSON(http.StatusInternalServerError, openai.NewError(http.StatusInternalServerError, err.Error()))
+			return
+		}
+
+		c.Request.Body = io.NopCloser(&b)
+
+		w := &ImageWriter{
+			BaseWriter: BaseWriter{ResponseWriter: c.Writer},
+		}
+
+		c.Writer = w
+		c.Next()
+	}
+}
diff --git a/middleware/openai_test.go b/middleware/openai_test.go
index 3b8f5088a..cc7c3c215 100644
--- a/middleware/openai_test.go
+++ b/middleware/openai_test.go
@@ -961,3 +961,154 @@ func TestRetrieveMiddleware(t *testing.T) {
 		}
 	}
 }
+
+func TestImageGenerationsMiddleware(t *testing.T) {
+	type testCase struct {
+		name string
+		body string
+		req  api.GenerateRequest
+		err  openai.ErrorResponse
+	}
+
+	var capturedRequest *api.GenerateRequest
+
+	testCases := []testCase{
+		{
+			name: "image generation basic",
+			body: `{
+				"model": "test-model",
+				"prompt": "a beautiful sunset"
+			}`,
+			req: api.GenerateRequest{
+				Model:  "test-model",
+				Prompt: "a beautiful sunset",
+			},
+		},
+		{
+			name: "image generation with size",
+			body: `{
+				"model": "test-model",
+				"prompt": "a beautiful sunset",
+				"size": "512x768"
+			}`,
+			req: api.GenerateRequest{
+				Model:  "test-model",
+				Prompt: "a beautiful sunset",
+				Width:  512,
+				Height: 768,
+			},
+		},
+		{
+			name: "image generation missing prompt",
+			body: `{
+				"model": "test-model"
+			}`,
+			err: openai.ErrorResponse{
+				Error: openai.Error{
+					Message: "prompt is required",
+					Type:    "invalid_request_error",
+				},
+			},
+		},
+		{
+			name: "image generation missing model",
+			body: `{
+				"prompt": "a beautiful sunset"
+			}`,
+			err: openai.ErrorResponse{
+				Error: openai.Error{
+					Message: "model is required",
+					Type:    "invalid_request_error",
+				},
+			},
+		},
+	}
+
+	endpoint := func(c *gin.Context) {
+		c.Status(http.StatusOK)
+	}
+
+	gin.SetMode(gin.TestMode)
+	router := gin.New()
+	router.Use(ImageGenerationsMiddleware(), captureRequestMiddleware(&capturedRequest))
+	router.Handle(http.MethodPost, "/api/generate", endpoint)
+
+	for _, tc := range testCases {
+		t.Run(tc.name, func(t *testing.T) {
+			req, _ := http.NewRequest(http.MethodPost, "/api/generate", strings.NewReader(tc.body))
+			req.Header.Set("Content-Type", "application/json")
+
+			defer func() { capturedRequest = nil }()
+
+			resp := httptest.NewRecorder()
+			router.ServeHTTP(resp, req)
+
+			if tc.err.Error.Message != "" {
+				var errResp openai.ErrorResponse
+				if err := json.Unmarshal(resp.Body.Bytes(), &errResp); err != nil {
+					t.Fatal(err)
+				}
+				if diff := cmp.Diff(tc.err, errResp); diff != "" {
+					t.Fatalf("errors did not match:\n%s", diff)
+				}
+				return
+			}
+
+			if resp.Code != http.StatusOK {
+				t.Fatalf("expected status 200, got %d: %s", resp.Code, resp.Body.String())
+			}
+
+			if diff := cmp.Diff(&tc.req, capturedRequest); diff != "" {
+				t.Fatalf("requests did not match:\n%s", diff)
+			}
+		})
+	}
+}
+
+func TestImageWriterResponse(t *testing.T) {
+	gin.SetMode(gin.TestMode)
+
+	// Test that ImageWriter transforms GenerateResponse to OpenAI format
+	endpoint := func(c *gin.Context) {
+		resp := api.GenerateResponse{
+			Model:     "test-model",
+			CreatedAt: time.Unix(1234567890, 0).UTC(),
+			Done:      true,
+			Image:     "dGVzdC1pbWFnZS1kYXRh", // base64 of "test-image-data"
+		}
+		data, _ := json.Marshal(resp)
+		c.Writer.Write(append(data, '\n'))
+	}
+
+	router := gin.New()
+	router.Use(ImageGenerationsMiddleware())
+	router.Handle(http.MethodPost, "/api/generate", endpoint)
+
+	body := `{"model": "test-model", "prompt": "test"}`
+	req, _ := http.NewRequest(http.MethodPost, "/api/generate", strings.NewReader(body))
+	req.Header.Set("Content-Type", "application/json")
+
+	resp := httptest.NewRecorder()
+	router.ServeHTTP(resp, req)
+
+	if resp.Code != http.StatusOK {
+		t.Fatalf("expected status 200, got %d: %s", resp.Code, resp.Body.String())
+	}
+
+	var imageResp openai.ImageGenerationResponse
+	if err := json.Unmarshal(resp.Body.Bytes(), &imageResp); err != nil {
+		t.Fatalf("failed to unmarshal response: %v", err)
+	}
+
+	if imageResp.Created != 1234567890 {
+		t.Errorf("expected created 1234567890, got %d", imageResp.Created)
+	}
+
+	if len(imageResp.Data) != 1 {
+		t.Fatalf("expected 1 image, got %d", len(imageResp.Data))
+	}
+
+	if imageResp.Data[0].B64JSON != "dGVzdC1pbWFnZS1kYXRh" {
+		t.Errorf("expected image data 'dGVzdC1pbWFnZS1kYXRh', got %s", imageResp.Data[0].B64JSON)
+	}
+}
diff --git a/openai/openai.go b/openai/openai.go
index 44ffb21bc..d1f75c4aa 100644
--- a/openai/openai.go
+++ b/openai/openai.go
@@ -737,3 +737,60 @@ func FromCompleteRequest(r CompletionRequest) (api.GenerateRequest, error) {
 		DebugRenderOnly: r.DebugRenderOnly,
 	}, nil
 }
+
+// ImageGenerationRequest is an OpenAI-compatible image generation request.
+type ImageGenerationRequest struct {
+	Model          string `json:"model"`
+	Prompt         string `json:"prompt"`
+	N              int    `json:"n,omitempty"`
+	Size           string `json:"size,omitempty"`
+	ResponseFormat string `json:"response_format,omitempty"`
+	Seed           *int64 `json:"seed,omitempty"`
+}
+
+// ImageGenerationResponse is an OpenAI-compatible image generation response.
+type ImageGenerationResponse struct {
+	Created int64            `json:"created"`
+	Data    []ImageURLOrData `json:"data"`
+}
+
+// ImageURLOrData contains either a URL or base64-encoded image data.
+type ImageURLOrData struct {
+	URL     string `json:"url,omitempty"`
+	B64JSON string `json:"b64_json,omitempty"`
+}
+
+// FromImageGenerationRequest converts an OpenAI image generation request to an Ollama GenerateRequest.
+func FromImageGenerationRequest(r ImageGenerationRequest) api.GenerateRequest {
+	req := api.GenerateRequest{
+		Model:  r.Model,
+		Prompt: r.Prompt,
+	}
+	// Parse size if provided (e.g., "1024x768")
+	if r.Size != "" {
+		var w, h int32
+		if _, err := fmt.Sscanf(r.Size, "%dx%d", &w, &h); err == nil {
+			req.Width = w
+			req.Height = h
+		}
+	}
+	if r.Seed != nil {
+		if req.Options == nil {
+			req.Options = map[string]any{}
+		}
+		req.Options["seed"] = *r.Seed
+	}
+	return req
+}
+
+// ToImageGenerationResponse converts an Ollama GenerateResponse to an OpenAI ImageGenerationResponse.
+func ToImageGenerationResponse(resp api.GenerateResponse) ImageGenerationResponse {
+	var data []ImageURLOrData
+	if resp.Image != "" {
+		data = []ImageURLOrData{{B64JSON: resp.Image}}
+	}
+	return ImageGenerationResponse{
+		Created: resp.CreatedAt.Unix(),
+		Data:    data,
+	}
+}
diff --git a/server/images.go b/server/images.go
index 22a635086..de795b20c 100644
--- a/server/images.go
+++ b/server/images.go
@@ -41,6 +41,7 @@ var (
 	errCapabilityVision    = errors.New("vision")
 	errCapabilityEmbedding = errors.New("embedding")
 	errCapabilityThinking  = errors.New("thinking")
+	errCapabilityImage     = errors.New("image generation")
 	errInsecureProtocol    = errors.New("insecure protocol http")
 )
 
@@ -76,7 +77,7 @@ func (m *Model) Capabilities() []model.Capability {
 	// Check for image generation model via config capabilities
 	if slices.Contains(m.Config.Capabilities, "image") {
-		return []model.Capability{model.CapabilityImageGeneration}
+		return []model.Capability{model.CapabilityImage}
 	}
 
 	// Check for completion capability
@@ -159,6 +160,7 @@ func (m *Model) CheckCapabilities(want ...model.Capability) error {
 		model.CapabilityVision:    errCapabilityVision,
 		model.CapabilityEmbedding: errCapabilityEmbedding,
 		model.CapabilityThinking:  errCapabilityThinking,
+		model.CapabilityImage:     errCapabilityImage,
 	}
 
 	for _, cap := range want {
diff --git a/server/images_test.go b/server/images_test.go
index 156914a07..9e581c8c3 100644
--- a/server/images_test.go
+++ b/server/images_test.go
@@ -54,7 +54,7 @@ func TestModelCapabilities(t *testing.T) {
 					Capabilities: []string{"image"},
 				},
 			},
-			expectedCaps: []model.Capability{model.CapabilityImageGeneration},
+			expectedCaps: []model.Capability{model.CapabilityImage},
 		},
 		{
 			name: "model with completion capability",
@@ -242,6 +242,24 @@ func TestModelCheckCapabilities(t *testing.T) {
 			checkCaps:      []model.Capability{"unknown"},
 			expectedErrMsg: "unknown capability",
 		},
+		{
+			name: "model missing image generation capability",
+			model: Model{
+				ModelPath: completionModelPath,
+				Template:  chatTemplate,
+			},
+			checkCaps:      []model.Capability{model.CapabilityImage},
+			expectedErrMsg: "does not support image generation",
+		},
+		{
+			name: "model with image generation capability",
+			model: Model{
+				Config: model.ConfigV2{
+					Capabilities: []string{"image"},
+				},
+			},
+			checkCaps: []model.Capability{model.CapabilityImage},
+		},
 	}
 
 	for _, tt := range tests {
diff --git a/server/routes.go b/server/routes.go
index e90b885e0..5029046b6 100644
--- a/server/routes.go
+++ b/server/routes.go
@@ -220,6 +220,12 @@ func (s *Server) GenerateHandler(c *gin.Context) {
 		return
 	}
 
+	// Handle image generation models
+	if slices.Contains(m.Capabilities(), model.CapabilityImage) {
+		s.handleImageGenerate(c, req, name.String(), checkpointStart)
+		return
+	}
+
 	if req.TopLogprobs < 0 || req.TopLogprobs > 20 {
 		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "top_logprobs must be between 0 and 20"})
 		return
@@ -1096,7 +1102,7 @@ func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, error) {
 	}
 
 	// For image generation models, populate details from imagegen package
-	if slices.Contains(m.Capabilities(), model.CapabilityImageGeneration) {
+	if slices.Contains(m.Capabilities(), model.CapabilityImage) {
 		if info, err := imagegen.GetModelInfo(name.String()); err == nil {
 			modelDetails.Family = info.Architecture
 			modelDetails.ParameterSize = format.HumanNumber(uint64(info.ParameterCount))
@@ -1202,7 +1208,7 @@ func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, error) {
 		return resp, nil
 	}
 
-	if slices.Contains(m.Capabilities(), model.CapabilityImageGeneration) {
+	if slices.Contains(m.Capabilities(), model.CapabilityImage) {
 		// Populate tensor info if verbose
 		if req.Verbose {
 			if tensors, err := xserver.GetSafetensorsTensorInfo(name.String()); err == nil {
@@ -1594,8 +1600,8 @@ func (s *Server) GenerateRoutes(rc *ollama.Registry) (http.Handler, error) {
 	r.GET("/v1/models", middleware.ListMiddleware(), s.ListHandler)
 	r.GET("/v1/models/:model", middleware.RetrieveMiddleware(), s.ShowHandler)
 	r.POST("/v1/responses", middleware.ResponsesMiddleware(), s.ChatHandler)
-	// Experimental OpenAI-compatible image generation endpoint
-	r.POST("/v1/images/generations", s.handleImageGeneration)
+	// OpenAI-compatible image generation endpoint
+	r.POST("/v1/images/generations", middleware.ImageGenerationsMiddleware(), s.GenerateHandler)
 
 	// Inference (Anthropic compatibility)
 	r.POST("/v1/messages", middleware.AnthropicMessagesMiddleware(), s.ChatHandler)
@@ -1917,62 +1923,6 @@ func toolCallId() string {
 	return "call_" + strings.ToLower(string(b))
 }
 
-func (s *Server) handleImageGeneration(c *gin.Context) {
-	var req struct {
-		Model  string `json:"model"`
-		Prompt string `json:"prompt"`
-		Size   string `json:"size"`
-	}
-	if err := c.ShouldBindJSON(&req); err != nil {
-		c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
-		return
-	}
-
-	m, err := GetModel(req.Model)
-	if err != nil {
-		c.JSON(http.StatusNotFound, gin.H{"error": err.Error()})
-		return
-	}
-
-	runnerCh, errCh := s.sched.GetRunner(c.Request.Context(), m, api.Options{}, nil)
-	var runner *runnerRef
-	select {
-	case runner = <-runnerCh:
-	case err := <-errCh:
-		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
-		return
-	}
-
-	// Parse size (e.g., "1024x768") into width and height
-	width, height := int32(1024), int32(1024)
-	if req.Size != "" {
-		if _, err := fmt.Sscanf(req.Size, "%dx%d", &width, &height); err != nil {
-			c.JSON(http.StatusBadRequest, gin.H{"error": "invalid size format, expected WxH"})
-			return
-		}
-	}
-
-	var image []byte
-	err = runner.llama.Completion(c.Request.Context(), llm.CompletionRequest{
-		Prompt: req.Prompt,
-		Width:  width,
-		Height: height,
-	}, func(resp llm.CompletionResponse) {
-		if len(resp.Image) > 0 {
-			image = resp.Image
-		}
-	})
-	if err != nil {
-		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
-		return
-	}
-
-	c.JSON(http.StatusOK, gin.H{
-		"created": time.Now().Unix(),
-		"data":    []gin.H{{"b64_json": base64.StdEncoding.EncodeToString(image)}},
-	})
-}
-
 func (s *Server) ChatHandler(c *gin.Context) {
 	checkpointStart := time.Now()
 
@@ -2522,3 +2472,91 @@ func filterThinkTags(msgs []api.Message, m *Model) []api.Message {
 	}
 	return msgs
 }
+
+// handleImageGenerate handles image generation requests within GenerateHandler.
+// This is called when the model has the Image capability.
+func (s *Server) handleImageGenerate(c *gin.Context, req api.GenerateRequest, modelName string, checkpointStart time.Time) {
+	// Validate image dimensions
+	const maxDimension int32 = 4096
+	if req.Width > maxDimension || req.Height > maxDimension {
+		c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("width and height must be <= %d", maxDimension)})
+		return
+	}
+
+	// Schedule the runner for image generation
+	runner, _, _, err := s.scheduleRunner(c.Request.Context(), modelName, []model.Capability{model.CapabilityImage}, nil, req.KeepAlive)
+	if err != nil {
+		handleScheduleError(c, req.Model, err)
+		return
+	}
+
+	checkpointLoaded := time.Now()
+
+	// Handle load-only request (empty prompt)
+	if req.Prompt == "" {
+		c.JSON(http.StatusOK, api.GenerateResponse{
+			Model:      req.Model,
+			CreatedAt:  time.Now().UTC(),
+			Done:       true,
+			DoneReason: "load",
+		})
+		return
+	}
+
+	// Set headers for streaming response
+	c.Header("Content-Type", "application/x-ndjson")
+
+	// Get seed from options if provided
+	var seed int64
+	if opt, ok := req.Options["seed"]; ok {
+		switch v := opt.(type) {
+		case int:
+			seed = int64(v)
+		case int64:
+			seed = v
+		case float64:
+			seed = int64(v)
+		}
+	}
+
+	var streamStarted bool
+	if err := runner.Completion(c.Request.Context(), llm.CompletionRequest{
+		Prompt: req.Prompt,
+		Width:  req.Width,
+		Height: req.Height,
+		Steps:  req.Steps,
+		Seed:   seed,
+	}, func(cr llm.CompletionResponse) {
+		streamStarted = true
+		res := api.GenerateResponse{
+			Model:     req.Model,
+			CreatedAt: time.Now().UTC(),
+			Done:      cr.Done,
+		}
+
+		if cr.TotalSteps > 0 {
+			res.Completed = int64(cr.Step)
+			res.Total = int64(cr.TotalSteps)
+		}
+
+		if cr.Image != "" {
+			res.Image = cr.Image
+		}
+
+		if cr.Done {
+			res.DoneReason = cr.DoneReason.String()
+			res.Metrics.TotalDuration = time.Since(checkpointStart)
+			res.Metrics.LoadDuration = checkpointLoaded.Sub(checkpointStart)
+		}
+
+		data, _ := json.Marshal(res)
+		c.Writer.Write(append(data, '\n'))
+		c.Writer.Flush()
+	}); err != nil {
+		// Only send a JSON error if streaming hasn't started yet
+		// (once streaming starts, headers are committed and we can't change the status code)
+		if !streamStarted {
+			c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
+		}
+	}
+}
diff --git a/server/sched.go b/server/sched.go
index df4fb2a2b..d83a70b36 100644
--- a/server/sched.go
+++ b/server/sched.go
@@ -571,10 +571,10 @@ func (s *Scheduler) loadImageGen(req *LlmRequest) bool {
 		model:           req.model,
 		modelPath:       req.model.ModelPath,
 		llama:           server,
-		Options:         &req.opts,
 		loading:         false,
 		sessionDuration: sessionDuration,
-		refCount:        1,
+		totalSize:       server.TotalSize(),
+		vramSize:        server.VRAMSize(),
 	}
 
 	s.loadedMu.Lock()
diff --git a/types/model/capability.go b/types/model/capability.go
index 62e1abd8b..7ecfd848c 100644
--- a/types/model/capability.go
+++ b/types/model/capability.go
@@ -9,7 +9,7 @@ const (
 	CapabilityVision    = Capability("vision")
 	CapabilityEmbedding = Capability("embedding")
 	CapabilityThinking  = Capability("thinking")
-	CapabilityImageGeneration = Capability("image")
+	CapabilityImage     = Capability("image")
 )
 
 func (c Capability) String() string {
diff --git a/x/imagegen/cli.go b/x/imagegen/cli.go
index c1e61cfe2..a55a1b016 100644
--- a/x/imagegen/cli.go
+++ b/x/imagegen/cli.go
@@ -51,6 +51,7 @@ func RegisterFlags(cmd *cobra.Command) {
 	cmd.Flags().Int("steps", 0, "Denoising steps (0 = model default)")
 	cmd.Flags().Int("seed", 0, "Random seed (0 for random)")
 	cmd.Flags().String("negative", "", "Negative prompt")
+	// Hide from main flags section - shown in separate section via AppendFlagsDocs
 	cmd.Flags().MarkHidden("width")
 	cmd.Flags().MarkHidden("height")
 	cmd.Flags().MarkHidden("steps")
@@ -58,6 +59,19 @@ func RegisterFlags(cmd *cobra.Command) {
 	cmd.Flags().MarkHidden("seed")
 	cmd.Flags().MarkHidden("negative")
 }
 
+// AppendFlagsDocs appends image generation flags documentation to the command's usage template.
+func AppendFlagsDocs(cmd *cobra.Command) {
+	usage := `
+Image Generation Flags (experimental):
+      --width int       Image width
+      --height int      Image height
+      --steps int       Denoising steps
+      --seed int        Random seed
+      --negative str    Negative prompt
+`
+	cmd.SetUsageTemplate(cmd.UsageTemplate() + usage)
+}
+
 // RunCLI handles the CLI for image generation models.
 // Returns true if it handled the request, false if the caller should continue with normal flow.
 // Supports flags: --width, --height, --steps, --seed, --negative
@@ -91,9 +105,7 @@
 }
 
 // generateImageWithOptions generates an image with the given options.
-// Note: opts are currently unused as the native API doesn't support size parameters.
-// Use OpenAI-compatible endpoint (/v1/images/generations) for dimension control.
-func generateImageWithOptions(cmd *cobra.Command, modelName, prompt string, keepAlive *api.Duration, _ ImageGenOptions) error {
+func generateImageWithOptions(cmd *cobra.Command, modelName, prompt string, keepAlive *api.Duration, opts ImageGenOptions) error {
 	client, err := api.ClientFromEnvironment()
 	if err != nil {
 		return err
 	}
@@ -102,7 +114,12 @@
 	req := &api.GenerateRequest{
 		Model:  modelName,
 		Prompt: prompt,
-		// Note: Size is only available via OpenAI-compatible /v1/images/generations endpoint
+		Width:  int32(opts.Width),
+		Height: int32(opts.Height),
+		Steps:  int32(opts.Steps),
+	}
+	if opts.Seed != 0 {
+		req.Options = map[string]any{"seed": opts.Seed}
 	}
 	if keepAlive != nil {
 		req.KeepAlive = keepAlive
 	}
@@ -116,32 +133,25 @@
 	var stepBar *progress.StepBar
 	var imageBase64 string
 	err = client.Generate(cmd.Context(), req, func(resp api.GenerateResponse) error {
-		content := resp.Response
-
-		// Handle progress updates - parse step info and switch to step bar
-		if strings.HasPrefix(content, "\rGenerating:") {
-			var step, total int
-			fmt.Sscanf(content, "\rGenerating: step %d/%d", &step, &total)
-			if stepBar == nil && total > 0 {
+		// Handle progress updates using structured fields
+		if resp.Total > 0 {
+			if stepBar == nil {
 				spinner.Stop()
-				stepBar = progress.NewStepBar("Generating", total)
+				stepBar = progress.NewStepBar("Generating", int(resp.Total))
 				p.Add("", stepBar)
 			}
-			if stepBar != nil {
-				stepBar.Set(step)
-			}
-			return nil
+			stepBar.Set(int(resp.Completed))
 		}
 
-		// Handle final response with base64 image data
-		if resp.Done && strings.HasPrefix(content, "IMAGE_BASE64:") {
-			imageBase64 = content[13:]
+		// Handle final response with image data
+		if resp.Done && resp.Image != "" {
+			imageBase64 = resp.Image
 		}
 		return nil
 	})
-	p.Stop()
+	p.StopAndClear()
 	if err != nil {
 		return err
 	}
@@ -179,6 +189,23 @@ func runInteractive(cmd *cobra.Command, modelName string, keepAlive *api.Duration) error {
 		return err
 	}
 
+	// Preload the model with the specified keepalive
+	p := progress.NewProgress(os.Stderr)
+	spinner := progress.NewSpinner("")
+	p.Add("", spinner)
+
+	preloadReq := &api.GenerateRequest{
+		Model:     modelName,
+		KeepAlive: keepAlive,
+	}
+	if err := client.Generate(cmd.Context(), preloadReq, func(resp api.GenerateResponse) error {
+		return nil
+	}); err != nil {
+		p.StopAndClear()
+		return fmt.Errorf("failed to load model: %w", err)
+	}
+	p.StopAndClear()
+
 	scanner, err := readline.New(readline.Prompt{
 		Prompt:      ">>> ",
 		Placeholder: "Describe an image to generate (/help for commands)",
@@ -216,7 +243,7 @@
 		case strings.HasPrefix(line, "/bye"):
 			return nil
 		case strings.HasPrefix(line, "/?"), strings.HasPrefix(line, "/help"):
-			printInteractiveHelp(opts)
+			printInteractiveHelp()
 			continue
 		case strings.HasPrefix(line, "/set "):
 			if err := handleSetCommand(line[5:], &opts); err != nil {
@@ -235,12 +262,12 @@
 		req := &api.GenerateRequest{
 			Model:  modelName,
 			Prompt: line,
-			Options: map[string]any{
-				"num_ctx":     opts.Width,
-				"num_gpu":     opts.Height,
-				"num_predict": opts.Steps,
-				"seed":        opts.Seed,
-			},
+			Width:  int32(opts.Width),
+			Height: int32(opts.Height),
+			Steps:  int32(opts.Steps),
+		}
+		if opts.Seed != 0 {
+			req.Options = map[string]any{"seed": opts.Seed}
 		}
 		if keepAlive != nil {
 			req.KeepAlive = keepAlive
 		}
@@ -255,32 +282,25 @@
 		var imageBase64 string
 		err = client.Generate(cmd.Context(), req, func(resp api.GenerateResponse) error {
-			content := resp.Response
-
-			// Handle progress updates - parse step info and switch to step bar
-			if strings.HasPrefix(content, "\rGenerating:") {
-				var step, total int
-				fmt.Sscanf(content, "\rGenerating: step %d/%d", &step, &total)
-				if stepBar == nil && total > 0 {
+			// Handle progress updates using structured fields
+			if resp.Total > 0 {
+				if stepBar == nil {
 					spinner.Stop()
-					stepBar = progress.NewStepBar("Generating", total)
+					stepBar = progress.NewStepBar("Generating", int(resp.Total))
 					p.Add("", stepBar)
 				}
-				if stepBar != nil {
-					stepBar.Set(step)
-				}
-				return nil
+				stepBar.Set(int(resp.Completed))
 			}
 
-			// Handle final response with base64 image data
-			if resp.Done && strings.HasPrefix(content, "IMAGE_BASE64:") {
-				imageBase64 = content[13:]
+			// Handle final response with image data
+			if resp.Done && resp.Image != "" {
+				imageBase64 = resp.Image
 			}
 			return nil
 		})
-		p.Stop()
+		p.StopAndClear()
 		if err != nil {
 			fmt.Fprintf(os.Stderr, "Error: %v\n", err)
 			continue
 		}
@@ -331,12 +351,13 @@ func sanitizeFilename(s string) string {
 }
 
 // printInteractiveHelp prints help for interactive mode commands.
-func printInteractiveHelp(opts ImageGenOptions) {
+// TODO: reconcile /set commands with /set parameter in text gen REPL (cmd/cmd.go)
+func printInteractiveHelp() {
 	fmt.Fprintln(os.Stderr, "Commands:")
-	fmt.Fprintln(os.Stderr, "  /set width      Set image width (current:", opts.Width, ")")
-	fmt.Fprintln(os.Stderr, "  /set height     Set image height (current:", opts.Height, ")")
-	fmt.Fprintln(os.Stderr, "  /set steps      Set denoising steps (current:", opts.Steps, ")")
-	fmt.Fprintln(os.Stderr, "  /set seed       Set random seed (current:", opts.Seed, ", 0=random)")
+	fmt.Fprintln(os.Stderr, "  /set width      Set image width")
+	fmt.Fprintln(os.Stderr, "  /set height     Set image height")
+	fmt.Fprintln(os.Stderr, "  /set steps      Set denoising steps")
+	fmt.Fprintln(os.Stderr, "  /set seed       Set random seed")
 	fmt.Fprintln(os.Stderr, "  /set negative   Set negative prompt")
 	fmt.Fprintln(os.Stderr, "  /show           Show current settings")
 	fmt.Fprintln(os.Stderr, "  /bye            Exit")
diff --git a/x/imagegen/runner/runner.go b/x/imagegen/runner/runner.go
index ede11e765..bc0458f4c 100644
--- a/x/imagegen/runner/runner.go
+++ b/x/imagegen/runner/runner.go
@@ -36,6 +36,8 @@ type Response struct {
 	Content string `json:"content,omitempty"`
 	Image   string `json:"image,omitempty"` // Base64-encoded PNG
 	Done    bool   `json:"done"`
+	Step    int    `json:"step,omitempty"`
+	Total   int    `json:"total,omitempty"`
 }
 
 // Server holds the model and handles requests
@@ -167,8 +169,9 @@ func (s *Server) completionHandler(w http.ResponseWriter, r *http.Request) {
 		Seed: req.Seed,
 		Progress: func(step, total int) {
 			resp := Response{
-				Content: fmt.Sprintf("\rGenerating: step %d/%d", step, total),
-				Done:    false,
+				Step:  step,
+				Total: total,
+				Done:  false,
 			}
 			data, _ := json.Marshal(resp)
 			w.Write(data)
diff --git a/x/imagegen/server.go b/x/imagegen/server.go
index 7c55ad77a..d7d282d8e 100644
--- a/x/imagegen/server.go
+++ b/x/imagegen/server.go
@@ -4,7 +4,6 @@ import (
 	"bufio"
 	"bytes"
 	"context"
-	"encoding/base64"
 	"encoding/json"
 	"errors"
 	"fmt"
@@ -232,11 +231,13 @@ func (s *Server) Completion(ctx context.Context, req llm.CompletionRequest, fn func(llm.CompletionResponse)) error {
 		Prompt string `json:"prompt"`
 		Width  int32  `json:"width,omitempty"`
 		Height int32  `json:"height,omitempty"`
+		Steps  int32  `json:"steps,omitempty"`
 		Seed   int64  `json:"seed,omitempty"`
 	}{
 		Prompt: req.Prompt,
 		Width:  req.Width,
 		Height: req.Height,
+		Steps:  req.Steps,
 		Seed:   seed,
 	}
 
@@ -279,15 +280,11 @@
 		// Convert to llm.CompletionResponse
 		cresp := llm.CompletionResponse{
-			Content: raw.Content,
-			Done:    raw.Done,
-			Step:    raw.Step,
-			Total:   raw.Total,
-		}
-		if raw.Image != "" {
-			if data, err := base64.StdEncoding.DecodeString(raw.Image); err == nil {
-				cresp.Image = data
-			}
+			Content:    raw.Content,
+			Done:       raw.Done,
+			Step:       raw.Step,
+			TotalSteps: raw.Total,
+			Image:      raw.Image,
 		}
 		fn(cresp)
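
With these changes, the imagegen runner streams structured progress frames instead of formatted text. As a rough illustration of the wire format implied by the new `Response` struct in `x/imagegen/runner/runner.go` (a sketch, not part of the patch):

```go
package main

import (
	"encoding/json"
	"fmt"
)

// Response mirrors the runner's streaming payload after this change
// (see x/imagegen/runner/runner.go in the diff above).
type Response struct {
	Content string `json:"content,omitempty"`
	Image   string `json:"image,omitempty"` // base64-encoded PNG, final frame only
	Done    bool   `json:"done"`
	Step    int    `json:"step,omitempty"`
	Total   int    `json:"total,omitempty"`
}

func main() {
	// A mid-generation progress frame: structured step/total fields
	// replace the old "\rGenerating: step N/M" content string.
	progress, _ := json.Marshal(Response{Step: 5, Total: 20})
	fmt.Println(string(progress)) // {"done":false,"step":5,"total":20}

	// The final frame carries the base64 image as-is; server.go now
	// forwards it as a string instead of decoding it to bytes.
	final, _ := json.Marshal(Response{Image: "iVBORw0KGgo...", Done: true})
	fmt.Println(string(final)) // {"image":"iVBORw0KGgo...","done":true}
}
```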