x/imagegen: clean up image generation code (#13725)
@@ -1464,6 +1464,11 @@ type CompletionRequest struct {

// TopLogprobs specifies the number of most likely alternative tokens to return (0-20)
TopLogprobs int

// Image generation fields
Width int32 `json:"width,omitempty"`
Height int32 `json:"height,omitempty"`
Seed int64 `json:"seed,omitempty"`
}

// DoneReason represents the reason why a completion response is done

@@ -1512,6 +1517,11 @@ type CompletionResponse struct {

// Logprobs contains log probability information if requested
Logprobs []Logprob `json:"logprobs,omitempty"`

// Image generation fields
Image []byte `json:"image,omitempty"` // Generated image
Step int `json:"step,omitempty"` // Current generation step
Total int `json:"total,omitempty"` // Total generation steps
}

func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn func(CompletionResponse)) error {
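A minimal sketch (illustrative, not part of the commit) of how a caller inside package llm might drive the new request fields and read the image back from the streamed responses; generateImage is a hypothetical helper, imports are elided, and it assumes an already-scheduled image generation runner:

// Hypothetical helper, shown only to illustrate the new fields above.
func generateImage(ctx context.Context, server LlamaServer, prompt string) ([]byte, error) {
	var image []byte
	err := server.Completion(ctx, CompletionRequest{
		Prompt: prompt,
		Width:  1024,
		Height: 1024,
		Seed:   42, // fixed seed for reproducibility; 0 means "pick a random one"
	}, func(resp CompletionResponse) {
		// resp.Step / resp.Total report generation progress while the run is in flight.
		if resp.Done && len(resp.Image) > 0 {
			image = resp.Image // decoded image bytes from the final response
		}
	})
	return image, err
}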
@@ -51,7 +51,6 @@ import (
"github.com/ollama/ollama/types/model"
"github.com/ollama/ollama/version"
"github.com/ollama/ollama/x/imagegen"
imagegenapi "github.com/ollama/ollama/x/imagegen/api"
)

const signinURLStr = "https://ollama.com/connect?name=%s&key=%s"
@@ -164,29 +163,6 @@ func (s *Server) scheduleRunner(ctx context.Context, name string, caps []model.C
return runner.llama, model, &opts, nil
}

// ScheduleImageGenRunner schedules an image generation model runner.
// This implements the imagegenapi.RunnerScheduler interface.
func (s *Server) ScheduleImageGenRunner(c *gin.Context, modelName string, opts api.Options, keepAlive *api.Duration) (llm.LlamaServer, error) {
m := &Model{
Name: modelName,
ShortName: modelName,
ModelPath: modelName, // For image gen, ModelPath is just the model name
Config: model.ConfigV2{
Capabilities: []string{"image"},
},
}

runnerCh, errCh := s.sched.GetRunner(c.Request.Context(), m, opts, keepAlive)
var runner *runnerRef
select {
case runner = <-runnerCh:
case err := <-errCh:
return nil, err
}

return runner.llama, nil
}

func signinURL() (string, error) {
pubKey, err := auth.GetPublicKey()
if err != nil {
@@ -214,12 +190,6 @@ func (s *Server) GenerateHandler(c *gin.Context) {
return
}

// Check if this is a known image generation model
if imagegen.ResolveModelName(req.Model) != "" {
imagegenapi.HandleGenerateRequest(c, s, req.Model, req.Prompt, req.KeepAlive, streamResponse)
return
}

name := model.ParseName(req.Model)
if !name.IsValid() {
// Ideally this is "invalid model name" but we're keeping with
@@ -1587,13 +1557,12 @@ func (s *Server) GenerateRoutes(rc *ollama.Registry) (http.Handler, error) {
r.GET("/v1/models", middleware.ListMiddleware(), s.ListHandler)
r.GET("/v1/models/:model", middleware.RetrieveMiddleware(), s.ShowHandler)
r.POST("/v1/responses", middleware.ResponsesMiddleware(), s.ChatHandler)
// Experimental OpenAI-compatible image generation endpoint
r.POST("/v1/images/generations", s.handleImageGeneration)

// Inference (Anthropic compatibility)
r.POST("/v1/messages", middleware.AnthropicMessagesMiddleware(), s.ChatHandler)

// Experimental image generation support
imagegenapi.RegisterRoutes(r, s)

if rc != nil {
// wrap old with new
rs := &registry.Local{
@@ -1911,6 +1880,62 @@ func toolCallId() string {
return "call_" + strings.ToLower(string(b))
}

func (s *Server) handleImageGeneration(c *gin.Context) {
var req struct {
Model string `json:"model"`
Prompt string `json:"prompt"`
Size string `json:"size"`
}
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}

m, err := GetModel(req.Model)
if err != nil {
c.JSON(http.StatusNotFound, gin.H{"error": err.Error()})
return
}

runnerCh, errCh := s.sched.GetRunner(c.Request.Context(), m, api.Options{}, nil)
var runner *runnerRef
select {
case runner = <-runnerCh:
case err := <-errCh:
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}

// Parse size (e.g., "1024x768") into width and height
width, height := int32(1024), int32(1024)
if req.Size != "" {
if _, err := fmt.Sscanf(req.Size, "%dx%d", &width, &height); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid size format, expected WxH"})
return
}
}

var image []byte
err = runner.llama.Completion(c.Request.Context(), llm.CompletionRequest{
Prompt: req.Prompt,
Width: width,
Height: height,
}, func(resp llm.CompletionResponse) {
if len(resp.Image) > 0 {
image = resp.Image
}
})
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}

c.JSON(http.StatusOK, gin.H{
"created": time.Now().Unix(),
"data": []gin.H{{"b64_json": base64.StdEncoding.EncodeToString(image)}},
})
}

func (s *Server) ChatHandler(c *gin.Context) {
checkpointStart := time.Now()
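An illustrative client-side sketch (not part of this commit) of calling the experimental endpoint registered above. It assumes Ollama is listening on the default address and that "z-image-turbo" names an installed image generation model; the request fields and the created/data[].b64_json response shape mirror handleImageGeneration:

package main

import (
	"bytes"
	"encoding/base64"
	"encoding/json"
	"net/http"
	"os"
)

func main() {
	body, _ := json.Marshal(map[string]string{
		"model":  "z-image-turbo", // hypothetical model name
		"prompt": "a lighthouse at dusk",
		"size":   "1024x768", // parsed server-side with fmt.Sscanf("%dx%d", ...)
	})
	resp, err := http.Post("http://127.0.0.1:11434/v1/images/generations", "application/json", bytes.NewReader(body))
	if err != nil {
		panic(err)
	}
	defer resp.Body.Close()

	var out struct {
		Created int64 `json:"created"`
		Data    []struct {
			B64JSON string `json:"b64_json"`
		} `json:"data"`
	}
	if err := json.NewDecoder(resp.Body).Decode(&out); err != nil || len(out.Data) == 0 {
		panic("unexpected response")
	}
	img, _ := base64.StdEncoding.DecodeString(out.Data[0].B64JSON)
	os.WriteFile("output.png", img, 0o644)
}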
@@ -6,7 +6,6 @@ import (
"errors"
"log/slog"
"os"
"slices"
"testing"
"time"

@@ -17,7 +16,6 @@ import (
"github.com/ollama/ollama/fs/ggml"
"github.com/ollama/ollama/llm"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/types/model"
)

func TestMain(m *testing.M) {
@@ -807,32 +805,8 @@ func (s *mockLlm) GetDeviceInfos(ctx context.Context) []ml.DeviceInfo { return n
func (s *mockLlm) HasExited() bool { return false }
func (s *mockLlm) GetActiveDeviceIDs() []ml.DeviceID { return nil }

// TestImageGenCapabilityDetection verifies that models with "image" capability
// are correctly identified and routed differently from language models.
func TestImageGenCapabilityDetection(t *testing.T) {
// Model with image capability should be detected
imageModel := &Model{
Config: model.ConfigV2{
Capabilities: []string{"image"},
},
}
require.True(t, slices.Contains(imageModel.Config.Capabilities, "image"))

// Model without image capability should not be detected
langModel := &Model{
Config: model.ConfigV2{
Capabilities: []string{"completion"},
},
}
require.False(t, slices.Contains(langModel.Config.Capabilities, "image"))

// Empty capabilities should not match
emptyModel := &Model{}
require.False(t, slices.Contains(emptyModel.Config.Capabilities, "image"))
}

// TestImageGenRunnerCanBeEvicted verifies that an image generation model
// loaded in the scheduler can be evicted by a language model request.
// loaded in the scheduler can be evicted when idle.
func TestImageGenRunnerCanBeEvicted(t *testing.T) {
ctx, done := context.WithTimeout(t.Context(), 500*time.Millisecond)
defer done()
@@ -864,3 +838,59 @@ func TestImageGenRunnerCanBeEvicted(t *testing.T) {
require.NotNil(t, runner)
require.Equal(t, "/fake/image/model", runner.modelPath)
}

// TestImageGenSchedulerCoexistence verifies that image generation models
// can coexist with language models in the scheduler and VRAM is tracked correctly.
func TestImageGenSchedulerCoexistence(t *testing.T) {
ctx, done := context.WithTimeout(t.Context(), 500*time.Millisecond)
defer done()

s := InitScheduler(ctx)
s.getGpuFn = getGpuFn
s.getSystemInfoFn = getSystemInfoFn

// Load both an imagegen runner and a language model runner
imageGenRunner := &runnerRef{
model: &Model{Name: "flux", ModelPath: "/fake/flux/model"},
modelPath: "/fake/flux/model",
llama: &mockLlm{vramSize: 8 * format.GigaByte, vramByGPU: map[ml.DeviceID]uint64{{Library: "Metal"}: 8 * format.GigaByte}},
sessionDuration: 10 * time.Millisecond,
numParallel: 1,
refCount: 0,
}

langModelRunner := &runnerRef{
model: &Model{Name: "llama3", ModelPath: "/fake/llama3/model"},
modelPath: "/fake/llama3/model",
llama: &mockLlm{vramSize: 4 * format.GigaByte, vramByGPU: map[ml.DeviceID]uint64{{Library: "Metal"}: 4 * format.GigaByte}},
sessionDuration: 10 * time.Millisecond,
numParallel: 1,
refCount: 0,
}

s.loadedMu.Lock()
s.loaded["/fake/flux/model"] = imageGenRunner
s.loaded["/fake/llama3/model"] = langModelRunner
s.loadedMu.Unlock()

// Verify both are loaded
s.loadedMu.Lock()
require.Len(t, s.loaded, 2)
require.NotNil(t, s.loaded["/fake/flux/model"])
require.NotNil(t, s.loaded["/fake/llama3/model"])
s.loadedMu.Unlock()

// Verify updateFreeSpace accounts for both
gpus := []ml.DeviceInfo{
{
DeviceID: ml.DeviceID{Library: "Metal"},
TotalMemory: 24 * format.GigaByte,
FreeMemory: 24 * format.GigaByte,
},
}
s.updateFreeSpace(gpus)

// Free memory should be reduced by both models
expectedFree := uint64(24*format.GigaByte) - uint64(8*format.GigaByte) - uint64(4*format.GigaByte)
require.Equal(t, expectedFree, gpus[0].FreeMemory)
}
@@ -1,231 +0,0 @@
package api

import (
"fmt"
"net/http"
"strconv"
"strings"
"time"

"github.com/gin-gonic/gin"

"github.com/ollama/ollama/api"
"github.com/ollama/ollama/llm"
"github.com/ollama/ollama/x/imagegen"
)

// RunnerScheduler is the interface for scheduling a model runner.
// This is implemented by server.Server to avoid circular imports.
type RunnerScheduler interface {
ScheduleImageGenRunner(ctx *gin.Context, modelName string, opts api.Options, keepAlive *api.Duration) (llm.LlamaServer, error)
}

// RegisterRoutes registers the image generation API routes.
func RegisterRoutes(r gin.IRouter, scheduler RunnerScheduler) {
r.POST("/v1/images/generations", func(c *gin.Context) {
ImageGenerationHandler(c, scheduler)
})
}

// ImageGenerationHandler handles OpenAI-compatible image generation requests.
func ImageGenerationHandler(c *gin.Context, scheduler RunnerScheduler) {
var req ImageGenerationRequest
if err := c.BindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": gin.H{"message": err.Error()}})
return
}

// Validate required fields
if req.Model == "" {
c.JSON(http.StatusBadRequest, gin.H{"error": gin.H{"message": "model is required"}})
return
}
if req.Prompt == "" {
c.JSON(http.StatusBadRequest, gin.H{"error": gin.H{"message": "prompt is required"}})
return
}

// Apply defaults
if req.N == 0 {
req.N = 1
}
if req.Size == "" {
req.Size = "1024x1024"
}
if req.ResponseFormat == "" {
req.ResponseFormat = "b64_json"
}

// Verify model exists
if imagegen.ResolveModelName(req.Model) == "" {
c.JSON(http.StatusNotFound, gin.H{"error": gin.H{"message": fmt.Sprintf("model %q not found", req.Model)}})
return
}

// Parse size
width, height := parseSize(req.Size)

// Build options - we repurpose NumCtx/NumGPU for width/height
opts := api.Options{}
opts.NumCtx = int(width)
opts.NumGPU = int(height)

// Schedule runner
runner, err := scheduler.ScheduleImageGenRunner(c, req.Model, opts, nil)
if err != nil {
status := http.StatusInternalServerError
if strings.Contains(err.Error(), "not found") {
status = http.StatusNotFound
}
c.JSON(status, gin.H{"error": gin.H{"message": err.Error()}})
return
}

// Build completion request
completionReq := llm.CompletionRequest{
Prompt: req.Prompt,
Options: &opts,
}

if req.Stream {
handleStreamingResponse(c, runner, completionReq, req.ResponseFormat)
} else {
handleNonStreamingResponse(c, runner, completionReq, req.ResponseFormat)
}
}

func handleStreamingResponse(c *gin.Context, runner llm.LlamaServer, req llm.CompletionRequest, format string) {
c.Header("Content-Type", "text/event-stream")
c.Header("Cache-Control", "no-cache")
c.Header("Connection", "keep-alive")

var imageBase64 string
err := runner.Completion(c.Request.Context(), req, func(resp llm.CompletionResponse) {
if resp.Done {
imageBase64 = extractBase64(resp.Content)
} else {
progress := parseProgress(resp.Content)
if progress.Total > 0 {
c.SSEvent("progress", progress)
c.Writer.Flush()
}
}
})
if err != nil {
c.SSEvent("error", gin.H{"error": err.Error()})
return
}

c.SSEvent("done", buildResponse(imageBase64, format))
}

func handleNonStreamingResponse(c *gin.Context, runner llm.LlamaServer, req llm.CompletionRequest, format string) {
var imageBase64 string
err := runner.Completion(c.Request.Context(), req, func(resp llm.CompletionResponse) {
if resp.Done {
imageBase64 = extractBase64(resp.Content)
}
})
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": gin.H{"message": err.Error()}})
return
}

c.JSON(http.StatusOK, buildResponse(imageBase64, format))
}

func parseSize(size string) (int32, int32) {
parts := strings.Split(size, "x")
if len(parts) != 2 {
return 1024, 1024
}
w, _ := strconv.Atoi(parts[0])
h, _ := strconv.Atoi(parts[1])
if w == 0 {
w = 1024
}
if h == 0 {
h = 1024
}
return int32(w), int32(h)
}

func extractBase64(content string) string {
if strings.HasPrefix(content, "IMAGE_BASE64:") {
return content[13:]
}
return ""
}

func parseProgress(content string) ImageProgressEvent {
var step, total int
fmt.Sscanf(content, "\rGenerating: step %d/%d", &step, &total)
return ImageProgressEvent{Step: step, Total: total}
}

func buildResponse(imageBase64, format string) ImageGenerationResponse {
resp := ImageGenerationResponse{
Created: time.Now().Unix(),
Data: make([]ImageData, 1),
}

if imageBase64 == "" {
return resp
}

if format == "url" {
// URL format not supported when using base64 transfer
resp.Data[0].B64JSON = imageBase64
} else {
resp.Data[0].B64JSON = imageBase64
}

return resp
}

// HandleGenerateRequest handles Ollama /api/generate requests for image gen models.
// This allows routes.go to delegate image generation with minimal code.
func HandleGenerateRequest(c *gin.Context, scheduler RunnerScheduler, modelName, prompt string, keepAlive *api.Duration, streamFn func(c *gin.Context, ch chan any)) {
opts := api.Options{}

// Schedule runner
runner, err := scheduler.ScheduleImageGenRunner(c, modelName, opts, keepAlive)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}

// Build completion request
completionReq := llm.CompletionRequest{
Prompt: prompt,
Options: &opts,
}

// Stream responses via channel
ch := make(chan any)
go func() {
defer close(ch)
err := runner.Completion(c.Request.Context(), completionReq, func(resp llm.CompletionResponse) {
ch <- GenerateResponse{
Model: modelName,
CreatedAt: time.Now().UTC(),
Response: resp.Content,
Done: resp.Done,
}
})
if err != nil {
// Log error but don't block - channel is already being consumed
_ = err
}
}()

streamFn(c, ch)
}

// GenerateResponse matches api.GenerateResponse structure for streaming.
type GenerateResponse struct {
Model string `json:"model"`
CreatedAt time.Time `json:"created_at"`
Response string `json:"response"`
Done bool `json:"done"`
}
@@ -1,31 +0,0 @@
// Package api provides OpenAI-compatible image generation API types.
package api

// ImageGenerationRequest is an OpenAI-compatible image generation request.
type ImageGenerationRequest struct {
Model string `json:"model"`
Prompt string `json:"prompt"`
N int `json:"n,omitempty"`
Size string `json:"size,omitempty"`
ResponseFormat string `json:"response_format,omitempty"`
Stream bool `json:"stream,omitempty"`
}

// ImageGenerationResponse is an OpenAI-compatible image generation response.
type ImageGenerationResponse struct {
Created int64 `json:"created"`
Data []ImageData `json:"data"`
}

// ImageData contains the generated image data.
type ImageData struct {
URL string `json:"url,omitempty"`
B64JSON string `json:"b64_json,omitempty"`
RevisedPrompt string `json:"revised_prompt,omitempty"`
}

// ImageProgressEvent is sent during streaming to indicate generation progress.
type ImageProgressEvent struct {
Step int `json:"step"`
Total int `json:"total"`
}
@@ -7,7 +7,6 @@ package imagegen

import (
"encoding/base64"
"encoding/json"
"errors"
"fmt"
"io"
@@ -39,77 +38,17 @@ func DefaultOptions() ImageGenOptions {
return ImageGenOptions{
Width: 1024,
Height: 1024,
Steps: 9,
Steps: 0, // 0 means model default
Seed: 0, // 0 means random
}
}

// ModelInfo contains metadata about an image generation model.
type ModelInfo struct {
Architecture string
ParameterCount int64
Quantization string
}

// GetModelInfo returns metadata about an image generation model.
func GetModelInfo(modelName string) (*ModelInfo, error) {
manifest, err := LoadManifest(modelName)
if err != nil {
return nil, fmt.Errorf("failed to load manifest: %w", err)
}

info := &ModelInfo{}

// Read model_index.json for architecture, parameter count, and quantization
if data, err := manifest.ReadConfig("model_index.json"); err == nil {
var index struct {
Architecture string `json:"architecture"`
ParameterCount int64 `json:"parameter_count"`
Quantization string `json:"quantization"`
}
if json.Unmarshal(data, &index) == nil {
info.Architecture = index.Architecture
info.ParameterCount = index.ParameterCount
info.Quantization = index.Quantization
}
}

// Fallback: detect quantization from tensor names if not in config
if info.Quantization == "" {
for _, layer := range manifest.Manifest.Layers {
if strings.HasSuffix(layer.Name, ".weight_scale") {
info.Quantization = "FP8"
break
}
}
if info.Quantization == "" {
info.Quantization = "BF16"
}
}

// Fallback: estimate parameter count if not in config
if info.ParameterCount == 0 {
var totalSize int64
for _, layer := range manifest.Manifest.Layers {
if layer.MediaType == "application/vnd.ollama.image.tensor" {
if !strings.HasSuffix(layer.Name, "_scale") && !strings.HasSuffix(layer.Name, "_qbias") {
totalSize += layer.Size
}
}
}
// Assume BF16 (2 bytes/param) as rough estimate
info.ParameterCount = totalSize / 2
}

return info, nil
}

// RegisterFlags adds image generation flags to the given command.
// Flags are hidden since they only apply to image generation models.
func RegisterFlags(cmd *cobra.Command) {
cmd.Flags().Int("width", 1024, "Image width")
cmd.Flags().Int("height", 1024, "Image height")
cmd.Flags().Int("steps", 9, "Denoising steps")
cmd.Flags().Int("steps", 0, "Denoising steps (0 = model default)")
cmd.Flags().Int("seed", 0, "Random seed (0 for random)")
cmd.Flags().String("negative", "", "Negative prompt")
cmd.Flags().MarkHidden("width")
@@ -152,23 +91,18 @@ func RunCLI(cmd *cobra.Command, name string, prompt string, interactive bool, ke
}

// generateImageWithOptions generates an image with the given options.
func generateImageWithOptions(cmd *cobra.Command, modelName, prompt string, keepAlive *api.Duration, opts ImageGenOptions) error {
// Note: opts are currently unused as the native API doesn't support size parameters.
// Use OpenAI-compatible endpoint (/v1/images/generations) for dimension control.
func generateImageWithOptions(cmd *cobra.Command, modelName, prompt string, keepAlive *api.Duration, _ ImageGenOptions) error {
client, err := api.ClientFromEnvironment()
if err != nil {
return err
}

// Build request with image gen options encoded in Options fields
// NumCtx=width, NumGPU=height, NumPredict=steps, Seed=seed
req := &api.GenerateRequest{
Model: modelName,
Prompt: prompt,
Options: map[string]any{
"num_ctx": opts.Width,
"num_gpu": opts.Height,
"num_predict": opts.Steps,
"seed": opts.Seed,
},
// Note: Size is only available via OpenAI-compatible /v1/images/generations endpoint
}
if keepAlive != nil {
req.KeepAlive = keepAlive
@@ -12,6 +12,7 @@ import (
"path/filepath"
"runtime/pprof"

"github.com/ollama/ollama/x/imagegen"
"github.com/ollama/ollama/x/imagegen/mlx"
"github.com/ollama/ollama/x/imagegen/models/gemma3"
"github.com/ollama/ollama/x/imagegen/models/gpt_oss"
@@ -48,7 +49,7 @@ func main() {
// Image generation params
width := flag.Int("width", 1024, "Image width")
height := flag.Int("height", 1024, "Image height")
steps := flag.Int("steps", 9, "Denoising steps")
steps := flag.Int("steps", 0, "Denoising steps (0 = model default)")
seed := flag.Int64("seed", 42, "Random seed")
out := flag.String("output", "output.png", "Output path")
@@ -175,3 +175,63 @@ func (m *ModelManifest) HasTensorLayers() bool {
}
return false
}

// ModelInfo contains metadata about an image generation model.
type ModelInfo struct {
Architecture string
ParameterCount int64
Quantization string
}

// GetModelInfo returns metadata about an image generation model.
func GetModelInfo(modelName string) (*ModelInfo, error) {
manifest, err := LoadManifest(modelName)
if err != nil {
return nil, fmt.Errorf("failed to load manifest: %w", err)
}

info := &ModelInfo{}

// Read model_index.json for architecture, parameter count, and quantization
if data, err := manifest.ReadConfig("model_index.json"); err == nil {
var index struct {
Architecture string `json:"architecture"`
ParameterCount int64 `json:"parameter_count"`
Quantization string `json:"quantization"`
}
if json.Unmarshal(data, &index) == nil {
info.Architecture = index.Architecture
info.ParameterCount = index.ParameterCount
info.Quantization = index.Quantization
}
}

// Fallback: detect quantization from tensor names if not in config
if info.Quantization == "" {
for _, layer := range manifest.Manifest.Layers {
if strings.HasSuffix(layer.Name, ".weight_scale") {
info.Quantization = "FP8"
break
}
}
if info.Quantization == "" {
info.Quantization = "BF16"
}
}

// Fallback: estimate parameter count if not in config
if info.ParameterCount == 0 {
var totalSize int64
for _, layer := range manifest.Manifest.Layers {
if layer.MediaType == "application/vnd.ollama.image.tensor" {
if !strings.HasSuffix(layer.Name, "_scale") && !strings.HasSuffix(layer.Name, "_qbias") {
totalSize += layer.Size
}
}
}
// Assume BF16 (2 bytes/param) as rough estimate
info.ParameterCount = totalSize / 2
}

return info, nil
}
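A short usage fragment (illustrative, not from the commit; imports and error handling elided, model name hypothetical) showing what the relocated GetModelInfo returns:

info, err := imagegen.GetModelInfo("z-image-turbo")
if err != nil {
	return err
}
// Architecture, ParameterCount, and Quantization come from model_index.json when present,
// with FP8/BF16 detection and a 2-bytes-per-parameter size estimate as fallbacks.
fmt.Printf("%s  %.1fB params  %s\n", info.Architecture, float64(info.ParameterCount)/1e9, info.Quantization)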
@@ -95,8 +95,3 @@ func EstimateVRAM(modelName string) uint64 {
}
return 21 * GB
}

// HasTensorLayers checks if the given model has tensor layers.
func HasTensorLayers(modelName string) bool {
return ResolveModelName(modelName) != ""
}

@@ -94,13 +94,6 @@ func TestEstimateVRAMDefault(t *testing.T) {
}
}

func TestHasTensorLayers(t *testing.T) {
// Non-existent model should return false
if HasTensorLayers("nonexistent-model") {
t.Error("HasTensorLayers() should return false for non-existent model")
}
}

func TestResolveModelName(t *testing.T) {
// Non-existent model should return empty string
result := ResolveModelName("nonexistent-model")
@@ -9,6 +9,7 @@ import (
"path/filepath"
"time"

"github.com/ollama/ollama/x/imagegen"
"github.com/ollama/ollama/x/imagegen/cache"
"github.com/ollama/ollama/x/imagegen/mlx"
"github.com/ollama/ollama/x/imagegen/tokenizer"
@@ -172,7 +173,7 @@ func (m *Model) generate(cfg *GenerateConfig) (*mlx.Array, error) {
cfg.Height = 1024
}
if cfg.Steps <= 0 {
cfg.Steps = 30
cfg.Steps = 50
}
if cfg.CFGScale <= 0 {
cfg.CFGScale = 4.0

@@ -194,7 +194,7 @@ func (m *Model) generate(ctx context.Context, cfg *GenerateConfig) (*mlx.Array,
cfg.Height = 1024
}
if cfg.Steps <= 0 {
cfg.Steps = 9 // Turbo default
cfg.Steps = 9 // Z-Image turbo default
}
if cfg.CFGScale <= 0 {
cfg.CFGScale = 4.0
@@ -136,16 +136,8 @@ func (s *Server) completionHandler(w http.ResponseWriter, r *http.Request) {
s.mu.Lock()
defer s.mu.Unlock()

// Apply defaults
if req.Width <= 0 {
req.Width = 1024
}
if req.Height <= 0 {
req.Height = 1024
}
if req.Steps <= 0 {
req.Steps = 9
}
// Model applies its own defaults for width/height/steps
// Only seed needs to be set here if not provided
if req.Seed <= 0 {
req.Seed = time.Now().UnixNano()
}
@@ -4,6 +4,7 @@ import (
"bufio"
"bytes"
"context"
"encoding/base64"
"encoding/json"
"errors"
"fmt"
@@ -25,6 +26,11 @@ import (
)

// Server wraps an image generation subprocess to implement llm.LlamaServer.
//
// This implementation is compatible with Ollama's scheduler and can be loaded/unloaded
// like any other model. The plan is to eventually bring this into the llm/ package
// and evolve llm/ to support MLX and multimodal models. For now, keeping the code
// separate allows for independent iteration on image generation support.
type Server struct {
mu sync.Mutex
cmd *exec.Cmd
@@ -37,22 +43,6 @@ type Server struct {
lastErrLock sync.Mutex
}

// completionRequest is sent to the subprocess
type completionRequest struct {
Prompt string `json:"prompt"`
Width int32 `json:"width,omitempty"`
Height int32 `json:"height,omitempty"`
Steps int `json:"steps,omitempty"`
Seed int64 `json:"seed,omitempty"`
}

// completionResponse is received from the subprocess
type completionResponse struct {
Content string `json:"content,omitempty"`
Image string `json:"image,omitempty"`
Done bool `json:"done"`
}

// NewServer spawns a new image generation subprocess and waits until it's ready.
func NewServer(modelName string) (*Server, error) {
// Validate platform support before attempting to start
@@ -139,7 +129,6 @@ func NewServer(modelName string) (*Server, error) {
for scanner.Scan() {
line := scanner.Text()
slog.Warn("image-runner", "msg", line)
// Capture last error line for better error reporting
s.lastErrLock.Lock()
s.lastErr = line
s.lastErrLock.Unlock()
@@ -171,7 +160,6 @@ func (s *Server) ModelPath() string {
return s.modelName
}

// Load is called by the scheduler after the server is created.
func (s *Server) Load(ctx context.Context, systemInfo ml.SystemInfo, gpus []ml.DeviceInfo, requireFull bool) ([]ml.DeviceID, error) {
return nil, nil
}
@@ -204,20 +192,16 @@ func (s *Server) waitUntilRunning() error {
for {
select {
case err := <-s.done:
// Include last stderr line for better error context
s.lastErrLock.Lock()
lastErr := s.lastErr
s.lastErrLock.Unlock()
if lastErr != "" {
return fmt.Errorf("image runner failed: %s (exit: %v)", lastErr, err)
// Include recent stderr lines for better error context
errMsg := s.getLastErr()
if errMsg != "" {
return fmt.Errorf("image runner failed: %s (exit: %v)", errMsg, err)
}
return fmt.Errorf("image runner exited unexpectedly: %w", err)
case <-timeout:
s.lastErrLock.Lock()
lastErr := s.lastErr
s.lastErrLock.Unlock()
if lastErr != "" {
return fmt.Errorf("timeout waiting for image runner: %s", lastErr)
errMsg := s.getLastErr()
if errMsg != "" {
return fmt.Errorf("timeout waiting for image runner: %s", errMsg)
}
return errors.New("timeout waiting for image runner to start")
case <-ticker.C:
@@ -229,44 +213,39 @@ func (s *Server) waitUntilRunning() error {
}
}

// WaitUntilRunning implements the LlamaServer interface (no-op since NewServer waits).
func (s *Server) WaitUntilRunning(ctx context.Context) error {
return nil
// getLastErr returns the last stderr line.
func (s *Server) getLastErr() string {
s.lastErrLock.Lock()
defer s.lastErrLock.Unlock()
return s.lastErr
}

// Completion generates an image from the prompt via the subprocess.
func (s *Server) WaitUntilRunning(ctx context.Context) error { return nil }
func (s *Server) Completion(ctx context.Context, req llm.CompletionRequest, fn func(llm.CompletionResponse)) error {
// Build request
creq := completionRequest{
seed := req.Seed
if seed == 0 {
seed = time.Now().UnixNano()
}

// Build request for subprocess
creq := struct {
Prompt string `json:"prompt"`
Width int32 `json:"width,omitempty"`
Height int32 `json:"height,omitempty"`
Seed int64 `json:"seed,omitempty"`
}{
Prompt: req.Prompt,
Width: 1024,
Height: 1024,
Steps: 9,
Seed: time.Now().UnixNano(),
Width: req.Width,
Height: req.Height,
Seed: seed,
}

if req.Options != nil {
if req.Options.NumCtx > 0 && req.Options.NumCtx <= 4096 {
creq.Width = int32(req.Options.NumCtx)
}
if req.Options.NumGPU > 0 && req.Options.NumGPU <= 4096 {
creq.Height = int32(req.Options.NumGPU)
}
if req.Options.NumPredict > 0 && req.Options.NumPredict <= 100 {
creq.Steps = req.Options.NumPredict
}
if req.Options.Seed > 0 {
creq.Seed = int64(req.Options.Seed)
}
}

// Encode request body
body, err := json.Marshal(creq)
if err != nil {
return err
}

// Send request to subprocess
url := fmt.Sprintf("http://127.0.0.1:%d/completion", s.port)
httpReq, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewReader(body))
if err != nil {
@@ -281,30 +260,40 @@ func (s *Server) Completion(ctx context.Context, req llm.CompletionRequest, fn f
defer resp.Body.Close()

if resp.StatusCode != http.StatusOK {
return fmt.Errorf("completion request failed: %d", resp.StatusCode)
return fmt.Errorf("request failed: %d", resp.StatusCode)
}

// Stream responses - use large buffer for base64 image data
scanner := bufio.NewScanner(resp.Body)
scanner.Buffer(make([]byte, 1024*1024), 16*1024*1024) // 16MB max
for scanner.Scan() {
var cresp completionResponse
if err := json.Unmarshal(scanner.Bytes(), &cresp); err != nil {
// Parse subprocess response (has singular "image" field)
var raw struct {
Image string `json:"image,omitempty"`
Content string `json:"content,omitempty"`
Done bool `json:"done"`
Step int `json:"step,omitempty"`
Total int `json:"total,omitempty"`
}
if err := json.Unmarshal(scanner.Bytes(), &raw); err != nil {
continue
}

content := cresp.Content
// If this is the final response with an image, encode it in the content
if cresp.Done && cresp.Image != "" {
content = "IMAGE_BASE64:" + cresp.Image
// Convert to llm.CompletionResponse
cresp := llm.CompletionResponse{
Content: raw.Content,
Done: raw.Done,
Step: raw.Step,
Total: raw.Total,
}
if raw.Image != "" {
if data, err := base64.StdEncoding.DecodeString(raw.Image); err == nil {
cresp.Image = data
}
}

fn(llm.CompletionResponse{
Content: content,
Done: cresp.Done,
})
fn(cresp)
if cresp.Done {
break
return nil
}
}
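To make the subprocess protocol above concrete, here is a hedged sketch (not from the commit; field values are illustrative only) of one request body sent to the runner's /completion endpoint and the newline-delimited JSON it streams back, decoded with the same shape as the struct in the diff:

package main

import (
	"bufio"
	"encoding/json"
	"fmt"
	"strings"
)

func main() {
	// What the server POSTs to the runner's /completion endpoint (illustrative values).
	request := `{"prompt":"a lighthouse at dusk","width":1024,"height":768,"seed":42}`
	_ = request

	// What the runner streams back: progress lines, then a final line carrying
	// the base64-encoded image. Field names mirror the decode struct above.
	stream := `{"content":"step 1/9","step":1,"total":9,"done":false}
{"image":"<base64 image data>","done":true}`

	scanner := bufio.NewScanner(strings.NewReader(stream))
	for scanner.Scan() {
		var raw struct {
			Image   string `json:"image,omitempty"`
			Content string `json:"content,omitempty"`
			Done    bool   `json:"done"`
			Step    int    `json:"step,omitempty"`
			Total   int    `json:"total,omitempty"`
		}
		if err := json.Unmarshal(scanner.Bytes(), &raw); err != nil {
			continue
		}
		fmt.Println(raw.Step, raw.Total, raw.Done, len(raw.Image) > 0)
	}
}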
@@ -346,22 +335,18 @@ func (s *Server) VRAMByGPU(id ml.DeviceID) uint64 {
return s.vramSize
}

// Embedding is not supported for image generation models.
func (s *Server) Embedding(ctx context.Context, input string) ([]float32, int, error) {
return nil, 0, errors.New("embedding not supported for image generation models")
return nil, 0, errors.New("not supported")
}

// Tokenize is not supported for image generation models.
func (s *Server) Tokenize(ctx context.Context, content string) ([]int, error) {
return nil, errors.New("tokenize not supported for image generation models")
return nil, errors.New("not supported")
}

// Detokenize is not supported for image generation models.
func (s *Server) Detokenize(ctx context.Context, tokens []int) (string, error) {
return "", errors.New("detokenize not supported for image generation models")
return "", errors.New("not supported")
}

// Pid returns the subprocess PID.
func (s *Server) Pid() int {
s.mu.Lock()
defer s.mu.Unlock()
@@ -371,17 +356,9 @@ func (s *Server) Pid() int {
return -1
}

// GetPort returns the subprocess port.
func (s *Server) GetPort() int {
return s.port
}
func (s *Server) GetPort() int { return s.port }
func (s *Server) GetDeviceInfos(ctx context.Context) []ml.DeviceInfo { return nil }

// GetDeviceInfos returns nil since we don't track GPU info.
func (s *Server) GetDeviceInfos(ctx context.Context) []ml.DeviceInfo {
return nil
}

// HasExited returns true if the subprocess has exited.
func (s *Server) HasExited() bool {
select {
case <-s.done: