diff --git a/x/imagegen/manifest.go b/x/imagegen/manifest.go
index f662915ee..7692d2b09 100644
--- a/x/imagegen/manifest.go
+++ b/x/imagegen/manifest.go
@@ -161,6 +161,17 @@ func (m *ModelManifest) HasTensorLayers() bool {
 	return false
 }
 
+// TotalTensorSize returns the total size in bytes of all tensor layers.
+func (m *ModelManifest) TotalTensorSize() int64 {
+	var total int64
+	for _, layer := range m.Manifest.Layers {
+		if layer.MediaType == "application/vnd.ollama.image.tensor" {
+			total += layer.Size
+		}
+	}
+	return total
+}
+
 // ModelInfo contains metadata about an image generation model.
 type ModelInfo struct {
 	Architecture string
diff --git a/x/imagegen/manifest_test.go b/x/imagegen/manifest_test.go
index a387d272c..281cb0bc5 100644
--- a/x/imagegen/manifest_test.go
+++ b/x/imagegen/manifest_test.go
@@ -5,6 +5,37 @@ import (
 	"testing"
 )
 
+func TestTotalTensorSize(t *testing.T) {
+	m := &ModelManifest{
+		Manifest: &Manifest{
+			Layers: []ManifestLayer{
+				{MediaType: "application/vnd.ollama.image.tensor", Size: 1000},
+				{MediaType: "application/vnd.ollama.image.tensor", Size: 2000},
+				{MediaType: "application/vnd.ollama.image.json", Size: 500}, // not a tensor
+				{MediaType: "application/vnd.ollama.image.tensor", Size: 3000},
+			},
+		},
+	}
+
+	got := m.TotalTensorSize()
+	want := int64(6000)
+	if got != want {
+		t.Errorf("TotalTensorSize() = %d, want %d", got, want)
+	}
+}
+
+func TestTotalTensorSizeEmpty(t *testing.T) {
+	m := &ModelManifest{
+		Manifest: &Manifest{
+			Layers: []ManifestLayer{},
+		},
+	}
+
+	if got := m.TotalTensorSize(); got != 0 {
+		t.Errorf("TotalTensorSize() = %d, want 0", got)
+	}
+}
+
 func TestManifestAndBlobDirsRespectOLLAMAModels(t *testing.T) {
 	modelsDir := filepath.Join(t.TempDir(), "models")
 
diff --git a/x/imagegen/memory.go b/x/imagegen/memory.go
index 57dc4667c..5672503f3 100644
--- a/x/imagegen/memory.go
+++ b/x/imagegen/memory.go
@@ -16,18 +16,9 @@ import (
 	"runtime"
 )
 
-// GB is a convenience constant for gigabytes.
-const GB = 1024 * 1024 * 1024
-
 // SupportedBackends lists the backends that support image generation.
 var SupportedBackends = []string{"metal", "cuda", "cpu"}
 
-// modelVRAMEstimates maps pipeline class names to their estimated VRAM requirements.
-var modelVRAMEstimates = map[string]uint64{
-	"ZImagePipeline": 21 * GB, // ~21GB for Z-Image (text encoder + transformer + VAE)
-	"FluxPipeline":   20 * GB, // ~20GB for Flux
-}
-
 // CheckPlatformSupport validates that image generation is supported on the current platform.
 // Returns nil if supported, or an error describing why it's not supported.
 func CheckPlatformSupport() error {
@@ -47,17 +38,6 @@ func CheckPlatformSupport() error {
 	}
 }
 
-// CheckMemoryRequirements validates that there's enough memory for image generation.
-// Returns nil if memory is sufficient, or an error if not.
-func CheckMemoryRequirements(modelName string, availableMemory uint64) error {
-	required := EstimateVRAM(modelName)
-	if availableMemory < required {
-		return fmt.Errorf("insufficient memory for image generation: need %d GB, have %d GB",
-			required/GB, availableMemory/GB)
-	}
-	return nil
-}
-
 // ResolveModelName checks if a model name is a known image generation model.
 // Returns the normalized model name if found, empty string otherwise.
 func ResolveModelName(modelName string) string {
@@ -68,16 +48,6 @@ func ResolveModelName(modelName string) string {
 	return ""
 }
 
-// EstimateVRAM returns the estimated VRAM needed for an image generation model.
-// Returns a conservative default of 21GB if the model type cannot be determined.
-func EstimateVRAM(modelName string) uint64 {
-	className := DetectModelType(modelName)
-	if estimate, ok := modelVRAMEstimates[className]; ok {
-		return estimate
-	}
-	return 21 * GB
-}
-
 // DetectModelType reads model_index.json and returns the model type.
 // Checks both "architecture" (Ollama format) and "_class_name" (diffusers format).
 // Returns empty string if detection fails.
diff --git a/x/imagegen/memory_test.go b/x/imagegen/memory_test.go
index 180021f6b..531cffda2 100644
--- a/x/imagegen/memory_test.go
+++ b/x/imagegen/memory_test.go
@@ -30,69 +30,6 @@ func TestCheckPlatformSupport(t *testing.T) {
 	}
 }
 
-func TestCheckMemoryRequirements(t *testing.T) {
-	tests := []struct {
-		name            string
-		availableMemory uint64
-		wantErr         bool
-	}{
-		{
-			name:            "sufficient memory",
-			availableMemory: 32 * GB,
-			wantErr:         false,
-		},
-		{
-			name:            "exactly enough memory",
-			availableMemory: 21 * GB,
-			wantErr:         false,
-		},
-		{
-			name:            "insufficient memory",
-			availableMemory: 16 * GB,
-			wantErr:         true,
-		},
-		{
-			name:            "zero memory",
-			availableMemory: 0,
-			wantErr:         true,
-		},
-	}
-
-	for _, tt := range tests {
-		t.Run(tt.name, func(t *testing.T) {
-			// Use a non-existent model name which will default to 21GB estimate
-			err := CheckMemoryRequirements("nonexistent-model", tt.availableMemory)
-			if (err != nil) != tt.wantErr {
-				t.Errorf("CheckMemoryRequirements() error = %v, wantErr %v", err, tt.wantErr)
-			}
-		})
-	}
-}
-
-func TestModelVRAMEstimates(t *testing.T) {
-	// Verify the VRAM estimates map has expected entries
-	expected := map[string]uint64{
-		"ZImagePipeline": 21 * GB,
-		"FluxPipeline":   20 * GB,
-	}
-
-	for name, expectedVRAM := range expected {
-		if actual, ok := modelVRAMEstimates[name]; !ok {
-			t.Errorf("Missing VRAM estimate for %s", name)
-		} else if actual != expectedVRAM {
-			t.Errorf("VRAM estimate for %s = %d GB, want %d GB", name, actual/GB, expectedVRAM/GB)
-		}
-	}
-}
-
-func TestEstimateVRAMDefault(t *testing.T) {
-	// Non-existent model should return default 21GB
-	vram := EstimateVRAM("nonexistent-model-that-does-not-exist")
-	if vram != 21*GB {
-		t.Errorf("EstimateVRAM() = %d GB, want 21 GB", vram/GB)
-	}
-}
-
 func TestResolveModelName(t *testing.T) {
 	// Non-existent model should return empty string
 	result := ResolveModelName("nonexistent-model")
diff --git a/x/imagegen/runner/runner.go b/x/imagegen/runner/runner.go
index baa0eb4bf..8fe5c2de1 100644
--- a/x/imagegen/runner/runner.go
+++ b/x/imagegen/runner/runner.go
@@ -78,14 +78,6 @@ func Execute(args []string) error {
 	slog.Info("MLX library initialized")
 	slog.Info("starting image runner", "model", *modelName, "port", *port)
 
-	// Check memory requirements before loading
-	requiredMemory := imagegen.EstimateVRAM(*modelName)
-	availableMemory := mlx.GetMemoryLimit()
-	if availableMemory > 0 && availableMemory < requiredMemory {
-		return fmt.Errorf("insufficient memory for image generation: need %d GB, have %d GB",
-			requiredMemory/(1024*1024*1024), availableMemory/(1024*1024*1024))
-	}
-
 	// Detect model type and load appropriate model
 	modelType := imagegen.DetectModelType(*modelName)
 	slog.Info("detected model type", "type", modelType)
diff --git a/x/imagegen/server.go b/x/imagegen/server.go
index d7d282d8e..b645e3065 100644
--- a/x/imagegen/server.go
+++ b/x/imagegen/server.go
@@ -104,11 +104,17 @@ func NewServer(modelName string) (*Server, error) {
 		slog.Debug("mlx subprocess library path", "LD_LIBRARY_PATH", pathEnvVal)
 	}
 
+	// Get total weight size from manifest
+	var weightSize uint64
+	if manifest, err := LoadManifest(modelName); err == nil {
+		weightSize = uint64(manifest.TotalTensorSize())
+	}
+
 	s := &Server{
 		cmd:       cmd,
 		port:      port,
 		modelName: modelName,
-		vramSize:  EstimateVRAM(modelName),
+		vramSize:  weightSize,
 		done:      make(chan error, 1),
 		client:    &http.Client{Timeout: 10 * time.Minute},
 	}
diff --git a/x/imagegen/server_test.go b/x/imagegen/server_test.go
index c3957236a..396aa140b 100644
--- a/x/imagegen/server_test.go
+++ b/x/imagegen/server_test.go
@@ -38,40 +38,6 @@ func TestPlatformSupport(t *testing.T) {
 	}
 }
 
-// TestMemoryRequirementsError verifies memory check returns clear error.
-func TestMemoryRequirementsError(t *testing.T) {
-	// Test with insufficient memory
-	err := CheckMemoryRequirements("test-model", 8*GB)
-	if err == nil {
-		t.Error("Expected error for insufficient memory (8GB < 21GB default)")
-	}
-
-	// Test with sufficient memory
-	err = CheckMemoryRequirements("test-model", 32*GB)
-	if err != nil {
-		t.Errorf("Expected no error for sufficient memory (32GB), got: %v", err)
-	}
-}
-
-// TestEstimateVRAMReturnsReasonableDefaults verifies VRAM estimates are sensible.
-func TestEstimateVRAMReturnsReasonableDefaults(t *testing.T) {
-	// Unknown model should return default (21GB)
-	vram := EstimateVRAM("unknown-model")
-	if vram < 10*GB || vram > 100*GB {
-		t.Errorf("VRAM estimate %d GB is outside reasonable range (10-100 GB)", vram/GB)
-	}
-
-	// Verify known pipeline estimates exist and are reasonable
-	for name, estimate := range modelVRAMEstimates {
-		if estimate < 10*GB {
-			t.Errorf("VRAM estimate for %s (%d GB) is suspiciously low", name, estimate/GB)
-		}
-		if estimate > 200*GB {
-			t.Errorf("VRAM estimate for %s (%d GB) is suspiciously high", name, estimate/GB)
-		}
-	}
-}
-
 // TestServerInterfaceCompliance verifies Server implements llm.LlamaServer.
 // This is a compile-time check but we document it as a test.
 func TestServerInterfaceCompliance(t *testing.T) {