From 2b2cda7a2b5ad55561c023cddccfb637713da39e Mon Sep 17 00:00:00 2001 From: Parth Sareen Date: Fri, 9 Jan 2026 11:53:36 -0800 Subject: [PATCH] api: implement anthropic api (#13600) * api: add Anthropic Messages API compatibility layer Add middleware to support the Anthropic Messages API format at /v1/messages. This enables tools like Claude Code to work with Ollama local and cloud models through the Anthropic API interface. --- anthropic/anthropic.go | 778 ++++++++++++++++++++++ anthropic/anthropic_test.go | 953 +++++++++++++++++++++++++++ docs/README.md | 1 + docs/api/anthropic-compatibility.mdx | 406 ++++++++++++ docs/docs.json | 12 +- docs/integrations/claude-code.mdx | 69 ++ middleware/anthropic.go | 149 +++++ middleware/anthropic_test.go | 584 ++++++++++++++++ server/routes.go | 3 + 9 files changed, 2952 insertions(+), 3 deletions(-) create mode 100644 anthropic/anthropic.go create mode 100644 anthropic/anthropic_test.go create mode 100644 docs/api/anthropic-compatibility.mdx create mode 100644 docs/integrations/claude-code.mdx create mode 100644 middleware/anthropic.go create mode 100644 middleware/anthropic_test.go diff --git a/anthropic/anthropic.go b/anthropic/anthropic.go new file mode 100644 index 000000000..9cb2c75c4 --- /dev/null +++ b/anthropic/anthropic.go @@ -0,0 +1,778 @@ +package anthropic + +import ( + "crypto/rand" + "encoding/base64" + "encoding/json" + "errors" + "fmt" + "log/slog" + "net/http" + "strings" + "time" + + "github.com/ollama/ollama/api" +) + +// Error types matching Anthropic API +type Error struct { + Type string `json:"type"` + Message string `json:"message"` +} + +type ErrorResponse struct { + Type string `json:"type"` // always "error" + Error Error `json:"error"` + RequestID string `json:"request_id,omitempty"` +} + +// NewError creates a new ErrorResponse with the appropriate error type based on HTTP status code +func NewError(code int, message string) ErrorResponse { + var etype string + switch code { + case http.StatusBadRequest: + etype = "invalid_request_error" + case http.StatusUnauthorized: + etype = "authentication_error" + case http.StatusForbidden: + etype = "permission_error" + case http.StatusNotFound: + etype = "not_found_error" + case http.StatusTooManyRequests: + etype = "rate_limit_error" + case http.StatusServiceUnavailable, 529: + etype = "overloaded_error" + default: + etype = "api_error" + } + + return ErrorResponse{ + Type: "error", + Error: Error{Type: etype, Message: message}, + RequestID: generateID("req"), + } +} + +// Request types + +// MessagesRequest represents an Anthropic Messages API request +type MessagesRequest struct { + Model string `json:"model"` + MaxTokens int `json:"max_tokens"` + Messages []MessageParam `json:"messages"` + System any `json:"system,omitempty"` // string or []ContentBlock + Stream bool `json:"stream,omitempty"` + Temperature *float64 `json:"temperature,omitempty"` + TopP *float64 `json:"top_p,omitempty"` + TopK *int `json:"top_k,omitempty"` + StopSequences []string `json:"stop_sequences,omitempty"` + Tools []Tool `json:"tools,omitempty"` + ToolChoice *ToolChoice `json:"tool_choice,omitempty"` + Thinking *ThinkingConfig `json:"thinking,omitempty"` + Metadata *Metadata `json:"metadata,omitempty"` +} + +// MessageParam represents a message in the request +type MessageParam struct { + Role string `json:"role"` // "user" or "assistant" + Content any `json:"content"` // string or []ContentBlock +} + +// ContentBlock represents a content block in a message. 
+// Text and Thinking use pointers so they serialize as the field being present (even if empty) +// only when set, which is required for SDK streaming accumulation. +type ContentBlock struct { + Type string `json:"type"` // text, image, tool_use, tool_result, thinking + + // For text blocks - pointer so field only appears when set (SDK requires it for accumulation) + Text *string `json:"text,omitempty"` + + // For image blocks + Source *ImageSource `json:"source,omitempty"` + + // For tool_use blocks + ID string `json:"id,omitempty"` + Name string `json:"name,omitempty"` + Input any `json:"input,omitempty"` + + // For tool_result blocks + ToolUseID string `json:"tool_use_id,omitempty"` + Content any `json:"content,omitempty"` // string or []ContentBlock + IsError bool `json:"is_error,omitempty"` + + // For thinking blocks - pointer so field only appears when set (SDK requires it for accumulation) + Thinking *string `json:"thinking,omitempty"` + Signature string `json:"signature,omitempty"` +} + +// ImageSource represents the source of an image +type ImageSource struct { + Type string `json:"type"` // "base64" or "url" + MediaType string `json:"media_type,omitempty"` + Data string `json:"data,omitempty"` + URL string `json:"url,omitempty"` +} + +// Tool represents a tool definition +type Tool struct { + Type string `json:"type,omitempty"` // "custom" for user-defined tools + Name string `json:"name"` + Description string `json:"description,omitempty"` + InputSchema json.RawMessage `json:"input_schema,omitempty"` +} + +// ToolChoice controls how the model uses tools +type ToolChoice struct { + Type string `json:"type"` // "auto", "any", "tool", "none" + Name string `json:"name,omitempty"` + DisableParallelToolUse bool `json:"disable_parallel_tool_use,omitempty"` +} + +// ThinkingConfig controls extended thinking +type ThinkingConfig struct { + Type string `json:"type"` // "enabled" or "disabled" + BudgetTokens int `json:"budget_tokens,omitempty"` +} + +// Metadata for the request +type Metadata struct { + UserID string `json:"user_id,omitempty"` +} + +// Response types + +// MessagesResponse represents an Anthropic Messages API response +type MessagesResponse struct { + ID string `json:"id"` + Type string `json:"type"` // "message" + Role string `json:"role"` // "assistant" + Model string `json:"model"` + Content []ContentBlock `json:"content"` + StopReason string `json:"stop_reason,omitempty"` + StopSequence string `json:"stop_sequence,omitempty"` + Usage Usage `json:"usage"` +} + +// Usage contains token usage information +type Usage struct { + InputTokens int `json:"input_tokens"` + OutputTokens int `json:"output_tokens"` +} + +// Streaming event types + +// MessageStartEvent is sent at the start of streaming +type MessageStartEvent struct { + Type string `json:"type"` // "message_start" + Message MessagesResponse `json:"message"` +} + +// ContentBlockStartEvent signals the start of a content block +type ContentBlockStartEvent struct { + Type string `json:"type"` // "content_block_start" + Index int `json:"index"` + ContentBlock ContentBlock `json:"content_block"` +} + +// ContentBlockDeltaEvent contains incremental content updates +type ContentBlockDeltaEvent struct { + Type string `json:"type"` // "content_block_delta" + Index int `json:"index"` + Delta Delta `json:"delta"` +} + +// Delta represents an incremental update +type Delta struct { + Type string `json:"type"` // "text_delta", "input_json_delta", "thinking_delta", "signature_delta" + Text string `json:"text,omitempty"` + 
PartialJSON string `json:"partial_json,omitempty"` + Thinking string `json:"thinking,omitempty"` + Signature string `json:"signature,omitempty"` +} + +// ContentBlockStopEvent signals the end of a content block +type ContentBlockStopEvent struct { + Type string `json:"type"` // "content_block_stop" + Index int `json:"index"` +} + +// MessageDeltaEvent contains updates to the message +type MessageDeltaEvent struct { + Type string `json:"type"` // "message_delta" + Delta MessageDelta `json:"delta"` + Usage DeltaUsage `json:"usage"` +} + +// MessageDelta contains stop information +type MessageDelta struct { + StopReason string `json:"stop_reason,omitempty"` + StopSequence string `json:"stop_sequence,omitempty"` +} + +// DeltaUsage contains cumulative token usage +type DeltaUsage struct { + OutputTokens int `json:"output_tokens"` +} + +// MessageStopEvent signals the end of the message +type MessageStopEvent struct { + Type string `json:"type"` // "message_stop" +} + +// PingEvent is a keepalive event +type PingEvent struct { + Type string `json:"type"` // "ping" +} + +// StreamErrorEvent is an error during streaming +type StreamErrorEvent struct { + Type string `json:"type"` // "error" + Error Error `json:"error"` +} + +// FromMessagesRequest converts an Anthropic MessagesRequest to an Ollama api.ChatRequest +func FromMessagesRequest(r MessagesRequest) (*api.ChatRequest, error) { + var messages []api.Message + + if r.System != nil { + switch sys := r.System.(type) { + case string: + if sys != "" { + messages = append(messages, api.Message{Role: "system", Content: sys}) + } + case []any: + // System can be an array of content blocks + var content strings.Builder + for _, block := range sys { + if blockMap, ok := block.(map[string]any); ok { + if blockMap["type"] == "text" { + if text, ok := blockMap["text"].(string); ok { + content.WriteString(text) + } + } + } + } + if content.Len() > 0 { + messages = append(messages, api.Message{Role: "system", Content: content.String()}) + } + } + } + + for _, msg := range r.Messages { + converted, err := convertMessage(msg) + if err != nil { + return nil, err + } + messages = append(messages, converted...) 
+ } + + options := make(map[string]any) + + options["num_predict"] = r.MaxTokens + + if r.Temperature != nil { + options["temperature"] = *r.Temperature + } + + if r.TopP != nil { + options["top_p"] = *r.TopP + } + + if r.TopK != nil { + options["top_k"] = *r.TopK + } + + if len(r.StopSequences) > 0 { + options["stop"] = r.StopSequences + } + + var tools api.Tools + for _, t := range r.Tools { + tool, err := convertTool(t) + if err != nil { + return nil, err + } + tools = append(tools, tool) + } + + var think *api.ThinkValue + if r.Thinking != nil && r.Thinking.Type == "enabled" { + think = &api.ThinkValue{Value: true} + } + + stream := r.Stream + + return &api.ChatRequest{ + Model: r.Model, + Messages: messages, + Options: options, + Stream: &stream, + Tools: tools, + Think: think, + }, nil +} + +// convertMessage converts an Anthropic MessageParam to Ollama api.Message(s) +func convertMessage(msg MessageParam) ([]api.Message, error) { + var messages []api.Message + role := strings.ToLower(msg.Role) + + switch content := msg.Content.(type) { + case string: + messages = append(messages, api.Message{Role: role, Content: content}) + + case []any: + var textContent strings.Builder + var images []api.ImageData + var toolCalls []api.ToolCall + var thinking string + var toolResults []api.Message + + for _, block := range content { + blockMap, ok := block.(map[string]any) + if !ok { + return nil, errors.New("invalid content block format") + } + + blockType, _ := blockMap["type"].(string) + + switch blockType { + case "text": + if text, ok := blockMap["text"].(string); ok { + textContent.WriteString(text) + } + + case "image": + source, ok := blockMap["source"].(map[string]any) + if !ok { + return nil, errors.New("invalid image source") + } + + sourceType, _ := source["type"].(string) + if sourceType == "base64" { + data, _ := source["data"].(string) + decoded, err := base64.StdEncoding.DecodeString(data) + if err != nil { + return nil, fmt.Errorf("invalid base64 image data: %w", err) + } + images = append(images, decoded) + } else { + return nil, fmt.Errorf("invalid image source type: %s. 
Only base64 images are supported.", sourceType) + } + // URL images would need to be fetched - skip for now + + case "tool_use": + id, ok := blockMap["id"].(string) + if !ok { + return nil, errors.New("tool_use block missing required 'id' field") + } + name, ok := blockMap["name"].(string) + if !ok { + return nil, errors.New("tool_use block missing required 'name' field") + } + tc := api.ToolCall{ + ID: id, + Function: api.ToolCallFunction{ + Name: name, + }, + } + if input, ok := blockMap["input"].(map[string]any); ok { + tc.Function.Arguments = mapToArgs(input) + } + toolCalls = append(toolCalls, tc) + + case "tool_result": + toolUseID, _ := blockMap["tool_use_id"].(string) + var resultContent string + + switch c := blockMap["content"].(type) { + case string: + resultContent = c + case []any: + for _, cb := range c { + if cbMap, ok := cb.(map[string]any); ok { + if cbMap["type"] == "text" { + if text, ok := cbMap["text"].(string); ok { + resultContent += text + } + } + } + } + } + + toolResults = append(toolResults, api.Message{ + Role: "tool", + Content: resultContent, + ToolCallID: toolUseID, + }) + + case "thinking": + if t, ok := blockMap["thinking"].(string); ok { + thinking = t + } + } + } + + if textContent.Len() > 0 || len(images) > 0 || len(toolCalls) > 0 || thinking != "" { + m := api.Message{ + Role: role, + Content: textContent.String(), + Images: images, + ToolCalls: toolCalls, + Thinking: thinking, + } + messages = append(messages, m) + } + + // Add tool results as separate messages + messages = append(messages, toolResults...) + + default: + return nil, fmt.Errorf("invalid message content type: %T", content) + } + + return messages, nil +} + +// convertTool converts an Anthropic Tool to an Ollama api.Tool +func convertTool(t Tool) (api.Tool, error) { + var params api.ToolFunctionParameters + if len(t.InputSchema) > 0 { + if err := json.Unmarshal(t.InputSchema, ¶ms); err != nil { + return api.Tool{}, fmt.Errorf("invalid input_schema for tool %q: %w", t.Name, err) + } + } + + return api.Tool{ + Type: "function", + Function: api.ToolFunction{ + Name: t.Name, + Description: t.Description, + Parameters: params, + }, + }, nil +} + +// ToMessagesResponse converts an Ollama api.ChatResponse to an Anthropic MessagesResponse +func ToMessagesResponse(id string, r api.ChatResponse) MessagesResponse { + var content []ContentBlock + + if r.Message.Thinking != "" { + content = append(content, ContentBlock{ + Type: "thinking", + Thinking: ptr(r.Message.Thinking), + }) + } + + if r.Message.Content != "" { + content = append(content, ContentBlock{ + Type: "text", + Text: ptr(r.Message.Content), + }) + } + + for _, tc := range r.Message.ToolCalls { + content = append(content, ContentBlock{ + Type: "tool_use", + ID: tc.ID, + Name: tc.Function.Name, + Input: tc.Function.Arguments, + }) + } + + stopReason := mapStopReason(r.DoneReason, len(r.Message.ToolCalls) > 0) + + return MessagesResponse{ + ID: id, + Type: "message", + Role: "assistant", + Model: r.Model, + Content: content, + StopReason: stopReason, + Usage: Usage{ + InputTokens: r.Metrics.PromptEvalCount, + OutputTokens: r.Metrics.EvalCount, + }, + } +} + +// mapStopReason converts Ollama done_reason to Anthropic stop_reason +func mapStopReason(reason string, hasToolCalls bool) string { + if hasToolCalls { + return "tool_use" + } + + switch reason { + case "stop": + return "end_turn" + case "length": + return "max_tokens" + default: + if reason != "" { + return "stop_sequence" + } + return "" + } +} + +// StreamConverter manages state 
for converting Ollama streaming responses to Anthropic format +type StreamConverter struct { + ID string + Model string + firstWrite bool + contentIndex int + inputTokens int + outputTokens int + thinkingStarted bool + thinkingDone bool + textStarted bool + toolCallsSent map[string]bool +} + +func NewStreamConverter(id, model string) *StreamConverter { + return &StreamConverter{ + ID: id, + Model: model, + firstWrite: true, + toolCallsSent: make(map[string]bool), + } +} + +// StreamEvent represents a streaming event to be sent to the client +type StreamEvent struct { + Event string + Data any +} + +// Process converts an Ollama ChatResponse to Anthropic streaming events +func (c *StreamConverter) Process(r api.ChatResponse) []StreamEvent { + var events []StreamEvent + + if c.firstWrite { + c.firstWrite = false + c.inputTokens = r.Metrics.PromptEvalCount + + events = append(events, StreamEvent{ + Event: "message_start", + Data: MessageStartEvent{ + Type: "message_start", + Message: MessagesResponse{ + ID: c.ID, + Type: "message", + Role: "assistant", + Model: c.Model, + Content: []ContentBlock{}, + Usage: Usage{ + InputTokens: c.inputTokens, + OutputTokens: 0, + }, + }, + }, + }) + } + + if r.Message.Thinking != "" && !c.thinkingDone { + if !c.thinkingStarted { + c.thinkingStarted = true + events = append(events, StreamEvent{ + Event: "content_block_start", + Data: ContentBlockStartEvent{ + Type: "content_block_start", + Index: c.contentIndex, + ContentBlock: ContentBlock{ + Type: "thinking", + Thinking: ptr(""), + }, + }, + }) + } + + events = append(events, StreamEvent{ + Event: "content_block_delta", + Data: ContentBlockDeltaEvent{ + Type: "content_block_delta", + Index: c.contentIndex, + Delta: Delta{ + Type: "thinking_delta", + Thinking: r.Message.Thinking, + }, + }, + }) + } + + if r.Message.Content != "" { + if c.thinkingStarted && !c.thinkingDone { + c.thinkingDone = true + events = append(events, StreamEvent{ + Event: "content_block_stop", + Data: ContentBlockStopEvent{ + Type: "content_block_stop", + Index: c.contentIndex, + }, + }) + c.contentIndex++ + } + + if !c.textStarted { + c.textStarted = true + events = append(events, StreamEvent{ + Event: "content_block_start", + Data: ContentBlockStartEvent{ + Type: "content_block_start", + Index: c.contentIndex, + ContentBlock: ContentBlock{ + Type: "text", + Text: ptr(""), + }, + }, + }) + } + + events = append(events, StreamEvent{ + Event: "content_block_delta", + Data: ContentBlockDeltaEvent{ + Type: "content_block_delta", + Index: c.contentIndex, + Delta: Delta{ + Type: "text_delta", + Text: r.Message.Content, + }, + }, + }) + } + + for _, tc := range r.Message.ToolCalls { + if c.toolCallsSent[tc.ID] { + continue + } + + if c.textStarted { + events = append(events, StreamEvent{ + Event: "content_block_stop", + Data: ContentBlockStopEvent{ + Type: "content_block_stop", + Index: c.contentIndex, + }, + }) + c.contentIndex++ + c.textStarted = false + } + + argsJSON, err := json.Marshal(tc.Function.Arguments) + if err != nil { + slog.Error("failed to marshal tool arguments", "error", err, "tool_id", tc.ID) + continue + } + + events = append(events, StreamEvent{ + Event: "content_block_start", + Data: ContentBlockStartEvent{ + Type: "content_block_start", + Index: c.contentIndex, + ContentBlock: ContentBlock{ + Type: "tool_use", + ID: tc.ID, + Name: tc.Function.Name, + Input: map[string]any{}, + }, + }, + }) + + events = append(events, StreamEvent{ + Event: "content_block_delta", + Data: ContentBlockDeltaEvent{ + Type: 
"content_block_delta", + Index: c.contentIndex, + Delta: Delta{ + Type: "input_json_delta", + PartialJSON: string(argsJSON), + }, + }, + }) + + events = append(events, StreamEvent{ + Event: "content_block_stop", + Data: ContentBlockStopEvent{ + Type: "content_block_stop", + Index: c.contentIndex, + }, + }) + + c.toolCallsSent[tc.ID] = true + c.contentIndex++ + } + + if r.Done { + if c.textStarted { + events = append(events, StreamEvent{ + Event: "content_block_stop", + Data: ContentBlockStopEvent{ + Type: "content_block_stop", + Index: c.contentIndex, + }, + }) + } else if c.thinkingStarted && !c.thinkingDone { + events = append(events, StreamEvent{ + Event: "content_block_stop", + Data: ContentBlockStopEvent{ + Type: "content_block_stop", + Index: c.contentIndex, + }, + }) + } + + c.outputTokens = r.Metrics.EvalCount + stopReason := mapStopReason(r.DoneReason, len(c.toolCallsSent) > 0) + + events = append(events, StreamEvent{ + Event: "message_delta", + Data: MessageDeltaEvent{ + Type: "message_delta", + Delta: MessageDelta{ + StopReason: stopReason, + }, + Usage: DeltaUsage{ + OutputTokens: c.outputTokens, + }, + }, + }) + + events = append(events, StreamEvent{ + Event: "message_stop", + Data: MessageStopEvent{ + Type: "message_stop", + }, + }) + } + + return events +} + +// generateID generates a unique ID with the given prefix using crypto/rand +func generateID(prefix string) string { + b := make([]byte, 12) + if _, err := rand.Read(b); err != nil { + // Fallback to time-based ID if crypto/rand fails + return fmt.Sprintf("%s_%d", prefix, time.Now().UnixNano()) + } + return fmt.Sprintf("%s_%x", prefix, b) +} + +// GenerateMessageID generates a unique message ID +func GenerateMessageID() string { + return generateID("msg") +} + +// ptr returns a pointer to the given string value +func ptr(s string) *string { + return &s +} + +// mapToArgs converts a map to ToolCallFunctionArguments +func mapToArgs(m map[string]any) api.ToolCallFunctionArguments { + args := api.NewToolCallFunctionArguments() + for k, v := range m { + args.Set(k, v) + } + return args +} diff --git a/anthropic/anthropic_test.go b/anthropic/anthropic_test.go new file mode 100644 index 000000000..117d183c9 --- /dev/null +++ b/anthropic/anthropic_test.go @@ -0,0 +1,953 @@ +package anthropic + +import ( + "encoding/base64" + "encoding/json" + "testing" + + "github.com/google/go-cmp/cmp" + + "github.com/ollama/ollama/api" +) + +const ( + testImage = `iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR42mNk+A8AAQUBAScY42YAAAAASUVORK5CYII=` +) + +// testArgs creates ToolCallFunctionArguments from a map (convenience function for tests) +func testArgs(m map[string]any) api.ToolCallFunctionArguments { + args := api.NewToolCallFunctionArguments() + for k, v := range m { + args.Set(k, v) + } + return args +} + +func TestFromMessagesRequest_Basic(t *testing.T) { + req := MessagesRequest{ + Model: "test-model", + MaxTokens: 1024, + Messages: []MessageParam{ + {Role: "user", Content: "Hello"}, + }, + } + + result, err := FromMessagesRequest(req) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if result.Model != "test-model" { + t.Errorf("expected model 'test-model', got %q", result.Model) + } + + if len(result.Messages) != 1 { + t.Fatalf("expected 1 message, got %d", len(result.Messages)) + } + + if result.Messages[0].Role != "user" || result.Messages[0].Content != "Hello" { + t.Errorf("unexpected message: %+v", result.Messages[0]) + } + + if numPredict, ok := result.Options["num_predict"].(int); !ok || numPredict != 
1024 { + t.Errorf("expected num_predict 1024, got %v", result.Options["num_predict"]) + } +} + +func TestFromMessagesRequest_WithSystemPrompt(t *testing.T) { + req := MessagesRequest{ + Model: "test-model", + MaxTokens: 1024, + System: "You are a helpful assistant.", + Messages: []MessageParam{ + {Role: "user", Content: "Hello"}, + }, + } + + result, err := FromMessagesRequest(req) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if len(result.Messages) != 2 { + t.Fatalf("expected 2 messages, got %d", len(result.Messages)) + } + + if result.Messages[0].Role != "system" || result.Messages[0].Content != "You are a helpful assistant." { + t.Errorf("unexpected system message: %+v", result.Messages[0]) + } +} + +func TestFromMessagesRequest_WithSystemPromptArray(t *testing.T) { + req := MessagesRequest{ + Model: "test-model", + MaxTokens: 1024, + System: []any{ + map[string]any{"type": "text", "text": "You are helpful."}, + map[string]any{"type": "text", "text": " Be concise."}, + }, + Messages: []MessageParam{ + {Role: "user", Content: "Hello"}, + }, + } + + result, err := FromMessagesRequest(req) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if len(result.Messages) != 2 { + t.Fatalf("expected 2 messages, got %d", len(result.Messages)) + } + + if result.Messages[0].Content != "You are helpful. Be concise." { + t.Errorf("unexpected system message content: %q", result.Messages[0].Content) + } +} + +func TestFromMessagesRequest_WithOptions(t *testing.T) { + temp := 0.7 + topP := 0.9 + topK := 40 + req := MessagesRequest{ + Model: "test-model", + MaxTokens: 2048, + Messages: []MessageParam{{Role: "user", Content: "Hello"}}, + Temperature: &temp, + TopP: &topP, + TopK: &topK, + StopSequences: []string{"\n", "END"}, + } + + result, err := FromMessagesRequest(req) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if result.Options["temperature"] != 0.7 { + t.Errorf("expected temperature 0.7, got %v", result.Options["temperature"]) + } + if result.Options["top_p"] != 0.9 { + t.Errorf("expected top_p 0.9, got %v", result.Options["top_p"]) + } + if result.Options["top_k"] != 40 { + t.Errorf("expected top_k 40, got %v", result.Options["top_k"]) + } + if diff := cmp.Diff([]string{"\n", "END"}, result.Options["stop"]); diff != "" { + t.Errorf("stop sequences mismatch: %s", diff) + } +} + +func TestFromMessagesRequest_WithImage(t *testing.T) { + imgData, _ := base64.StdEncoding.DecodeString(testImage) + + req := MessagesRequest{ + Model: "test-model", + MaxTokens: 1024, + Messages: []MessageParam{ + { + Role: "user", + Content: []any{ + map[string]any{"type": "text", "text": "What's in this image?"}, + map[string]any{ + "type": "image", + "source": map[string]any{ + "type": "base64", + "media_type": "image/png", + "data": testImage, + }, + }, + }, + }, + }, + } + + result, err := FromMessagesRequest(req) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if len(result.Messages) != 1 { + t.Fatalf("expected 1 message, got %d", len(result.Messages)) + } + + if result.Messages[0].Content != "What's in this image?" 
{ + t.Errorf("expected content 'What's in this image?', got %q", result.Messages[0].Content) + } + + if len(result.Messages[0].Images) != 1 { + t.Fatalf("expected 1 image, got %d", len(result.Messages[0].Images)) + } + + if string(result.Messages[0].Images[0]) != string(imgData) { + t.Error("image data mismatch") + } +} + +func TestFromMessagesRequest_WithToolUse(t *testing.T) { + req := MessagesRequest{ + Model: "test-model", + MaxTokens: 1024, + Messages: []MessageParam{ + {Role: "user", Content: "What's the weather in Paris?"}, + { + Role: "assistant", + Content: []any{ + map[string]any{ + "type": "tool_use", + "id": "call_123", + "name": "get_weather", + "input": map[string]any{"location": "Paris"}, + }, + }, + }, + }, + } + + result, err := FromMessagesRequest(req) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if len(result.Messages) != 2 { + t.Fatalf("expected 2 messages, got %d", len(result.Messages)) + } + + if len(result.Messages[1].ToolCalls) != 1 { + t.Fatalf("expected 1 tool call, got %d", len(result.Messages[1].ToolCalls)) + } + + tc := result.Messages[1].ToolCalls[0] + if tc.ID != "call_123" { + t.Errorf("expected tool call ID 'call_123', got %q", tc.ID) + } + if tc.Function.Name != "get_weather" { + t.Errorf("expected tool name 'get_weather', got %q", tc.Function.Name) + } +} + +func TestFromMessagesRequest_WithToolResult(t *testing.T) { + req := MessagesRequest{ + Model: "test-model", + MaxTokens: 1024, + Messages: []MessageParam{ + { + Role: "user", + Content: []any{ + map[string]any{ + "type": "tool_result", + "tool_use_id": "call_123", + "content": "The weather in Paris is sunny, 22°C", + }, + }, + }, + }, + } + + result, err := FromMessagesRequest(req) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if len(result.Messages) != 1 { + t.Fatalf("expected 1 message, got %d", len(result.Messages)) + } + + msg := result.Messages[0] + if msg.Role != "tool" { + t.Errorf("expected role 'tool', got %q", msg.Role) + } + if msg.ToolCallID != "call_123" { + t.Errorf("expected tool_call_id 'call_123', got %q", msg.ToolCallID) + } + if msg.Content != "The weather in Paris is sunny, 22°C" { + t.Errorf("unexpected content: %q", msg.Content) + } +} + +func TestFromMessagesRequest_WithTools(t *testing.T) { + req := MessagesRequest{ + Model: "test-model", + MaxTokens: 1024, + Messages: []MessageParam{{Role: "user", Content: "Hello"}}, + Tools: []Tool{ + { + Name: "get_weather", + Description: "Get current weather", + InputSchema: json.RawMessage(`{"type":"object","properties":{"location":{"type":"string"}},"required":["location"]}`), + }, + }, + } + + result, err := FromMessagesRequest(req) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if len(result.Tools) != 1 { + t.Fatalf("expected 1 tool, got %d", len(result.Tools)) + } + + tool := result.Tools[0] + if tool.Type != "function" { + t.Errorf("expected type 'function', got %q", tool.Type) + } + if tool.Function.Name != "get_weather" { + t.Errorf("expected name 'get_weather', got %q", tool.Function.Name) + } + if tool.Function.Description != "Get current weather" { + t.Errorf("expected description 'Get current weather', got %q", tool.Function.Description) + } +} + +func TestFromMessagesRequest_WithThinking(t *testing.T) { + req := MessagesRequest{ + Model: "test-model", + MaxTokens: 1024, + Messages: []MessageParam{{Role: "user", Content: "Hello"}}, + Thinking: &ThinkingConfig{Type: "enabled", BudgetTokens: 1000}, + } + + result, err := FromMessagesRequest(req) + if err != nil { + 
t.Fatalf("unexpected error: %v", err) + } + + if result.Think == nil { + t.Fatal("expected Think to be set") + } + if v, ok := result.Think.Value.(bool); !ok || !v { + t.Errorf("expected Think.Value to be true, got %v", result.Think.Value) + } +} + +// TestFromMessagesRequest_ThinkingOnlyBlock verifies that messages containing only +// a thinking block (no text, images, or tool calls) are preserved and not dropped. +func TestFromMessagesRequest_ThinkingOnlyBlock(t *testing.T) { + req := MessagesRequest{ + Model: "test-model", + MaxTokens: 1024, + Messages: []MessageParam{ + {Role: "user", Content: "Hello"}, + { + Role: "assistant", + Content: []any{ + map[string]any{ + "type": "thinking", + "thinking": "Let me think about this...", + }, + }, + }, + }, + } + + result, err := FromMessagesRequest(req) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if len(result.Messages) != 2 { + t.Fatalf("expected 2 messages, got %d", len(result.Messages)) + } + + assistantMsg := result.Messages[1] + if assistantMsg.Thinking != "Let me think about this..." { + t.Errorf("expected thinking content, got %q", assistantMsg.Thinking) + } +} + +func TestFromMessagesRequest_ToolUseMissingID(t *testing.T) { + req := MessagesRequest{ + Model: "test-model", + MaxTokens: 1024, + Messages: []MessageParam{ + { + Role: "assistant", + Content: []any{ + map[string]any{ + "type": "tool_use", + "name": "get_weather", + }, + }, + }, + }, + } + + _, err := FromMessagesRequest(req) + if err == nil { + t.Fatal("expected error for missing tool_use id") + } + if err.Error() != "tool_use block missing required 'id' field" { + t.Errorf("unexpected error message: %v", err) + } +} + +func TestFromMessagesRequest_ToolUseMissingName(t *testing.T) { + req := MessagesRequest{ + Model: "test-model", + MaxTokens: 1024, + Messages: []MessageParam{ + { + Role: "assistant", + Content: []any{ + map[string]any{ + "type": "tool_use", + "id": "call_123", + }, + }, + }, + }, + } + + _, err := FromMessagesRequest(req) + if err == nil { + t.Fatal("expected error for missing tool_use name") + } + if err.Error() != "tool_use block missing required 'name' field" { + t.Errorf("unexpected error message: %v", err) + } +} + +func TestFromMessagesRequest_InvalidToolSchema(t *testing.T) { + req := MessagesRequest{ + Model: "test-model", + MaxTokens: 1024, + Messages: []MessageParam{{Role: "user", Content: "Hello"}}, + Tools: []Tool{ + { + Name: "bad_tool", + InputSchema: json.RawMessage(`{invalid json`), + }, + }, + } + + _, err := FromMessagesRequest(req) + if err == nil { + t.Fatal("expected error for invalid tool schema") + } +} + +func TestToMessagesResponse_Basic(t *testing.T) { + resp := api.ChatResponse{ + Model: "test-model", + Message: api.Message{ + Role: "assistant", + Content: "Hello there!", + }, + Done: true, + DoneReason: "stop", + Metrics: api.Metrics{ + PromptEvalCount: 10, + EvalCount: 5, + }, + } + + result := ToMessagesResponse("msg_123", resp) + + if result.ID != "msg_123" { + t.Errorf("expected ID 'msg_123', got %q", result.ID) + } + if result.Type != "message" { + t.Errorf("expected type 'message', got %q", result.Type) + } + if result.Role != "assistant" { + t.Errorf("expected role 'assistant', got %q", result.Role) + } + if len(result.Content) != 1 { + t.Fatalf("expected 1 content block, got %d", len(result.Content)) + } + if result.Content[0].Type != "text" || result.Content[0].Text == nil || *result.Content[0].Text != "Hello there!" 
{ + t.Errorf("unexpected content: %+v", result.Content[0]) + } + if result.StopReason != "end_turn" { + t.Errorf("expected stop_reason 'end_turn', got %q", result.StopReason) + } + if result.Usage.InputTokens != 10 || result.Usage.OutputTokens != 5 { + t.Errorf("unexpected usage: %+v", result.Usage) + } +} + +func TestToMessagesResponse_WithToolCalls(t *testing.T) { + resp := api.ChatResponse{ + Model: "test-model", + Message: api.Message{ + Role: "assistant", + ToolCalls: []api.ToolCall{ + { + ID: "call_123", + Function: api.ToolCallFunction{ + Name: "get_weather", + Arguments: testArgs(map[string]any{"location": "Paris"}), + }, + }, + }, + }, + Done: true, + DoneReason: "stop", + } + + result := ToMessagesResponse("msg_123", resp) + + if len(result.Content) != 1 { + t.Fatalf("expected 1 content block, got %d", len(result.Content)) + } + if result.Content[0].Type != "tool_use" { + t.Errorf("expected type 'tool_use', got %q", result.Content[0].Type) + } + if result.Content[0].ID != "call_123" { + t.Errorf("expected ID 'call_123', got %q", result.Content[0].ID) + } + if result.Content[0].Name != "get_weather" { + t.Errorf("expected name 'get_weather', got %q", result.Content[0].Name) + } + if result.StopReason != "tool_use" { + t.Errorf("expected stop_reason 'tool_use', got %q", result.StopReason) + } +} + +func TestToMessagesResponse_WithThinking(t *testing.T) { + resp := api.ChatResponse{ + Model: "test-model", + Message: api.Message{ + Role: "assistant", + Content: "The answer is 42.", + Thinking: "Let me think about this...", + }, + Done: true, + DoneReason: "stop", + } + + result := ToMessagesResponse("msg_123", resp) + + if len(result.Content) != 2 { + t.Fatalf("expected 2 content blocks, got %d", len(result.Content)) + } + if result.Content[0].Type != "thinking" { + t.Errorf("expected first block type 'thinking', got %q", result.Content[0].Type) + } + if result.Content[0].Thinking == nil || *result.Content[0].Thinking != "Let me think about this..." 
{ + t.Errorf("unexpected thinking content: %v", result.Content[0].Thinking) + } + if result.Content[1].Type != "text" { + t.Errorf("expected second block type 'text', got %q", result.Content[1].Type) + } +} + +func TestMapStopReason(t *testing.T) { + tests := []struct { + reason string + hasToolCalls bool + want string + }{ + {"stop", false, "end_turn"}, + {"length", false, "max_tokens"}, + {"stop", true, "tool_use"}, + {"other", false, "stop_sequence"}, + {"", false, ""}, + } + + for _, tt := range tests { + got := mapStopReason(tt.reason, tt.hasToolCalls) + if got != tt.want { + t.Errorf("mapStopReason(%q, %v) = %q, want %q", tt.reason, tt.hasToolCalls, got, tt.want) + } + } +} + +func TestNewError(t *testing.T) { + tests := []struct { + code int + want string + }{ + {400, "invalid_request_error"}, + {401, "authentication_error"}, + {403, "permission_error"}, + {404, "not_found_error"}, + {429, "rate_limit_error"}, + {500, "api_error"}, + {503, "overloaded_error"}, + {529, "overloaded_error"}, + } + + for _, tt := range tests { + result := NewError(tt.code, "test message") + if result.Type != "error" { + t.Errorf("NewError(%d) type = %q, want 'error'", tt.code, result.Type) + } + if result.Error.Type != tt.want { + t.Errorf("NewError(%d) error.type = %q, want %q", tt.code, result.Error.Type, tt.want) + } + if result.Error.Message != "test message" { + t.Errorf("NewError(%d) message = %q, want 'test message'", tt.code, result.Error.Message) + } + if result.RequestID == "" { + t.Errorf("NewError(%d) request_id should not be empty", tt.code) + } + } +} + +func TestGenerateMessageID(t *testing.T) { + id1 := GenerateMessageID() + id2 := GenerateMessageID() + + if id1 == "" { + t.Error("GenerateMessageID returned empty string") + } + if id1 == id2 { + t.Error("GenerateMessageID returned duplicate IDs") + } + if len(id1) < 10 { + t.Errorf("GenerateMessageID returned short ID: %q", id1) + } + if id1[:4] != "msg_" { + t.Errorf("GenerateMessageID should start with 'msg_', got %q", id1[:4]) + } +} + +func TestStreamConverter_Basic(t *testing.T) { + conv := NewStreamConverter("msg_123", "test-model") + + // First chunk + resp1 := api.ChatResponse{ + Model: "test-model", + Message: api.Message{ + Role: "assistant", + Content: "Hello", + }, + Metrics: api.Metrics{PromptEvalCount: 10}, + } + + events1 := conv.Process(resp1) + if len(events1) < 3 { + t.Fatalf("expected at least 3 events for first chunk, got %d", len(events1)) + } + + // Should have message_start, content_block_start, content_block_delta + if events1[0].Event != "message_start" { + t.Errorf("expected first event 'message_start', got %q", events1[0].Event) + } + if events1[1].Event != "content_block_start" { + t.Errorf("expected second event 'content_block_start', got %q", events1[1].Event) + } + if events1[2].Event != "content_block_delta" { + t.Errorf("expected third event 'content_block_delta', got %q", events1[2].Event) + } + + // Final chunk + resp2 := api.ChatResponse{ + Model: "test-model", + Message: api.Message{ + Role: "assistant", + Content: " world!", + }, + Done: true, + DoneReason: "stop", + Metrics: api.Metrics{EvalCount: 5}, + } + + events2 := conv.Process(resp2) + + // Should have content_block_delta, content_block_stop, message_delta, message_stop + hasStop := false + for _, e := range events2 { + if e.Event == "message_stop" { + hasStop = true + } + } + if !hasStop { + t.Error("expected message_stop event in final chunk") + } +} + +func TestStreamConverter_WithToolCalls(t *testing.T) { + conv := 
NewStreamConverter("msg_123", "test-model") + + resp := api.ChatResponse{ + Model: "test-model", + Message: api.Message{ + Role: "assistant", + ToolCalls: []api.ToolCall{ + { + ID: "call_123", + Function: api.ToolCallFunction{ + Name: "get_weather", + Arguments: testArgs(map[string]any{"location": "Paris"}), + }, + }, + }, + }, + Done: true, + DoneReason: "stop", + Metrics: api.Metrics{PromptEvalCount: 10, EvalCount: 5}, + } + + events := conv.Process(resp) + + hasToolStart := false + hasToolDelta := false + for _, e := range events { + if e.Event == "content_block_start" { + if start, ok := e.Data.(ContentBlockStartEvent); ok { + if start.ContentBlock.Type == "tool_use" { + hasToolStart = true + } + } + } + if e.Event == "content_block_delta" { + if delta, ok := e.Data.(ContentBlockDeltaEvent); ok { + if delta.Delta.Type == "input_json_delta" { + hasToolDelta = true + } + } + } + } + + if !hasToolStart { + t.Error("expected tool_use content_block_start event") + } + if !hasToolDelta { + t.Error("expected input_json_delta event") + } +} + +func TestStreamConverter_ToolCallWithUnmarshalableArgs(t *testing.T) { + // Test that unmarshalable arguments (like channels) are handled gracefully + // and don't cause a panic or corrupt stream + conv := NewStreamConverter("msg_123", "test-model") + + // Create a channel which cannot be JSON marshaled + unmarshalable := make(chan int) + badArgs := api.NewToolCallFunctionArguments() + badArgs.Set("channel", unmarshalable) + + resp := api.ChatResponse{ + Model: "test-model", + Message: api.Message{ + Role: "assistant", + ToolCalls: []api.ToolCall{ + { + ID: "call_bad", + Function: api.ToolCallFunction{ + Name: "bad_function", + Arguments: badArgs, + }, + }, + }, + }, + Done: true, + DoneReason: "stop", + } + + // Should not panic and should skip the unmarshalable tool call + events := conv.Process(resp) + + // Verify no tool_use block was started (since marshal failed before block start) + hasToolStart := false + for _, e := range events { + if e.Event == "content_block_start" { + if start, ok := e.Data.(ContentBlockStartEvent); ok { + if start.ContentBlock.Type == "tool_use" { + hasToolStart = true + } + } + } + } + + if hasToolStart { + t.Error("expected no tool_use block when arguments cannot be marshaled") + } +} + +func TestStreamConverter_MultipleToolCallsWithMixedValidity(t *testing.T) { + // Test that valid tool calls still work when mixed with invalid ones + conv := NewStreamConverter("msg_123", "test-model") + + unmarshalable := make(chan int) + badArgs := api.NewToolCallFunctionArguments() + badArgs.Set("channel", unmarshalable) + + resp := api.ChatResponse{ + Model: "test-model", + Message: api.Message{ + Role: "assistant", + ToolCalls: []api.ToolCall{ + { + ID: "call_good", + Function: api.ToolCallFunction{ + Name: "good_function", + Arguments: testArgs(map[string]any{"location": "Paris"}), + }, + }, + { + ID: "call_bad", + Function: api.ToolCallFunction{ + Name: "bad_function", + Arguments: badArgs, + }, + }, + }, + }, + Done: true, + DoneReason: "stop", + } + + events := conv.Process(resp) + + // Count tool_use blocks - should only have 1 (the valid one) + toolStartCount := 0 + toolDeltaCount := 0 + for _, e := range events { + if e.Event == "content_block_start" { + if start, ok := e.Data.(ContentBlockStartEvent); ok { + if start.ContentBlock.Type == "tool_use" { + toolStartCount++ + if start.ContentBlock.Name != "good_function" { + t.Errorf("expected tool name 'good_function', got %q", start.ContentBlock.Name) + } + } + } + } + if 
e.Event == "content_block_delta" { + if delta, ok := e.Data.(ContentBlockDeltaEvent); ok { + if delta.Delta.Type == "input_json_delta" { + toolDeltaCount++ + } + } + } + } + + if toolStartCount != 1 { + t.Errorf("expected 1 tool_use block, got %d", toolStartCount) + } + if toolDeltaCount != 1 { + t.Errorf("expected 1 input_json_delta, got %d", toolDeltaCount) + } +} + +// TestContentBlockJSON_EmptyFieldsPresent verifies that empty text and thinking fields +// are serialized in JSON output. The Anthropic SDK requires these fields to be present +// (even when empty) in content_block_start events to properly accumulate streaming deltas. +// Without these fields, the SDK throws: "TypeError: unsupported operand type(s) for +=: 'NoneType' and 'str'" +func TestContentBlockJSON_EmptyFieldsPresent(t *testing.T) { + tests := []struct { + name string + block ContentBlock + wantKeys []string + }{ + { + name: "text block includes empty text field", + block: ContentBlock{ + Type: "text", + Text: ptr(""), + }, + wantKeys: []string{"type", "text"}, + }, + { + name: "thinking block includes empty thinking field", + block: ContentBlock{ + Type: "thinking", + Thinking: ptr(""), + }, + wantKeys: []string{"type", "thinking"}, + }, + { + name: "text block with content", + block: ContentBlock{ + Type: "text", + Text: ptr("hello"), + }, + wantKeys: []string{"type", "text"}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + data, err := json.Marshal(tt.block) + if err != nil { + t.Fatalf("failed to marshal: %v", err) + } + + var result map[string]any + if err := json.Unmarshal(data, &result); err != nil { + t.Fatalf("failed to unmarshal: %v", err) + } + + for _, key := range tt.wantKeys { + if _, ok := result[key]; !ok { + t.Errorf("expected key %q to be present in JSON output, got: %s", key, string(data)) + } + } + }) + } +} + +// TestStreamConverter_ContentBlockStartIncludesEmptyFields verifies that content_block_start +// events include the required empty fields for SDK compatibility. 
+func TestStreamConverter_ContentBlockStartIncludesEmptyFields(t *testing.T) { + t.Run("text block start includes empty text", func(t *testing.T) { + conv := NewStreamConverter("msg_123", "test-model") + + resp := api.ChatResponse{ + Model: "test-model", + Message: api.Message{Role: "assistant", Content: "hello"}, + } + + events := conv.Process(resp) + + var foundTextStart bool + for _, e := range events { + if e.Event == "content_block_start" { + if start, ok := e.Data.(ContentBlockStartEvent); ok { + if start.ContentBlock.Type == "text" { + foundTextStart = true + // Marshal and verify the text field is present + data, _ := json.Marshal(start) + var result map[string]any + json.Unmarshal(data, &result) + cb := result["content_block"].(map[string]any) + if _, ok := cb["text"]; !ok { + t.Error("content_block_start for text should include 'text' field") + } + } + } + } + } + + if !foundTextStart { + t.Error("expected text content_block_start event") + } + }) + + t.Run("thinking block start includes empty thinking", func(t *testing.T) { + conv := NewStreamConverter("msg_123", "test-model") + + resp := api.ChatResponse{ + Model: "test-model", + Message: api.Message{Role: "assistant", Thinking: "let me think..."}, + } + + events := conv.Process(resp) + + var foundThinkingStart bool + for _, e := range events { + if e.Event == "content_block_start" { + if start, ok := e.Data.(ContentBlockStartEvent); ok { + if start.ContentBlock.Type == "thinking" { + foundThinkingStart = true + data, _ := json.Marshal(start) + var result map[string]any + json.Unmarshal(data, &result) + cb := result["content_block"].(map[string]any) + if _, ok := cb["thinking"]; !ok { + t.Error("content_block_start for thinking should include 'thinking' field") + } + } + } + } + } + + if !foundThinkingStart { + t.Error("expected thinking content_block_start event") + } + }) +} diff --git a/docs/README.md b/docs/README.md index 74544a321..4483eb550 100644 --- a/docs/README.md +++ b/docs/README.md @@ -14,6 +14,7 @@ * [API Reference](https://docs.ollama.com/api) * [Modelfile Reference](https://docs.ollama.com/modelfile) * [OpenAI Compatibility](https://docs.ollama.com/api/openai-compatibility) +* [Anthropic Compatibility](./api/anthropic-compatibility.mdx) ### Resources diff --git a/docs/api/anthropic-compatibility.mdx b/docs/api/anthropic-compatibility.mdx new file mode 100644 index 000000000..a0f2cd7fd --- /dev/null +++ b/docs/api/anthropic-compatibility.mdx @@ -0,0 +1,406 @@ +--- +title: Anthropic compatibility +--- + +Ollama provides compatibility with the [Anthropic Messages API](https://docs.anthropic.com/en/api/messages) to help connect existing applications to Ollama, including tools like Claude Code. + +## Recommended models + +For coding use cases, models like `glm-4.7:cloud`, `minimax-m2.1:cloud`, and `qwen3-coder` are recommended. 
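+
+To see which models are already available locally, you can list them with the Ollama CLI (cloud models run remotely and are available without pulling):
+
+```shell
+ollama list
+```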
+ +Pull a model before use: +```shell +ollama pull qwen3-coder +ollama pull glm-4.7:cloud +``` + +## Usage + +### Environment variables + +To use Ollama with tools that expect the Anthropic API (like Claude Code), set these environment variables: + +```shell +export ANTHROPIC_BASE_URL=http://localhost:11434 +export ANTHROPIC_API_KEY=ollama # required but ignored +``` + +### Simple `/v1/messages` example + + + +```python basic.py +import anthropic + +client = anthropic.Anthropic( + base_url='http://localhost:11434', + api_key='ollama', # required but ignored +) + +message = client.messages.create( + model='qwen3-coder', + max_tokens=1024, + messages=[ + {'role': 'user', 'content': 'Hello, how are you?'} + ] +) +print(message.content[0].text) +``` + +```javascript basic.js +import Anthropic from "@anthropic-ai/sdk"; + +const anthropic = new Anthropic({ + baseURL: "http://localhost:11434", + apiKey: "ollama", // required but ignored +}); + +const message = await anthropic.messages.create({ + model: "qwen3-coder", + max_tokens: 1024, + messages: [{ role: "user", content: "Hello, how are you?" }], +}); + +console.log(message.content[0].text); +``` + +```shell basic.sh +curl -X POST http://localhost:11434/v1/messages \ +-H "Content-Type: application/json" \ +-H "x-api-key: ollama" \ +-H "anthropic-version: 2023-06-01" \ +-d '{ + "model": "qwen3-coder", + "max_tokens": 1024, + "messages": [{ "role": "user", "content": "Hello, how are you?" }] +}' +``` + + + +### Streaming example + + + +```python streaming.py +import anthropic + +client = anthropic.Anthropic( + base_url='http://localhost:11434', + api_key='ollama', +) + +with client.messages.stream( + model='qwen3-coder', + max_tokens=1024, + messages=[{'role': 'user', 'content': 'Count from 1 to 10'}] +) as stream: + for text in stream.text_stream: + print(text, end='', flush=True) +``` + +```javascript streaming.js +import Anthropic from "@anthropic-ai/sdk"; + +const anthropic = new Anthropic({ + baseURL: "http://localhost:11434", + apiKey: "ollama", +}); + +const stream = await anthropic.messages.stream({ + model: "qwen3-coder", + max_tokens: 1024, + messages: [{ role: "user", content: "Count from 1 to 10" }], +}); + +for await (const event of stream) { + if ( + event.type === "content_block_delta" && + event.delta.type === "text_delta" + ) { + process.stdout.write(event.delta.text); + } +} +``` + +```shell streaming.sh +curl -X POST http://localhost:11434/v1/messages \ +-H "Content-Type: application/json" \ +-d '{ + "model": "qwen3-coder", + "max_tokens": 1024, + "stream": true, + "messages": [{ "role": "user", "content": "Count from 1 to 10" }] +}' +``` + + + +### Tool calling example + + + +```python tools.py +import anthropic + +client = anthropic.Anthropic( + base_url='http://localhost:11434', + api_key='ollama', +) + +message = client.messages.create( + model='qwen3-coder', + max_tokens=1024, + tools=[ + { + 'name': 'get_weather', + 'description': 'Get the current weather in a location', + 'input_schema': { + 'type': 'object', + 'properties': { + 'location': { + 'type': 'string', + 'description': 'The city and state, e.g. 
San Francisco, CA' + } + }, + 'required': ['location'] + } + } + ], + messages=[{'role': 'user', 'content': "What's the weather in San Francisco?"}] +) + +for block in message.content: + if block.type == 'tool_use': + print(f'Tool: {block.name}') + print(f'Input: {block.input}') +``` + +```javascript tools.js +import Anthropic from "@anthropic-ai/sdk"; + +const anthropic = new Anthropic({ + baseURL: "http://localhost:11434", + apiKey: "ollama", +}); + +const message = await anthropic.messages.create({ + model: "qwen3-coder", + max_tokens: 1024, + tools: [ + { + name: "get_weather", + description: "Get the current weather in a location", + input_schema: { + type: "object", + properties: { + location: { + type: "string", + description: "The city and state, e.g. San Francisco, CA", + }, + }, + required: ["location"], + }, + }, + ], + messages: [{ role: "user", content: "What's the weather in San Francisco?" }], +}); + +for (const block of message.content) { + if (block.type === "tool_use") { + console.log("Tool:", block.name); + console.log("Input:", block.input); + } +} +``` + +```shell tools.sh +curl -X POST http://localhost:11434/v1/messages \ +-H "Content-Type: application/json" \ +-d '{ + "model": "qwen3-coder", + "max_tokens": 1024, + "tools": [ + { + "name": "get_weather", + "description": "Get the current weather in a location", + "input_schema": { + "type": "object", + "properties": { + "location": { + "type": "string", + "description": "The city and state" + } + }, + "required": ["location"] + } + } + ], + "messages": [{ "role": "user", "content": "What is the weather in San Francisco?" }] +}' +``` + + + +## Using with Claude Code + +[Claude Code](https://code.claude.com/docs/en/overview) can be configured to use Ollama as its backend: + +```shell +ANTHROPIC_BASE_URL=http://localhost:11434 ANTHROPIC_API_KEY=ollama claude --model qwen3-coder +``` + +Or set the environment variables in your shell profile: + +```shell +export ANTHROPIC_BASE_URL=http://localhost:11434 +export ANTHROPIC_API_KEY=ollama +``` + +Then run Claude Code with any Ollama model: + +```shell +# Local models +claude --model qwen3-coder +claude --model gpt-oss:20b + +# Cloud models +claude --model glm-4.7:cloud +claude --model minimax-m2.1:cloud +``` + +## Endpoints + +### `/v1/messages` + +#### Supported features + +- [x] Messages +- [x] Streaming +- [x] System prompts +- [x] Multi-turn conversations +- [x] Vision (images) +- [x] Tools (function calling) +- [x] Tool results +- [x] Thinking/extended thinking + +#### Supported request fields + +- [x] `model` +- [x] `max_tokens` +- [x] `messages` + - [x] Text `content` + - [x] Image `content` (base64) + - [x] Array of content blocks + - [x] `tool_use` blocks + - [x] `tool_result` blocks + - [x] `thinking` blocks +- [x] `system` (string or array) +- [x] `stream` +- [x] `temperature` +- [x] `top_p` +- [x] `top_k` +- [x] `stop_sequences` +- [x] `tools` +- [x] `thinking` +- [ ] `tool_choice` +- [ ] `metadata` + +#### Supported response fields + +- [x] `id` +- [x] `type` +- [x] `role` +- [x] `model` +- [x] `content` (text, tool_use, thinking blocks) +- [x] `stop_reason` (end_turn, max_tokens, tool_use) +- [x] `usage` (input_tokens, output_tokens) + +#### Streaming events + +- [x] `message_start` +- [x] `content_block_start` +- [x] `content_block_delta` (text_delta, input_json_delta, thinking_delta) +- [x] `content_block_stop` +- [x] `message_delta` +- [x] `message_stop` +- [x] `ping` +- [x] `error` + +## Models + +Ollama supports both local and cloud models. 
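+
+Both kinds of models are used the same way through this API; only the `model` field changes. For example, assuming `qwen3-coder` has already been pulled:
+
+```shell
+# Local model
+curl http://localhost:11434/v1/messages \
+  -H "Content-Type: application/json" \
+  -d '{"model": "qwen3-coder", "max_tokens": 256, "messages": [{"role": "user", "content": "Hello"}]}'
+
+# Cloud model (no pull required)
+curl http://localhost:11434/v1/messages \
+  -H "Content-Type: application/json" \
+  -d '{"model": "glm-4.7:cloud", "max_tokens": 256, "messages": [{"role": "user", "content": "Hello"}]}'
+```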
+ +### Local models + +Pull a local model before use: + +```shell +ollama pull qwen3-coder +``` + +Recommended local models: +- `qwen3-coder` - Excellent for coding tasks +- `gpt-oss:20b` - Strong general-purpose model + +### Cloud models + +Cloud models are available immediately without pulling: + +- `glm-4.7:cloud` - High-performance cloud model +- `minimax-m2.1:cloud` - Fast cloud model + +### Default model names + +For tooling that relies on default Anthropic model names such as `claude-3-5-sonnet`, use `ollama cp` to copy an existing model name: + +```shell +ollama cp qwen3-coder claude-3-5-sonnet +``` + +Afterwards, this new model name can be specified in the `model` field: + +```shell +curl http://localhost:11434/v1/messages \ + -H "Content-Type: application/json" \ + -d '{ + "model": "claude-3-5-sonnet", + "max_tokens": 1024, + "messages": [ + { + "role": "user", + "content": "Hello!" + } + ] + }' +``` + +## Differences from the Anthropic API + +### Behavior differences + +- API key is accepted but not validated +- `anthropic-version` header is accepted but not used +- Token counts are approximations based on the underlying model's tokenizer + +### Not supported + +The following Anthropic API features are not currently supported: + +| Feature | Description | +|---------|-------------| +| `/v1/messages/count_tokens` | Token counting endpoint | +| `tool_choice` | Forcing specific tool use or disabling tools | +| `metadata` | Request metadata (user_id) | +| Prompt caching | `cache_control` blocks for caching prefixes | +| Batches API | `/v1/messages/batches` for async batch processing | +| Citations | `citations` content blocks | +| PDF support | `document` content blocks with PDF files | +| Server-sent errors | `error` events during streaming (errors return HTTP status) | + +### Partial support + +| Feature | Status | +|---------|--------| +| Image content | Base64 images supported; URL images not supported | +| Extended thinking | Basic support; `budget_tokens` accepted but not enforced | diff --git a/docs/docs.json b/docs/docs.json index 71a6f17a0..810e94733 100644 --- a/docs/docs.json +++ b/docs/docs.json @@ -32,7 +32,9 @@ "codeblocks": "system" }, "contextual": { - "options": ["copy"] + "options": [ + "copy" + ] }, "navbar": { "links": [ @@ -52,7 +54,9 @@ "display": "simple" }, "examples": { - "languages": ["curl"] + "languages": [ + "curl" + ] } }, "redirects": [ @@ -97,6 +101,7 @@ { "group": "Integrations", "pages": [ + "/integrations/claude-code", "/integrations/vscode", "/integrations/jetbrains", "/integrations/codex", @@ -139,7 +144,8 @@ "/api/streaming", "/api/usage", "/api/errors", - "/api/openai-compatibility" + "/api/openai-compatibility", + "/api/anthropic-compatibility" ] }, { diff --git a/docs/integrations/claude-code.mdx b/docs/integrations/claude-code.mdx new file mode 100644 index 000000000..6d1d8322a --- /dev/null +++ b/docs/integrations/claude-code.mdx @@ -0,0 +1,69 @@ +--- +title: Claude Code +--- + +## Install + +Install [Claude Code](https://code.claude.com/docs/en/overview): + + + +```shell macOS / Linux +curl -fsSL https://claude.ai/install.sh | bash +``` + +```powershell Windows +irm https://claude.ai/install.ps1 | iex +``` + + + +## Usage with Ollama + +Claude Code connects to Ollama using the Anthropic-compatible API. + +1. Set the environment variables: + +```shell +export ANTHROPIC_BASE_URL=http://localhost:11434 +export ANTHROPIC_API_KEY=ollama +``` + +2. 
Run Claude Code with an Ollama model: + +```shell +claude --model qwen3-coder +``` + +Or run with environment variables inline: + +```shell +ANTHROPIC_BASE_URL=http://localhost:11434 ANTHROPIC_API_KEY=ollama claude --model qwen3-coder +``` + +## Connecting to ollama.com + +1. Create an [API key](https://ollama.com/settings/keys) on ollama.com +2. Set the environment variables: + +```shell +export ANTHROPIC_BASE_URL=https://ollama.com +export ANTHROPIC_API_KEY= +``` + +3. Run Claude Code with a cloud model: + +```shell +claude --model glm-4.7:cloud +``` + +## Recommended Models + +### Cloud models +- `glm-4.7:cloud` - High-performance cloud model +- `minimax-m2.1:cloud` - Fast cloud model +- `qwen3-coder:480b` - Large coding model + +### Local models +- `qwen3-coder` - Excellent for coding tasks +- `gpt-oss:20b` - Strong general-purpose model diff --git a/middleware/anthropic.go b/middleware/anthropic.go new file mode 100644 index 000000000..f697f4078 --- /dev/null +++ b/middleware/anthropic.go @@ -0,0 +1,149 @@ +package middleware + +import ( + "bytes" + "encoding/json" + "fmt" + "io" + "net/http" + + "github.com/gin-gonic/gin" + + "github.com/ollama/ollama/anthropic" + "github.com/ollama/ollama/api" +) + +// AnthropicWriter wraps the response writer to transform Ollama responses to Anthropic format +type AnthropicWriter struct { + BaseWriter + stream bool + id string + model string + converter *anthropic.StreamConverter +} + +func (w *AnthropicWriter) writeError(data []byte) (int, error) { + var errData struct { + Error string `json:"error"` + } + if err := json.Unmarshal(data, &errData); err != nil { + return 0, err + } + + w.ResponseWriter.Header().Set("Content-Type", "application/json") + err := json.NewEncoder(w.ResponseWriter).Encode(anthropic.NewError(w.ResponseWriter.Status(), errData.Error)) + if err != nil { + return 0, err + } + + return len(data), nil +} + +func (w *AnthropicWriter) writeEvent(eventType string, data any) error { + d, err := json.Marshal(data) + if err != nil { + return err + } + _, err = w.ResponseWriter.Write([]byte(fmt.Sprintf("event: %s\ndata: %s\n\n", eventType, d))) + if err != nil { + return err + } + if f, ok := w.ResponseWriter.(http.Flusher); ok { + f.Flush() + } + return nil +} + +func (w *AnthropicWriter) writeResponse(data []byte) (int, error) { + var chatResponse api.ChatResponse + err := json.Unmarshal(data, &chatResponse) + if err != nil { + return 0, err + } + + if w.stream { + w.ResponseWriter.Header().Set("Content-Type", "text/event-stream") + + events := w.converter.Process(chatResponse) + for _, event := range events { + if err := w.writeEvent(event.Event, event.Data); err != nil { + return 0, err + } + } + return len(data), nil + } + + w.ResponseWriter.Header().Set("Content-Type", "application/json") + response := anthropic.ToMessagesResponse(w.id, chatResponse) + return len(data), json.NewEncoder(w.ResponseWriter).Encode(response) +} + +func (w *AnthropicWriter) Write(data []byte) (int, error) { + code := w.ResponseWriter.Status() + if code != http.StatusOK { + return w.writeError(data) + } + + return w.writeResponse(data) +} + +// AnthropicMessagesMiddleware handles Anthropic Messages API requests +func AnthropicMessagesMiddleware() gin.HandlerFunc { + return func(c *gin.Context) { + var req anthropic.MessagesRequest + err := c.ShouldBindJSON(&req) + if err != nil { + c.AbortWithStatusJSON(http.StatusBadRequest, anthropic.NewError(http.StatusBadRequest, err.Error())) + return + } + + if req.Model == "" { + 
diff --git a/docs/docs.json b/docs/docs.json
index 71a6f17a0..810e94733 100644
--- a/docs/docs.json
+++ b/docs/docs.json
@@ -32,7 +32,9 @@
       "codeblocks": "system"
     },
     "contextual": {
-      "options": ["copy"]
+      "options": [
+        "copy"
+      ]
     },
     "navbar": {
       "links": [
@@ -52,7 +54,9 @@
       "display": "simple"
     },
     "examples": {
-      "languages": ["curl"]
+      "languages": [
+        "curl"
+      ]
     }
   },
   "redirects": [
@@ -97,6 +101,7 @@
     {
       "group": "Integrations",
       "pages": [
+        "/integrations/claude-code",
        "/integrations/vscode",
        "/integrations/jetbrains",
        "/integrations/codex",
@@ -139,7 +144,8 @@
        "/api/streaming",
        "/api/usage",
        "/api/errors",
-       "/api/openai-compatibility"
+       "/api/openai-compatibility",
+       "/api/anthropic-compatibility"
       ]
     },
     {
diff --git a/docs/integrations/claude-code.mdx b/docs/integrations/claude-code.mdx
new file mode 100644
index 000000000..6d1d8322a
--- /dev/null
+++ b/docs/integrations/claude-code.mdx
@@ -0,0 +1,69 @@
+---
+title: Claude Code
+---
+
+## Install
+
+Install [Claude Code](https://code.claude.com/docs/en/overview):
+
+<CodeGroup>
+
+```shell macOS / Linux
+curl -fsSL https://claude.ai/install.sh | bash
+```
+
+```powershell Windows
+irm https://claude.ai/install.ps1 | iex
+```
+
+</CodeGroup>
+
+## Usage with Ollama
+
+Claude Code connects to Ollama using the Anthropic-compatible API.
+
+1. Set the environment variables:
+
+```shell
+export ANTHROPIC_BASE_URL=http://localhost:11434
+export ANTHROPIC_API_KEY=ollama
+```
+
+2. Run Claude Code with an Ollama model:
+
+```shell
+claude --model qwen3-coder
+```
+
+Or run with environment variables inline:
+
+```shell
+ANTHROPIC_BASE_URL=http://localhost:11434 ANTHROPIC_API_KEY=ollama claude --model qwen3-coder
+```
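+
+To verify the endpoint is reachable before launching Claude Code, you can send a minimal request by hand (a quick check; it assumes `qwen3-coder` has already been pulled):
+
+```shell
+curl http://localhost:11434/v1/messages \
+  -H "Content-Type: application/json" \
+  -d '{"model": "qwen3-coder", "max_tokens": 64, "messages": [{"role": "user", "content": "Hello!"}]}'
+```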
"top_p": 0.9, + "top_k": 40, + "stop_sequences": ["\n", "END"], + "messages": [ + {"role": "user", "content": "Hello"} + ] + }`, + req: api.ChatRequest{ + Model: "test-model", + Messages: []api.Message{ + {Role: "user", Content: "Hello"}, + }, + Options: map[string]any{ + "num_predict": 2048, + "temperature": 0.7, + "top_p": 0.9, + "top_k": 40, + "stop": []string{"\n", "END"}, + }, + Stream: &False, + }, + }, + { + name: "streaming", + body: `{ + "model": "test-model", + "max_tokens": 1024, + "stream": true, + "messages": [ + {"role": "user", "content": "Hello"} + ] + }`, + req: api.ChatRequest{ + Model: "test-model", + Messages: []api.Message{ + {Role: "user", Content: "Hello"}, + }, + Options: map[string]any{"num_predict": 1024}, + Stream: &stream, + }, + }, + { + name: "with tools", + body: `{ + "model": "test-model", + "max_tokens": 1024, + "messages": [ + {"role": "user", "content": "What's the weather?"} + ], + "tools": [{ + "name": "get_weather", + "description": "Get current weather", + "input_schema": { + "type": "object", + "properties": { + "location": {"type": "string"} + }, + "required": ["location"] + } + }] + }`, + req: api.ChatRequest{ + Model: "test-model", + Messages: []api.Message{ + {Role: "user", Content: "What's the weather?"}, + }, + Tools: []api.Tool{ + { + Type: "function", + Function: api.ToolFunction{ + Name: "get_weather", + Description: "Get current weather", + Parameters: api.ToolFunctionParameters{ + Type: "object", + Required: []string{"location"}, + Properties: testProps(map[string]api.ToolProperty{ + "location": {Type: api.PropertyType{"string"}}, + }), + }, + }, + }, + }, + Options: map[string]any{"num_predict": 1024}, + Stream: &False, + }, + }, + { + name: "with tool result", + body: `{ + "model": "test-model", + "max_tokens": 1024, + "messages": [ + {"role": "user", "content": "What's the weather?"}, + {"role": "assistant", "content": [ + {"type": "tool_use", "id": "call_123", "name": "get_weather", "input": {"location": "Paris"}} + ]}, + {"role": "user", "content": [ + {"type": "tool_result", "tool_use_id": "call_123", "content": "Sunny, 22°C"} + ]} + ] + }`, + req: api.ChatRequest{ + Model: "test-model", + Messages: []api.Message{ + {Role: "user", Content: "What's the weather?"}, + { + Role: "assistant", + ToolCalls: []api.ToolCall{ + { + ID: "call_123", + Function: api.ToolCallFunction{ + Name: "get_weather", + Arguments: testArgs(map[string]any{"location": "Paris"}), + }, + }, + }, + }, + {Role: "tool", Content: "Sunny, 22°C", ToolCallID: "call_123"}, + }, + Options: map[string]any{"num_predict": 1024}, + Stream: &False, + }, + }, + { + name: "with thinking enabled", + body: `{ + "model": "test-model", + "max_tokens": 1024, + "thinking": {"type": "enabled", "budget_tokens": 1000}, + "messages": [ + {"role": "user", "content": "Hello"} + ] + }`, + req: api.ChatRequest{ + Model: "test-model", + Messages: []api.Message{ + {Role: "user", Content: "Hello"}, + }, + Options: map[string]any{"num_predict": 1024}, + Stream: &False, + Think: &api.ThinkValue{Value: true}, + }, + }, + { + name: "missing model error", + body: `{ + "max_tokens": 1024, + "messages": [ + {"role": "user", "content": "Hello"} + ] + }`, + err: anthropic.ErrorResponse{ + Type: "error", + Error: anthropic.Error{ + Type: "invalid_request_error", + Message: "model is required", + }, + }, + }, + { + name: "missing max_tokens error", + body: `{ + "model": "test-model", + "messages": [ + {"role": "user", "content": "Hello"} + ] + }`, + err: anthropic.ErrorResponse{ + Type: "error", + Error: 
+
+## Recommended models
+
+### Cloud models
+- `glm-4.7:cloud` - High-performance cloud model
+- `minimax-m2.1:cloud` - Fast cloud model
+- `qwen3-coder:480b` - Large coding model
+
+### Local models
+- `qwen3-coder` - Excellent for coding tasks
+- `gpt-oss:20b` - Strong general-purpose model
diff --git a/middleware/anthropic.go b/middleware/anthropic.go
new file mode 100644
index 000000000..f697f4078
--- /dev/null
+++ b/middleware/anthropic.go
@@ -0,0 +1,149 @@
+package middleware
+
+import (
+    "bytes"
+    "encoding/json"
+    "fmt"
+    "io"
+    "net/http"
+
+    "github.com/gin-gonic/gin"
+
+    "github.com/ollama/ollama/anthropic"
+    "github.com/ollama/ollama/api"
+)
+
+// AnthropicWriter wraps the response writer to transform Ollama responses to Anthropic format
+type AnthropicWriter struct {
+    BaseWriter
+    stream    bool
+    id        string
+    model     string
+    converter *anthropic.StreamConverter
+}
+
+// writeError translates an upstream {"error": "..."} body into an Anthropic error
+// envelope, deriving the error type from the HTTP status already set on the response.
+func (w *AnthropicWriter) writeError(data []byte) (int, error) {
+    var errData struct {
+        Error string `json:"error"`
+    }
+    if err := json.Unmarshal(data, &errData); err != nil {
+        return 0, err
+    }
+
+    w.ResponseWriter.Header().Set("Content-Type", "application/json")
+    err := json.NewEncoder(w.ResponseWriter).Encode(anthropic.NewError(w.ResponseWriter.Status(), errData.Error))
+    if err != nil {
+        return 0, err
+    }
+
+    return len(data), nil
+}
+
+// writeEvent writes a single SSE frame ("event: <type>\ndata: <json>\n\n") and
+// flushes it so clients receive streaming events promptly.
+func (w *AnthropicWriter) writeEvent(eventType string, data any) error {
+    d, err := json.Marshal(data)
+    if err != nil {
+        return err
+    }
+    _, err = w.ResponseWriter.Write([]byte(fmt.Sprintf("event: %s\ndata: %s\n\n", eventType, d)))
+    if err != nil {
+        return err
+    }
+    if f, ok := w.ResponseWriter.(http.Flusher); ok {
+        f.Flush()
+    }
+    return nil
+}
+
+func (w *AnthropicWriter) writeResponse(data []byte) (int, error) {
+    var chatResponse api.ChatResponse
+    err := json.Unmarshal(data, &chatResponse)
+    if err != nil {
+        return 0, err
+    }
+
+    if w.stream {
+        w.ResponseWriter.Header().Set("Content-Type", "text/event-stream")
+
+        events := w.converter.Process(chatResponse)
+        for _, event := range events {
+            if err := w.writeEvent(event.Event, event.Data); err != nil {
+                return 0, err
+            }
+        }
+        return len(data), nil
+    }
+
+    w.ResponseWriter.Header().Set("Content-Type", "application/json")
+    response := anthropic.ToMessagesResponse(w.id, chatResponse)
+    return len(data), json.NewEncoder(w.ResponseWriter).Encode(response)
+}
+
+func (w *AnthropicWriter) Write(data []byte) (int, error) {
+    code := w.ResponseWriter.Status()
+    if code != http.StatusOK {
+        return w.writeError(data)
+    }
+
+    return w.writeResponse(data)
+}
+
+// AnthropicMessagesMiddleware handles Anthropic Messages API requests
+func AnthropicMessagesMiddleware() gin.HandlerFunc {
+    return func(c *gin.Context) {
+        var req anthropic.MessagesRequest
+        err := c.ShouldBindJSON(&req)
+        if err != nil {
+            c.AbortWithStatusJSON(http.StatusBadRequest, anthropic.NewError(http.StatusBadRequest, err.Error()))
+            return
+        }
+
+        if req.Model == "" {
+            c.AbortWithStatusJSON(http.StatusBadRequest, anthropic.NewError(http.StatusBadRequest, "model is required"))
+            return
+        }
+
+        if req.MaxTokens <= 0 {
+            c.AbortWithStatusJSON(http.StatusBadRequest, anthropic.NewError(http.StatusBadRequest, "max_tokens is required and must be positive"))
+            return
+        }
+
+        if len(req.Messages) == 0 {
+            c.AbortWithStatusJSON(http.StatusBadRequest, anthropic.NewError(http.StatusBadRequest, "messages is required"))
+            return
+        }
+
+        chatReq, err := anthropic.FromMessagesRequest(req)
+        if err != nil {
+            c.AbortWithStatusJSON(http.StatusBadRequest, anthropic.NewError(http.StatusBadRequest, err.Error()))
+            return
+        }
+
+        var b bytes.Buffer
+        if err := json.NewEncoder(&b).Encode(chatReq); err != nil {
+            c.AbortWithStatusJSON(http.StatusInternalServerError, anthropic.NewError(http.StatusInternalServerError, err.Error()))
+            return
+        }
+
+        c.Request.Body = io.NopCloser(&b)
+
+        messageID := anthropic.GenerateMessageID()
+
+        w := &AnthropicWriter{
+            BaseWriter: BaseWriter{ResponseWriter: c.Writer},
+            stream:     req.Stream,
+            id:         messageID,
+            model:      req.Model,
+            converter:  anthropic.NewStreamConverter(messageID, req.Model),
+        }
+
+        if req.Stream {
+            c.Writer.Header().Set("Content-Type", "text/event-stream")
+            c.Writer.Header().Set("Cache-Control", "no-cache")
+            c.Writer.Header().Set("Connection", "keep-alive")
+        }
+
+        c.Writer = w
+
+        c.Next()
+    }
+}
diff --git a/middleware/anthropic_test.go b/middleware/anthropic_test.go
new file mode 100644
index 000000000..40df7fbb4
--- /dev/null
+++ b/middleware/anthropic_test.go
@@ -0,0 +1,584 @@
+package middleware
+
+import (
+    "bytes"
+    "encoding/json"
+    "io"
+    "net/http"
+    "net/http/httptest"
+    "strings"
+    "testing"
+
+    "github.com/gin-gonic/gin"
+    "github.com/google/go-cmp/cmp"
+    "github.com/google/go-cmp/cmp/cmpopts"
+
+    "github.com/ollama/ollama/anthropic"
+    "github.com/ollama/ollama/api"
+)
+
+// captureAnthropicRequest stores the converted request body so tests can inspect
+// what the middleware forwarded to the chat handler.
+func captureAnthropicRequest(capturedRequest any) gin.HandlerFunc {
+    return func(c *gin.Context) {
+        bodyBytes, _ := io.ReadAll(c.Request.Body)
+        c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes))
+        _ = json.Unmarshal(bodyBytes, capturedRequest)
+        c.Next()
+    }
+}
+
+// testProps creates a ToolPropertiesMap from a map (convenience function for tests)
+func testProps(m map[string]api.ToolProperty) *api.ToolPropertiesMap {
+    props := api.NewToolPropertiesMap()
+    for k, v := range m {
+        props.Set(k, v)
+    }
+    return props
+}
+
+func TestAnthropicMessagesMiddleware(t *testing.T) {
+    type testCase struct {
+        name string
+        body string
+        req  api.ChatRequest
+        err  anthropic.ErrorResponse
+    }
+
+    var capturedRequest *api.ChatRequest
+    stream := true
+
+    testCases := []testCase{
+        {
+            name: "basic message",
+            body: `{
+                "model": "test-model",
+                "max_tokens": 1024,
+                "messages": [
+                    {"role": "user", "content": "Hello"}
+                ]
+            }`,
+            req: api.ChatRequest{
+                Model: "test-model",
+                Messages: []api.Message{
+                    {Role: "user", Content: "Hello"},
+                },
+                Options: map[string]any{"num_predict": 1024},
+                Stream:  &False,
+            },
+        },
+        {
+            name: "with system prompt",
+            body: `{
+                "model": "test-model",
+                "max_tokens": 1024,
+                "system": "You are helpful.",
+                "messages": [
+                    {"role": "user", "content": "Hello"}
+                ]
+            }`,
+            req: api.ChatRequest{
+                Model: "test-model",
+                Messages: []api.Message{
+                    {Role: "system", Content: "You are helpful."},
+                    {Role: "user", Content: "Hello"},
+                },
+                Options: map[string]any{"num_predict": 1024},
+                Stream:  &False,
+            },
+        },
+        {
+            name: "with options",
+            body: `{
+                "model": "test-model",
+                "max_tokens": 2048,
+                "temperature": 0.7,
+                "top_p": 0.9,
+                "top_k": 40,
+                "stop_sequences": ["\n", "END"],
+                "messages": [
+                    {"role": "user", "content": "Hello"}
+                ]
+            }`,
+            req: api.ChatRequest{
+                Model: "test-model",
+                Messages: []api.Message{
+                    {Role: "user", Content: "Hello"},
+                },
+                Options: map[string]any{
+                    "num_predict": 2048,
+                    "temperature": 0.7,
+                    "top_p":       0.9,
+                    "top_k":       40,
+                    "stop":        []string{"\n", "END"},
+                },
+                Stream: &False,
+            },
+        },
+        {
+            name: "streaming",
+            body: `{
+                "model": "test-model",
+                "max_tokens": 1024,
+                "stream": true,
+                "messages": [
+                    {"role": "user", "content": "Hello"}
+                ]
+            }`,
+            req: api.ChatRequest{
+                Model: "test-model",
+                Messages: []api.Message{
+                    {Role: "user", Content: "Hello"},
+                },
+                Options: map[string]any{"num_predict": 1024},
+                Stream:  &stream,
+            },
+        },
+        {
+            name: "with tools",
+            body: `{
+                "model": "test-model",
+                "max_tokens": 1024,
+                "messages": [
+                    {"role": "user", "content": "What's the weather?"}
+                ],
+                "tools": [{
+                    "name": "get_weather",
+                    "description": "Get current weather",
+                    "input_schema": {
+                        "type": "object",
+                        "properties": {
+                            "location": {"type": "string"}
+                        },
+                        "required": ["location"]
+                    }
+                }]
+            }`,
+            req: api.ChatRequest{
+                Model: "test-model",
+                Messages: []api.Message{
+                    {Role: "user", Content: "What's the weather?"},
+                },
+                Tools: []api.Tool{
+                    {
+                        Type: "function",
+                        Function: api.ToolFunction{
+                            Name:        "get_weather",
+                            Description: "Get current weather",
+                            Parameters: api.ToolFunctionParameters{
+                                Type:     "object",
+                                Required: []string{"location"},
+                                Properties: testProps(map[string]api.ToolProperty{
+                                    "location": {Type: api.PropertyType{"string"}},
+                                }),
+                            },
+                        },
+                    },
+                },
+                Options: map[string]any{"num_predict": 1024},
+                Stream:  &False,
+            },
+        },
+        {
+            name: "with tool result",
+            body: `{
+                "model": "test-model",
+                "max_tokens": 1024,
+                "messages": [
+                    {"role": "user", "content": "What's the weather?"},
+                    {"role": "assistant", "content": [
+                        {"type": "tool_use", "id": "call_123", "name": "get_weather", "input": {"location": "Paris"}}
+                    ]},
+                    {"role": "user", "content": [
+                        {"type": "tool_result", "tool_use_id": "call_123", "content": "Sunny, 22°C"}
+                    ]}
+                ]
+            }`,
+            req: api.ChatRequest{
+                Model: "test-model",
+                Messages: []api.Message{
+                    {Role: "user", Content: "What's the weather?"},
+                    {
+                        Role: "assistant",
+                        ToolCalls: []api.ToolCall{
+                            {
+                                ID: "call_123",
+                                Function: api.ToolCallFunction{
+                                    Name:      "get_weather",
+                                    Arguments: testArgs(map[string]any{"location": "Paris"}),
+                                },
+                            },
+                        },
+                    },
+                    {Role: "tool", Content: "Sunny, 22°C", ToolCallID: "call_123"},
+                },
+                Options: map[string]any{"num_predict": 1024},
+                Stream:  &False,
+            },
+        },
+        {
+            name: "with thinking enabled",
+            body: `{
+                "model": "test-model",
+                "max_tokens": 1024,
+                "thinking": {"type": "enabled", "budget_tokens": 1000},
+                "messages": [
+                    {"role": "user", "content": "Hello"}
+                ]
+            }`,
+            req: api.ChatRequest{
+                Model: "test-model",
+                Messages: []api.Message{
+                    {Role: "user", Content: "Hello"},
+                },
+                Options: map[string]any{"num_predict": 1024},
+                Stream:  &False,
+                Think:   &api.ThinkValue{Value: true},
+            },
+        },
+        {
+            name: "missing model error",
+            body: `{
+                "max_tokens": 1024,
+                "messages": [
+                    {"role": "user", "content": "Hello"}
+                ]
+            }`,
+            err: anthropic.ErrorResponse{
+                Type: "error",
+                Error: anthropic.Error{
+                    Type:    "invalid_request_error",
+                    Message: "model is required",
+                },
+            },
+        },
+        {
+            name: "missing max_tokens error",
+            body: `{
+                "model": "test-model",
+                "messages": [
+                    {"role": "user", "content": "Hello"}
+                ]
+            }`,
+            err: anthropic.ErrorResponse{
+                Type: "error",
+                Error: anthropic.Error{
+                    Type:    "invalid_request_error",
+                    Message: "max_tokens is required and must be positive",
+                },
+            },
+        },
+        {
+            name: "missing messages error",
+            body: `{
+                "model": "test-model",
+                "max_tokens": 1024
+            }`,
+            err: anthropic.ErrorResponse{
+                Type: "error",
+                Error: anthropic.Error{
+                    Type:    "invalid_request_error",
+                    Message: "messages is required",
+                },
+            },
+        },
+        {
+            name: "tool_use missing id error",
+            body: `{
+                "model": "test-model",
+                "max_tokens": 1024,
+                "messages": [
+                    {"role": "assistant", "content": [
+                        {"type": "tool_use", "name": "test"}
+                    ]}
+                ]
+            }`,
+            err: anthropic.ErrorResponse{
+                Type: "error",
+                Error: anthropic.Error{
+                    Type:    "invalid_request_error",
+                    Message: "tool_use block missing required 'id' field",
+                },
+            },
+        },
+    }
+
+    endpoint := func(c *gin.Context) {
+        c.Status(http.StatusOK)
+    }
+
+    gin.SetMode(gin.TestMode)
+    router := gin.New()
+    router.Use(AnthropicMessagesMiddleware(), captureAnthropicRequest(&capturedRequest))
+    router.Handle(http.MethodPost, "/v1/messages", endpoint)
+
+    for _, tc := range testCases {
+        t.Run(tc.name, func(t *testing.T) {
+            req, _ := http.NewRequest(http.MethodPost, "/v1/messages", strings.NewReader(tc.body))
+            req.Header.Set("Content-Type", "application/json")
+
+            defer func() { capturedRequest = nil }()
+
+            resp := httptest.NewRecorder()
+            router.ServeHTTP(resp, req)
+
+            if tc.err.Type != "" {
+                // Expect error
+                if resp.Code == http.StatusOK {
+                    t.Fatalf("expected error response, got 200 OK")
+                }
+                var errResp anthropic.ErrorResponse
+                if err := json.Unmarshal(resp.Body.Bytes(), &errResp); err != nil {
+                    t.Fatalf("failed to unmarshal error: %v", err)
+                }
+                if errResp.Type != tc.err.Type {
+                    t.Errorf("expected error type %q, got %q", tc.err.Type, errResp.Type)
+                }
+                if errResp.Error.Type != tc.err.Error.Type {
+                    t.Errorf("expected error.type %q, got %q", tc.err.Error.Type, errResp.Error.Type)
+                }
+                if errResp.Error.Message != tc.err.Error.Message {
+                    t.Errorf("expected error.message %q, got %q", tc.err.Error.Message, errResp.Error.Message)
+                }
+                return
+            }
+
+            if resp.Code != http.StatusOK {
+                t.Fatalf("unexpected status code: %d, body: %s", resp.Code, resp.Body.String())
+            }
+
+            if capturedRequest == nil {
+                t.Fatal("request was not captured")
+            }
+
+            // Compare relevant fields
+            if capturedRequest.Model != tc.req.Model {
+                t.Errorf("model mismatch: got %q, want %q", capturedRequest.Model, tc.req.Model)
+            }
+
+            if diff := cmp.Diff(tc.req.Messages, capturedRequest.Messages,
+                cmpopts.IgnoreUnexported(api.ToolCallFunctionArguments{}, api.ToolPropertiesMap{})); diff != "" {
+                t.Errorf("messages mismatch (-want +got):\n%s", diff)
+            }
+
+            if tc.req.Stream != nil && capturedRequest.Stream != nil {
+                if *tc.req.Stream != *capturedRequest.Stream {
+                    t.Errorf("stream mismatch: got %v, want %v", *capturedRequest.Stream, *tc.req.Stream)
+                }
+            }
+
+            if tc.req.Think != nil {
+                if capturedRequest.Think == nil {
+                    t.Error("expected Think to be set")
+                } else if capturedRequest.Think.Value != tc.req.Think.Value {
+                    t.Errorf("Think mismatch: got %v, want %v", capturedRequest.Think.Value, tc.req.Think.Value)
+                }
+            }
+        })
+    }
+}
+
+func TestAnthropicMessagesMiddleware_Headers(t *testing.T) {
+    gin.SetMode(gin.TestMode)
+
+    t.Run("streaming sets correct headers", func(t *testing.T) {
+        router := gin.New()
+        router.Use(AnthropicMessagesMiddleware())
+        router.POST("/v1/messages", func(c *gin.Context) {
+            // Check headers were set
+            if c.Writer.Header().Get("Content-Type") != "text/event-stream" {
+                t.Errorf("expected Content-Type text/event-stream, got %q", c.Writer.Header().Get("Content-Type"))
+            }
+            if c.Writer.Header().Get("Cache-Control") != "no-cache" {
+                t.Errorf("expected Cache-Control no-cache, got %q", c.Writer.Header().Get("Cache-Control"))
+            }
+            c.Status(http.StatusOK)
+        })
+
+        body := `{"model": "test", "max_tokens": 100, "stream": true, "messages": [{"role": "user", "content": "Hi"}]}`
+        req, _ := http.NewRequest(http.MethodPost, "/v1/messages", strings.NewReader(body))
+        req.Header.Set("Content-Type", "application/json")
+
+        resp := httptest.NewRecorder()
+        router.ServeHTTP(resp, req)
+    })
+}
+
+func TestAnthropicMessagesMiddleware_InvalidJSON(t *testing.T) {
+    gin.SetMode(gin.TestMode)
+    router := gin.New()
+    router.Use(AnthropicMessagesMiddleware())
+    router.POST("/v1/messages", func(c *gin.Context) {
+        c.Status(http.StatusOK)
+    })
+
+    req, _ := http.NewRequest(http.MethodPost, "/v1/messages", strings.NewReader(`{invalid json`))
+    req.Header.Set("Content-Type", "application/json")
+
+    resp := httptest.NewRecorder()
+    router.ServeHTTP(resp, req)
+
+    if resp.Code != http.StatusBadRequest {
+        t.Errorf("expected status 400, got %d", resp.Code)
+    }
+
+    var errResp anthropic.ErrorResponse
+    if err := json.Unmarshal(resp.Body.Bytes(), &errResp); err != nil {
+        t.Fatalf("failed to unmarshal error: %v", err)
+    }
+
+    if errResp.Type != "error" {
+        t.Errorf("expected type 'error', got %q", errResp.Type)
+    }
+    if errResp.Error.Type != "invalid_request_error" {
+        t.Errorf("expected error type 'invalid_request_error', got %q", errResp.Error.Type)
+    }
+}
+
+func TestAnthropicWriter_NonStreaming(t *testing.T) {
+    gin.SetMode(gin.TestMode)
+
+    router := gin.New()
+    router.Use(AnthropicMessagesMiddleware())
+    router.POST("/v1/messages", func(c *gin.Context) {
+        // Simulate Ollama response
+        resp := api.ChatResponse{
+            Model: "test-model",
+            Message: api.Message{
+                Role:    "assistant",
+                Content: "Hello there!",
+            },
+            Done:       true,
+            DoneReason: "stop",
+            Metrics: api.Metrics{
+                PromptEvalCount: 10,
+                EvalCount:       5,
+            },
+        }
+        data, _ := json.Marshal(resp)
+        c.Writer.WriteHeader(http.StatusOK)
+        _, _ = c.Writer.Write(data)
+    })
+
+    body := `{"model": "test-model", "max_tokens": 100, "messages": [{"role": "user", "content": "Hi"}]}`
+    req, _ := http.NewRequest(http.MethodPost, "/v1/messages", strings.NewReader(body))
+    req.Header.Set("Content-Type", "application/json")
+
+    resp := httptest.NewRecorder()
+    router.ServeHTTP(resp, req)
+
+    if resp.Code != http.StatusOK {
+        t.Fatalf("expected status 200, got %d", resp.Code)
+    }
+
+    var result anthropic.MessagesResponse
+    if err := json.Unmarshal(resp.Body.Bytes(), &result); err != nil {
+        t.Fatalf("failed to unmarshal response: %v", err)
+    }
+
+    if result.Type != "message" {
+        t.Errorf("expected type 'message', got %q", result.Type)
+    }
+    if result.Role != "assistant" {
+        t.Errorf("expected role 'assistant', got %q", result.Role)
+    }
+    if len(result.Content) != 1 {
+        t.Fatalf("expected 1 content block, got %d", len(result.Content))
+    }
+    if result.Content[0].Text == nil || *result.Content[0].Text != "Hello there!" {
+        t.Errorf("expected text 'Hello there!', got %v", result.Content[0].Text)
+    }
+    if result.StopReason != "end_turn" {
+        t.Errorf("expected stop_reason 'end_turn', got %q", result.StopReason)
+    }
+    if result.Usage.InputTokens != 10 {
+        t.Errorf("expected input_tokens 10, got %d", result.Usage.InputTokens)
+    }
+    if result.Usage.OutputTokens != 5 {
+        t.Errorf("expected output_tokens 5, got %d", result.Usage.OutputTokens)
+    }
+}
+
+// TestAnthropicWriter_ErrorFromRoutes tests error handling when routes.go sends
+// gin.H{"error": "message"} without a StatusCode field (which is the common case)
+func TestAnthropicWriter_ErrorFromRoutes(t *testing.T) {
+    gin.SetMode(gin.TestMode)
+
+    tests := []struct {
+        name          string
+        statusCode    int
+        errorPayload  any
+        wantErrorType string
+        wantMessage   string
+    }{
+        // routes.go sends errors without StatusCode in JSON, so we must use the HTTP status
+        {
+            name:          "404 with gin.H error (model not found)",
+            statusCode:    http.StatusNotFound,
+            errorPayload:  gin.H{"error": "model 'nonexistent' not found"},
+            wantErrorType: "not_found_error",
+            wantMessage:   "model 'nonexistent' not found",
+        },
+        {
+            name:          "400 with gin.H error (bad request)",
+            statusCode:    http.StatusBadRequest,
+            errorPayload:  gin.H{"error": "model is required"},
+            wantErrorType: "invalid_request_error",
+            wantMessage:   "model is required",
+        },
+        {
+            name:          "500 with gin.H error (internal error)",
+            statusCode:    http.StatusInternalServerError,
+            errorPayload:  gin.H{"error": "something went wrong"},
+            wantErrorType: "api_error",
+            wantMessage:   "something went wrong",
+        },
+        {
+            name:       "404 with api.StatusError",
+            statusCode: http.StatusNotFound,
+            errorPayload: api.StatusError{
+                StatusCode:   http.StatusNotFound,
+                ErrorMessage: "model not found via StatusError",
+            },
+            wantErrorType: "not_found_error",
+            wantMessage:   "model not found via StatusError",
+        },
+    }
+
+    for _, tt := range tests {
+        t.Run(tt.name, func(t *testing.T) {
+            router := gin.New()
+            router.Use(AnthropicMessagesMiddleware())
+            router.POST("/v1/messages", func(c *gin.Context) {
+                // Simulate what routes.go does - set status and write error JSON
+                data, _ := json.Marshal(tt.errorPayload)
+                c.Writer.WriteHeader(tt.statusCode)
+                _, _ = c.Writer.Write(data)
+            })
+
+            body := `{"model": "test-model", "max_tokens": 100, "messages": [{"role": "user", "content": "Hi"}]}`
+            req, _ := http.NewRequest(http.MethodPost, "/v1/messages", strings.NewReader(body))
+            req.Header.Set("Content-Type", "application/json")
+
+            resp := httptest.NewRecorder()
+            router.ServeHTTP(resp, req)
+
+            if resp.Code != tt.statusCode {
+                t.Errorf("expected status %d, got %d", tt.statusCode, resp.Code)
+            }
+
+            var errResp anthropic.ErrorResponse
+            if err := json.Unmarshal(resp.Body.Bytes(), &errResp); err != nil {
+                t.Fatalf("failed to unmarshal error response: %v\nbody: %s", err, resp.Body.String())
+            }
+
+            if errResp.Type != "error" {
+                t.Errorf("expected type 'error', got %q", errResp.Type)
+            }
+            if errResp.Error.Type != tt.wantErrorType {
+                t.Errorf("expected error type %q, got %q", tt.wantErrorType, errResp.Error.Type)
+            }
+            if errResp.Error.Message != tt.wantMessage {
+                t.Errorf("expected message %q, got %q", tt.wantMessage, errResp.Error.Message)
+            }
+        })
+    }
+}
diff --git a/server/routes.go b/server/routes.go
index 977a13ff2..8e199bada 100644
--- a/server/routes.go
+++ b/server/routes.go
@@ -1544,6 +1544,9 @@ func (s *Server) GenerateRoutes(rc *ollama.Registry) (http.Handler, error) {
 	r.GET("/v1/models/:model", middleware.RetrieveMiddleware(), s.ShowHandler)
 	r.POST("/v1/responses", middleware.ResponsesMiddleware(), s.ChatHandler)
+	// Inference (Anthropic compatibility)
+	r.POST("/v1/messages", middleware.AnthropicMessagesMiddleware(), s.ChatHandler)
+
 	if rc != nil {
 		// wrap old with new
 		rs := &registry.Local{