diff --git a/api/types.go b/api/types.go index 63b898975..2434fe478 100644 --- a/api/types.go +++ b/api/types.go @@ -3,6 +3,7 @@ package api import ( "encoding/json" "fmt" + "iter" "log/slog" "math" "os" @@ -14,6 +15,7 @@ import ( "github.com/google/uuid" "github.com/ollama/ollama/envconfig" + "github.com/ollama/ollama/internal/orderedmap" "github.com/ollama/ollama/types/model" ) @@ -227,13 +229,79 @@ type ToolCallFunction struct { Arguments ToolCallFunctionArguments `json:"arguments"` } -type ToolCallFunctionArguments map[string]any +// ToolCallFunctionArguments holds tool call arguments in insertion order. +type ToolCallFunctionArguments struct { + om *orderedmap.Map[string, any] +} + +// NewToolCallFunctionArguments creates a new empty ToolCallFunctionArguments. +func NewToolCallFunctionArguments() ToolCallFunctionArguments { + return ToolCallFunctionArguments{om: orderedmap.New[string, any]()} +} + +// Get retrieves a value by key. +func (t *ToolCallFunctionArguments) Get(key string) (any, bool) { + if t == nil || t.om == nil { + return nil, false + } + return t.om.Get(key) +} + +// Set sets a key-value pair, preserving insertion order. +func (t *ToolCallFunctionArguments) Set(key string, value any) { + if t == nil { + return + } + if t.om == nil { + t.om = orderedmap.New[string, any]() + } + t.om.Set(key, value) +} + +// Len returns the number of arguments. +func (t *ToolCallFunctionArguments) Len() int { + if t == nil || t.om == nil { + return 0 + } + return t.om.Len() +} + +// All returns an iterator over all key-value pairs in insertion order. +func (t *ToolCallFunctionArguments) All() iter.Seq2[string, any] { + if t == nil || t.om == nil { + return func(yield func(string, any) bool) {} + } + return t.om.All() +} + +// ToMap returns a regular map (order not preserved). +func (t *ToolCallFunctionArguments) ToMap() map[string]any { + if t == nil || t.om == nil { + return nil + } + return t.om.ToMap() +} func (t *ToolCallFunctionArguments) String() string { - bts, _ := json.Marshal(t) + if t == nil || t.om == nil { + return "{}" + } + bts, _ := json.Marshal(t.om) return string(bts) } +func (t *ToolCallFunctionArguments) UnmarshalJSON(data []byte) error { + t.om = orderedmap.New[string, any]() + return json.Unmarshal(data, t.om) +} + +func (t ToolCallFunctionArguments) MarshalJSON() ([]byte, error) { + if t.om == nil { + return []byte("{}"), nil + } + return json.Marshal(t.om) +} + type Tool struct { Type string `json:"type"` Items any `json:"items,omitempty"` @@ -282,13 +350,78 @@ func (pt PropertyType) String() string { return fmt.Sprintf("%v", []string(pt)) } +// ToolPropertiesMap holds tool properties in insertion order. +type ToolPropertiesMap struct { + om *orderedmap.Map[string, ToolProperty] +} + +// NewToolPropertiesMap creates a new empty ToolPropertiesMap. +func NewToolPropertiesMap() *ToolPropertiesMap { + return &ToolPropertiesMap{om: orderedmap.New[string, ToolProperty]()} +} + +// Get retrieves a property by name. +func (t *ToolPropertiesMap) Get(key string) (ToolProperty, bool) { + if t == nil || t.om == nil { + return ToolProperty{}, false + } + return t.om.Get(key) +} + +// Set sets a property, preserving insertion order. +func (t *ToolPropertiesMap) Set(key string, value ToolProperty) { + if t == nil { + return + } + if t.om == nil { + t.om = orderedmap.New[string, ToolProperty]() + } + t.om.Set(key, value) +} + +// Len returns the number of properties. +func (t *ToolPropertiesMap) Len() int { + if t == nil || t.om == nil { + return 0 + } + return t.om.Len() +} + +// All returns an iterator over all properties in insertion order. +func (t *ToolPropertiesMap) All() iter.Seq2[string, ToolProperty] { + if t == nil || t.om == nil { + return func(yield func(string, ToolProperty) bool) {} + } + return t.om.All() +} + +// ToMap returns a regular map (order not preserved). +func (t *ToolPropertiesMap) ToMap() map[string]ToolProperty { + if t == nil || t.om == nil { + return nil + } + return t.om.ToMap() +} + +func (t ToolPropertiesMap) MarshalJSON() ([]byte, error) { + if t.om == nil { + return []byte("null"), nil + } + return json.Marshal(t.om) +} + +func (t *ToolPropertiesMap) UnmarshalJSON(data []byte) error { + t.om = orderedmap.New[string, ToolProperty]() + return json.Unmarshal(data, t.om) +} + type ToolProperty struct { - AnyOf []ToolProperty `json:"anyOf,omitempty"` - Type PropertyType `json:"type,omitempty"` - Items any `json:"items,omitempty"` - Description string `json:"description,omitempty"` - Enum []any `json:"enum,omitempty"` - Properties map[string]ToolProperty `json:"properties,omitempty"` + AnyOf []ToolProperty `json:"anyOf,omitempty"` + Type PropertyType `json:"type,omitempty"` + Items any `json:"items,omitempty"` + Description string `json:"description,omitempty"` + Enum []any `json:"enum,omitempty"` + Properties *ToolPropertiesMap `json:"properties,omitempty"` } // ToTypeScriptType converts a ToolProperty to a TypeScript type string @@ -337,11 +470,11 @@ func mapToTypeScriptType(jsonType string) string { } type ToolFunctionParameters struct { - Type string `json:"type"` - Defs any `json:"$defs,omitempty"` - Items any `json:"items,omitempty"` - Required []string `json:"required,omitempty"` - Properties map[string]ToolProperty `json:"properties"` + Type string `json:"type"` + Defs any `json:"$defs,omitempty"` + Items any `json:"items,omitempty"` + Required []string `json:"required,omitempty"` + Properties *ToolPropertiesMap `json:"properties"` } func (t *ToolFunctionParameters) String() string { diff --git a/api/types_test.go b/api/types_test.go index da1581f48..69d9c5a3d 100644 --- a/api/types_test.go +++ b/api/types_test.go @@ -11,6 +11,24 @@ import ( "github.com/stretchr/testify/require" ) +// testPropsMap creates a ToolPropertiesMap from a map (convenience function for tests, order not preserved) +func testPropsMap(m map[string]ToolProperty) *ToolPropertiesMap { + props := NewToolPropertiesMap() + for k, v := range m { + props.Set(k, v) + } + return props +} + +// testArgs creates ToolCallFunctionArguments from a map (convenience function for tests, order not preserved) +func testArgs(m map[string]any) ToolCallFunctionArguments { + args := NewToolCallFunctionArguments() + for k, v := range m { + args.Set(k, v) + } + return args +} + func TestKeepAliveParsingFromJSON(t *testing.T) { tests := []struct { name string @@ -309,9 +327,9 @@ func TestToolFunctionParameters_MarshalJSON(t *testing.T) { input: ToolFunctionParameters{ Type: "object", Required: []string{"name"}, - Properties: map[string]ToolProperty{ + Properties: testPropsMap(map[string]ToolProperty{ "name": {Type: PropertyType{"string"}}, - }, + }), }, expected: `{"type":"object","required":["name"],"properties":{"name":{"type":"string"}}}`, }, @@ -319,9 +337,9 @@ func TestToolFunctionParameters_MarshalJSON(t *testing.T) { name: "no required", input: ToolFunctionParameters{ Type: "object", - Properties: map[string]ToolProperty{ + Properties: testPropsMap(map[string]ToolProperty{ "name": {Type: PropertyType{"string"}}, - }, + }), }, expected: `{"type":"object","properties":{"name":{"type":"string"}}}`, }, @@ -339,7 +357,7 @@ func TestToolFunctionParameters_MarshalJSON(t *testing.T) { func TestToolCallFunction_IndexAlwaysMarshals(t *testing.T) { fn := ToolCallFunction{ Name: "echo", - Arguments: ToolCallFunctionArguments{"message": "hi"}, + Arguments: testArgs(map[string]any{"message": "hi"}), } data, err := json.Marshal(fn) @@ -529,7 +547,7 @@ func TestToolPropertyNestedProperties(t *testing.T) { expected: ToolProperty{ Type: PropertyType{"object"}, Description: "Location details", - Properties: map[string]ToolProperty{ + Properties: testPropsMap(map[string]ToolProperty{ "address": { Type: PropertyType{"string"}, Description: "Street address", @@ -538,7 +556,7 @@ func TestToolPropertyNestedProperties(t *testing.T) { Type: PropertyType{"string"}, Description: "City name", }, - }, + }), }, }, { @@ -566,22 +584,22 @@ func TestToolPropertyNestedProperties(t *testing.T) { expected: ToolProperty{ Type: PropertyType{"object"}, Description: "Event", - Properties: map[string]ToolProperty{ + Properties: testPropsMap(map[string]ToolProperty{ "location": { Type: PropertyType{"object"}, Description: "Location", - Properties: map[string]ToolProperty{ + Properties: testPropsMap(map[string]ToolProperty{ "coordinates": { Type: PropertyType{"object"}, Description: "GPS coordinates", - Properties: map[string]ToolProperty{ + Properties: testPropsMap(map[string]ToolProperty{ "lat": {Type: PropertyType{"number"}, Description: "Latitude"}, "lng": {Type: PropertyType{"number"}, Description: "Longitude"}, - }, + }), }, - }, + }), }, - }, + }), }, }, } @@ -591,7 +609,13 @@ func TestToolPropertyNestedProperties(t *testing.T) { var prop ToolProperty err := json.Unmarshal([]byte(tt.input), &prop) require.NoError(t, err) - assert.Equal(t, tt.expected, prop) + + // Compare JSON representations since pointer comparison doesn't work + expectedJSON, err := json.Marshal(tt.expected) + require.NoError(t, err) + actualJSON, err := json.Marshal(prop) + require.NoError(t, err) + assert.JSONEq(t, string(expectedJSON), string(actualJSON)) // Round-trip test: marshal and unmarshal again data, err := json.Marshal(prop) @@ -600,7 +624,10 @@ func TestToolPropertyNestedProperties(t *testing.T) { var prop2 ToolProperty err = json.Unmarshal(data, &prop2) require.NoError(t, err) - assert.Equal(t, tt.expected, prop2) + + prop2JSON, err := json.Marshal(prop2) + require.NoError(t, err) + assert.JSONEq(t, string(expectedJSON), string(prop2JSON)) }) } } @@ -616,12 +643,12 @@ func TestToolFunctionParameters_String(t *testing.T) { params: ToolFunctionParameters{ Type: "object", Required: []string{"name"}, - Properties: map[string]ToolProperty{ + Properties: testPropsMap(map[string]ToolProperty{ "name": { Type: PropertyType{"string"}, Description: "The name of the person", }, - }, + }), }, expected: `{"type":"object","required":["name"],"properties":{"name":{"type":"string","description":"The name of the person"}}}`, }, @@ -638,7 +665,7 @@ func TestToolFunctionParameters_String(t *testing.T) { s.Self = s return s }(), - Properties: map[string]ToolProperty{}, + Properties: testPropsMap(map[string]ToolProperty{}), }, expected: "", }, @@ -651,3 +678,235 @@ func TestToolFunctionParameters_String(t *testing.T) { }) } } + +func TestToolCallFunctionArguments_OrderPreservation(t *testing.T) { + t.Run("marshal preserves insertion order", func(t *testing.T) { + args := NewToolCallFunctionArguments() + args.Set("zebra", "z") + args.Set("apple", "a") + args.Set("mango", "m") + + data, err := json.Marshal(args) + require.NoError(t, err) + + // Should preserve insertion order, not alphabetical + assert.Equal(t, `{"zebra":"z","apple":"a","mango":"m"}`, string(data)) + }) + + t.Run("unmarshal preserves JSON order", func(t *testing.T) { + jsonData := `{"zebra":"z","apple":"a","mango":"m"}` + + var args ToolCallFunctionArguments + err := json.Unmarshal([]byte(jsonData), &args) + require.NoError(t, err) + + // Verify iteration order matches JSON order + var keys []string + for k := range args.All() { + keys = append(keys, k) + } + assert.Equal(t, []string{"zebra", "apple", "mango"}, keys) + }) + + t.Run("round trip preserves order", func(t *testing.T) { + original := `{"z":1,"a":2,"m":3,"b":4}` + + var args ToolCallFunctionArguments + err := json.Unmarshal([]byte(original), &args) + require.NoError(t, err) + + data, err := json.Marshal(args) + require.NoError(t, err) + + assert.Equal(t, original, string(data)) + }) + + t.Run("String method returns ordered JSON", func(t *testing.T) { + args := NewToolCallFunctionArguments() + args.Set("c", 3) + args.Set("a", 1) + args.Set("b", 2) + + assert.Equal(t, `{"c":3,"a":1,"b":2}`, args.String()) + }) + + t.Run("Get retrieves correct values", func(t *testing.T) { + args := NewToolCallFunctionArguments() + args.Set("key1", "value1") + args.Set("key2", 42) + + v, ok := args.Get("key1") + assert.True(t, ok) + assert.Equal(t, "value1", v) + + v, ok = args.Get("key2") + assert.True(t, ok) + assert.Equal(t, 42, v) + + _, ok = args.Get("nonexistent") + assert.False(t, ok) + }) + + t.Run("Len returns correct count", func(t *testing.T) { + args := NewToolCallFunctionArguments() + assert.Equal(t, 0, args.Len()) + + args.Set("a", 1) + assert.Equal(t, 1, args.Len()) + + args.Set("b", 2) + assert.Equal(t, 2, args.Len()) + }) + + t.Run("empty args marshal to empty object", func(t *testing.T) { + args := NewToolCallFunctionArguments() + data, err := json.Marshal(args) + require.NoError(t, err) + assert.Equal(t, `{}`, string(data)) + }) + + t.Run("zero value args marshal to empty object", func(t *testing.T) { + var args ToolCallFunctionArguments + assert.Equal(t, "{}", args.String()) + }) +} + +func TestToolPropertiesMap_OrderPreservation(t *testing.T) { + t.Run("marshal preserves insertion order", func(t *testing.T) { + props := NewToolPropertiesMap() + props.Set("zebra", ToolProperty{Type: PropertyType{"string"}}) + props.Set("apple", ToolProperty{Type: PropertyType{"number"}}) + props.Set("mango", ToolProperty{Type: PropertyType{"boolean"}}) + + data, err := json.Marshal(props) + require.NoError(t, err) + + // Should preserve insertion order, not alphabetical + expected := `{"zebra":{"type":"string"},"apple":{"type":"number"},"mango":{"type":"boolean"}}` + assert.Equal(t, expected, string(data)) + }) + + t.Run("unmarshal preserves JSON order", func(t *testing.T) { + jsonData := `{"zebra":{"type":"string"},"apple":{"type":"number"},"mango":{"type":"boolean"}}` + + var props ToolPropertiesMap + err := json.Unmarshal([]byte(jsonData), &props) + require.NoError(t, err) + + // Verify iteration order matches JSON order + var keys []string + for k := range props.All() { + keys = append(keys, k) + } + assert.Equal(t, []string{"zebra", "apple", "mango"}, keys) + }) + + t.Run("round trip preserves order", func(t *testing.T) { + original := `{"z":{"type":"string"},"a":{"type":"number"},"m":{"type":"boolean"}}` + + var props ToolPropertiesMap + err := json.Unmarshal([]byte(original), &props) + require.NoError(t, err) + + data, err := json.Marshal(props) + require.NoError(t, err) + + assert.Equal(t, original, string(data)) + }) + + t.Run("Get retrieves correct values", func(t *testing.T) { + props := NewToolPropertiesMap() + props.Set("name", ToolProperty{Type: PropertyType{"string"}, Description: "The name"}) + props.Set("age", ToolProperty{Type: PropertyType{"integer"}, Description: "The age"}) + + v, ok := props.Get("name") + assert.True(t, ok) + assert.Equal(t, "The name", v.Description) + + v, ok = props.Get("age") + assert.True(t, ok) + assert.Equal(t, "The age", v.Description) + + _, ok = props.Get("nonexistent") + assert.False(t, ok) + }) + + t.Run("Len returns correct count", func(t *testing.T) { + props := NewToolPropertiesMap() + assert.Equal(t, 0, props.Len()) + + props.Set("a", ToolProperty{}) + assert.Equal(t, 1, props.Len()) + + props.Set("b", ToolProperty{}) + assert.Equal(t, 2, props.Len()) + }) + + t.Run("nil props marshal to null", func(t *testing.T) { + var props *ToolPropertiesMap + data, err := json.Marshal(props) + require.NoError(t, err) + assert.Equal(t, `null`, string(data)) + }) + + t.Run("ToMap returns regular map", func(t *testing.T) { + props := NewToolPropertiesMap() + props.Set("a", ToolProperty{Type: PropertyType{"string"}}) + props.Set("b", ToolProperty{Type: PropertyType{"number"}}) + + m := props.ToMap() + assert.Equal(t, 2, len(m)) + assert.Equal(t, PropertyType{"string"}, m["a"].Type) + assert.Equal(t, PropertyType{"number"}, m["b"].Type) + }) +} + +func TestToolCallFunctionArguments_ComplexValues(t *testing.T) { + t.Run("nested objects preserve order", func(t *testing.T) { + jsonData := `{"outer":{"z":1,"a":2},"simple":"value"}` + + var args ToolCallFunctionArguments + err := json.Unmarshal([]byte(jsonData), &args) + require.NoError(t, err) + + // Outer keys should be in order + var keys []string + for k := range args.All() { + keys = append(keys, k) + } + assert.Equal(t, []string{"outer", "simple"}, keys) + }) + + t.Run("arrays as values", func(t *testing.T) { + args := NewToolCallFunctionArguments() + args.Set("items", []string{"a", "b", "c"}) + args.Set("numbers", []int{1, 2, 3}) + + data, err := json.Marshal(args) + require.NoError(t, err) + + assert.Equal(t, `{"items":["a","b","c"],"numbers":[1,2,3]}`, string(data)) + }) +} + +func TestToolPropertiesMap_NestedProperties(t *testing.T) { + t.Run("nested properties preserve order", func(t *testing.T) { + props := NewToolPropertiesMap() + + nestedProps := NewToolPropertiesMap() + nestedProps.Set("z_field", ToolProperty{Type: PropertyType{"string"}}) + nestedProps.Set("a_field", ToolProperty{Type: PropertyType{"number"}}) + + props.Set("outer", ToolProperty{ + Type: PropertyType{"object"}, + Properties: nestedProps, + }) + + data, err := json.Marshal(props) + require.NoError(t, err) + + // Both outer and inner should preserve order + expected := `{"outer":{"type":"object","properties":{"z_field":{"type":"string"},"a_field":{"type":"number"}}}}` + assert.Equal(t, expected, string(data)) + }) +} diff --git a/app/ui/ui.go b/app/ui/ui.go index 26de71422..0b32f917e 100644 --- a/app/ui/ui.go +++ b/app/ui/ui.go @@ -997,7 +997,7 @@ func (s *Server) chat(w http.ResponseWriter, r *http.Request) error { for _, toolCall := range res.Message.ToolCalls { // continues loop as tools were executed toolsExecuted = true - result, content, err := registry.Execute(ctx, toolCall.Function.Name, toolCall.Function.Arguments) + result, content, err := registry.Execute(ctx, toolCall.Function.Name, toolCall.Function.Arguments.ToMap()) if err != nil { errContent := fmt.Sprintf("Error: %v", err) toolErrMsg := store.NewMessage("tool", errContent, nil) @@ -1558,13 +1558,13 @@ func convertToOllamaTool(toolSchema map[string]any) api.Tool { tool.Function.Parameters.Type = "object" tool.Function.Parameters.Required = []string{} - tool.Function.Parameters.Properties = make(map[string]api.ToolProperty) + tool.Function.Parameters.Properties = api.NewToolPropertiesMap() if schemaProps, ok := toolSchema["schema"].(map[string]any); ok { tool.Function.Parameters.Type = getStringFromMap(schemaProps, "type", "object") if props, ok := schemaProps["properties"].(map[string]any); ok { - tool.Function.Parameters.Properties = make(map[string]api.ToolProperty) + tool.Function.Parameters.Properties = api.NewToolPropertiesMap() for propName, propDef := range props { if propMap, ok := propDef.(map[string]any); ok { @@ -1572,7 +1572,7 @@ func convertToOllamaTool(toolSchema map[string]any) api.Tool { Type: api.PropertyType{getStringFromMap(propMap, "type", "string")}, Description: getStringFromMap(propMap, "description", ""), } - tool.Function.Parameters.Properties[propName] = prop + tool.Function.Parameters.Properties.Set(propName, prop) } } } diff --git a/go.mod b/go.mod index f7c9ff295..b912a9a0a 100644 --- a/go.mod +++ b/go.mod @@ -28,6 +28,7 @@ require ( github.com/nlpodyssey/gopickle v0.3.0 github.com/pdevine/tensor v0.0.0-20240510204454-f88f4562727c github.com/tkrajina/typescriptify-golang-structs v0.2.0 + github.com/wk8/go-ordered-map/v2 v2.1.8 golang.org/x/image v0.22.0 golang.org/x/mod v0.30.0 golang.org/x/tools v0.38.0 @@ -36,6 +37,8 @@ require ( require ( github.com/apache/arrow/go/arrow v0.0.0-20211112161151-bc219186db40 // indirect + github.com/bahlo/generic-list-go v0.2.0 // indirect + github.com/buger/jsonparser v1.1.1 // indirect github.com/bytedance/sonic/loader v0.1.1 // indirect github.com/chewxy/hm v1.0.0 // indirect github.com/chewxy/math32 v1.11.0 // indirect @@ -45,6 +48,7 @@ require ( github.com/gogo/protobuf v1.3.2 // indirect github.com/google/flatbuffers v24.3.25+incompatible // indirect github.com/kr/text v0.2.0 // indirect + github.com/mailru/easyjson v0.7.7 // indirect github.com/pkg/errors v0.9.1 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect github.com/rivo/uniseg v0.2.0 // indirect diff --git a/go.sum b/go.sum index 936c040a0..83014fc5b 100644 --- a/go.sum +++ b/go.sum @@ -14,7 +14,11 @@ github.com/apache/arrow/go/arrow v0.0.0-20211112161151-bc219186db40 h1:q4dksr6IC github.com/apache/arrow/go/arrow v0.0.0-20211112161151-bc219186db40/go.mod h1:Q7yQnSMnLvcXlZ8RV+jwz/6y1rQTqbX6C82SndT52Zs= github.com/arbovm/levenshtein v0.0.0-20160628152529-48b4e1c0c4d0 h1:jfIu9sQUG6Ig+0+Ap1h4unLjW6YQJpKZVmUzxsD4E/Q= github.com/arbovm/levenshtein v0.0.0-20160628152529-48b4e1c0c4d0/go.mod h1:t2tdKJDJF9BV14lnkjHmOQgcvEKgtqs5a1N3LNdJhGE= +github.com/bahlo/generic-list-go v0.2.0 h1:5sz/EEAK+ls5wF+NeqDpk5+iNdMDXrh3z3nPnH1Wvgk= +github.com/bahlo/generic-list-go v0.2.0/go.mod h1:2KvAjgMlE5NNynlg/5iLrrCCZ2+5xWbdbCW3pNTGyYg= github.com/boombuler/barcode v1.0.0/go.mod h1:paBWMcWSl3LHKBqUq+rly7CNSldXjb2rDl3JlRe0mD8= +github.com/buger/jsonparser v1.1.1 h1:2PnMjfWD7wBILjqQbt530v576A/cAbQvEW9gGIpYMUs= +github.com/buger/jsonparser v1.1.1/go.mod h1:6RYKKt7H4d4+iWqouImQ9R2FZql3VbhNgx27UK13J/0= github.com/bytedance/sonic v1.11.6 h1:oUp34TzMlL+OY1OUWxHqsdkgC/Zfc85zGqw9siXjrc0= github.com/bytedance/sonic v1.11.6/go.mod h1:LysEHSvpvDySVdC2f87zGWf6CIKJcAvqab1ZaiQtds4= github.com/bytedance/sonic/loader v0.1.1 h1:c+e5Pt1k/cy5wMveRDyk2X4B9hF4g7an8N3zCYjJFNM= @@ -123,6 +127,7 @@ github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+ github.com/grpc-ecosystem/grpc-gateway v1.16.0/go.mod h1:BDjrQk3hbvj6Nolgz8mAMFbcEtjT1g+wF4CSlocrBnw= github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8= github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw= +github.com/josharian/intern v1.0.0/go.mod h1:5DoeVV0s6jJacbCEi61lwdGj/aVlrQvzHFFd8Hwg//Y= github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM= github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo= github.com/jung-kurt/gofpdf v1.0.0/go.mod h1:7Id9E/uU8ce6rXgefFLlgrJj/GYY22cpxn+r32jIOes= @@ -143,6 +148,8 @@ github.com/ledongthuc/pdf v0.0.0-20250511090121-5959a4027728 h1:QwWKgMY28TAXaDl+ github.com/ledongthuc/pdf v0.0.0-20250511090121-5959a4027728/go.mod h1:1fEHWurg7pvf5SG6XNE5Q8UZmOwex51Mkx3SLhrW5B4= github.com/leodido/go-urn v1.4.0 h1:WT9HwE9SGECu3lg4d/dIA+jxlljEa1/ffXKmRjqdmIQ= github.com/leodido/go-urn v1.4.0/go.mod h1:bvxc+MVxLKB4z00jd1z+Dvzr47oO32F/QSNjSBOlFxI= +github.com/mailru/easyjson v0.7.7 h1:UGYAvKxe3sBsEDzO8ZeWOSlIQfWFlxbzLZe7hwFURr0= +github.com/mailru/easyjson v0.7.7/go.mod h1:xzfreul335JAWq5oZzymOObrkdz5UnU4kGfJJLY9Nlc= github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= github.com/mattn/go-runewidth v0.0.9/go.mod h1:H031xJmbD/WCDINGzjvQ9THkh0rPKHF+m2gUSrubnMI= @@ -207,6 +214,8 @@ github.com/twitchyliquid64/golang-asm v0.15.1 h1:SU5vSMR7hnwNxj24w34ZyCi/FmDZTkS github.com/twitchyliquid64/golang-asm v0.15.1/go.mod h1:a1lVb/DtPvCB8fslRZhAngC2+aY1QWCk3Cedj/Gdt08= github.com/ugorji/go/codec v1.2.12 h1:9LC83zGrHhuUA9l16C9AHXAqEV/2wBQ4nkvumAE65EE= github.com/ugorji/go/codec v1.2.12/go.mod h1:UNopzCgEMSXjBc6AOMqYvWC1ktqTAfzJZUZgYf6w6lg= +github.com/wk8/go-ordered-map/v2 v2.1.8 h1:5h/BUHu93oj4gIdvHHHGsScSTMijfx5PeYkE/fJgbpc= +github.com/wk8/go-ordered-map/v2 v2.1.8/go.mod h1:5nJHM5DyteebpVlHnWMV0rPz6Zp7+xBAnxjb1X5vnTw= github.com/x448/float16 v0.8.4 h1:qLwI1I70+NjRFUR3zs1JPUCgaCXSh3SW62uAKT1mSBM= github.com/x448/float16 v0.8.4/go.mod h1:14CWIYCyZA/cWjXOioeEpHeN/83MdbZDRQHoFcYsOfg= github.com/xtgo/set v1.0.0 h1:6BCNBRv3ORNDQ7fyoJXRv+tstJz3m1JVFQErfeZz2pY= diff --git a/integration/tools_test.go b/integration/tools_test.go index fa37d8f3a..e74f40413 100644 --- a/integration/tools_test.go +++ b/integration/tools_test.go @@ -11,6 +11,15 @@ import ( "github.com/ollama/ollama/api" ) +// testPropsMap creates a ToolPropertiesMap from a map (convenience function for tests) +func testPropsMap(m map[string]api.ToolProperty) *api.ToolPropertiesMap { + props := api.NewToolPropertiesMap() + for k, v := range m { + props.Set(k, v) + } + return props +} + func TestAPIToolCalling(t *testing.T) { initialTimeout := 60 * time.Second streamTimeout := 60 * time.Second @@ -57,12 +66,12 @@ func TestAPIToolCalling(t *testing.T) { Parameters: api.ToolFunctionParameters{ Type: "object", Required: []string{"location"}, - Properties: map[string]api.ToolProperty{ + Properties: testPropsMap(map[string]api.ToolProperty{ "location": { Type: api.PropertyType{"string"}, Description: "The city and state, e.g. San Francisco, CA", }, - }, + }), }, }, }, diff --git a/internal/orderedmap/orderedmap.go b/internal/orderedmap/orderedmap.go new file mode 100644 index 000000000..5ee5a9403 --- /dev/null +++ b/internal/orderedmap/orderedmap.go @@ -0,0 +1,94 @@ +// Package orderedmap provides a generic ordered map that maintains insertion order. +// It wraps github.com/wk8/go-ordered-map/v2 to encapsulate the dependency. +package orderedmap + +import ( + "encoding/json" + "iter" + + orderedmap "github.com/wk8/go-ordered-map/v2" +) + +// Map is a generic ordered map that maintains insertion order. +type Map[K comparable, V any] struct { + om *orderedmap.OrderedMap[K, V] +} + +// New creates a new empty ordered map. +func New[K comparable, V any]() *Map[K, V] { + return &Map[K, V]{ + om: orderedmap.New[K, V](), + } +} + +// Get retrieves a value by key. +func (m *Map[K, V]) Get(key K) (V, bool) { + if m == nil || m.om == nil { + var zero V + return zero, false + } + return m.om.Get(key) +} + +// Set sets a key-value pair. If the key already exists, its value is updated +// but its position in the iteration order is preserved. If the key is new, +// it is appended to the end. +func (m *Map[K, V]) Set(key K, value V) { + if m == nil { + return + } + if m.om == nil { + m.om = orderedmap.New[K, V]() + } + m.om.Set(key, value) +} + +// Len returns the number of entries. +func (m *Map[K, V]) Len() int { + if m == nil || m.om == nil { + return 0 + } + return m.om.Len() +} + +// All returns an iterator over all key-value pairs in insertion order. +func (m *Map[K, V]) All() iter.Seq2[K, V] { + return func(yield func(K, V) bool) { + if m == nil || m.om == nil { + return + } + for pair := m.om.Oldest(); pair != nil; pair = pair.Next() { + if !yield(pair.Key, pair.Value) { + return + } + } + } +} + +// ToMap converts to a regular Go map. +// Note: The resulting map does not preserve order. +func (m *Map[K, V]) ToMap() map[K]V { + if m == nil || m.om == nil { + return nil + } + result := make(map[K]V, m.om.Len()) + for pair := m.om.Oldest(); pair != nil; pair = pair.Next() { + result[pair.Key] = pair.Value + } + return result +} + +// MarshalJSON implements json.Marshaler. The JSON output preserves key order. +func (m *Map[K, V]) MarshalJSON() ([]byte, error) { + if m == nil || m.om == nil { + return []byte("null"), nil + } + return json.Marshal(m.om) +} + +// UnmarshalJSON implements json.Unmarshaler. The insertion order matches the +// order of keys in the JSON input. +func (m *Map[K, V]) UnmarshalJSON(data []byte) error { + m.om = orderedmap.New[K, V]() + return json.Unmarshal(data, &m.om) +} diff --git a/internal/orderedmap/orderedmap_test.go b/internal/orderedmap/orderedmap_test.go new file mode 100644 index 000000000..9886d24b7 --- /dev/null +++ b/internal/orderedmap/orderedmap_test.go @@ -0,0 +1,348 @@ +package orderedmap + +import ( + "encoding/json" + "slices" + "testing" +) + +func TestMap_BasicOperations(t *testing.T) { + m := New[string, int]() + + // Test empty map + if m.Len() != 0 { + t.Errorf("expected Len() = 0, got %d", m.Len()) + } + v, ok := m.Get("a") + if ok { + t.Error("expected Get on empty map to return false") + } + if v != 0 { + t.Errorf("expected zero value, got %d", v) + } + + // Test Set and Get + m.Set("a", 1) + m.Set("b", 2) + m.Set("c", 3) + + if m.Len() != 3 { + t.Errorf("expected Len() = 3, got %d", m.Len()) + } + + v, ok = m.Get("a") + if !ok || v != 1 { + t.Errorf("expected Get(a) = (1, true), got (%d, %v)", v, ok) + } + + v, ok = m.Get("b") + if !ok || v != 2 { + t.Errorf("expected Get(b) = (2, true), got (%d, %v)", v, ok) + } + + v, ok = m.Get("c") + if !ok || v != 3 { + t.Errorf("expected Get(c) = (3, true), got (%d, %v)", v, ok) + } + + // Test updating existing key preserves position + m.Set("a", 10) + v, ok = m.Get("a") + if !ok || v != 10 { + t.Errorf("expected Get(a) = (10, true), got (%d, %v)", v, ok) + } + if m.Len() != 3 { + t.Errorf("expected Len() = 3 after update, got %d", m.Len()) + } +} + +func TestMap_InsertionOrderPreserved(t *testing.T) { + m := New[string, int]() + + // Insert in non-alphabetical order + m.Set("z", 1) + m.Set("a", 2) + m.Set("m", 3) + m.Set("b", 4) + + // Verify iteration order matches insertion order + var keys []string + var values []int + for k, v := range m.All() { + keys = append(keys, k) + values = append(values, v) + } + + expectedKeys := []string{"z", "a", "m", "b"} + expectedValues := []int{1, 2, 3, 4} + + if !slices.Equal(keys, expectedKeys) { + t.Errorf("expected keys %v, got %v", expectedKeys, keys) + } + if !slices.Equal(values, expectedValues) { + t.Errorf("expected values %v, got %v", expectedValues, values) + } +} + +func TestMap_UpdatePreservesPosition(t *testing.T) { + m := New[string, int]() + + m.Set("first", 1) + m.Set("second", 2) + m.Set("third", 3) + + // Update middle element + m.Set("second", 20) + + var keys []string + for k := range m.All() { + keys = append(keys, k) + } + + // Order should still be first, second, third + expected := []string{"first", "second", "third"} + if !slices.Equal(keys, expected) { + t.Errorf("expected keys %v, got %v", expected, keys) + } +} + +func TestMap_MarshalJSON_PreservesOrder(t *testing.T) { + m := New[string, int]() + + // Insert in non-alphabetical order + m.Set("z", 1) + m.Set("a", 2) + m.Set("m", 3) + + data, err := json.Marshal(m) + if err != nil { + t.Fatalf("Marshal failed: %v", err) + } + + // JSON should preserve insertion order, not alphabetical + expected := `{"z":1,"a":2,"m":3}` + if string(data) != expected { + t.Errorf("expected %s, got %s", expected, string(data)) + } +} + +func TestMap_UnmarshalJSON_PreservesOrder(t *testing.T) { + // JSON with non-alphabetical key order + jsonData := `{"z":1,"a":2,"m":3}` + + m := New[string, int]() + if err := json.Unmarshal([]byte(jsonData), m); err != nil { + t.Fatalf("Unmarshal failed: %v", err) + } + + // Verify iteration order matches JSON order + var keys []string + for k := range m.All() { + keys = append(keys, k) + } + + expected := []string{"z", "a", "m"} + if !slices.Equal(keys, expected) { + t.Errorf("expected keys %v, got %v", expected, keys) + } +} + +func TestMap_JSONRoundTrip(t *testing.T) { + // Test that unmarshal -> marshal produces identical JSON + original := `{"zebra":"z","apple":"a","mango":"m","banana":"b"}` + + m := New[string, string]() + if err := json.Unmarshal([]byte(original), m); err != nil { + t.Fatalf("Unmarshal failed: %v", err) + } + + data, err := json.Marshal(m) + if err != nil { + t.Fatalf("Marshal failed: %v", err) + } + + if string(data) != original { + t.Errorf("round trip failed: expected %s, got %s", original, string(data)) + } +} + +func TestMap_ToMap(t *testing.T) { + m := New[string, int]() + m.Set("a", 1) + m.Set("b", 2) + + regular := m.ToMap() + + if len(regular) != 2 { + t.Errorf("expected len 2, got %d", len(regular)) + } + if regular["a"] != 1 { + t.Errorf("expected regular[a] = 1, got %d", regular["a"]) + } + if regular["b"] != 2 { + t.Errorf("expected regular[b] = 2, got %d", regular["b"]) + } +} + +func TestMap_NilSafety(t *testing.T) { + var m *Map[string, int] + + // All operations should be safe on nil + if m.Len() != 0 { + t.Errorf("expected Len() = 0 on nil map, got %d", m.Len()) + } + + v, ok := m.Get("a") + if ok { + t.Error("expected Get on nil map to return false") + } + if v != 0 { + t.Errorf("expected zero value from nil map, got %d", v) + } + + // Set on nil is a no-op + m.Set("a", 1) + if m.Len() != 0 { + t.Errorf("expected Len() = 0 after Set on nil, got %d", m.Len()) + } + + // All returns empty iterator + var keys []string + for k := range m.All() { + keys = append(keys, k) + } + if len(keys) != 0 { + t.Errorf("expected empty iteration on nil map, got %v", keys) + } + + // ToMap returns nil + if m.ToMap() != nil { + t.Error("expected ToMap to return nil on nil map") + } + + // MarshalJSON returns null + data, err := json.Marshal(m) + if err != nil { + t.Fatalf("Marshal failed: %v", err) + } + if string(data) != "null" { + t.Errorf("expected null, got %s", string(data)) + } +} + +func TestMap_EmptyMapMarshal(t *testing.T) { + m := New[string, int]() + + data, err := json.Marshal(m) + if err != nil { + t.Fatalf("Marshal failed: %v", err) + } + if string(data) != "{}" { + t.Errorf("expected {}, got %s", string(data)) + } +} + +func TestMap_NestedValues(t *testing.T) { + m := New[string, any]() + m.Set("string", "hello") + m.Set("number", 42) + m.Set("bool", true) + m.Set("nested", map[string]int{"x": 1}) + + data, err := json.Marshal(m) + if err != nil { + t.Fatalf("Marshal failed: %v", err) + } + + expected := `{"string":"hello","number":42,"bool":true,"nested":{"x":1}}` + if string(data) != expected { + t.Errorf("expected %s, got %s", expected, string(data)) + } +} + +func TestMap_AllIteratorEarlyExit(t *testing.T) { + m := New[string, int]() + m.Set("a", 1) + m.Set("b", 2) + m.Set("c", 3) + m.Set("d", 4) + + // Collect only first 2 + var keys []string + for k := range m.All() { + keys = append(keys, k) + if len(keys) == 2 { + break + } + } + + expected := []string{"a", "b"} + if !slices.Equal(keys, expected) { + t.Errorf("expected %v, got %v", expected, keys) + } +} + +func TestMap_IntegerKeys(t *testing.T) { + m := New[int, string]() + m.Set(3, "three") + m.Set(1, "one") + m.Set(2, "two") + + var keys []int + for k := range m.All() { + keys = append(keys, k) + } + + // Should preserve insertion order, not numerical order + expected := []int{3, 1, 2} + if !slices.Equal(keys, expected) { + t.Errorf("expected %v, got %v", expected, keys) + } +} + +func TestMap_UnmarshalIntoExisting(t *testing.T) { + m := New[string, int]() + m.Set("existing", 999) + + // Unmarshal should replace contents + if err := json.Unmarshal([]byte(`{"new":1}`), m); err != nil { + t.Fatalf("Unmarshal failed: %v", err) + } + + _, ok := m.Get("existing") + if ok { + t.Error("existing key should be gone after unmarshal") + } + + v, ok := m.Get("new") + if !ok || v != 1 { + t.Errorf("expected Get(new) = (1, true), got (%d, %v)", v, ok) + } +} + +func TestMap_LargeOrderPreservation(t *testing.T) { + m := New[string, int]() + + // Create many keys in specific order + keys := make([]string, 100) + for i := range 100 { + keys[i] = string(rune('a' + (99 - i))) // reverse order: 'd', 'c', 'b', 'a' (extended) + if i >= 26 { + keys[i] = string(rune('A'+i-26)) + string(rune('a'+i%26)) + } + } + + for i, k := range keys { + m.Set(k, i) + } + + // Verify order preserved + var resultKeys []string + for k := range m.All() { + resultKeys = append(resultKeys, k) + } + + if !slices.Equal(keys, resultKeys) { + t.Error("large map should preserve insertion order") + } +} diff --git a/middleware/openai_test.go b/middleware/openai_test.go index fc71b57d2..3b8f5088a 100644 --- a/middleware/openai_test.go +++ b/middleware/openai_test.go @@ -19,6 +19,40 @@ import ( "github.com/ollama/ollama/openai" ) +// testPropsMap creates a ToolPropertiesMap from a map (convenience function for tests) +func testPropsMap(m map[string]api.ToolProperty) *api.ToolPropertiesMap { + props := api.NewToolPropertiesMap() + for k, v := range m { + props.Set(k, v) + } + return props +} + +// testArgs creates ToolCallFunctionArguments from a map (convenience function for tests) +func testArgs(m map[string]any) api.ToolCallFunctionArguments { + args := api.NewToolCallFunctionArguments() + for k, v := range m { + args.Set(k, v) + } + return args +} + +// argsComparer provides cmp options for comparing ToolCallFunctionArguments by value +var argsComparer = cmp.Comparer(func(a, b api.ToolCallFunctionArguments) bool { + return cmp.Equal(a.ToMap(), b.ToMap()) +}) + +// propsComparer provides cmp options for comparing ToolPropertiesMap by value +var propsComparer = cmp.Comparer(func(a, b *api.ToolPropertiesMap) bool { + if a == nil && b == nil { + return true + } + if a == nil || b == nil { + return false + } + return cmp.Equal(a.ToMap(), b.ToMap()) +}) + const ( prefix = `data:image/jpeg;base64,` image = `iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR42mNk+A8AAQUBAScY42YAAAAASUVORK5CYII=` @@ -221,10 +255,10 @@ func TestChatMiddleware(t *testing.T) { ID: "id", Function: api.ToolCallFunction{ Name: "get_current_weather", - Arguments: map[string]any{ + Arguments: testArgs(map[string]any{ "location": "Paris, France", "format": "celsius", - }, + }), }, }, }, @@ -261,10 +295,10 @@ func TestChatMiddleware(t *testing.T) { ID: "id", Function: api.ToolCallFunction{ Name: "get_current_weather", - Arguments: map[string]any{ + Arguments: testArgs(map[string]any{ "location": "Paris, France", "format": "celsius", - }, + }), }, }, }, @@ -300,10 +334,10 @@ func TestChatMiddleware(t *testing.T) { ID: "id", Function: api.ToolCallFunction{ Name: "get_current_weather", - Arguments: map[string]any{ + Arguments: testArgs(map[string]any{ "location": "Paris, France", "format": "celsius", - }, + }), }, }, }, @@ -340,10 +374,10 @@ func TestChatMiddleware(t *testing.T) { ID: "id", Function: api.ToolCallFunction{ Name: "get_current_weather", - Arguments: map[string]any{ + Arguments: testArgs(map[string]any{ "location": "Paris, France", "format": "celsius", - }, + }), }, }, }, @@ -380,10 +414,10 @@ func TestChatMiddleware(t *testing.T) { ID: "id_abc", Function: api.ToolCallFunction{ Name: "get_current_weather", - Arguments: map[string]any{ + Arguments: testArgs(map[string]any{ "location": "Paris, France", "format": "celsius", - }, + }), }, }, }, @@ -426,10 +460,10 @@ func TestChatMiddleware(t *testing.T) { ID: "id", Function: api.ToolCallFunction{ Name: "get_current_weather", - Arguments: map[string]any{ + Arguments: testArgs(map[string]any{ "location": "Paris, France", "format": "celsius", - }, + }), }, }, }, @@ -494,7 +528,7 @@ func TestChatMiddleware(t *testing.T) { Parameters: api.ToolFunctionParameters{ Type: "object", Required: []string{"location"}, - Properties: map[string]api.ToolProperty{ + Properties: testPropsMap(map[string]api.ToolProperty{ "location": { Type: api.PropertyType{"string"}, Description: "The city and state", @@ -503,7 +537,7 @@ func TestChatMiddleware(t *testing.T) { Type: api.PropertyType{"string"}, Enum: []any{"celsius", "fahrenheit"}, }, - }, + }), }, }, }, @@ -558,7 +592,7 @@ func TestChatMiddleware(t *testing.T) { } return } - if diff := cmp.Diff(&tc.req, capturedRequest); diff != "" { + if diff := cmp.Diff(&tc.req, capturedRequest, argsComparer, propsComparer); diff != "" { t.Fatalf("requests did not match: %+v", diff) } if diff := cmp.Diff(tc.err, errResp); diff != "" { diff --git a/model/parsers/cogito_test.go b/model/parsers/cogito_test.go index 7eaa1c2e2..932e1b9a6 100644 --- a/model/parsers/cogito_test.go +++ b/model/parsers/cogito_test.go @@ -40,9 +40,9 @@ func TestCogitoParser(t *testing.T) { { Function: api.ToolCallFunction{ Name: "get_weather", - Arguments: api.ToolCallFunctionArguments{ + Arguments: testArgs(map[string]any{ "location": "Paris", - }, + }), }, }, }, @@ -52,9 +52,9 @@ func TestCogitoParser(t *testing.T) { Function: api.ToolFunction{ Name: "get_weather", Parameters: api.ToolFunctionParameters{ - Properties: map[string]api.ToolProperty{ + Properties: testPropsMap(map[string]api.ToolProperty{ "location": {Type: api.PropertyType{"string"}}, - }, + }), }, }, }, @@ -71,9 +71,9 @@ func TestCogitoParser(t *testing.T) { { Function: api.ToolCallFunction{ Name: "get_weather", - Arguments: api.ToolCallFunctionArguments{ + Arguments: testArgs(map[string]any{ "location": "Paris", - }, + }), }, }, }, @@ -83,9 +83,9 @@ func TestCogitoParser(t *testing.T) { Function: api.ToolFunction{ Name: "get_weather", Parameters: api.ToolFunctionParameters{ - Properties: map[string]api.ToolProperty{ + Properties: testPropsMap(map[string]api.ToolProperty{ "location": {Type: api.PropertyType{"string"}}, - }, + }), }, }, }, @@ -103,17 +103,17 @@ func TestCogitoParser(t *testing.T) { { Function: api.ToolCallFunction{ Name: "get_weather", - Arguments: api.ToolCallFunctionArguments{ + Arguments: testArgs(map[string]any{ "location": "Paris", - }, + }), }, }, { Function: api.ToolCallFunction{ Name: "get_weather", - Arguments: api.ToolCallFunctionArguments{ + Arguments: testArgs(map[string]any{ "location": "London", - }, + }), }, }, }, @@ -123,9 +123,9 @@ func TestCogitoParser(t *testing.T) { Function: api.ToolFunction{ Name: "get_weather", Parameters: api.ToolFunctionParameters{ - Properties: map[string]api.ToolProperty{ + Properties: testPropsMap(map[string]api.ToolProperty{ "location": {Type: api.PropertyType{"string"}}, - }, + }), }, }, }, @@ -140,11 +140,11 @@ func TestCogitoParser(t *testing.T) { { Function: api.ToolCallFunction{ Name: "process_data", - Arguments: api.ToolCallFunctionArguments{ + Arguments: testArgs(map[string]any{ "items": []any{"item1", "item2"}, "config": map[string]any{"enabled": true, "threshold": 0.95}, "count": 42.0, - }, + }), }, }, }, @@ -238,7 +238,7 @@ This is line 3Final response here.`, t.Errorf("thinking mismatch (-want +got):\n%s", diff) } - if diff := cmp.Diff(tt.expectedToolCalls, toolCalls); diff != "" { + if diff := cmp.Diff(tt.expectedToolCalls, toolCalls, argsComparer); diff != "" { t.Errorf("tool calls mismatch (-want +got):\n%s", diff) } }) @@ -277,9 +277,9 @@ func TestCogitoParser_Streaming(t *testing.T) { { Function: api.ToolCallFunction{ Name: "test_tool", - Arguments: api.ToolCallFunctionArguments{ + Arguments: testArgs(map[string]any{ "arg": "value", - }, + }), }, }, } @@ -292,7 +292,7 @@ func TestCogitoParser_Streaming(t *testing.T) { t.Errorf("expected thinking %q, got %q", expectedThinking, finalThinking.String()) } - if diff := cmp.Diff(expectedToolCalls, finalToolCalls); diff != "" { + if diff := cmp.Diff(expectedToolCalls, finalToolCalls, argsComparer); diff != "" { t.Errorf("tool calls mismatch (-want +got):\n%s", diff) } } @@ -367,7 +367,7 @@ func TestCogitoParser_StreamingEdgeCases(t *testing.T) { t.Errorf("expected thinking %q, got %q", tt.expectedThinking, finalThinking.String()) } - if diff := cmp.Diff(tt.expectedToolCalls, finalToolCalls); diff != "" { + if diff := cmp.Diff(tt.expectedToolCalls, finalToolCalls, argsComparer); diff != "" { t.Errorf("tool calls mismatch (-want +got):\n%s", diff) } }) @@ -412,9 +412,9 @@ func TestCogitoParser_parseToolCallContent(t *testing.T) { expected: api.ToolCall{ Function: api.ToolCallFunction{ Name: "get_weather", - Arguments: api.ToolCallFunctionArguments{ + Arguments: testArgs(map[string]any{ "location": "Paris", - }, + }), }, }, expectError: false, @@ -427,11 +427,11 @@ func TestCogitoParser_parseToolCallContent(t *testing.T) { expected: api.ToolCall{ Function: api.ToolCallFunction{ Name: "process_data", - Arguments: api.ToolCallFunctionArguments{ + Arguments: testArgs(map[string]any{ "items": []any{"item1", "item2"}, "config": map[string]any{"enabled": true}, "count": 42.0, - }, + }), }, }, expectError: false, @@ -444,7 +444,7 @@ func TestCogitoParser_parseToolCallContent(t *testing.T) { expected: api.ToolCall{ Function: api.ToolCallFunction{ Name: "no_args_tool", - Arguments: api.ToolCallFunctionArguments{}, + Arguments: api.NewToolCallFunctionArguments(), }, }, expectError: false, @@ -493,9 +493,9 @@ func TestCogitoParser_parseToolCallContent(t *testing.T) { expected: api.ToolCall{ Function: api.ToolCallFunction{ Name: "get_weather", - Arguments: api.ToolCallFunctionArguments{ + Arguments: testArgs(map[string]any{ "location": "Paris", - }, + }), }, }, expectError: false, @@ -511,10 +511,10 @@ func TestCogitoParser_parseToolCallContent(t *testing.T) { expected: api.ToolCall{ Function: api.ToolCallFunction{ Name: "get_weather", - Arguments: api.ToolCallFunctionArguments{ + Arguments: testArgs(map[string]any{ "location": "Paris", "units": "metric", - }, + }), }, }, expectError: false, @@ -527,13 +527,13 @@ func TestCogitoParser_parseToolCallContent(t *testing.T) { expected: api.ToolCall{ Function: api.ToolCallFunction{ Name: "complex_tool", - Arguments: api.ToolCallFunctionArguments{ + Arguments: testArgs(map[string]any{ "nested": map[string]any{ "deep": map[string]any{ "value": 123.0, }, }, - }, + }), }, }, expectError: false, @@ -557,7 +557,7 @@ func TestCogitoParser_parseToolCallContent(t *testing.T) { t.Fatalf("unexpected error: %v", err) } - if diff := cmp.Diff(tt.expected, result); diff != "" { + if diff := cmp.Diff(tt.expected, result, argsComparer); diff != "" { t.Errorf("tool call mismatch (-want +got):\n%s", diff) } }) diff --git a/model/parsers/deepseek3_test.go b/model/parsers/deepseek3_test.go index 4e3180d47..d648300d7 100644 --- a/model/parsers/deepseek3_test.go +++ b/model/parsers/deepseek3_test.go @@ -51,9 +51,9 @@ func TestDeepSeekParser(t *testing.T) { { Function: api.ToolCallFunction{ Name: "get_weather", - Arguments: api.ToolCallFunctionArguments{ + Arguments: testArgs(map[string]any{ "location": "Paris", - }, + }), }, }, }, @@ -67,17 +67,17 @@ func TestDeepSeekParser(t *testing.T) { { Function: api.ToolCallFunction{ Name: "get_weather", - Arguments: api.ToolCallFunctionArguments{ + Arguments: testArgs(map[string]any{ "location": "Paris", - }, + }), }, }, { Function: api.ToolCallFunction{ Name: "get_weather", - Arguments: api.ToolCallFunctionArguments{ + Arguments: testArgs(map[string]any{ "location": "London", - }, + }), }, }, }, @@ -97,10 +97,10 @@ func TestDeepSeekParser(t *testing.T) { { Function: api.ToolCallFunction{ Name: "process_data", - Arguments: api.ToolCallFunctionArguments{ + Arguments: testArgs(map[string]any{ "items": []interface{}{"item1", "item2"}, "config": map[string]interface{}{"enabled": true, "threshold": 0.95}, - }, + }), }, }, }, @@ -115,9 +115,9 @@ func TestDeepSeekParser(t *testing.T) { { Function: api.ToolCallFunction{ Name: "get_weather", - Arguments: api.ToolCallFunctionArguments{ + Arguments: testArgs(map[string]any{ "location": "Paris", - }, + }), }, }, }, @@ -162,9 +162,9 @@ func TestDeepSeekParser(t *testing.T) { { Function: api.ToolCallFunction{ Name: "get_weather", - Arguments: api.ToolCallFunctionArguments{ + Arguments: testArgs(map[string]any{ "location": "Tokyo", - }, + }), }, }, }, @@ -191,10 +191,10 @@ func TestDeepSeekParser(t *testing.T) { { Function: api.ToolCallFunction{ Name: "search", - Arguments: api.ToolCallFunctionArguments{ + Arguments: testArgs(map[string]any{ "query": "北京天气", "language": "中文", - }, + }), }, }, }, @@ -220,10 +220,10 @@ func TestDeepSeekParser(t *testing.T) { { Function: api.ToolCallFunction{ Name: "execute_command", - Arguments: api.ToolCallFunctionArguments{ + Arguments: testArgs(map[string]any{ "command": "ls && echo \"done\"", "path": "/home/user", - }, + }), }, }, }, @@ -244,7 +244,7 @@ func TestDeepSeekParser(t *testing.T) { { Function: api.ToolCallFunction{ Name: "ping", - Arguments: api.ToolCallFunctionArguments{}, + Arguments: api.NewToolCallFunctionArguments(), }, }, }, @@ -276,7 +276,7 @@ func TestDeepSeekParser(t *testing.T) { t.Errorf("Thinking mismatch (-want +got):\n%s", diff) } - if diff := cmp.Diff(tt.expectedCalls, calls); diff != "" { + if diff := cmp.Diff(tt.expectedCalls, calls, argsComparer); diff != "" { t.Errorf("Tool calls mismatch (-want +got):\n%s", diff) } }) @@ -313,9 +313,9 @@ func TestDeepSeekParser_Streaming(t *testing.T) { { Function: api.ToolCallFunction{ Name: "get_weather", - Arguments: api.ToolCallFunctionArguments{ + Arguments: testArgs(map[string]any{ "location": "Paris", - }, + }), }, }, }, @@ -342,7 +342,7 @@ func TestDeepSeekParser_Streaming(t *testing.T) { { Function: api.ToolCallFunction{ Name: "test", - Arguments: api.ToolCallFunctionArguments{}, + Arguments: api.NewToolCallFunctionArguments(), }, }, }, @@ -375,10 +375,10 @@ func TestDeepSeekParser_Streaming(t *testing.T) { { Function: api.ToolCallFunction{ Name: "calc", - Arguments: api.ToolCallFunctionArguments{ + Arguments: testArgs(map[string]any{ "x": float64(42), "y": float64(24), - }, + }), }, }, }, @@ -414,7 +414,7 @@ func TestDeepSeekParser_Streaming(t *testing.T) { t.Errorf("Thinking mismatch (-want +got):\n%s", diff) } - if diff := cmp.Diff(tt.expectedCalls, allCalls); diff != "" { + if diff := cmp.Diff(tt.expectedCalls, allCalls, argsComparer); diff != "" { t.Errorf("Tool calls mismatch (-want +got):\n%s", diff) } }) @@ -469,7 +469,7 @@ func TestDeepSeekParser_Init(t *testing.T) { returnedTools := parser.Init(tools, nil, &api.ThinkValue{Value: true}) - if diff := cmp.Diff(tools, returnedTools); diff != "" { + if diff := cmp.Diff(tools, returnedTools, toolsComparer); diff != "" { t.Errorf("Init() returned tools mismatch (-want +got):\n%s", diff) } @@ -492,9 +492,9 @@ func TestDeepSeek3Parser_parseToolCallContent(t *testing.T) { expected: api.ToolCall{ Function: api.ToolCallFunction{ Name: "get_weather", - Arguments: api.ToolCallFunctionArguments{ + Arguments: testArgs(map[string]any{ "location": "Paris", - }, + }), }, }, }, @@ -504,10 +504,10 @@ func TestDeepSeek3Parser_parseToolCallContent(t *testing.T) { expected: api.ToolCall{ Function: api.ToolCallFunction{ Name: "process_data", - Arguments: api.ToolCallFunctionArguments{ + Arguments: testArgs(map[string]any{ "items": []interface{}{"a", "b"}, "config": map[string]interface{}{"enabled": true}, - }, + }), }, }, }, @@ -517,7 +517,7 @@ func TestDeepSeek3Parser_parseToolCallContent(t *testing.T) { expected: api.ToolCall{ Function: api.ToolCallFunction{ Name: "ping", - Arguments: api.ToolCallFunctionArguments{}, + Arguments: api.NewToolCallFunctionArguments(), }, }, }, @@ -527,9 +527,9 @@ func TestDeepSeek3Parser_parseToolCallContent(t *testing.T) { expected: api.ToolCall{ Function: api.ToolCallFunction{ Name: "获取天气", - Arguments: api.ToolCallFunctionArguments{ + Arguments: testArgs(map[string]any{ "城市": "北京", - }, + }), }, }, }, @@ -539,10 +539,10 @@ func TestDeepSeek3Parser_parseToolCallContent(t *testing.T) { expected: api.ToolCall{ Function: api.ToolCallFunction{ Name: "execute", - Arguments: api.ToolCallFunctionArguments{ + Arguments: testArgs(map[string]any{ "command": "ls && echo \"done\"", "path": "/home/user", - }, + }), }, }, }, @@ -552,11 +552,11 @@ func TestDeepSeek3Parser_parseToolCallContent(t *testing.T) { expected: api.ToolCall{ Function: api.ToolCallFunction{ Name: "calculate", - Arguments: api.ToolCallFunctionArguments{ + Arguments: testArgs(map[string]any{ "x": 3.14, "y": float64(42), "enabled": true, - }, + }), }, }, }, @@ -577,9 +577,9 @@ func TestDeepSeek3Parser_parseToolCallContent(t *testing.T) { expected: api.ToolCall{ Function: api.ToolCallFunction{ Name: "", - Arguments: api.ToolCallFunctionArguments{ + Arguments: testArgs(map[string]any{ "arg": "value", - }, + }), }, }, }, @@ -606,7 +606,7 @@ func TestDeepSeek3Parser_parseToolCallContent(t *testing.T) { t.Fatalf("Unexpected error: %v", err) } - if diff := cmp.Diff(tt.expected, result); diff != "" { + if diff := cmp.Diff(tt.expected, result, argsComparer); diff != "" { t.Errorf("parseToolCallContent() mismatch (-want +got):\n%s", diff) } }) diff --git a/model/parsers/functiongemma.go b/model/parsers/functiongemma.go index 35f8791cc..9d3df9edb 100644 --- a/model/parsers/functiongemma.go +++ b/model/parsers/functiongemma.go @@ -166,7 +166,7 @@ func (p *FunctionGemmaParser) parseToolCall(content string) (api.ToolCall, error // parseArguments parses the key:value,key:value format func (p *FunctionGemmaParser) parseArguments(argsStr string) api.ToolCallFunctionArguments { - args := make(api.ToolCallFunctionArguments) + args := api.NewToolCallFunctionArguments() if argsStr == "" { return args } @@ -185,7 +185,7 @@ func (p *FunctionGemmaParser) parseArguments(argsStr string) api.ToolCallFunctio value := part[colonIdx+1:] // Parse the value - args[key] = p.parseValue(value) + args.Set(key, p.parseValue(value)) } return args diff --git a/model/parsers/functiongemma_test.go b/model/parsers/functiongemma_test.go index 227abdb8f..092763019 100644 --- a/model/parsers/functiongemma_test.go +++ b/model/parsers/functiongemma_test.go @@ -3,6 +3,7 @@ package parsers import ( "testing" + "github.com/google/go-cmp/cmp" "github.com/ollama/ollama/api" "github.com/stretchr/testify/assert" ) @@ -36,9 +37,9 @@ func TestFunctionGemmaParser(t *testing.T) { Name: "get_weather", Parameters: api.ToolFunctionParameters{ Type: "object", - Properties: map[string]api.ToolProperty{ + Properties: testPropsMap(map[string]api.ToolProperty{ "city": {Type: api.PropertyType{"string"}}, - }, + }), }, }, }, @@ -47,7 +48,7 @@ func TestFunctionGemmaParser(t *testing.T) { { Function: api.ToolCallFunction{ Name: "get_weather", - Arguments: api.ToolCallFunctionArguments{"city": "Paris"}, + Arguments: testArgs(map[string]any{"city": "Paris"}), }, }, }, @@ -66,7 +67,7 @@ func TestFunctionGemmaParser(t *testing.T) { { Function: api.ToolCallFunction{ Name: "get_weather", - Arguments: api.ToolCallFunctionArguments{"city": "Paris"}, + Arguments: testArgs(map[string]any{"city": "Paris"}), }, }, }, @@ -84,7 +85,7 @@ func TestFunctionGemmaParser(t *testing.T) { { Function: api.ToolCallFunction{ Name: "add", - Arguments: api.ToolCallFunctionArguments{"a": int64(1), "b": int64(2)}, + Arguments: testArgs(map[string]any{"a": int64(1), "b": int64(2)}), }, }, }, @@ -102,7 +103,7 @@ func TestFunctionGemmaParser(t *testing.T) { { Function: api.ToolCallFunction{ Name: "set_flag", - Arguments: api.ToolCallFunctionArguments{"enabled": true, "verbose": false}, + Arguments: testArgs(map[string]any{"enabled": true, "verbose": false}), }, }, }, @@ -124,13 +125,13 @@ func TestFunctionGemmaParser(t *testing.T) { { Function: api.ToolCallFunction{ Name: "get_weather", - Arguments: api.ToolCallFunctionArguments{"city": "Paris"}, + Arguments: testArgs(map[string]any{"city": "Paris"}), }, }, { Function: api.ToolCallFunction{ Name: "get_weather", - Arguments: api.ToolCallFunctionArguments{"city": "London"}, + Arguments: testArgs(map[string]any{"city": "London"}), }, }, }, @@ -152,7 +153,7 @@ func TestFunctionGemmaParser(t *testing.T) { { Function: api.ToolCallFunction{ Name: "process", - Arguments: api.ToolCallFunctionArguments{"items": []any{"a", "b", "c"}}, + Arguments: testArgs(map[string]any{"items": []any{"a", "b", "c"}}), }, }, }, @@ -173,9 +174,9 @@ func TestFunctionGemmaParser(t *testing.T) { { Function: api.ToolCallFunction{ Name: "update", - Arguments: api.ToolCallFunctionArguments{ + Arguments: testArgs(map[string]any{ "data": map[string]any{"name": "test", "value": int64(42)}, - }, + }), }, }, }, @@ -198,7 +199,7 @@ func TestFunctionGemmaParser(t *testing.T) { { Function: api.ToolCallFunction{ Name: "get_time", - Arguments: api.ToolCallFunctionArguments{}, + Arguments: api.NewToolCallFunctionArguments(), }, }, }, @@ -224,7 +225,7 @@ func TestFunctionGemmaParser(t *testing.T) { { Function: api.ToolCallFunction{ Name: "set_temp", - Arguments: api.ToolCallFunctionArguments{"value": 3.14}, + Arguments: testArgs(map[string]any{"value": 3.14}), }, }, }, @@ -242,7 +243,7 @@ func TestFunctionGemmaParser(t *testing.T) { { Function: api.ToolCallFunction{ Name: "test", - Arguments: api.ToolCallFunctionArguments{}, + Arguments: api.NewToolCallFunctionArguments(), }, }, }, @@ -261,7 +262,7 @@ func TestFunctionGemmaParser(t *testing.T) { { Function: api.ToolCallFunction{ Name: "greet", - Arguments: api.ToolCallFunctionArguments{"name": "日本語"}, + Arguments: testArgs(map[string]any{"name": "日本語"}), }, }, }, @@ -281,11 +282,11 @@ func TestFunctionGemmaParser(t *testing.T) { { Function: api.ToolCallFunction{ Name: "search", - Arguments: api.ToolCallFunctionArguments{ + Arguments: testArgs(map[string]any{ "query": "test", "limit": int64(10), "offset": int64(0), - }, + }), }, }, }, @@ -308,14 +309,14 @@ func TestFunctionGemmaParser(t *testing.T) { { Function: api.ToolCallFunction{ Name: "create", - Arguments: api.ToolCallFunctionArguments{ + Arguments: testArgs(map[string]any{ "config": map[string]any{ "settings": map[string]any{ "enabled": true, "name": "test", }, }, - }, + }), }, }, }, @@ -345,13 +346,13 @@ func TestFunctionGemmaParser(t *testing.T) { { Function: api.ToolCallFunction{ Name: "get_weather", - Arguments: api.ToolCallFunctionArguments{"city": "Paris"}, + Arguments: testArgs(map[string]any{"city": "Paris"}), }, }, { Function: api.ToolCallFunction{ Name: "get_time", - Arguments: api.ToolCallFunctionArguments{"timezone": "UTC"}, + Arguments: testArgs(map[string]any{"timezone": "UTC"}), }, }, }, @@ -372,13 +373,13 @@ func TestFunctionGemmaParser(t *testing.T) { { Function: api.ToolCallFunction{ Name: "first", - Arguments: api.ToolCallFunctionArguments{}, + Arguments: api.NewToolCallFunctionArguments(), }, }, { Function: api.ToolCallFunction{ Name: "second", - Arguments: api.ToolCallFunctionArguments{}, + Arguments: api.NewToolCallFunctionArguments(), }, }, }, @@ -411,7 +412,9 @@ func TestFunctionGemmaParser(t *testing.T) { } assert.Equal(t, tt.expectedText, allContent) - assert.Equal(t, tt.expectedCalls, allCalls) + if diff := cmp.Diff(tt.expectedCalls, allCalls, argsComparer); diff != "" { + t.Errorf("calls mismatch (-want +got):\n%s", diff) + } }) } } diff --git a/model/parsers/ministral.go b/model/parsers/ministral.go index fbb54ad2d..2acf10c5f 100644 --- a/model/parsers/ministral.go +++ b/model/parsers/ministral.go @@ -112,8 +112,8 @@ func (p *MinistralParser) Add(s string, done bool) (content string, thinking str before, _ := splitAtTag(&p.buffer, "}", false) before += "}" - var data map[string]any - if err := json.Unmarshal([]byte(before), &data); err != nil { + var args api.ToolCallFunctionArguments + if err := json.Unmarshal([]byte(before), &args); err != nil { // todo - throw a better error return "", "", calls, err } @@ -123,7 +123,7 @@ func (p *MinistralParser) Add(s string, done bool) (content string, thinking str call := api.ToolCall{ Function: api.ToolCallFunction{ Name: p.currentTool.Function.Name, - Arguments: api.ToolCallFunctionArguments(data), + Arguments: args, }, } calls = append(calls, call) diff --git a/model/parsers/nemotron3nano.go b/model/parsers/nemotron3nano.go index 6e662fba5..7fda8cdc7 100644 --- a/model/parsers/nemotron3nano.go +++ b/model/parsers/nemotron3nano.go @@ -225,7 +225,7 @@ func (p *Nemotron3NanoParser) parseToolCall(content string) (api.ToolCall, error toolCall.Function.Name = fnMatch[1] // Extract parameters - toolCall.Function.Arguments = make(api.ToolCallFunctionArguments) + toolCall.Function.Arguments = api.NewToolCallFunctionArguments() paramMatches := nemotronParameterRegex.FindAllStringSubmatch(content, -1) for _, match := range paramMatches { if len(match) >= 3 { @@ -233,7 +233,7 @@ func (p *Nemotron3NanoParser) parseToolCall(content string) (api.ToolCall, error paramValue := strings.TrimSpace(match[2]) // Try to parse as typed value based on tool definition - toolCall.Function.Arguments[paramName] = p.parseParamValue(paramName, paramValue) + toolCall.Function.Arguments.Set(paramName, p.parseParamValue(paramName, paramValue)) } } @@ -244,9 +244,11 @@ func (p *Nemotron3NanoParser) parseParamValue(paramName string, raw string) any // Find the matching tool to get parameter type var paramType api.PropertyType for _, tool := range p.tools { - if prop, ok := tool.Function.Parameters.Properties[paramName]; ok { - paramType = prop.Type - break + if tool.Function.Parameters.Properties != nil { + if prop, ok := tool.Function.Parameters.Properties.Get(paramName); ok { + paramType = prop.Type + break + } } } diff --git a/model/parsers/nemotron3nano_test.go b/model/parsers/nemotron3nano_test.go index a4517fc44..408a31e85 100644 --- a/model/parsers/nemotron3nano_test.go +++ b/model/parsers/nemotron3nano_test.go @@ -51,7 +51,7 @@ func TestNemotron3NanoParser(t *testing.T) { { Function: api.ToolCallFunction{ Name: "get_weather", - Arguments: map[string]any{"city": "Paris"}, + Arguments: testArgs(map[string]any{"city": "Paris"}), }, }, }, @@ -65,7 +65,7 @@ func TestNemotron3NanoParser(t *testing.T) { { Function: api.ToolCallFunction{ Name: "get_weather", - Arguments: map[string]any{"city": "NYC"}, + Arguments: testArgs(map[string]any{"city": "NYC"}), }, }, }, @@ -78,10 +78,10 @@ func TestNemotron3NanoParser(t *testing.T) { { Function: api.ToolCallFunction{ Name: "book_flight", - Arguments: map[string]any{ + Arguments: testArgs(map[string]any{ "from": "SFO", "to": "NYC", - }, + }), }, }, }, @@ -95,13 +95,13 @@ func TestNemotron3NanoParser(t *testing.T) { { Function: api.ToolCallFunction{ Name: "get_weather", - Arguments: map[string]any{"city": "San Francisco"}, + Arguments: testArgs(map[string]any{"city": "San Francisco"}), }, }, { Function: api.ToolCallFunction{ Name: "get_weather", - Arguments: map[string]any{"city": "New York"}, + Arguments: testArgs(map[string]any{"city": "New York"}), }, }, }, @@ -115,7 +115,7 @@ func TestNemotron3NanoParser(t *testing.T) { { Function: api.ToolCallFunction{ Name: "get_weather", - Arguments: map[string]any{"city": "Paris"}, + Arguments: testArgs(map[string]any{"city": "Paris"}), }, }, }, @@ -130,7 +130,7 @@ func TestNemotron3NanoParser(t *testing.T) { { Function: api.ToolCallFunction{ Name: "search", - Arguments: map[string]any{"query": "test"}, + Arguments: testArgs(map[string]any{"query": "test"}), }, }, }, @@ -143,7 +143,7 @@ func TestNemotron3NanoParser(t *testing.T) { { Function: api.ToolCallFunction{ Name: "create_note", - Arguments: map[string]any{"content": "Line 1\nLine 2\nLine 3"}, + Arguments: testArgs(map[string]any{"content": "Line 1\nLine 2\nLine 3"}), }, }, }, @@ -165,7 +165,7 @@ func TestNemotron3NanoParser(t *testing.T) { name: "tool call with no function name - returns empty tool call", input: "\n\n\n", thinkValue: nil, - expectedCalls: []api.ToolCall{{Function: api.ToolCallFunction{Name: "", Arguments: nil}}}, + expectedCalls: []api.ToolCall{{Function: api.ToolCallFunction{Name: "", Arguments: api.NewToolCallFunctionArguments()}}}, }, { name: "content with newlines preserved", @@ -194,7 +194,7 @@ func TestNemotron3NanoParser(t *testing.T) { { Function: api.ToolCallFunction{ Name: "set_temp", - Arguments: map[string]any{"value": "42"}, + Arguments: testArgs(map[string]any{"value": "42"}), }, }, }, @@ -226,7 +226,7 @@ func TestNemotron3NanoParser(t *testing.T) { if diff := cmp.Diff(thinking, tt.expectedThinking); diff != "" { t.Errorf("thinking mismatch (-got +want):\n%s", diff) } - if diff := cmp.Diff(calls, tt.expectedCalls); diff != "" { + if diff := cmp.Diff(calls, tt.expectedCalls, argsComparer); diff != "" { t.Errorf("calls mismatch (-got +want):\n%s", diff) } }) @@ -276,7 +276,7 @@ func TestNemotron3NanoParser_Streaming(t *testing.T) { { Function: api.ToolCallFunction{ Name: "get_weather", - Arguments: map[string]any{"city": "Paris"}, + Arguments: testArgs(map[string]any{"city": "Paris"}), }, }, }, @@ -290,7 +290,7 @@ func TestNemotron3NanoParser_Streaming(t *testing.T) { { Function: api.ToolCallFunction{ Name: "get_weather", - Arguments: map[string]any{"city": "NYC"}, + Arguments: testArgs(map[string]any{"city": "NYC"}), }, }, }, @@ -302,7 +302,7 @@ func TestNemotron3NanoParser_Streaming(t *testing.T) { { Function: api.ToolCallFunction{ Name: "test", - Arguments: map[string]any{}, + Arguments: api.NewToolCallFunctionArguments(), }, }, }, @@ -329,10 +329,10 @@ func TestNemotron3NanoParser_Streaming(t *testing.T) { { Function: api.ToolCallFunction{ Name: "book_flight", - Arguments: map[string]any{ + Arguments: testArgs(map[string]any{ "from": "SFO", "to": "NYC", - }, + }), }, }, }, @@ -347,7 +347,7 @@ func TestNemotron3NanoParser_Streaming(t *testing.T) { { Function: api.ToolCallFunction{ Name: "search", - Arguments: map[string]any{"query": "test query"}, + Arguments: testArgs(map[string]any{"query": "test query"}), }, }, }, @@ -367,13 +367,13 @@ func TestNemotron3NanoParser_Streaming(t *testing.T) { { Function: api.ToolCallFunction{ Name: "get_weather", - Arguments: map[string]any{"city": "San Francisco"}, + Arguments: testArgs(map[string]any{"city": "San Francisco"}), }, }, { Function: api.ToolCallFunction{ Name: "get_weather", - Arguments: map[string]any{"city": "New York"}, + Arguments: testArgs(map[string]any{"city": "New York"}), }, }, }, @@ -386,7 +386,7 @@ func TestNemotron3NanoParser_Streaming(t *testing.T) { { Function: api.ToolCallFunction{ Name: "create_note", - Arguments: map[string]any{"content": "Line 1\nLine 2\nLine 3"}, + Arguments: testArgs(map[string]any{"content": "Line 1\nLine 2\nLine 3"}), }, }, }, @@ -413,7 +413,7 @@ func TestNemotron3NanoParser_Streaming(t *testing.T) { { Function: api.ToolCallFunction{ Name: "test", - Arguments: map[string]any{}, + Arguments: api.NewToolCallFunctionArguments(), }, }, }, @@ -426,7 +426,7 @@ func TestNemotron3NanoParser_Streaming(t *testing.T) { { Function: api.ToolCallFunction{ Name: "test", - Arguments: map[string]any{"name": ""}, + Arguments: testArgs(map[string]any{"name": ""}), }, }, }, @@ -473,7 +473,7 @@ func TestNemotron3NanoParser_Streaming(t *testing.T) { if diff := cmp.Diff(allThinking, tt.expectedThinking); diff != "" { t.Errorf("thinking mismatch (-got +want):\n%s", diff) } - if diff := cmp.Diff(allCalls, tt.expectedCalls); diff != "" { + if diff := cmp.Diff(allCalls, tt.expectedCalls, argsComparer); diff != "" { t.Errorf("calls mismatch (-got +want):\n%s", diff) } }) @@ -537,9 +537,9 @@ func TestNemotron3NanoParser_WithTools(t *testing.T) { Name: "get_weather", Parameters: api.ToolFunctionParameters{ Type: "object", - Properties: map[string]api.ToolProperty{ + Properties: testPropsMap(map[string]api.ToolProperty{ "city": {Type: api.PropertyType{"string"}}, - }, + }), }, }, }, @@ -548,7 +548,7 @@ func TestNemotron3NanoParser_WithTools(t *testing.T) { p := &Nemotron3NanoParser{} returnedTools := p.Init(tools, nil, nil) - if diff := cmp.Diff(returnedTools, tools); diff != "" { + if diff := cmp.Diff(returnedTools, tools, toolsComparer); diff != "" { t.Errorf("tools mismatch (-got +want):\n%s", diff) } @@ -563,12 +563,12 @@ func TestNemotron3NanoParser_WithTools(t *testing.T) { { Function: api.ToolCallFunction{ Name: "get_weather", - Arguments: map[string]any{"city": "Paris"}, + Arguments: testArgs(map[string]any{"city": "Paris"}), }, }, } - if diff := cmp.Diff(calls, expectedCalls); diff != "" { + if diff := cmp.Diff(calls, expectedCalls, argsComparer); diff != "" { t.Errorf("calls mismatch (-got +want):\n%s", diff) } } diff --git a/model/parsers/olmo3.go b/model/parsers/olmo3.go index ee4037a69..285a31f62 100644 --- a/model/parsers/olmo3.go +++ b/model/parsers/olmo3.go @@ -242,8 +242,8 @@ func parseOlmo3SingleFunctionCall(s string) (api.ToolCall, error) { // parseOlmo3Arguments parses comma-separated key=value pairs // Handles nested parentheses, brackets, braces, and quoted strings -func parseOlmo3Arguments(s string) (map[string]any, error) { - args := make(map[string]any) +func parseOlmo3Arguments(s string) (api.ToolCallFunctionArguments, error) { + args := api.NewToolCallFunctionArguments() s = strings.TrimSpace(s) if s == "" { return args, nil @@ -261,7 +261,7 @@ func parseOlmo3Arguments(s string) (map[string]any, error) { // Find the first = sign eqIdx := strings.Index(part, "=") if eqIdx == -1 { - return nil, fmt.Errorf("invalid argument format: %s", part) + return api.ToolCallFunctionArguments{}, fmt.Errorf("invalid argument format: %s", part) } key := strings.TrimSpace(part[:eqIdx]) @@ -269,10 +269,10 @@ func parseOlmo3Arguments(s string) (map[string]any, error) { value, err := parseOlmo3Value(valueStr) if err != nil { - return nil, fmt.Errorf("failed to parse value for %s: %w", key, err) + return api.ToolCallFunctionArguments{}, fmt.Errorf("failed to parse value for %s: %w", key, err) } - args[key] = value + args.Set(key, value) } return args, nil diff --git a/model/parsers/olmo3_test.go b/model/parsers/olmo3_test.go index 6c5b57b8b..1710e3bf3 100644 --- a/model/parsers/olmo3_test.go +++ b/model/parsers/olmo3_test.go @@ -28,7 +28,7 @@ func TestOlmo3Parser(t *testing.T) { { Function: api.ToolCallFunction{ Name: "get_weather", - Arguments: map[string]any{"location": "San Francisco"}, + Arguments: testArgs(map[string]any{"location": "San Francisco"}), }, }, }, @@ -41,7 +41,7 @@ func TestOlmo3Parser(t *testing.T) { { Function: api.ToolCallFunction{ Name: "get_weather", - Arguments: map[string]any{"location": "NYC"}, + Arguments: testArgs(map[string]any{"location": "NYC"}), }, }, }, @@ -53,11 +53,11 @@ func TestOlmo3Parser(t *testing.T) { { Function: api.ToolCallFunction{ Name: "book_flight", - Arguments: map[string]any{ + Arguments: testArgs(map[string]any{ "from": "SFO", "to": "NYC", "date": "2024-01-15", - }, + }), }, }, }, @@ -70,13 +70,13 @@ get_weather(location="New York")`, { Function: api.ToolCallFunction{ Name: "get_weather", - Arguments: map[string]any{"location": "San Francisco"}, + Arguments: testArgs(map[string]any{"location": "San Francisco"}), }, }, { Function: api.ToolCallFunction{ Name: "get_weather", - Arguments: map[string]any{"location": "New York"}, + Arguments: testArgs(map[string]any{"location": "New York"}), }, }, }, @@ -88,7 +88,7 @@ get_weather(location="New York")`, { Function: api.ToolCallFunction{ Name: "set_temperature", - Arguments: map[string]any{"value": int64(72)}, + Arguments: testArgs(map[string]any{"value": int64(72)}), }, }, }, @@ -100,7 +100,7 @@ get_weather(location="New York")`, { Function: api.ToolCallFunction{ Name: "set_price", - Arguments: map[string]any{"amount": 19.99}, + Arguments: testArgs(map[string]any{"amount": 19.99}), }, }, }, @@ -112,7 +112,7 @@ get_weather(location="New York")`, { Function: api.ToolCallFunction{ Name: "toggle_setting", - Arguments: map[string]any{"enabled": true}, + Arguments: testArgs(map[string]any{"enabled": true}), }, }, }, @@ -124,7 +124,7 @@ get_weather(location="New York")`, { Function: api.ToolCallFunction{ Name: "clear_value", - Arguments: map[string]any{"field": nil}, + Arguments: testArgs(map[string]any{"field": nil}), }, }, }, @@ -136,7 +136,7 @@ get_weather(location="New York")`, { Function: api.ToolCallFunction{ Name: "process_items", - Arguments: map[string]any{"items": []any{"apple", "banana", "cherry"}}, + Arguments: testArgs(map[string]any{"items": []any{"apple", "banana", "cherry"}}), }, }, }, @@ -148,12 +148,12 @@ get_weather(location="New York")`, { Function: api.ToolCallFunction{ Name: "update_config", - Arguments: map[string]any{ + Arguments: testArgs(map[string]any{ "settings": map[string]any{ "theme": "dark", "fontSize": int64(14), }, - }, + }), }, }, }, @@ -165,7 +165,7 @@ get_weather(location="New York")`, { Function: api.ToolCallFunction{ Name: "create_request", - Arguments: map[string]any{ + Arguments: testArgs(map[string]any{ "data": map[string]any{ "user": map[string]any{ "name": "John", @@ -173,7 +173,7 @@ get_weather(location="New York")`, }, "active": true, }, - }, + }), }, }, }, @@ -185,7 +185,7 @@ get_weather(location="New York")`, { Function: api.ToolCallFunction{ Name: "get_current_time", - Arguments: map[string]any{}, + Arguments: testArgs(map[string]any{}), }, }, }, @@ -197,7 +197,7 @@ get_weather(location="New York")`, { Function: api.ToolCallFunction{ Name: "search", - Arguments: map[string]any{"query": "hello world"}, + Arguments: testArgs(map[string]any{"query": "hello world"}), }, }, }, @@ -209,7 +209,7 @@ get_weather(location="New York")`, { Function: api.ToolCallFunction{ Name: "search", - Arguments: map[string]any{"query": `say "hello"`}, + Arguments: testArgs(map[string]any{"query": `say "hello"`}), }, }, }, @@ -221,11 +221,11 @@ get_weather(location="New York")`, { Function: api.ToolCallFunction{ Name: "create_user", - Arguments: map[string]any{ + Arguments: testArgs(map[string]any{ "name": "John", "age": int64(30), "active": true, - }, + }), }, }, }, @@ -257,7 +257,7 @@ get_weather(location="New York")`, if diff := cmp.Diff(thinking, tt.expectedThinking); diff != "" { t.Errorf("thinking mismatch (-got +want):\n%s", diff) } - if diff := cmp.Diff(calls, tt.expectedCalls); diff != "" { + if diff := cmp.Diff(calls, tt.expectedCalls, argsComparer); diff != "" { t.Errorf("calls mismatch (-got +want):\n%s", diff) } }) @@ -283,7 +283,7 @@ func TestOlmo3Parser_Streaming(t *testing.T) { { Function: api.ToolCallFunction{ Name: "get_weather", - Arguments: map[string]any{"location": "SF"}, + Arguments: testArgs(map[string]any{"location": "SF"}), }, }, }, @@ -296,7 +296,7 @@ func TestOlmo3Parser_Streaming(t *testing.T) { { Function: api.ToolCallFunction{ Name: "get_weather", - Arguments: map[string]any{"location": "NYC"}, + Arguments: testArgs(map[string]any{"location": "NYC"}), }, }, }, @@ -308,7 +308,7 @@ func TestOlmo3Parser_Streaming(t *testing.T) { { Function: api.ToolCallFunction{ Name: "test", - Arguments: map[string]any{}, + Arguments: testArgs(map[string]any{}), }, }, }, @@ -343,7 +343,7 @@ func TestOlmo3Parser_Streaming(t *testing.T) { if diff := cmp.Diff(allContent, tt.expectedContent); diff != "" { t.Errorf("content mismatch (-got +want):\n%s", diff) } - if diff := cmp.Diff(allCalls, tt.expectedCalls); diff != "" { + if diff := cmp.Diff(allCalls, tt.expectedCalls, argsComparer); diff != "" { t.Errorf("calls mismatch (-got +want):\n%s", diff) } }) @@ -378,7 +378,7 @@ func TestParseOlmo3FunctionCalls(t *testing.T) { { Function: api.ToolCallFunction{ Name: "get_weather", - Arguments: map[string]any{"location": "SF"}, + Arguments: testArgs(map[string]any{"location": "SF"}), }, }, }, @@ -390,11 +390,11 @@ func TestParseOlmo3FunctionCalls(t *testing.T) { { Function: api.ToolCallFunction{ Name: "send_email", - Arguments: map[string]any{ + Arguments: testArgs(map[string]any{ "to": "user@example.com", "subject": "Hello", "body": "Test message", - }, + }), }, }, }, @@ -407,13 +407,13 @@ get_time(timezone="PST")`, { Function: api.ToolCallFunction{ Name: "get_weather", - Arguments: map[string]any{"location": "SF"}, + Arguments: testArgs(map[string]any{"location": "SF"}), }, }, { Function: api.ToolCallFunction{ Name: "get_time", - Arguments: map[string]any{"timezone": "PST"}, + Arguments: testArgs(map[string]any{"timezone": "PST"}), }, }, }, @@ -437,7 +437,7 @@ get_time(timezone="PST")`, t.Errorf("parseOlmo3FunctionCalls() error = %v, wantErr %v", err, tt.wantErr) return } - if diff := cmp.Diff(calls, tt.expected); diff != "" { + if diff := cmp.Diff(calls, tt.expected, argsComparer); diff != "" { t.Errorf("calls mismatch (-got +want):\n%s", diff) } }) diff --git a/model/parsers/qwen3coder.go b/model/parsers/qwen3coder.go index 9a073b1c4..cf8f214e2 100644 --- a/model/parsers/qwen3coder.go +++ b/model/parsers/qwen3coder.go @@ -270,12 +270,12 @@ func parseToolCall(raw qwenEventRawToolCall, tools []api.Tool) (api.ToolCall, er } } - toolCall.Function.Arguments = make(api.ToolCallFunctionArguments) + toolCall.Function.Arguments = api.NewToolCallFunctionArguments() for _, parameter := range functionCall.Parameters { // Look up the parameter type if we found the tool var paramType api.PropertyType if matchedTool != nil && matchedTool.Function.Parameters.Properties != nil { - if prop, ok := matchedTool.Function.Parameters.Properties[parameter.Name]; ok { + if prop, ok := matchedTool.Function.Parameters.Properties.Get(parameter.Name); ok { // Handle anyOf by collecting all types from the union if len(prop.AnyOf) > 0 { for _, anyOfProp := range prop.AnyOf { @@ -287,7 +287,7 @@ func parseToolCall(raw qwenEventRawToolCall, tools []api.Tool) (api.ToolCall, er } } - toolCall.Function.Arguments[parameter.Name] = parseValue(parameter.Value, paramType) + toolCall.Function.Arguments.Set(parameter.Name, parseValue(parameter.Value, paramType)) } return toolCall, nil diff --git a/model/parsers/qwen3coder_test.go b/model/parsers/qwen3coder_test.go index e4246abcd..01c39924b 100644 --- a/model/parsers/qwen3coder_test.go +++ b/model/parsers/qwen3coder_test.go @@ -11,7 +11,7 @@ import ( func tool(name string, props map[string]api.ToolProperty) api.Tool { t := api.Tool{Type: "function", Function: api.ToolFunction{Name: name}} t.Function.Parameters.Type = "object" - t.Function.Parameters.Properties = props + t.Function.Parameters.Properties = testPropsMap(props) return t } @@ -369,10 +369,10 @@ celsius wantToolCall: api.ToolCall{ Function: api.ToolCallFunction{ Name: "get_current_temperature", - Arguments: map[string]any{ + Arguments: testArgs(map[string]any{ "location": "San Francisco", "unit": "celsius", - }, + }), }, }, }, @@ -390,10 +390,10 @@ celsius wantToolCall: api.ToolCall{ Function: api.ToolCallFunction{ Name: "get current temperature", - Arguments: map[string]any{ + Arguments: testArgs(map[string]any{ "location with spaces": "San Francisco", "unit with spaces": "celsius", - }, + }), }, }, }, @@ -415,10 +415,10 @@ San Francisco wantToolCall: api.ToolCall{ Function: api.ToolCallFunction{ Name: "\"get current temperature\"", - Arguments: map[string]any{ + Arguments: testArgs(map[string]any{ "\"location with spaces\"": "San Francisco", "\"unit with spaces\"": "\"celsius\"", - }, + }), }, }, }, @@ -449,12 +449,12 @@ true wantToolCall: api.ToolCall{ Function: api.ToolCallFunction{ Name: "calculate", - Arguments: map[string]any{ + Arguments: testArgs(map[string]any{ "x": 3.14, "y": 42, "enabled": true, "items": []any{"a", "b", "c"}, - }, + }), }, }, }, @@ -470,9 +470,9 @@ ls && echo "done" wantToolCall: api.ToolCall{ Function: api.ToolCallFunction{ Name: "exec", - Arguments: map[string]any{ + Arguments: testArgs(map[string]any{ "command": "ls && echo \"done\"", - }, + }), }, }, }, @@ -487,9 +487,9 @@ ls && echo "a > b and a < b" wantToolCall: api.ToolCall{ Function: api.ToolCallFunction{ Name: "exec", - Arguments: map[string]any{ + Arguments: testArgs(map[string]any{ "command": "ls && echo \"a > b and a < b\"", - }, + }), }, }, }, @@ -507,10 +507,10 @@ Hello! 你好! 🌟 مرحبا wantToolCall: api.ToolCall{ Function: api.ToolCallFunction{ Name: "获取天气", - Arguments: map[string]any{ + Arguments: testArgs(map[string]any{ "城市": "北京", "message": "Hello! 你好! 🌟 مرحبا", - }, + }), }, }, }, @@ -521,7 +521,7 @@ Hello! 你好! 🌟 مرحبا if err != nil { t.Errorf("step %d (%s): %v", i, step.name, err) } - if !reflect.DeepEqual(gotToolCall, step.wantToolCall) { + if !toolCallEqual(gotToolCall, step.wantToolCall) { t.Errorf("step %d (%s): got tool call %#v, want %#v", i, step.name, gotToolCall, step.wantToolCall) } } diff --git a/model/parsers/qwen3vl_nonthinking_test.go b/model/parsers/qwen3vl_nonthinking_test.go index 803824a68..9b1129d98 100644 --- a/model/parsers/qwen3vl_nonthinking_test.go +++ b/model/parsers/qwen3vl_nonthinking_test.go @@ -550,10 +550,10 @@ func TestQwen3VLNonThinkingToolParser(t *testing.T) { wantToolCall: api.ToolCall{ Function: api.ToolCallFunction{ Name: "get-current-weather", - Arguments: map[string]any{ + Arguments: testArgs(map[string]any{ "location": "San Francisco, CA", "unit": "fahrenheit", - }, + }), }, }, }, @@ -564,10 +564,10 @@ func TestQwen3VLNonThinkingToolParser(t *testing.T) { wantToolCall: api.ToolCall{ Function: api.ToolCallFunction{ Name: "get current temperature", - Arguments: map[string]any{ + Arguments: testArgs(map[string]any{ "location with spaces": "San Francisco", "unit with spaces": "celsius", - }, + }), }, }, }, @@ -578,10 +578,10 @@ func TestQwen3VLNonThinkingToolParser(t *testing.T) { wantToolCall: api.ToolCall{ Function: api.ToolCallFunction{ Name: "\"get current temperature\"", - Arguments: map[string]any{ + Arguments: testArgs(map[string]any{ "\"location with spaces\"": "San Francisco", "\"unit with spaces\"": "\"celsius\"", - }, + }), }, }, }, @@ -592,12 +592,12 @@ func TestQwen3VLNonThinkingToolParser(t *testing.T) { wantToolCall: api.ToolCall{ Function: api.ToolCallFunction{ Name: "calculate", - Arguments: map[string]any{ + Arguments: testArgs(map[string]any{ "x": 3.14, "y": float64(42), "enabled": true, "items": []any{"a", "b", "c"}, - }, + }), }, }, }, @@ -608,9 +608,9 @@ func TestQwen3VLNonThinkingToolParser(t *testing.T) { wantToolCall: api.ToolCall{ Function: api.ToolCallFunction{ Name: "exec", - Arguments: map[string]any{ + Arguments: testArgs(map[string]any{ "command": "ls && echo \"done\"", - }, + }), }, }, }, @@ -621,9 +621,9 @@ func TestQwen3VLNonThinkingToolParser(t *testing.T) { wantToolCall: api.ToolCall{ Function: api.ToolCallFunction{ Name: "exec", - Arguments: map[string]any{ + Arguments: testArgs(map[string]any{ "command": "ls && echo \"a > b and a < b\"", - }, + }), }, }, }, @@ -634,10 +634,10 @@ func TestQwen3VLNonThinkingToolParser(t *testing.T) { wantToolCall: api.ToolCall{ Function: api.ToolCallFunction{ Name: "获取天气", - Arguments: map[string]any{ + Arguments: testArgs(map[string]any{ "城市": "北京", "message": "Hello! 你好! 🌟 مرحبا", - }, + }), }, }, }, @@ -648,7 +648,7 @@ func TestQwen3VLNonThinkingToolParser(t *testing.T) { if err != nil { t.Errorf("step %d (%s): %v", i, step.name, err) } - if !reflect.DeepEqual(gotToolCall, step.wantToolCall) { + if !toolCallEqual(gotToolCall, step.wantToolCall) { t.Errorf("step %d (%s): got tool call %#v, want %#v", i, step.name, gotToolCall, step.wantToolCall) } } diff --git a/model/parsers/qwen3vl_thinking_test.go b/model/parsers/qwen3vl_thinking_test.go index 2d2424d20..ff3dc1683 100644 --- a/model/parsers/qwen3vl_thinking_test.go +++ b/model/parsers/qwen3vl_thinking_test.go @@ -241,10 +241,10 @@ func TestQwen3VLThinkingToolParser(t *testing.T) { wantToolCall: api.ToolCall{ Function: api.ToolCallFunction{ Name: "get-current-weather", - Arguments: map[string]any{ + Arguments: testArgs(map[string]any{ "location": "San Francisco, CA", "unit": "fahrenheit", - }, + }), }, }, }, @@ -255,10 +255,10 @@ func TestQwen3VLThinkingToolParser(t *testing.T) { wantToolCall: api.ToolCall{ Function: api.ToolCallFunction{ Name: "get current temperature", - Arguments: map[string]any{ + Arguments: testArgs(map[string]any{ "location with spaces": "San Francisco", "unit with spaces": "celsius", - }, + }), }, }, }, @@ -269,10 +269,10 @@ func TestQwen3VLThinkingToolParser(t *testing.T) { wantToolCall: api.ToolCall{ Function: api.ToolCallFunction{ Name: "\"get current temperature\"", - Arguments: map[string]any{ + Arguments: testArgs(map[string]any{ "\"location with spaces\"": "San Francisco", "\"unit with spaces\"": "\"celsius\"", - }, + }), }, }, }, @@ -283,12 +283,12 @@ func TestQwen3VLThinkingToolParser(t *testing.T) { wantToolCall: api.ToolCall{ Function: api.ToolCallFunction{ Name: "calculate", - Arguments: map[string]any{ + Arguments: testArgs(map[string]any{ "x": 3.14, "y": float64(42), "enabled": true, "items": []any{"a", "b", "c"}, - }, + }), }, }, }, @@ -299,9 +299,9 @@ func TestQwen3VLThinkingToolParser(t *testing.T) { wantToolCall: api.ToolCall{ Function: api.ToolCallFunction{ Name: "exec", - Arguments: map[string]any{ + Arguments: testArgs(map[string]any{ "command": "ls && echo \"done\"", - }, + }), }, }, }, @@ -312,9 +312,9 @@ func TestQwen3VLThinkingToolParser(t *testing.T) { wantToolCall: api.ToolCall{ Function: api.ToolCallFunction{ Name: "exec", - Arguments: map[string]any{ + Arguments: testArgs(map[string]any{ "command": "ls && echo \"a > b and a < b\"", - }, + }), }, }, }, @@ -325,10 +325,10 @@ func TestQwen3VLThinkingToolParser(t *testing.T) { wantToolCall: api.ToolCall{ Function: api.ToolCallFunction{ Name: "获取天气", - Arguments: map[string]any{ + Arguments: testArgs(map[string]any{ "城市": "北京", "message": "Hello! 你好! 🌟 مرحبا", - }, + }), }, }, }, @@ -339,7 +339,7 @@ func TestQwen3VLThinkingToolParser(t *testing.T) { if err != nil { t.Errorf("step %d (%s): %v", i, step.name, err) } - if !reflect.DeepEqual(gotToolCall, step.wantToolCall) { + if !toolCallEqual(gotToolCall, step.wantToolCall) { t.Errorf("step %d (%s): got tool call %#v, want %#v", i, step.name, gotToolCall, step.wantToolCall) } } diff --git a/model/parsers/testhelpers_test.go b/model/parsers/testhelpers_test.go new file mode 100644 index 000000000..0c252be83 --- /dev/null +++ b/model/parsers/testhelpers_test.go @@ -0,0 +1,98 @@ +package parsers + +import ( + "encoding/json" + + "github.com/google/go-cmp/cmp" + "github.com/ollama/ollama/api" +) + +// argsComparer provides cmp options for comparing ToolCallFunctionArguments +// It compares by logical equality (same keys with same values) not by order +var argsComparer = cmp.Comparer(func(a, b api.ToolCallFunctionArguments) bool { + // Convert both to maps and compare + aMap := a.ToMap() + bMap := b.ToMap() + if len(aMap) != len(bMap) { + return false + } + for k, av := range aMap { + bv, ok := bMap[k] + if !ok { + return false + } + // Use JSON encoding for deep comparison of values + aJSON, _ := json.Marshal(av) + bJSON, _ := json.Marshal(bv) + if string(aJSON) != string(bJSON) { + return false + } + } + return true +}) + +// propsComparer provides cmp options for comparing ToolPropertiesMap +var propsComparer = cmp.Comparer(func(a, b *api.ToolPropertiesMap) bool { + if a == nil && b == nil { + return true + } + if a == nil || b == nil { + return false + } + aJSON, _ := json.Marshal(a) + bJSON, _ := json.Marshal(b) + return string(aJSON) == string(bJSON) +}) + +// toolsComparer combines argsComparer and propsComparer for comparing tools +var toolsComparer = cmp.Options{argsComparer, propsComparer} + +// toolCallEqual compares two tool calls by comparing their components +// It compares arguments by logical equality (same keys with same values) not by order +func toolCallEqual(a, b api.ToolCall) bool { + if a.ID != b.ID { + return false + } + if a.Function.Index != b.Function.Index { + return false + } + if a.Function.Name != b.Function.Name { + return false + } + // Compare arguments by logical equality using argsComparer logic + aMap := a.Function.Arguments.ToMap() + bMap := b.Function.Arguments.ToMap() + if len(aMap) != len(bMap) { + return false + } + for k, av := range aMap { + bv, ok := bMap[k] + if !ok { + return false + } + aJSON, _ := json.Marshal(av) + bJSON, _ := json.Marshal(bv) + if string(aJSON) != string(bJSON) { + return false + } + } + return true +} + +// testPropsMap creates a ToolPropertiesMap from a map (convenience function for tests, order not preserved) +func testPropsMap(m map[string]api.ToolProperty) *api.ToolPropertiesMap { + props := api.NewToolPropertiesMap() + for k, v := range m { + props.Set(k, v) + } + return props +} + +// testArgs creates ToolCallFunctionArguments from a map (convenience function for tests, order not preserved) +func testArgs(m map[string]any) api.ToolCallFunctionArguments { + args := api.NewToolCallFunctionArguments() + for k, v := range m { + args.Set(k, v) + } + return args +} diff --git a/model/renderers/cogito_test.go b/model/renderers/cogito_test.go index 2b472502e..ea169f8e4 100644 --- a/model/renderers/cogito_test.go +++ b/model/renderers/cogito_test.go @@ -94,12 +94,12 @@ You are a helpful assistant. Description: "Get current weather", Parameters: api.ToolFunctionParameters{ Type: "object", - Properties: map[string]api.ToolProperty{ + Properties: testPropsMap(map[string]api.ToolProperty{ "location": { Type: api.PropertyType{"string"}, Description: "City name", }, - }, + }), Required: []string{"location"}, }, }, @@ -139,9 +139,9 @@ You have the following functions available: { Function: api.ToolCallFunction{ Name: "get_weather", - Arguments: api.ToolCallFunctionArguments{ + Arguments: testArgs(map[string]any{ "location": "Paris", - }, + }), }, }, }, @@ -162,9 +162,9 @@ You have the following functions available: { Function: api.ToolCallFunction{ Name: "get_weather", - Arguments: api.ToolCallFunctionArguments{ + Arguments: testArgs(map[string]any{ "location": "Paris", - }, + }), }, }, }, @@ -186,17 +186,17 @@ You have the following functions available: { Function: api.ToolCallFunction{ Name: "get_weather", - Arguments: api.ToolCallFunctionArguments{ + Arguments: testArgs(map[string]any{ "location": "Paris", - }, + }), }, }, { Function: api.ToolCallFunction{ Name: "get_weather", - Arguments: api.ToolCallFunctionArguments{ + Arguments: testArgs(map[string]any{ "location": "London", - }, + }), }, }, }, @@ -226,12 +226,12 @@ You have the following functions available: Description: "Get current weather", Parameters: api.ToolFunctionParameters{ Type: "object", - Properties: map[string]api.ToolProperty{ + Properties: testPropsMap(map[string]api.ToolProperty{ "location": { Type: api.PropertyType{"string"}, Description: "City name", }, - }, + }), Required: []string{"location"}, }, }, @@ -378,9 +378,9 @@ You are a pirate chatbot who always responds in pirate speak! { Function: api.ToolCallFunction{ Name: "get_weather", - Arguments: api.ToolCallFunctionArguments{ + Arguments: testArgs(map[string]any{ "location": "Paris", - }, + }), }, }, }, @@ -401,14 +401,14 @@ You are a pirate chatbot who always responds in pirate speak! { Function: api.ToolCallFunction{ Name: "process_data", - Arguments: api.ToolCallFunctionArguments{ - "items": []any{"item1", "item2", "item3"}, - "config": map[string]any{ + Arguments: testArgsOrdered([]orderedArg{ + {"config", map[string]any{ "enabled": true, "threshold": 0.95, "tags": []string{"important", "urgent"}, - }, - }, + }}, + {"items", []any{"item1", "item2", "item3"}}, + }), }, }, }, diff --git a/model/renderers/deepseek3_test.go b/model/renderers/deepseek3_test.go index c43a9f93a..913e9bcec 100644 --- a/model/renderers/deepseek3_test.go +++ b/model/renderers/deepseek3_test.go @@ -82,9 +82,9 @@ Second instruction<|User|>Hello<|Assistant|>`, { Function: api.ToolCallFunction{ Name: "get_weather", - Arguments: api.ToolCallFunctionArguments{ + Arguments: testArgs(map[string]any{ "location": "Paris", - }, + }), }, }, }, @@ -104,9 +104,9 @@ Second instruction<|User|>Hello<|Assistant|>`, { Function: api.ToolCallFunction{ Name: "get_weather", - Arguments: api.ToolCallFunctionArguments{ + Arguments: testArgs(map[string]any{ "location": "Paris", - }, + }), }, }, }, @@ -125,9 +125,9 @@ Second instruction<|User|>Hello<|Assistant|>`, { Function: api.ToolCallFunction{ Name: "get_weather", - Arguments: api.ToolCallFunctionArguments{ + Arguments: testArgs(map[string]any{ "location": "Paris", - }, + }), }, }, }, @@ -147,17 +147,17 @@ Second instruction<|User|>Hello<|Assistant|>`, { Function: api.ToolCallFunction{ Name: "get_weather", - Arguments: api.ToolCallFunctionArguments{ + Arguments: testArgs(map[string]any{ "location": "Paris", - }, + }), }, }, { Function: api.ToolCallFunction{ Name: "get_weather", - Arguments: api.ToolCallFunctionArguments{ + Arguments: testArgs(map[string]any{ "location": "London", - }, + }), }, }, }, @@ -214,9 +214,9 @@ Second instruction<|User|>Hello<|Assistant|>`, { Function: api.ToolCallFunction{ Name: "get_weather", - Arguments: api.ToolCallFunctionArguments{ + Arguments: testArgs(map[string]any{ "location": "Paris", - }, + }), }, }, }, @@ -235,9 +235,9 @@ Second instruction<|User|>Hello<|Assistant|>`, { Function: api.ToolCallFunction{ Name: "process", - Arguments: api.ToolCallFunctionArguments{ + Arguments: testArgs(map[string]any{ "data": "test", - }, + }), }, }, }, @@ -281,9 +281,9 @@ Second instruction<|User|>Hello<|Assistant|>`, { Function: api.ToolCallFunction{ Name: "get_weather", - Arguments: api.ToolCallFunctionArguments{ + Arguments: testArgs(map[string]any{ "location": "Paris", - }, + }), }, }, }, @@ -305,9 +305,9 @@ Second instruction<|User|>Hello<|Assistant|>`, { Function: api.ToolCallFunction{ Name: "get_weather", - Arguments: api.ToolCallFunctionArguments{ + Arguments: testArgs(map[string]any{ "location": "Paris", - }, + }), }, }, }, @@ -355,9 +355,9 @@ Second instruction<|User|>Hello<|Assistant|>`, { Function: api.ToolCallFunction{ Name: "get_weather", - Arguments: api.ToolCallFunctionArguments{ + Arguments: testArgs(map[string]any{ "location": "Paris", - }, + }), }, }, }, @@ -379,9 +379,9 @@ Second instruction<|User|>Hello<|Assistant|>`, { Function: api.ToolCallFunction{ Name: "get_weather", - Arguments: api.ToolCallFunctionArguments{ + Arguments: testArgs(map[string]any{ "location": "Paris", - }, + }), }, }, }, @@ -436,17 +436,17 @@ Second instruction<|User|>Hello<|Assistant|>`, { Function: api.ToolCallFunction{ Name: "get_weather", - Arguments: api.ToolCallFunctionArguments{ + Arguments: testArgs(map[string]any{ "location": "Tokyo", - }, + }), }, }, { Function: api.ToolCallFunction{ Name: "get_weather", - Arguments: api.ToolCallFunctionArguments{ + Arguments: testArgs(map[string]any{ "location": "New York", - }, + }), }, }, }, @@ -489,12 +489,12 @@ Second instruction<|User|>Hello<|Assistant|>`, Description: "Get current weather information", Parameters: api.ToolFunctionParameters{ Type: "object", - Properties: map[string]api.ToolProperty{ + Properties: testPropsMap(map[string]api.ToolProperty{ "location": { Type: api.PropertyType{"string"}, Description: "City name", }, - }, + }), Required: []string{"location"}, }, }, @@ -535,12 +535,12 @@ Where: Description: "Get current weather information", Parameters: api.ToolFunctionParameters{ Type: "object", - Properties: map[string]api.ToolProperty{ + Properties: testPropsMap(map[string]api.ToolProperty{ "location": { Type: api.PropertyType{"string"}, Description: "City name", }, - }, + }), Required: []string{"location"}, }, }, @@ -578,9 +578,9 @@ Where: { Function: api.ToolCallFunction{ Name: "get_weather", - Arguments: api.ToolCallFunctionArguments{ + Arguments: testArgs(map[string]any{ "location": "Paris", - }, + }), }, }, }, @@ -594,12 +594,12 @@ Where: Description: "Get current weather information", Parameters: api.ToolFunctionParameters{ Type: "object", - Properties: map[string]api.ToolProperty{ + Properties: testPropsMap(map[string]api.ToolProperty{ "location": { Type: api.PropertyType{"string"}, Description: "City name", }, - }, + }), Required: []string{"location"}, }, }, @@ -638,9 +638,9 @@ Where: { Function: api.ToolCallFunction{ Name: "get_weather", - Arguments: api.ToolCallFunctionArguments{ + Arguments: testArgs(map[string]any{ "location": "Paris", - }, + }), }, }, }, @@ -656,12 +656,12 @@ Where: Description: "Get current weather information", Parameters: api.ToolFunctionParameters{ Type: "object", - Properties: map[string]api.ToolProperty{ + Properties: testPropsMap(map[string]api.ToolProperty{ "location": { Type: api.PropertyType{"string"}, Description: "City name", }, - }, + }), Required: []string{"location"}, }, }, @@ -701,9 +701,9 @@ Where: { Function: api.ToolCallFunction{ Name: "get_weather", - Arguments: api.ToolCallFunctionArguments{ + Arguments: testArgs(map[string]any{ "location": "Tokyo", - }, + }), }, }, }, @@ -724,12 +724,12 @@ Where: Description: "Get current weather information", Parameters: api.ToolFunctionParameters{ Type: "object", - Properties: map[string]api.ToolProperty{ + Properties: testPropsMap(map[string]api.ToolProperty{ "location": { Type: api.PropertyType{"string"}, Description: "City name", }, - }, + }), Required: []string{"location"}, }, }, @@ -770,12 +770,12 @@ Where: Description: "Get current weather information", Parameters: api.ToolFunctionParameters{ Type: "object", - Properties: map[string]api.ToolProperty{ + Properties: testPropsMap(map[string]api.ToolProperty{ "location": { Type: api.PropertyType{"string"}, Description: "City name", }, - }, + }), Required: []string{"location"}, }, }, @@ -787,12 +787,12 @@ Where: Description: "Perform mathematical calculations", Parameters: api.ToolFunctionParameters{ Type: "object", - Properties: map[string]api.ToolProperty{ + Properties: testPropsMap(map[string]api.ToolProperty{ "expression": { Type: api.PropertyType{"string"}, Description: "Mathematical expression to evaluate", }, - }, + }), Required: []string{"expression"}, }, }, @@ -834,17 +834,17 @@ Where: { Function: api.ToolCallFunction{ Name: "get_weather", - Arguments: api.ToolCallFunctionArguments{ + Arguments: testArgs(map[string]any{ "location": "Paris", - }, + }), }, }, { Function: api.ToolCallFunction{ Name: "calculate", - Arguments: api.ToolCallFunctionArguments{ + Arguments: testArgs(map[string]any{ "expression": "25 * 4", - }, + }), }, }, }, @@ -860,12 +860,12 @@ Where: Description: "Get current weather information", Parameters: api.ToolFunctionParameters{ Type: "object", - Properties: map[string]api.ToolProperty{ + Properties: testPropsMap(map[string]api.ToolProperty{ "location": { Type: api.PropertyType{"string"}, Description: "City name", }, - }, + }), Required: []string{"location"}, }, }, @@ -877,12 +877,12 @@ Where: Description: "Perform mathematical calculations", Parameters: api.ToolFunctionParameters{ Type: "object", - Properties: map[string]api.ToolProperty{ + Properties: testPropsMap(map[string]api.ToolProperty{ "expression": { Type: api.PropertyType{"string"}, Description: "Mathematical expression to evaluate", }, - }, + }), Required: []string{"expression"}, }, }, @@ -927,12 +927,12 @@ Where: Description: "Get current weather information", Parameters: api.ToolFunctionParameters{ Type: "object", - Properties: map[string]api.ToolProperty{ + Properties: testPropsMap(map[string]api.ToolProperty{ "location": { Type: api.PropertyType{"string"}, Description: "City name", }, - }, + }), Required: []string{"location"}, }, }, diff --git a/model/renderers/functiongemma.go b/model/renderers/functiongemma.go index dcbcc0626..e767c84bc 100644 --- a/model/renderers/functiongemma.go +++ b/model/renderers/functiongemma.go @@ -136,7 +136,7 @@ func (r *FunctionGemmaRenderer) renderToolDeclaration(tool api.Tool) string { needsComma := false // Only include properties:{} if there are actual properties - if len(fn.Parameters.Properties) > 0 { + if fn.Parameters.Properties != nil && fn.Parameters.Properties.Len() > 0 { sb.WriteString("properties:{") r.writeProperties(&sb, fn.Parameters.Properties) sb.WriteString("}") @@ -172,16 +172,16 @@ func (r *FunctionGemmaRenderer) renderToolDeclaration(tool api.Tool) string { return sb.String() } -func (r *FunctionGemmaRenderer) writeProperties(sb *strings.Builder, props map[string]api.ToolProperty) { - keys := make([]string, 0, len(props)) - for k := range props { +func (r *FunctionGemmaRenderer) writeProperties(sb *strings.Builder, props *api.ToolPropertiesMap) { + keys := make([]string, 0, props.Len()) + for k := range props.All() { keys = append(keys, k) } sort.Strings(keys) first := true for _, name := range keys { - prop := props[name] + prop, _ := props.Get(name) if !first { sb.WriteString(",") } @@ -203,15 +203,15 @@ func (r *FunctionGemmaRenderer) formatToolCall(tc api.ToolCall) string { var sb strings.Builder sb.WriteString("call:" + tc.Function.Name + "{") - keys := make([]string, 0, len(tc.Function.Arguments)) - for k := range tc.Function.Arguments { + keys := make([]string, 0, tc.Function.Arguments.Len()) + for k := range tc.Function.Arguments.All() { keys = append(keys, k) } sort.Strings(keys) first := true for _, key := range keys { - value := tc.Function.Arguments[key] + value, _ := tc.Function.Arguments.Get(key) if !first { sb.WriteString(",") } diff --git a/model/renderers/functiongemma_test.go b/model/renderers/functiongemma_test.go index 733ff3744..fe9bd54e7 100644 --- a/model/renderers/functiongemma_test.go +++ b/model/renderers/functiongemma_test.go @@ -51,9 +51,9 @@ func TestFunctionGemmaRenderer(t *testing.T) { Description: "Get weather", Parameters: api.ToolFunctionParameters{ Type: "object", - Properties: map[string]api.ToolProperty{ + Properties: testPropsMap(map[string]api.ToolProperty{ "city": {Type: api.PropertyType{"string"}, Description: "City"}, - }, + }), }, }, }, @@ -75,9 +75,9 @@ func TestFunctionGemmaRenderer(t *testing.T) { Description: "Get weather", Parameters: api.ToolFunctionParameters{ Type: "object", - Properties: map[string]api.ToolProperty{ + Properties: testPropsMap(map[string]api.ToolProperty{ "city": {Type: api.PropertyType{"string"}, Description: "City"}, - }, + }), }, }, }, @@ -107,9 +107,9 @@ func TestFunctionGemmaRenderer(t *testing.T) { Description: "Get weather", Parameters: api.ToolFunctionParameters{ Type: "object", - Properties: map[string]api.ToolProperty{ + Properties: testPropsMap(map[string]api.ToolProperty{ "city": {Type: api.PropertyType{"string"}, Description: "City"}, - }, + }), }, }, }, @@ -126,7 +126,7 @@ func TestFunctionGemmaRenderer(t *testing.T) { { Function: api.ToolCallFunction{ Name: "get_weather", - Arguments: api.ToolCallFunctionArguments{"city": "Paris"}, + Arguments: testArgs(map[string]any{"city": "Paris"}), }, }, }, @@ -141,9 +141,9 @@ func TestFunctionGemmaRenderer(t *testing.T) { Description: "Get weather", Parameters: api.ToolFunctionParameters{ Type: "object", - Properties: map[string]api.ToolProperty{ + Properties: testPropsMap(map[string]api.ToolProperty{ "city": {Type: api.PropertyType{"string"}, Description: "City"}, - }, + }), }, }, }, @@ -161,7 +161,7 @@ func TestFunctionGemmaRenderer(t *testing.T) { { Function: api.ToolCallFunction{ Name: "get_weather", - Arguments: api.ToolCallFunctionArguments{"city": "Paris"}, + Arguments: testArgs(map[string]any{"city": "Paris"}), }, }, }, @@ -176,9 +176,9 @@ func TestFunctionGemmaRenderer(t *testing.T) { Description: "Get weather", Parameters: api.ToolFunctionParameters{ Type: "object", - Properties: map[string]api.ToolProperty{ + Properties: testPropsMap(map[string]api.ToolProperty{ "city": {Type: api.PropertyType{"string"}, Description: "City"}, - }, + }), }, }, }, @@ -195,7 +195,7 @@ func TestFunctionGemmaRenderer(t *testing.T) { { Function: api.ToolCallFunction{ Name: "add", - Arguments: api.ToolCallFunctionArguments{"a": float64(1), "b": float64(2)}, + Arguments: testArgs(map[string]any{"a": float64(1), "b": float64(2)}), }, }, }, @@ -210,10 +210,10 @@ func TestFunctionGemmaRenderer(t *testing.T) { Description: "Add numbers", Parameters: api.ToolFunctionParameters{ Type: "object", - Properties: map[string]api.ToolProperty{ + Properties: testPropsMap(map[string]api.ToolProperty{ "a": {Type: api.PropertyType{"number"}}, "b": {Type: api.PropertyType{"number"}}, - }, + }), }, }, }, @@ -239,10 +239,10 @@ func TestFunctionGemmaRenderer(t *testing.T) { Parameters: api.ToolFunctionParameters{ Type: "object", Required: []string{"city"}, - Properties: map[string]api.ToolProperty{ + Properties: testPropsMap(map[string]api.ToolProperty{ "city": {Type: api.PropertyType{"string"}, Description: "City Name"}, "country": {Type: api.PropertyType{"string"}, Description: "Country Name"}, - }, + }), }, }, }, @@ -263,9 +263,9 @@ func TestFunctionGemmaRenderer(t *testing.T) { Description: "Get weather", Parameters: api.ToolFunctionParameters{ Type: "object", - Properties: map[string]api.ToolProperty{ + Properties: testPropsMap(map[string]api.ToolProperty{ "city": {Type: api.PropertyType{"string"}, Description: "City"}, - }, + }), }, }, }, @@ -276,9 +276,9 @@ func TestFunctionGemmaRenderer(t *testing.T) { Description: "Get current time", Parameters: api.ToolFunctionParameters{ Type: "object", - Properties: map[string]api.ToolProperty{ + Properties: testPropsMap(map[string]api.ToolProperty{ "timezone": {Type: api.PropertyType{"string"}, Description: "Timezone"}, - }, + }), }, }, }, @@ -296,13 +296,13 @@ func TestFunctionGemmaRenderer(t *testing.T) { { Function: api.ToolCallFunction{ Name: "get_weather", - Arguments: api.ToolCallFunctionArguments{"city": "Paris"}, + Arguments: testArgs(map[string]any{"city": "Paris"}), }, }, { Function: api.ToolCallFunction{ Name: "get_time", - Arguments: api.ToolCallFunctionArguments{"timezone": "UTC"}, + Arguments: testArgs(map[string]any{"timezone": "UTC"}), }, }, }, @@ -318,9 +318,9 @@ func TestFunctionGemmaRenderer(t *testing.T) { Description: "Get weather", Parameters: api.ToolFunctionParameters{ Type: "object", - Properties: map[string]api.ToolProperty{ + Properties: testPropsMap(map[string]api.ToolProperty{ "city": {Type: api.PropertyType{"string"}, Description: "City"}, - }, + }), }, }, }, @@ -331,9 +331,9 @@ func TestFunctionGemmaRenderer(t *testing.T) { Description: "Get current time", Parameters: api.ToolFunctionParameters{ Type: "object", - Properties: map[string]api.ToolProperty{ + Properties: testPropsMap(map[string]api.ToolProperty{ "timezone": {Type: api.PropertyType{"string"}, Description: "Timezone"}, - }, + }), }, }, }, @@ -351,7 +351,7 @@ func TestFunctionGemmaRenderer(t *testing.T) { { Function: api.ToolCallFunction{ Name: "get_weather", - Arguments: api.ToolCallFunctionArguments{"city": "Paris"}, + Arguments: testArgs(map[string]any{"city": "Paris"}), }, }, }, @@ -367,9 +367,9 @@ func TestFunctionGemmaRenderer(t *testing.T) { Description: "Get weather", Parameters: api.ToolFunctionParameters{ Type: "object", - Properties: map[string]api.ToolProperty{ + Properties: testPropsMap(map[string]api.ToolProperty{ "city": {Type: api.PropertyType{"string"}, Description: "City"}, - }, + }), }, }, }, @@ -391,7 +391,7 @@ func TestFunctionGemmaRenderer(t *testing.T) { Description: "", Parameters: api.ToolFunctionParameters{ Type: "object", - Properties: map[string]api.ToolProperty{}, + Properties: testPropsMap(map[string]api.ToolProperty{}), }, }, }, @@ -430,7 +430,7 @@ func TestFunctionGemmaRenderer(t *testing.T) { { Function: api.ToolCallFunction{ Name: "set_flag", - Arguments: api.ToolCallFunctionArguments{"enabled": true}, + Arguments: testArgs(map[string]any{"enabled": true}), }, }, }, @@ -445,9 +445,9 @@ func TestFunctionGemmaRenderer(t *testing.T) { Description: "Set a flag", Parameters: api.ToolFunctionParameters{ Type: "object", - Properties: map[string]api.ToolProperty{ + Properties: testPropsMap(map[string]api.ToolProperty{ "enabled": {Type: api.PropertyType{"boolean"}, Description: "Flag value"}, - }, + }), }, }, }, @@ -468,11 +468,11 @@ func TestFunctionGemmaRenderer(t *testing.T) { Parameters: api.ToolFunctionParameters{ Type: "object", Required: []string{"a", "b", "c"}, - Properties: map[string]api.ToolProperty{ + Properties: testPropsMap(map[string]api.ToolProperty{ "a": {Type: api.PropertyType{"string"}, Description: "A"}, "b": {Type: api.PropertyType{"string"}, Description: "B"}, "c": {Type: api.PropertyType{"string"}, Description: "C"}, - }, + }), }, }, }, @@ -492,9 +492,9 @@ func TestFunctionGemmaRenderer(t *testing.T) { Description: "Test", Parameters: api.ToolFunctionParameters{ Type: "object", - Properties: map[string]api.ToolProperty{ + Properties: testPropsMap(map[string]api.ToolProperty{ "items": {Type: api.PropertyType{"array"}, Description: "List of items"}, - }, + }), }, }, }, diff --git a/model/renderers/nemotron3nano.go b/model/renderers/nemotron3nano.go index 478a59bdd..df847b48f 100644 --- a/model/renderers/nemotron3nano.go +++ b/model/renderers/nemotron3nano.go @@ -114,7 +114,7 @@ func (r *Nemotron3NanoRenderer) renderTools(tools []api.Tool) string { sb.WriteString("\n") if fn.Parameters.Properties != nil { - for paramName, paramFields := range fn.Parameters.Properties { + for paramName, paramFields := range fn.Parameters.Properties.All() { sb.WriteString("\n") sb.WriteString("\n" + paramName + "") @@ -202,7 +202,7 @@ func (r *Nemotron3NanoRenderer) formatContent(content string, truncate bool, add func (r *Nemotron3NanoRenderer) writeToolCalls(sb *strings.Builder, toolCalls []api.ToolCall) { for _, tc := range toolCalls { sb.WriteString("\n\n") - for name, value := range tc.Function.Arguments { + for name, value := range tc.Function.Arguments.All() { sb.WriteString("\n" + r.formatArgValue(value) + "\n\n") } sb.WriteString("\n\n") diff --git a/model/renderers/nemotron3nano_test.go b/model/renderers/nemotron3nano_test.go index ca1feb931..db8329fa7 100644 --- a/model/renderers/nemotron3nano_test.go +++ b/model/renderers/nemotron3nano_test.go @@ -75,9 +75,9 @@ func TestNemotron3NanoRenderer(t *testing.T) { Parameters: api.ToolFunctionParameters{ Type: "object", Required: []string{"city"}, - Properties: map[string]api.ToolProperty{ + Properties: testPropsMap(map[string]api.ToolProperty{ "city": {Type: api.PropertyType{"string"}, Description: "The city name"}, - }, + }), }, }, }, @@ -113,7 +113,7 @@ func TestNemotron3NanoRenderer(t *testing.T) { { Function: api.ToolCallFunction{ Name: "get_weather", - Arguments: map[string]any{"city": "Paris"}, + Arguments: testArgs(map[string]any{"city": "Paris"}), }, }, }, @@ -129,9 +129,9 @@ func TestNemotron3NanoRenderer(t *testing.T) { Parameters: api.ToolFunctionParameters{ Type: "object", Required: []string{"city"}, - Properties: map[string]api.ToolProperty{ + Properties: testPropsMap(map[string]api.ToolProperty{ "city": {Type: api.PropertyType{"string"}, Description: "The city name"}, - }, + }), }, }, }, @@ -171,7 +171,7 @@ func TestNemotron3NanoRenderer(t *testing.T) { { Function: api.ToolCallFunction{ Name: "get_weather", - Arguments: map[string]any{"city": "Paris"}, + Arguments: testArgs(map[string]any{"city": "Paris"}), }, }, }, @@ -185,9 +185,9 @@ func TestNemotron3NanoRenderer(t *testing.T) { Name: "get_weather", Parameters: api.ToolFunctionParameters{ Type: "object", - Properties: map[string]api.ToolProperty{ + Properties: testPropsMap(map[string]api.ToolProperty{ "city": {Type: api.PropertyType{"string"}}, - }, + }), }, }, }, @@ -238,13 +238,13 @@ func TestNemotron3NanoRenderer(t *testing.T) { { Function: api.ToolCallFunction{ Name: "get_weather", - Arguments: map[string]any{"city": "Paris"}, + Arguments: testArgs(map[string]any{"city": "Paris"}), }, }, { Function: api.ToolCallFunction{ Name: "get_weather", - Arguments: map[string]any{"city": "London"}, + Arguments: testArgs(map[string]any{"city": "London"}), }, }, }, @@ -259,9 +259,9 @@ func TestNemotron3NanoRenderer(t *testing.T) { Name: "get_weather", Parameters: api.ToolFunctionParameters{ Type: "object", - Properties: map[string]api.ToolProperty{ + Properties: testPropsMap(map[string]api.ToolProperty{ "city": {Type: api.PropertyType{"string"}}, - }, + }), }, }, }, @@ -304,13 +304,13 @@ func TestNemotron3NanoRenderer(t *testing.T) { msgs: []api.Message{ {Role: "user", Content: "What's the weather in Paris and London? Also, what's 2+2?"}, {Role: "assistant", Content: "", Thinking: "I need to check the weather for both cities and calculate 2+2. Let me start with the weather calls.", ToolCalls: []api.ToolCall{ - {Function: api.ToolCallFunction{Name: "get_weather", Arguments: api.ToolCallFunctionArguments{"city": "Paris"}}}, - {Function: api.ToolCallFunction{Name: "get_weather", Arguments: api.ToolCallFunctionArguments{"city": "London"}}}, + {Function: api.ToolCallFunction{Name: "get_weather", Arguments: testArgs(map[string]any{"city": "Paris"})}}, + {Function: api.ToolCallFunction{Name: "get_weather", Arguments: testArgs(map[string]any{"city": "London"})}}, }}, {Role: "tool", Content: "Sunny, 22°C", ToolCallID: "call1"}, {Role: "tool", Content: "Rainy, 15°C", ToolCallID: "call2"}, {Role: "assistant", Content: "", Thinking: "Now I have the weather data. Let me calculate 2+2.", ToolCalls: []api.ToolCall{ - {Function: api.ToolCallFunction{Name: "calculate", Arguments: api.ToolCallFunctionArguments{"expression": "2+2"}}}, + {Function: api.ToolCallFunction{Name: "calculate", Arguments: testArgs(map[string]any{"expression": "2+2"})}}, }}, {Role: "tool", Content: "4", ToolCallID: "call3"}, {Role: "assistant", Content: "Based on the weather data, Paris is sunny at 22°C and London is rainy at 15°C. Also, 2+2 equals 4.", Thinking: "Perfect! I have all the information needed to provide a complete answer."}, @@ -322,9 +322,9 @@ func TestNemotron3NanoRenderer(t *testing.T) { Name: "get_weather", Parameters: api.ToolFunctionParameters{ Type: "object", - Properties: map[string]api.ToolProperty{ + Properties: testPropsMap(map[string]api.ToolProperty{ "city": {Type: api.PropertyType{"string"}}, - }, + }), }, }, }, @@ -334,9 +334,9 @@ func TestNemotron3NanoRenderer(t *testing.T) { Name: "calculate", Parameters: api.ToolFunctionParameters{ Type: "object", - Properties: map[string]api.ToolProperty{ + Properties: testPropsMap(map[string]api.ToolProperty{ "expression": {Type: api.PropertyType{"string"}}, - }, + }), }, }, }, @@ -389,7 +389,7 @@ func TestNemotron3NanoRenderer(t *testing.T) { { Role: "assistant", ToolCalls: []api.ToolCall{ - {Function: api.ToolCallFunction{Name: "get_user", Arguments: map[string]any{"id": "123"}}}, + {Function: api.ToolCallFunction{Name: "get_user", Arguments: testArgs(map[string]any{"id": "123"})}}, }, }, {Role: "tool", Content: `{"name": "John", "age": 30, "active": true}`}, @@ -401,7 +401,7 @@ func TestNemotron3NanoRenderer(t *testing.T) { Name: "get_user", Parameters: api.ToolFunctionParameters{ Type: "object", - Properties: map[string]api.ToolProperty{"id": {Type: api.PropertyType{"string"}}}, + Properties: testPropsMap(map[string]api.ToolProperty{"id": {Type: api.PropertyType{"string"}}}), }, }, }, @@ -450,9 +450,9 @@ func TestNemotron3NanoRenderer(t *testing.T) { ToolCalls: []api.ToolCall{ {Function: api.ToolCallFunction{ Name: "create", - Arguments: map[string]any{ + Arguments: testArgs(map[string]any{ "data": map[string]any{"nested": "value", "count": 42}, - }, + }), }}, }, }, @@ -465,7 +465,7 @@ func TestNemotron3NanoRenderer(t *testing.T) { Name: "create", Parameters: api.ToolFunctionParameters{ Type: "object", - Properties: map[string]api.ToolProperty{"data": {Type: api.PropertyType{"object"}}}, + Properties: testPropsMap(map[string]api.ToolProperty{"data": {Type: api.PropertyType{"object"}}}), }, }, }, @@ -512,7 +512,7 @@ func TestNemotron3NanoRenderer(t *testing.T) { { Role: "assistant", ToolCalls: []api.ToolCall{ - {Function: api.ToolCallFunction{Name: "translate", Arguments: map[string]any{"text": "你好"}}}, + {Function: api.ToolCallFunction{Name: "translate", Arguments: testArgs(map[string]any{"text": "你好"})}}, }, }, {Role: "tool", Content: "Hello"}, @@ -524,9 +524,9 @@ func TestNemotron3NanoRenderer(t *testing.T) { Name: "translate", Parameters: api.ToolFunctionParameters{ Type: "object", - Properties: map[string]api.ToolProperty{ + Properties: testPropsMap(map[string]api.ToolProperty{ "text": {Type: api.PropertyType{"string"}}, - }, + }), }, }, }, diff --git a/model/renderers/olmo3.go b/model/renderers/olmo3.go index c6cdaa722..2e53c0c4c 100644 --- a/model/renderers/olmo3.go +++ b/model/renderers/olmo3.go @@ -100,8 +100,8 @@ func (r *Olmo3Renderer) Render(messages []api.Message, tools []api.Tool, _ *api. sb.WriteString("(") // Get sorted keys for deterministic output - keys := make([]string, 0, len(tc.Function.Arguments)) - for k := range tc.Function.Arguments { + keys := make([]string, 0, tc.Function.Arguments.Len()) + for k := range tc.Function.Arguments.All() { keys = append(keys, k) } sort.Strings(keys) @@ -110,7 +110,8 @@ func (r *Olmo3Renderer) Render(messages []api.Message, tools []api.Tool, _ *api. if k > 0 { sb.WriteString(", ") } - value, err := json.Marshal(tc.Function.Arguments[key]) + val, _ := tc.Function.Arguments.Get(key) + value, err := json.Marshal(val) if err != nil { return "", err } diff --git a/model/renderers/olmo3_test.go b/model/renderers/olmo3_test.go index be9c4eac2..bd1f717dc 100644 --- a/model/renderers/olmo3_test.go +++ b/model/renderers/olmo3_test.go @@ -53,9 +53,9 @@ func TestOlmo3Renderer(t *testing.T) { Parameters: api.ToolFunctionParameters{ Type: "object", Required: []string{"location"}, - Properties: map[string]api.ToolProperty{ + Properties: testPropsMap(map[string]api.ToolProperty{ "location": {Type: api.PropertyType{"string"}, Description: "The city"}, - }, + }), }, }, }, @@ -80,9 +80,9 @@ func TestOlmo3Renderer(t *testing.T) { Parameters: api.ToolFunctionParameters{ Type: "object", Required: []string{"location"}, - Properties: map[string]api.ToolProperty{ + Properties: testPropsMap(map[string]api.ToolProperty{ "location": {Type: api.PropertyType{"string"}, Description: "The city"}, - }, + }), }, }, }, @@ -108,9 +108,9 @@ func TestOlmo3Renderer(t *testing.T) { ID: "call_1", Function: api.ToolCallFunction{ Name: "get_weather", - Arguments: map[string]any{ + Arguments: testArgs(map[string]any{ "location": "San Francisco", - }, + }), }, }, }, @@ -126,9 +126,9 @@ func TestOlmo3Renderer(t *testing.T) { Parameters: api.ToolFunctionParameters{ Type: "object", Required: []string{"location"}, - Properties: map[string]api.ToolProperty{ + Properties: testPropsMap(map[string]api.ToolProperty{ "location": {Type: api.PropertyType{"string"}, Description: "The city"}, - }, + }), }, }, }, @@ -172,14 +172,14 @@ func TestOlmo3Renderer(t *testing.T) { ID: "call_1", Function: api.ToolCallFunction{ Name: "get_weather", - Arguments: map[string]any{"location": "San Francisco"}, + Arguments: testArgs(map[string]any{"location": "San Francisco"}), }, }, { ID: "call_2", Function: api.ToolCallFunction{ Name: "get_weather", - Arguments: map[string]any{"location": "New York"}, + Arguments: testArgs(map[string]any{"location": "New York"}), }, }, }, @@ -194,9 +194,9 @@ func TestOlmo3Renderer(t *testing.T) { Name: "get_weather", Parameters: api.ToolFunctionParameters{ Type: "object", - Properties: map[string]api.ToolProperty{ + Properties: testPropsMap(map[string]api.ToolProperty{ "location": {Type: api.PropertyType{"string"}}, - }, + }), }, }, }, @@ -227,10 +227,10 @@ func TestOlmo3Renderer(t *testing.T) { ID: "call_1", Function: api.ToolCallFunction{ Name: "book_flight", - Arguments: map[string]any{ + Arguments: testArgs(map[string]any{ "from": "SFO", "to": "NYC", - }, + }), }, }, }, @@ -243,10 +243,10 @@ func TestOlmo3Renderer(t *testing.T) { Name: "book_flight", Parameters: api.ToolFunctionParameters{ Type: "object", - Properties: map[string]api.ToolProperty{ + Properties: testPropsMap(map[string]api.ToolProperty{ "from": {Type: api.PropertyType{"string"}}, "to": {Type: api.PropertyType{"string"}}, - }, + }), }, }, }, diff --git a/model/renderers/olmo3_think_test.go b/model/renderers/olmo3_think_test.go index 8bfd5fdce..ba03d8cf2 100644 --- a/model/renderers/olmo3_think_test.go +++ b/model/renderers/olmo3_think_test.go @@ -78,7 +78,7 @@ func TestOlmo3ThinkRenderer(t *testing.T) { ID: "call_1", Function: api.ToolCallFunction{ Name: "get_weather", - Arguments: map[string]any{"location": "San Francisco"}, + Arguments: testArgs(map[string]any{"location": "San Francisco"}), }, }, }, diff --git a/model/renderers/qwen3coder.go b/model/renderers/qwen3coder.go index 18853019c..2b5a5ae95 100644 --- a/model/renderers/qwen3coder.go +++ b/model/renderers/qwen3coder.go @@ -96,7 +96,7 @@ func (r *Qwen3CoderRenderer) Render(messages []api.Message, tools []api.Tool, _ } sb.WriteString("\n") - for name, prop := range tool.Function.Parameters.Properties { + for name, prop := range tool.Function.Parameters.Properties.All() { sb.WriteString("\n") sb.WriteString("\n" + name + "") @@ -147,7 +147,7 @@ func (r *Qwen3CoderRenderer) Render(messages []api.Message, tools []api.Tool, _ } for _, toolCall := range message.ToolCalls { sb.WriteString("\n\n") - for name, value := range toolCall.Function.Arguments { + for name, value := range toolCall.Function.Arguments.All() { valueStr := formatToolCallArgument(value) sb.WriteString("\n\n" + valueStr + "\n") } diff --git a/model/renderers/qwen3coder_test.go b/model/renderers/qwen3coder_test.go index 1addee9e1..b6ca56e75 100644 --- a/model/renderers/qwen3coder_test.go +++ b/model/renderers/qwen3coder_test.go @@ -39,9 +39,9 @@ Hello, how are you?<|im_end|> { Function: api.ToolCallFunction{ Name: "get_weather", - Arguments: map[string]any{ + Arguments: testArgs(map[string]any{ "unit": "fahrenheit", - }, + }), }, }, }, @@ -55,7 +55,7 @@ Hello, how are you?<|im_end|> Description: "Get the current weather in a given location", Parameters: api.ToolFunctionParameters{ Required: []string{"unit"}, - Properties: map[string]api.ToolProperty{ + Properties: testPropsMap(map[string]api.ToolProperty{ "unit": {Type: api.PropertyType{"string"}, Enum: []any{"celsius", "fahrenheit"}, Description: "The unit of temperature"}, // TODO(drifkin): add multiple params back once we have predictable // order via some sort of ordered map type (see @@ -63,7 +63,7 @@ Hello, how are you?<|im_end|> /* "location": {Type: api.PropertyType{"string"}, Description: "The city and state, e.g. San Francisco, CA"}, */ - }, + }), }, }}, }, @@ -140,19 +140,19 @@ That sounds nice! What about New York?<|im_end|> {Role: "system", Content: "You are a helpful assistant with access to tools."}, {Role: "user", Content: "call double(1) and triple(2)"}, {Role: "assistant", Content: "I'll call double(1) and triple(2) for you.", ToolCalls: []api.ToolCall{ - {Function: api.ToolCallFunction{Name: "double", Arguments: map[string]any{"number": "1"}}}, - {Function: api.ToolCallFunction{Name: "triple", Arguments: map[string]any{"number": "2"}}}, + {Function: api.ToolCallFunction{Name: "double", Arguments: testArgs(map[string]any{"number": "1"})}}, + {Function: api.ToolCallFunction{Name: "triple", Arguments: testArgs(map[string]any{"number": "2"})}}, }}, {Role: "tool", Content: "{\"number\": 2}", ToolName: "double"}, {Role: "tool", Content: "{\"number\": 6}", ToolName: "triple"}, }, tools: []api.Tool{ - {Function: api.ToolFunction{Name: "double", Description: "Double a number", Parameters: api.ToolFunctionParameters{Properties: map[string]api.ToolProperty{ + {Function: api.ToolFunction{Name: "double", Description: "Double a number", Parameters: api.ToolFunctionParameters{Properties: testPropsMap(map[string]api.ToolProperty{ "number": {Type: api.PropertyType{"string"}, Description: "The number to double"}, - }}}}, - {Function: api.ToolFunction{Name: "triple", Description: "Triple a number", Parameters: api.ToolFunctionParameters{Properties: map[string]api.ToolProperty{ + })}}}, + {Function: api.ToolFunction{Name: "triple", Description: "Triple a number", Parameters: api.ToolFunctionParameters{Properties: testPropsMap(map[string]api.ToolProperty{ "number": {Type: api.PropertyType{"string"}, Description: "The number to triple"}, - }}}}, + })}}}, }, expected: `<|im_start|>system You are a helpful assistant with access to tools. @@ -259,9 +259,9 @@ I'll tell you something interesting about cats`, {Role: "assistant", ToolCalls: []api.ToolCall{ {Function: api.ToolCallFunction{ Name: "echo", - Arguments: map[string]any{ + Arguments: testArgs(map[string]any{ "payload": map[string]any{"foo": "bar"}, - }, + }), }}, }}, {Role: "tool", Content: "{\"payload\": {\"foo\": \"bar\"}}", ToolName: "echo"}, diff --git a/model/renderers/qwen3vl_nonthinking_test.go b/model/renderers/qwen3vl_nonthinking_test.go index d3377e39d..70ba68645 100644 --- a/model/renderers/qwen3vl_nonthinking_test.go +++ b/model/renderers/qwen3vl_nonthinking_test.go @@ -337,7 +337,7 @@ Let me analyze this image.`, Role: "assistant", Content: "I'll check.", ToolCalls: []api.ToolCall{ - {Function: api.ToolCallFunction{Name: "get-current-weather", Arguments: map[string]any{"location": "Paris", "unit": "celsius"}}}, + {Function: api.ToolCallFunction{Name: "get-current-weather", Arguments: testArgsOrdered([]orderedArg{{"location", "Paris"}, {"unit", "celsius"}})}}, }, }, {Role: "user", Content: "\n18\n"}, @@ -367,8 +367,8 @@ Thanks!<|im_end|> Role: "assistant", Content: "before", ToolCalls: []api.ToolCall{ - {Function: api.ToolCallFunction{Name: "add", Arguments: map[string]any{"a": 2, "b": 3}}}, - {Function: api.ToolCallFunction{Name: "mul", Arguments: map[string]any{"x": 4, "y": 5}}}, + {Function: api.ToolCallFunction{Name: "add", Arguments: testArgsOrdered([]orderedArg{{"a", 2}, {"b", 3}})}}, + {Function: api.ToolCallFunction{Name: "mul", Arguments: testArgsOrdered([]orderedArg{{"x", 4}, {"y", 5}})}}, }, }, }, @@ -387,7 +387,7 @@ before name: "consecutive tool responses grouped", msgs: []api.Message{ {Role: "user", Content: "Compute results"}, - {Role: "assistant", Content: "ok", ToolCalls: []api.ToolCall{{Function: api.ToolCallFunction{Name: "job", Arguments: map[string]any{"n": 1}}}}}, + {Role: "assistant", Content: "ok", ToolCalls: []api.ToolCall{{Function: api.ToolCallFunction{Name: "job", Arguments: testArgs(map[string]any{"n": 1})}}}}, {Role: "tool", Content: "5", ToolName: "job"}, {Role: "tool", Content: "6", ToolName: "job"}, }, @@ -412,7 +412,7 @@ ok name: "last message is tool then prefill", msgs: []api.Message{ {Role: "user", Content: "run"}, - {Role: "assistant", Content: "ok", ToolCalls: []api.ToolCall{{Function: api.ToolCallFunction{Name: "exec", Arguments: map[string]any{"cmd": "ls"}}}}}, + {Role: "assistant", Content: "ok", ToolCalls: []api.ToolCall{{Function: api.ToolCallFunction{Name: "exec", Arguments: testArgs(map[string]any{"cmd": "ls"})}}}}, {Role: "tool", Content: "done", ToolName: "exec"}, }, expected: `<|im_start|>user @@ -447,7 +447,7 @@ done Role: "assistant", Content: "I'll check.", ToolCalls: []api.ToolCall{ - {Function: api.ToolCallFunction{Name: "get-current-weather", Arguments: map[string]any{"location": "Paris", "unit": "celsius"}}}, + {Function: api.ToolCallFunction{Name: "get-current-weather", Arguments: testArgsOrdered([]orderedArg{{"location", "Paris"}, {"unit", "celsius"}})}}, }, }, {Role: "user", Content: "\n18\n"}, @@ -477,7 +477,7 @@ Thanks!<|im_end|> Role: "assistant", Content: "I'll check.", ToolCalls: []api.ToolCall{ - {Function: api.ToolCallFunction{Name: "get-current-weather", Arguments: map[string]any{"location": "Paris", "unit": "celsius"}}}, + {Function: api.ToolCallFunction{Name: "get-current-weather", Arguments: testArgsOrdered([]orderedArg{{"location", "Paris"}, {"unit", "celsius"}})}}, }, }, {Role: "user", Content: "\n\n\n\n\n18\n extra\n\n\n\n\n\n"}, diff --git a/model/renderers/qwen3vl_thinking_test.go b/model/renderers/qwen3vl_thinking_test.go index eb53e6a92..7fc3f2af6 100644 --- a/model/renderers/qwen3vl_thinking_test.go +++ b/model/renderers/qwen3vl_thinking_test.go @@ -128,10 +128,10 @@ Speak poetry after the first sentence.Speak poetry after the seco // { // Function: api.ToolCallFunction{ // Name: "get-current-weather", - // Arguments: map[string]any{ + // Arguments: testArgs(map[string]any{ // "location": "New York", // "unit": "fahrenheit", - // }, + // }), // }, // }, // }, @@ -148,7 +148,7 @@ Speak poetry after the first sentence.Speak poetry after the seco // Parameters: api.ToolFunctionParameters{ // Type: "object", // Required: []string{"location"}, - // Properties: map[string]api.ToolProperty{ + // Properties: testPropsMap(map[string]api.ToolProperty{ // "location": { // Type: api.PropertyType{"string"}, // Description: "The city and state, e.g. San Francisco, CA", @@ -158,7 +158,7 @@ Speak poetry after the first sentence.Speak poetry after the seco // Enum: []any{"celsius", "fahrenheit"}, // Description: "The temperature unit", // }, - // }, + // }), // }, // }, // }, @@ -216,19 +216,19 @@ Speak poetry after the first sentence.Speak poetry after the seco // { // Function: api.ToolCallFunction{ // Name: "add", - // Arguments: map[string]any{ + // Arguments: testArgs(map[string]any{ // "a": 2, // "b": 3, - // }, + // }), // }, // }, // { // Function: api.ToolCallFunction{ // Name: "multiply", - // Arguments: map[string]any{ + // Arguments: testArgs(map[string]any{ // "x": 4, // "y": 5, - // }, + // }), // }, // }, // }, @@ -257,10 +257,10 @@ Speak poetry after the first sentence.Speak poetry after the seco // Parameters: api.ToolFunctionParameters{ // Type: "object", // Required: []string{"a", "b"}, - // Properties: map[string]api.ToolProperty{ + // Properties: testPropsMap(map[string]api.ToolProperty{ // "a": {Type: api.PropertyType{"integer"}, Description: "First number"}, // "b": {Type: api.PropertyType{"integer"}, Description: "Second number"}, - // }, + // }), // }, // }, // }, @@ -272,10 +272,10 @@ Speak poetry after the first sentence.Speak poetry after the seco // Parameters: api.ToolFunctionParameters{ // Type: "object", // Required: []string{"x", "y"}, - // Properties: map[string]api.ToolProperty{ + // Properties: testPropsMap(map[string]api.ToolProperty{ // "x": {Type: api.PropertyType{"integer"}, Description: "First factor"}, // "y": {Type: api.PropertyType{"integer"}, Description: "Second factor"}, - // }, + // }), // }, // }, // }, diff --git a/model/renderers/testhelpers_test.go b/model/renderers/testhelpers_test.go new file mode 100644 index 000000000..6eac8eee4 --- /dev/null +++ b/model/renderers/testhelpers_test.go @@ -0,0 +1,36 @@ +package renderers + +import "github.com/ollama/ollama/api" + +// testPropsMap creates a ToolPropertiesMap from a map (convenience function for tests, order not preserved) +func testPropsMap(m map[string]api.ToolProperty) *api.ToolPropertiesMap { + props := api.NewToolPropertiesMap() + for k, v := range m { + props.Set(k, v) + } + return props +} + +// testArgs creates ToolCallFunctionArguments from a map (convenience function for tests, order not preserved) +func testArgs(m map[string]any) api.ToolCallFunctionArguments { + args := api.NewToolCallFunctionArguments() + for k, v := range m { + args.Set(k, v) + } + return args +} + +// orderedArg represents a key-value pair for ordered argument creation +type orderedArg struct { + Key string + Value any +} + +// testArgsOrdered creates ToolCallFunctionArguments with a specific key order +func testArgsOrdered(pairs []orderedArg) api.ToolCallFunctionArguments { + args := api.NewToolCallFunctionArguments() + for _, p := range pairs { + args.Set(p.Key, p.Value) + } + return args +} diff --git a/openai/openai_test.go b/openai/openai_test.go index 51e243dec..f76af7090 100644 --- a/openai/openai_test.go +++ b/openai/openai_test.go @@ -10,6 +10,20 @@ import ( "github.com/ollama/ollama/api" ) +// testArgs creates ToolCallFunctionArguments from a map (convenience function for tests) +func testArgs(m map[string]any) api.ToolCallFunctionArguments { + args := api.NewToolCallFunctionArguments() + for k, v := range m { + args.Set(k, v) + } + return args +} + +// argsComparer provides cmp options for comparing ToolCallFunctionArguments by value +var argsComparer = cmp.Comparer(func(a, b api.ToolCallFunctionArguments) bool { + return cmp.Equal(a.ToMap(), b.ToMap()) +}) + const ( prefix = `data:image/jpeg;base64,` image = `iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR42mNk+A8AAQUBAScY42YAAAAASUVORK5CYII=` @@ -159,9 +173,9 @@ func TestToToolCallsPreservesIDs(t *testing.T) { Function: api.ToolCallFunction{ Index: 2, Name: "get_weather", - Arguments: api.ToolCallFunctionArguments{ + Arguments: testArgs(map[string]any{ "location": "Seattle", - }, + }), }, }, { @@ -169,9 +183,9 @@ func TestToToolCallsPreservesIDs(t *testing.T) { Function: api.ToolCallFunction{ Index: 7, Name: "get_time", - Arguments: api.ToolCallFunctionArguments{ + Arguments: testArgs(map[string]any{ "timezone": "UTC", - }, + }), }, }, } @@ -215,7 +229,7 @@ func TestToToolCallsPreservesIDs(t *testing.T) { t.Errorf("tool calls mismatch (-want +got):\n%s", diff) } - if diff := cmp.Diff(original, toolCalls); diff != "" { + if diff := cmp.Diff(original, toolCalls, argsComparer); diff != "" { t.Errorf("input tool calls mutated (-want +got):\n%s", diff) } } diff --git a/openai/responses_test.go b/openai/responses_test.go index 86731e72b..bfb6bb36e 100644 --- a/openai/responses_test.go +++ b/openai/responses_test.go @@ -925,7 +925,7 @@ func TestResponsesStreamConverter_ToolCalls(t *testing.T) { ID: "call_abc", Function: api.ToolCallFunction{ Name: "get_weather", - Arguments: api.ToolCallFunctionArguments{"city": "Paris"}, + Arguments: testArgs(map[string]any{"city": "Paris"}), }, }, }, @@ -1800,7 +1800,7 @@ func TestResponsesStreamConverter_FunctionCallStatus(t *testing.T) { ID: "call_abc", Function: api.ToolCallFunction{ Name: "get_weather", - Arguments: api.ToolCallFunctionArguments{"city": "Paris"}, + Arguments: testArgs(map[string]any{"city": "Paris"}), }, }, }, diff --git a/server/routes_generate_test.go b/server/routes_generate_test.go index 13befff2a..111a9678a 100644 --- a/server/routes_generate_test.go +++ b/server/routes_generate_test.go @@ -22,6 +22,29 @@ import ( "github.com/ollama/ollama/ml" ) +// testPropsMap creates a ToolPropertiesMap from a map (convenience function for tests) +func testPropsMap(m map[string]api.ToolProperty) *api.ToolPropertiesMap { + props := api.NewToolPropertiesMap() + for k, v := range m { + props.Set(k, v) + } + return props +} + +// testArgs creates ToolCallFunctionArguments from a map (convenience function for tests) +func testArgs(m map[string]any) api.ToolCallFunctionArguments { + args := api.NewToolCallFunctionArguments() + for k, v := range m { + args.Set(k, v) + } + return args +} + +// argsComparer provides cmp options for comparing ToolCallFunctionArguments by value +var argsComparer = cmp.Comparer(func(a, b api.ToolCallFunctionArguments) bool { + return cmp.Equal(a.ToMap(), b.ToMap()) +}) + type mockRunner struct { llm.LlamaServer @@ -488,7 +511,7 @@ func TestGenerateChat(t *testing.T) { Parameters: api.ToolFunctionParameters{ Type: "object", Required: []string{"location"}, - Properties: map[string]api.ToolProperty{ + Properties: testPropsMap(map[string]api.ToolProperty{ "location": { Type: api.PropertyType{"string"}, Description: "The city and state", @@ -497,7 +520,7 @@ func TestGenerateChat(t *testing.T) { Type: api.PropertyType{"string"}, Enum: []any{"celsius", "fahrenheit"}, }, - }, + }), }, }, }, @@ -559,15 +582,15 @@ func TestGenerateChat(t *testing.T) { expectedToolCall := api.ToolCall{ Function: api.ToolCallFunction{ Name: "get_weather", - Arguments: api.ToolCallFunctionArguments{ + Arguments: testArgs(map[string]any{ "location": "Seattle, WA", "unit": "celsius", - }, + }), }, } expectedToolCall.ID = gotToolCall.ID - if diff := cmp.Diff(gotToolCall, expectedToolCall); diff != "" { + if diff := cmp.Diff(gotToolCall, expectedToolCall, argsComparer); diff != "" { t.Errorf("tool call mismatch (-got +want):\n%s", diff) } }) @@ -582,7 +605,7 @@ func TestGenerateChat(t *testing.T) { Parameters: api.ToolFunctionParameters{ Type: "object", Required: []string{"location"}, - Properties: map[string]api.ToolProperty{ + Properties: testPropsMap(map[string]api.ToolProperty{ "location": { Type: api.PropertyType{"string"}, Description: "The city and state", @@ -591,7 +614,7 @@ func TestGenerateChat(t *testing.T) { Type: api.PropertyType{"string"}, Enum: []any{"celsius", "fahrenheit"}, }, - }, + }), }, }, }, @@ -688,10 +711,10 @@ func TestGenerateChat(t *testing.T) { expectedToolCall := api.ToolCall{ Function: api.ToolCallFunction{ Name: "get_weather", - Arguments: api.ToolCallFunctionArguments{ + Arguments: testArgs(map[string]any{ "location": "Seattle, WA", "unit": "celsius", - }, + }), }, } @@ -703,7 +726,7 @@ func TestGenerateChat(t *testing.T) { } expectedToolCall.ID = finalToolCall.ID - if diff := cmp.Diff(finalToolCall, expectedToolCall); diff != "" { + if diff := cmp.Diff(finalToolCall, expectedToolCall, argsComparer); diff != "" { t.Errorf("final tool call mismatch (-got +want):\n%s", diff) } }) @@ -716,9 +739,9 @@ func TestGenerateChat(t *testing.T) { Name: "get_weather", Parameters: api.ToolFunctionParameters{ Type: "object", - Properties: map[string]api.ToolProperty{ + Properties: testPropsMap(map[string]api.ToolProperty{ "location": {Type: api.PropertyType{"string"}}, - }, + }), }, }, }, diff --git a/server/routes_harmony_streaming_test.go b/server/routes_harmony_streaming_test.go index 1fb41ff48..de130c8c8 100644 --- a/server/routes_harmony_streaming_test.go +++ b/server/routes_harmony_streaming_test.go @@ -29,12 +29,12 @@ func getTestTools() []api.Tool { Parameters: api.ToolFunctionParameters{ Type: "object", Required: []string{"location"}, - Properties: map[string]api.ToolProperty{ + Properties: testPropsMap(map[string]api.ToolProperty{ "location": { Type: api.PropertyType{"string"}, Description: "The city and state, e.g. San Francisco, CA", }, - }, + }), }, }, }, @@ -46,12 +46,12 @@ func getTestTools() []api.Tool { Parameters: api.ToolFunctionParameters{ Type: "object", Required: []string{"expression"}, - Properties: map[string]api.ToolProperty{ + Properties: testPropsMap(map[string]api.ToolProperty{ "expression": { Type: api.PropertyType{"string"}, Description: "The mathematical expression to calculate", }, - }, + }), }, }, }, @@ -185,9 +185,9 @@ func TestChatHarmonyParserStreamingRealtime(t *testing.T) { { Function: api.ToolCallFunction{ Name: "get_weather", - Arguments: api.ToolCallFunctionArguments{ + Arguments: testArgs(map[string]any{ "location": "San Francisco", - }, + }), }, }, }, @@ -211,9 +211,9 @@ func TestChatHarmonyParserStreamingRealtime(t *testing.T) { { Function: api.ToolCallFunction{ Name: "calculate", - Arguments: api.ToolCallFunctionArguments{ + Arguments: testArgs(map[string]any{ "expression": "2+2", - }, + }), }, }, }, diff --git a/template/template.go b/template/template.go index 39b6ad7b0..9bcec1a7e 100644 --- a/template/template.go +++ b/template/template.go @@ -272,8 +272,8 @@ func (t *Template) Execute(w io.Writer, v Values) error { } else if !v.forceLegacy && slices.Contains(vars, "messages") { return t.Template.Execute(w, map[string]any{ "System": system, - "Messages": messages, - "Tools": v.Tools, + "Messages": convertMessagesForTemplate(messages), + "Tools": convertToolsForTemplate(v.Tools), "Response": "", "Think": v.Think, "ThinkLevel": v.ThinkLevel, @@ -373,6 +373,118 @@ func collate(msgs []api.Message) (string, []*api.Message) { return strings.Join(system, "\n\n"), collated } +// templateTools is a slice of templateTool that marshals to JSON. +type templateTools []templateTool + +func (t templateTools) String() string { + bts, _ := json.Marshal(t) + return string(bts) +} + +// templateTool is a template-compatible representation of api.Tool +// with Properties as a regular map for template ranging. +type templateTool struct { + Type string `json:"type"` + Items any `json:"items,omitempty"` + Function templateToolFunction `json:"function"` +} + +type templateToolFunction struct { + Name string `json:"name"` + Description string `json:"description"` + Parameters templateToolFunctionParameters `json:"parameters"` +} + +type templateToolFunctionParameters struct { + Type string `json:"type"` + Defs any `json:"$defs,omitempty"` + Items any `json:"items,omitempty"` + Required []string `json:"required,omitempty"` + Properties map[string]api.ToolProperty `json:"properties"` +} + +// templateToolCall is a template-compatible representation of api.ToolCall +// with Arguments as a regular map for template ranging. +type templateToolCall struct { + ID string + Function templateToolCallFunction +} + +type templateToolCallFunction struct { + Index int + Name string + Arguments map[string]any +} + +// templateMessage is a template-compatible representation of api.Message +// with ToolCalls converted for template use. +type templateMessage struct { + Role string + Content string + Thinking string + Images []api.ImageData + ToolCalls []templateToolCall + ToolName string + ToolCallID string +} + +// convertToolsForTemplate converts Tools to template-compatible format. +func convertToolsForTemplate(tools api.Tools) templateTools { + if tools == nil { + return nil + } + result := make(templateTools, len(tools)) + for i, tool := range tools { + result[i] = templateTool{ + Type: tool.Type, + Items: tool.Items, + Function: templateToolFunction{ + Name: tool.Function.Name, + Description: tool.Function.Description, + Parameters: templateToolFunctionParameters{ + Type: tool.Function.Parameters.Type, + Defs: tool.Function.Parameters.Defs, + Items: tool.Function.Parameters.Items, + Required: tool.Function.Parameters.Required, + Properties: tool.Function.Parameters.Properties.ToMap(), + }, + }, + } + } + return result +} + +// convertMessagesForTemplate converts Messages to template-compatible format. +func convertMessagesForTemplate(messages []*api.Message) []*templateMessage { + if messages == nil { + return nil + } + result := make([]*templateMessage, len(messages)) + for i, msg := range messages { + var toolCalls []templateToolCall + for _, tc := range msg.ToolCalls { + toolCalls = append(toolCalls, templateToolCall{ + ID: tc.ID, + Function: templateToolCallFunction{ + Index: tc.Function.Index, + Name: tc.Function.Name, + Arguments: tc.Function.Arguments.ToMap(), + }, + }) + } + result[i] = &templateMessage{ + Role: msg.Role, + Content: msg.Content, + Thinking: msg.Thinking, + Images: msg.Images, + ToolCalls: toolCalls, + ToolName: msg.ToolName, + ToolCallID: msg.ToolCallID, + } + } + return result +} + // Identifiers walks the node tree returning any identifiers it finds along the way func Identifiers(n parse.Node) ([]string, error) { switch n := n.(type) { diff --git a/tools/tools.go b/tools/tools.go index 7b8d726b0..b76d1154d 100644 --- a/tools/tools.go +++ b/tools/tools.go @@ -124,16 +124,21 @@ func (p *Parser) parseToolCall() *api.ToolCall { return nil } - var args map[string]any + var argsMap map[string]any if found, i := findArguments(tool, p.buffer); found == nil { return nil } else { - args = found + argsMap = found if i > end { end = i } } + args := api.NewToolCallFunctionArguments() + for k, v := range argsMap { + args.Set(k, v) + } + tc := &api.ToolCall{ Function: api.ToolCallFunction{ Name: tool.Function.Name, diff --git a/tools/tools_test.go b/tools/tools_test.go index b849e2194..2b8b04f8b 100644 --- a/tools/tools_test.go +++ b/tools/tools_test.go @@ -9,6 +9,29 @@ import ( "github.com/ollama/ollama/api" ) +// argsComparer provides cmp options for comparing ToolCallFunctionArguments by value (order-insensitive) +var argsComparer = cmp.Comparer(func(a, b api.ToolCallFunctionArguments) bool { + return cmp.Equal(a.ToMap(), b.ToMap()) +}) + +// testPropsMap creates a ToolPropertiesMap from a map (convenience function for tests, order not preserved) +func testPropsMap(m map[string]api.ToolProperty) *api.ToolPropertiesMap { + props := api.NewToolPropertiesMap() + for k, v := range m { + props.Set(k, v) + } + return props +} + +// testArgs creates ToolCallFunctionArguments from a map (convenience function for tests, order not preserved) +func testArgs(m map[string]any) api.ToolCallFunctionArguments { + args := api.NewToolCallFunctionArguments() + for k, v := range m { + args.Set(k, v) + } + return args +} + func TestParser(t *testing.T) { qwen, err := template.New("qwen").Parse(`{{if .ToolCalls}}{{range .ToolCalls}}{"name": "{{.Function.Name}}", "arguments": {{.Function.Arguments}}}{{end}}{{end}}`) if err != nil { @@ -44,7 +67,7 @@ func TestParser(t *testing.T) { Parameters: api.ToolFunctionParameters{ Type: "object", Required: []string{"city"}, - Properties: map[string]api.ToolProperty{ + Properties: testPropsMap(map[string]api.ToolProperty{ "format": { Type: api.PropertyType{"string"}, Description: "The format to return the temperature in", @@ -54,7 +77,7 @@ func TestParser(t *testing.T) { Type: api.PropertyType{"string"}, Description: "The city to get the temperature for", }, - }, + }), }, }, }, @@ -65,12 +88,12 @@ func TestParser(t *testing.T) { Description: "Retrieve the current weather conditions for a given location", Parameters: api.ToolFunctionParameters{ Type: "object", - Properties: map[string]api.ToolProperty{ + Properties: testPropsMap(map[string]api.ToolProperty{ "location": { Type: api.PropertyType{"string"}, Description: "The location to get the weather conditions for", }, - }, + }), }, }, }, @@ -95,12 +118,12 @@ func TestParser(t *testing.T) { Description: "Get the address of a given location", Parameters: api.ToolFunctionParameters{ Type: "object", - Properties: map[string]api.ToolProperty{ + Properties: testPropsMap(map[string]api.ToolProperty{ "location": { Type: api.PropertyType{"string"}, Description: "The location to get the address for", }, - }, + }), }, }, }, @@ -111,7 +134,7 @@ func TestParser(t *testing.T) { Description: "Add two numbers", Parameters: api.ToolFunctionParameters{ Type: "object", - Properties: map[string]api.ToolProperty{ + Properties: testPropsMap(map[string]api.ToolProperty{ "a": { Type: api.PropertyType{"string"}, Description: "The first number to add", @@ -120,7 +143,7 @@ func TestParser(t *testing.T) { Type: api.PropertyType{"string"}, Description: "The second number to add", }, - }, + }), }, }, }, @@ -157,9 +180,9 @@ func TestParser(t *testing.T) { Function: api.ToolCallFunction{ Index: 0, Name: "get_conditions", - Arguments: api.ToolCallFunctionArguments{ + Arguments: testArgs(map[string]any{ "location": "San Francisco", - }, + }), }, }, }, @@ -174,7 +197,7 @@ func TestParser(t *testing.T) { Function: api.ToolCallFunction{ Index: 0, Name: "get_conditions", - Arguments: api.ToolCallFunctionArguments{}, + Arguments: api.NewToolCallFunctionArguments(), }, }, }, @@ -189,9 +212,9 @@ func TestParser(t *testing.T) { Function: api.ToolCallFunction{ Index: 0, Name: "get_temperature", - Arguments: api.ToolCallFunctionArguments{ + Arguments: testArgs(map[string]any{ "city": "New York", - }, + }), }, }, }, @@ -213,19 +236,19 @@ func TestParser(t *testing.T) { Function: api.ToolCallFunction{ Index: 0, Name: "get_temperature", - Arguments: api.ToolCallFunctionArguments{ + Arguments: testArgs(map[string]any{ "city": "London", "format": "fahrenheit", - }, + }), }, }, { Function: api.ToolCallFunction{ Index: 1, Name: "get_conditions", - Arguments: api.ToolCallFunctionArguments{ + Arguments: testArgs(map[string]any{ "location": "Tokyo", - }, + }), }, }, }, @@ -240,19 +263,19 @@ func TestParser(t *testing.T) { Function: api.ToolCallFunction{ Index: 0, Name: "get_temperature", - Arguments: api.ToolCallFunctionArguments{ + Arguments: testArgs(map[string]any{ "city": "London", "format": "fahrenheit", - }, + }), }, }, { Function: api.ToolCallFunction{ Index: 1, Name: "get_conditions", - Arguments: api.ToolCallFunctionArguments{ + Arguments: testArgs(map[string]any{ "location": "Tokyo", - }, + }), }, }, }, @@ -267,17 +290,17 @@ func TestParser(t *testing.T) { Function: api.ToolCallFunction{ Index: 0, Name: "say_hello", - Arguments: api.ToolCallFunctionArguments{}, + Arguments: api.NewToolCallFunctionArguments(), }, }, { Function: api.ToolCallFunction{ Index: 1, Name: "get_temperature", - Arguments: api.ToolCallFunctionArguments{ + Arguments: testArgs(map[string]any{ "city": "London", "format": "fahrenheit", - }, + }), }, }, }, @@ -292,16 +315,16 @@ func TestParser(t *testing.T) { Function: api.ToolCallFunction{ Index: 0, Name: "get_conditions", - Arguments: api.ToolCallFunctionArguments{}, + Arguments: api.NewToolCallFunctionArguments(), }, }, { Function: api.ToolCallFunction{ Index: 1, Name: "get_conditions", - Arguments: api.ToolCallFunctionArguments{ + Arguments: testArgs(map[string]any{ "location": "Tokyo", - }, + }), }, }, }, @@ -316,9 +339,9 @@ func TestParser(t *testing.T) { Function: api.ToolCallFunction{ Index: 0, Name: "get_temperature", - Arguments: api.ToolCallFunctionArguments{ + Arguments: testArgs(map[string]any{ "city": "Tokyo", - }, + }), }, }, }, @@ -347,9 +370,9 @@ func TestParser(t *testing.T) { Function: api.ToolCallFunction{ Index: 0, Name: "get_temperature", - Arguments: api.ToolCallFunctionArguments{ + Arguments: testArgs(map[string]any{ "city": "Tokyo", - }, + }), }, }, }, @@ -371,9 +394,9 @@ func TestParser(t *testing.T) { Function: api.ToolCallFunction{ Index: 0, Name: "get_temperature", - Arguments: api.ToolCallFunctionArguments{ + Arguments: testArgs(map[string]any{ "city": "Tokyo", - }, + }), }, }, }, @@ -453,18 +476,18 @@ func TestParser(t *testing.T) { Function: api.ToolCallFunction{ Index: 0, Name: "get_temperature", - Arguments: api.ToolCallFunctionArguments{ + Arguments: testArgs(map[string]any{ "city": "London", - }, + }), }, }, { Function: api.ToolCallFunction{ Index: 1, Name: "get_conditions", - Arguments: api.ToolCallFunctionArguments{ + Arguments: testArgs(map[string]any{ "location": "Tokyo", - }, + }), }, }, }, @@ -486,9 +509,9 @@ func TestParser(t *testing.T) { Function: api.ToolCallFunction{ Index: 0, Name: "get_conditions", - Arguments: api.ToolCallFunctionArguments{ + Arguments: testArgs(map[string]any{ "location": "Tokyo", - }, + }), }, }, }, @@ -528,9 +551,9 @@ func TestParser(t *testing.T) { Function: api.ToolCallFunction{ Index: 0, Name: "get_conditions", - Arguments: api.ToolCallFunctionArguments{ + Arguments: testArgs(map[string]any{ "location": "Tokyo", - }, + }), }, }, }, @@ -563,7 +586,7 @@ func TestParser(t *testing.T) { Function: api.ToolCallFunction{ Index: 0, Name: "say_hello_world", - Arguments: api.ToolCallFunctionArguments{}, + Arguments: api.NewToolCallFunctionArguments(), }, }, }, @@ -591,14 +614,14 @@ func TestParser(t *testing.T) { Function: api.ToolCallFunction{ Index: 0, Name: "say_hello_world", - Arguments: api.ToolCallFunctionArguments{}, + Arguments: api.NewToolCallFunctionArguments(), }, }, { Function: api.ToolCallFunction{ Index: 1, Name: "say_hello", - Arguments: api.ToolCallFunctionArguments{}, + Arguments: api.NewToolCallFunctionArguments(), }, }, }, @@ -624,14 +647,14 @@ func TestParser(t *testing.T) { Function: api.ToolCallFunction{ Index: 0, Name: "say_hello", - Arguments: api.ToolCallFunctionArguments{}, + Arguments: api.NewToolCallFunctionArguments(), }, }, { Function: api.ToolCallFunction{ Index: 1, Name: "say_hello_world", - Arguments: api.ToolCallFunctionArguments{}, + Arguments: api.NewToolCallFunctionArguments(), }, }, }, @@ -648,7 +671,7 @@ func TestParser(t *testing.T) { Function: api.ToolCallFunction{ Index: 0, Name: "say_hello", - Arguments: api.ToolCallFunctionArguments{}, + Arguments: api.NewToolCallFunctionArguments(), }, }, }, @@ -665,7 +688,7 @@ func TestParser(t *testing.T) { Function: api.ToolCallFunction{ Index: 0, Name: "say_hello_world", - Arguments: api.ToolCallFunctionArguments{}, + Arguments: api.NewToolCallFunctionArguments(), }, }, }, @@ -687,9 +710,9 @@ func TestParser(t *testing.T) { Function: api.ToolCallFunction{ Index: 0, Name: "get_address", - Arguments: api.ToolCallFunctionArguments{ + Arguments: testArgs(map[string]any{ "location": "London", - }, + }), }, }, }, @@ -706,9 +729,9 @@ func TestParser(t *testing.T) { Function: api.ToolCallFunction{ Index: 0, Name: "get_address", - Arguments: api.ToolCallFunctionArguments{ + Arguments: testArgs(map[string]any{ "location": "London", - }, + }), }, }, }, @@ -725,10 +748,10 @@ func TestParser(t *testing.T) { Function: api.ToolCallFunction{ Index: 0, Name: "add", - Arguments: api.ToolCallFunctionArguments{ + Arguments: testArgs(map[string]any{ "a": "5", "b": "10", - }, + }), }, }, }, @@ -756,7 +779,7 @@ func TestParser(t *testing.T) { } for i, want := range tt.calls { - if diff := cmp.Diff(calls[i], want); diff != "" { + if diff := cmp.Diff(calls[i], want, argsComparer); diff != "" { t.Errorf("Tool call %d mismatch (-got +want):\n%s", i, diff) } } @@ -1316,7 +1339,7 @@ func TestFindArguments(t *testing.T) { got, _ := findArguments(&api.Tool{Function: api.ToolFunction{Name: tt.tool}}, tt.buffer) if diff := cmp.Diff(got, tt.want); diff != "" { - t.Errorf("scanArguments() args mismatch (-got +want):\n%s", diff) + t.Errorf("findArguments() args mismatch (-got +want):\n%s", diff) } }) }